-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
125 lines (103 loc) · 4.86 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from __future__ import division
from __future__ import print_function
import time
import tensorflow as tf
from gcn.utils import *
from gcn.models import GCN, MLP
# Model types accepted by ``run_training``.
_MODEL_TYPES = ('gcn', 'gcn_cheby', 'dense')


def _define_flags():
    """Register the hyperparameter flags on ``tf.app.flags`` and return ``FLAGS``.

    Safe to call repeatedly. Re-defining a flag that already exists raises in
    ``tf.app.flags`` (absl-backed on newer TF 1.x), which would otherwise break
    a second call to ``run_training`` in the same process. Each definition is
    therefore attempted on its own, so a flag that already exists does not
    prevent the remaining flags from being registered.
    """
    flags = tf.app.flags
    definitions = (
        # Registered, but run_training selects the model from its model_type
        # argument, not from this flag. Values: 'gcn', 'gcn_cheby', 'dense'.
        (flags.DEFINE_string, 'model', 'gcn', 'Model string.'),
        (flags.DEFINE_float, 'learning_rate', 0.01, 'Initial learning rate.'),
        (flags.DEFINE_integer, 'epochs', 500, 'Number of epochs to train.'),
        (flags.DEFINE_integer, 'hidden1', 16, 'Number of units in hidden layer 1.'),
        (flags.DEFINE_float, 'dropout', 0.5, 'Dropout rate (1 - keep probability).'),
        (flags.DEFINE_float, 'weight_decay', 5e-4, 'Weight for L2 loss on embedding matrix.'),
        (flags.DEFINE_integer, 'early_stopping', 10, 'Tolerance for early stopping (# of epochs).'),
        (flags.DEFINE_integer, 'max_degree', 3, 'Maximum Chebyshev polynomial degree.'),
    )
    for define, name, default, help_text in definitions:
        try:
            define(name, default, help_text)
        except Exception:  # pylint: disable=broad-except
            # Best effort: presumably a duplicate-definition error from an
            # earlier call (absl DuplicateFlagError) -- confirm for the TF
            # version in use. Unlike a bare ``except:``, this still lets
            # KeyboardInterrupt / SystemExit propagate.
            pass
    return flags.FLAGS


def run_training(adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, model_type, model=None):
    """Train a GCN-family model for ``FLAGS.epochs`` steps, then evaluate it on the test feed.

    Hyperparameters (epochs, dropout, max_degree, ...) are read from
    ``tf.app.flags`` and registered on first use with the defaults listed
    in ``_define_flags``.

    Args:
        adj: Graph adjacency; passed to ``preprocess_adj`` (or to
            ``chebyshev_polynomials`` for 'gcn_cheby').
        features: Node features; passed to ``preprocess_features``. Element
            ``[2]`` of the result is used as the sparse feature shape, and
            ``[2][1]`` as the model input dimension (presumably a
            (coords, values, shape) tuple -- confirm in gcn.utils).
        y_train: Training labels; ``y_train.shape[1]`` sets the label width.
        y_val: Unused (validation is not performed).
        y_test: Labels fed for the final evaluation.
        train_mask: Mask fed with ``y_train`` for every training step.
        val_mask: Unused (validation is not performed).
        test_mask: Mask fed with ``y_test`` for the final evaluation.
        model_type: One of 'gcn', 'gcn_cheby', 'dense'.
        model: Optional already-constructed model to reuse instead of
            building a new ``GCN``/``MLP``.

    Returns:
        ``(model, test_outputs)``, where ``test_outputs`` is the value of
        ``model.outputs`` fetched with the test feed.

    Raises:
        ValueError: If ``model_type`` is not supported. This is checked
            before any TensorFlow work is done.
    """
    # Fail fast, and report the argument that was actually rejected.
    if model_type not in _MODEL_TYPES:
        raise ValueError('Invalid argument for model: ' + str(model_type))

    # Fixed seeds (NumPy and TF graph-level) for repeatable runs.
    seed = 123
    np.random.seed(seed)
    tf.set_random_seed(seed)

    FLAGS = _define_flags()

    # Preprocess the inputs and pick the model class.
    features = preprocess_features(features)
    if model_type == 'gcn_cheby':
        support = chebyshev_polynomials(adj, FLAGS.max_degree)
        num_supports = 1 + FLAGS.max_degree
        model_func = GCN
    else:
        support = [preprocess_adj(adj)]  # Not used by the 'dense' (MLP) model
        num_supports = 1
        model_func = GCN if model_type == 'gcn' else MLP

    # Define placeholders.
    placeholders = {
        'support': [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)],
        'features': tf.sparse_placeholder(tf.float32, shape=tf.constant(features[2], dtype=tf.int64)),
        'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
        'labels_mask': tf.placeholder(tf.int32),
        'dropout': tf.placeholder_with_default(0., shape=()),
        'num_features_nonzero': tf.placeholder(tf.int32)  # helper variable for sparse dropout
    }

    # Create or rebind the model.
    if model is None:
        model = model_func(placeholders, input_dim=features[2][1], logging=True)
    else:
        # NOTE(review): this only rebinds attributes on an already-constructed
        # model. If the model built its layer/loss/optimizer ops when it was
        # constructed, those ops still reference the placeholders it was
        # built with, and the feed dicts below would not reach them --
        # confirm against gcn.models. Variables are also re-initialized in a
        # fresh session below, so previously trained weights are not reused.
        model.placeholders = placeholders
        model.inputs = placeholders['features']
        model.output_dim = placeholders['labels'].get_shape().as_list()[1]
        model.input_dim = features[2][1]

    sess = tf.Session()

    def evaluate(features, support, labels, mask, placeholders):
        """Fetch loss, accuracy and outputs for one feed; also return elapsed wall time."""
        t_test = time.time()
        feed_dict_val = construct_feed_dict(features, support, labels, mask, placeholders)
        outs_val = sess.run([model.loss, model.accuracy, model.outputs], feed_dict=feed_dict_val)
        return outs_val[0], outs_val[1], outs_val[2], (time.time() - t_test)

    try:
        sess.run(tf.global_variables_initializer())

        # The training feed is identical every epoch, so build it once.
        feed_dict = construct_feed_dict(features, support, y_train, train_mask, placeholders)
        feed_dict.update({placeholders['dropout']: FLAGS.dropout})
        for _ in range(FLAGS.epochs):
            sess.run([model.opt_op, model.loss, model.accuracy], feed_dict=feed_dict)
        # Validation and early stopping are not performed: y_val, val_mask
        # and FLAGS.early_stopping are unused here.

        test_cost, _, test_outputs, test_duration = evaluate(features, support, y_test, test_mask, placeholders)
    finally:
        # The session is local and never returned, and test_outputs has
        # already been fetched into Python, so release the session now.
        sess.close()

    print("Training results:", "cost=", "{:.5f}".format(test_cost),
          "time=", "{:.5f}".format(test_duration))
    return model, test_outputs