-
Notifications
You must be signed in to change notification settings - Fork 38
/
run_split.py
96 lines (78 loc) · 3.14 KB
/
run_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
import tensorflow as tf
import gzip
import cPickle
import sys
sys.path.extend(['alg/'])
import vcl
import coreset
import utils
from copy import deepcopy
class SplitMnistGenerator():
def __init__(self):
f = gzip.open('data/mnist.pkl.gz', 'rb')
train_set, valid_set, test_set = cPickle.load(f)
f.close()
self.X_train = np.vstack((train_set[0], valid_set[0]))
self.X_test = test_set[0]
self.train_label = np.hstack((train_set[1], valid_set[1]))
self.test_label = test_set[1]
self.sets_0 = [0, 2, 4, 6, 8]
self.sets_1 = [1, 3, 5, 7, 9]
self.max_iter = len(self.sets_0)
self.cur_iter = 0
def get_dims(self):
# Get data input and output dimensions
return self.X_train.shape[1], 2
def next_task(self):
if self.cur_iter >= self.max_iter:
raise Exception('Number of tasks exceeded!')
else:
# Retrieve train data
train_0_id = np.where(self.train_label == self.sets_0[self.cur_iter])[0]
train_1_id = np.where(self.train_label == self.sets_1[self.cur_iter])[0]
next_x_train = np.vstack((self.X_train[train_0_id], self.X_train[train_1_id]))
next_y_train = np.vstack((np.ones((train_0_id.shape[0], 1)), np.zeros((train_1_id.shape[0], 1))))
next_y_train = np.hstack((next_y_train, 1-next_y_train))
# Retrieve test data
test_0_id = np.where(self.test_label == self.sets_0[self.cur_iter])[0]
test_1_id = np.where(self.test_label == self.sets_1[self.cur_iter])[0]
next_x_test = np.vstack((self.X_test[test_0_id], self.X_test[test_1_id]))
next_y_test = np.vstack((np.ones((test_0_id.shape[0], 1)), np.zeros((test_1_id.shape[0], 1))))
next_y_test = np.hstack((next_y_test, 1-next_y_test))
self.cur_iter += 1
return next_x_train, next_y_train, next_x_test, next_y_test
hidden_size = [256, 256]
batch_size = None
no_epochs = 120
single_head = False
# Run vanilla VCL
tf.set_random_seed(12)
np.random.seed(1)
coreset_size = 0
data_gen = SplitMnistGenerator()
vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
coreset.rand_from_batch, coreset_size, batch_size, single_head)
print vcl_result
# Run random coreset VCL
tf.reset_default_graph()
tf.set_random_seed(12)
np.random.seed(1)
coreset_size = 40
data_gen = SplitMnistGenerator()
rand_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
coreset.rand_from_batch, coreset_size, batch_size, single_head)
print rand_vcl_result
# Run k-center coreset VCL
tf.reset_default_graph()
tf.set_random_seed(12)
np.random.seed(1)
data_gen = SplitMnistGenerator()
kcen_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
coreset.k_center, coreset_size, batch_size, single_head)
print kcen_vcl_result
# Plot average accuracy
vcl_avg = np.nanmean(vcl_result, 1)
rand_vcl_avg = np.nanmean(rand_vcl_result, 1)
kcen_vcl_avg = np.nanmean(kcen_vcl_result, 1)
utils.plot('results/split.jpg', vcl_avg, rand_vcl_avg, kcen_vcl_avg)