-
Notifications
You must be signed in to change notification settings - Fork 0
/
Trainer.py
43 lines (37 loc) · 1.49 KB
/
Trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import itertools

from NeuralNet import NeuralNet
from DataLoader import LoadFilesData, DataLoader
def _status_log_path(step, loaded_checkpoint, log_path="./logs.tsv"):
    """Return the file the status at ``step`` should be logged to, or ``None``.

    Args:
        step: Iteration index within the current run (0-based).
        loaded_checkpoint: ``network.checkpoint_num`` read at start-up.
        log_path: Destination file for status rows.

    Returns:
        ``None`` for the very first status of a resumed run (``step == 0``
        and ``loaded_checkpoint != 0``); otherwise ``log_path``.
    """
    # On resume, the first status is printed for step ``loaded_checkpoint``.
    # Presumably the previous run already logged that step just before it
    # saved the checkpoint, so skip it to avoid a duplicate row in the log.
    if step == 0 and loaded_checkpoint != 0:
        return None
    return log_path


def main():
    """Train the network indefinitely on batches from a background DataLoader.

    Every ``print_interval`` iterations the network status is printed for the
    batch about to be trained on (and logged to ``./logs.tsv`` unless it would
    duplicate a resumed run's last entry). Every ``save_interval`` iterations,
    after training, a checkpoint is saved. There is no exit condition;
    interrupt the process to stop training.
    """
    # --- data loader ---
    dataset_dir = "/home/sanche/Datasets/IMDB-WIKI"
    csv_path = "./dataset.csv"
    indices_path = "./indices.p"
    csv_data, indices = LoadFilesData(dataset_dir, csv_path, indices_path)

    image_size = 64
    num_per_bin = 4
    # NOTE(review): 8 * 2 presumably matches 8 age bins x 2 sexes that the
    # DataLoader draws num_per_bin samples from -- confirm against DataLoader.
    batch_size = num_per_bin * 8 * 2
    noise_size = 100
    loader = DataLoader(
        indices,
        csv_data,
        numPerBin=num_per_bin,
        imageSize=image_size,
        numWorkerThreads=10,
        bufferMax=20,
        debugLogs=False,
    )
    loader.start()

    # --- training ---
    network = NeuralNet(
        batch_size=batch_size,
        image_size=image_size,
        noise_size=noise_size,
        learningRate=5e-4,
    )
    print_interval = 100
    save_interval = 1000
    # Offset added to the loop index when reporting status; presumably the
    # global step of the checkpoint the network resumed from.
    loaded_checkpoint = network.checkpoint_num

    for i in itertools.count():
        batch = loader.getData()
        batch_image = batch["image"]
        batch_age = batch["age"]
        batch_sex = batch["sex"]

        # Status is reported *before* training on this batch.
        if i % print_interval == 0:
            network.printStatus(
                i + loaded_checkpoint,
                batch_image,
                batch_sex,
                batch_age,
                logFilePath=_status_log_path(i, loaded_checkpoint),
            )

        network.train(batch_image, batch_sex, batch_age)

        if i % save_interval == 0 and i != 0:
            # NOTE(review): passes the interval, not the absolute step;
            # presumably saveCheckpoint advances checkpoint_num by this
            # amount -- confirm in NeuralNet.
            network.saveCheckpoint(save_interval)


if __name__ == "__main__":
    main()