"""
Recover the toy dataset using Gibbs.
Measure the convergence over iterations and time.
We run the algorithm 10 times with the same seed, and take the average timestamps.
"""
import sys, os
# Make the project root importable, relative to this file's location.
project_location = os.path.dirname(__file__)+"/../../../../"
sys.path.append(project_location)
from BNMTF_ARD.code.models.bnmf_gibbs import bnmf_gibbs
import numpy
import random
import matplotlib.pyplot as plt
''' Location of toy data, and where to store the performances. '''
input_folder = project_location+"BNMTF_ARD/data/toy/bnmf/"
output_folder = project_location+"BNMTF_ARD/experiments/experiments_toy/convergence/results/"
output_file_performances = output_folder+'nmf_gibbs_all_performances.txt'
output_file_times = output_folder+'nmf_gibbs_all_times.txt'
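# Make sure the results folder exists before writing to it (a small added
# safeguard, not part of the original script).
if not os.path.exists(output_folder):
    os.makedirs(output_folder)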
''' Model settings. '''
iterations = 500        # number of Gibbs samples per run
init_UV = 'random'      # initialise U and V randomly
I, J, K = 100, 80, 10   # dimensions of R, and number of latent factors
ARD = False             # no automatic relevance determination
repeats = 10            # number of runs to average the timestamps over
lambdaU = 0.1           # rate of the exponential prior on U
lambdaV = 0.1           # rate of the exponential prior on V
alphatau, betatau = 1., 1.  # Gamma prior on the noise precision tau
alpha0, beta0 = 1., 1.      # Gamma hyperparameters for ARD (unused, as ARD=False)
hyperparams = { 'alphatau':alphatau, 'betatau':betatau, 'alpha0':alpha0, 'beta0':beta0, 'lambdaU':lambdaU, 'lambdaV':lambdaV }
''' Load in data. '''
R = numpy.loadtxt(input_folder+"R.txt")
M = numpy.ones((I,J))   # mask of observed entries: all of R is observed
''' Run the algorithm 'repeats' times, and average the timestamps. '''
times_repeats = []
performances_repeats = []
for i in range(repeats):
    # Set all the seeds, so that every repeat produces an identical Gibbs chain
    numpy.random.seed(0)
    random.seed(0)
    # Run the model
    BNMF = bnmf_gibbs(R,M,K,ARD,hyperparams)
    BNMF.initialise(init_UV)
    BNMF.run(iterations)
    # Extract the performances and timestamps across all iterations
    times_repeats.append(BNMF.all_times)
    performances_repeats.append(BNMF.all_performances)
''' Check whether seeding worked: all performances should be identical. '''
assert all([numpy.array_equal(p, performances_repeats[0]) for p in performances_repeats]), \
    "Seed went wrong - performances not the same across repeats!"
''' Print the performances and the average times, and store them in files. '''
all_times_average = list(numpy.average(times_repeats, axis=0))
all_performances = performances_repeats[0]
print("all_times_average = %s" % all_times_average)
print("all_performances = %s" % all_performances)
with open(output_file_times,'w') as f:
    f.write("%s" % all_times_average)
with open(output_file_performances,'w') as f:
    f.write("%s" % all_performances)
''' Plot the performance (MSE) against average time, and against iterations. '''
plt.figure()
plt.title("Performance against average time")
plt.plot(all_times_average, all_performances['MSE'])
plt.xlabel("Average time (s)")
plt.ylabel("MSE")
plt.ylim(0,10)
plt.figure()
plt.title("Performance against iteration")
plt.plot(all_performances['MSE'])
plt.xlabel("Iteration")
plt.ylabel("MSE")
plt.ylim(0,10)
plt.show()
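''' The output files above store the Python repr of a list / dict, so they can
    be parsed back for later analysis. A minimal added sketch (assuming the
    stored reprs are plain Python literals, as written by the code above): '''
import ast
reloaded_times = ast.literal_eval(open(output_file_times, 'r').read())
reloaded_performances = ast.literal_eval(open(output_file_performances, 'r').read())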