-
Notifications
You must be signed in to change notification settings - Fork 22
/
Run.py
executable file
·85 lines (70 loc) · 2.99 KB
/
Run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import sys
from Module.Train import train
from Module.Eval import eval
import zipfile
import os
from FilesManager.FilesManager import FilesManager
from Utils.Logger import Logger
import urllib
def _exit_if_data_missing(version_filename_flag):
    """Abort with exit status 1 unless the data-version marker file exists.

    The marker is created by the ``download`` command after the archive has
    been fetched and extracted. It is a relative path, so it is looked up in
    the current working directory: run ``train``/``eval`` from the same
    directory ``download`` was run from.

    Args:
        version_filename_flag: path of the marker file to look for.
    """
    if not os.path.isfile(version_filename_flag):
        print("Error: Data wasn't downloaded. Type python Run.py for instructions how to download\n\n")
        # Non-zero status so callers (shells, CI) see the failure; the bare
        # site builtin exit() used previously reported success (status 0).
        sys.exit(1)


def _download_data(logger, version_filename_flag):
    """Download the data archive, extract it, then create the version marker.

    The archive is saved as ``data.zip`` in the directory returned by
    ``FilesManager().get_file_path("data.visual_genome.data")`` and is
    extracted into that same directory.

    Args:
        logger: project ``Logger`` used for progress messages.
        version_filename_flag: marker file created once extraction succeeds.
    """
    # urlretrieve lives in urllib.request on Python 3 and in urllib on
    # Python 2; the old `urllib.urlretrieve` call only worked on Python 2.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    logger.log("Command: Download()")
    filesmanager = FilesManager()
    path = filesmanager.get_file_path("data.visual_genome.data")
    file_name = os.path.join(path, "data.zip")

    # Download Data
    logger.log("Download Data ...")
    url = "http://www.cs.tau.ac.il/~taunlp/scene_graph/data.zip"
    urlretrieve(url, file_name)

    # Extract data; the context manager closes the archive even on error.
    logger.log("Extract ZIP file ...")
    with zipfile.ZipFile(file_name, 'r') as zip_ref:
        zip_ref.extractall(path)

    # Mark the data version as downloaded only after extraction succeeded.
    open(version_filename_flag, "wb").close()


def _print_usage():
    """Print the command-line help text."""
    print("Error: unexpected usage\n\n")
    print("SGP Runner")
    print("----------")
    print("Download data: \"python Run.py download\"")
    print(" Should be run just once, on the first time the module used")
    print("Train Module: \"python Run.py train <module_name> <gpu_number>\"")
    print(" Train lingustic SGP")
    print(" Module weights with the highest score over the validation set will be saved as \"<module_name>_best\"")
    print(" Module weights of the last epoch will be saved as \"<module_name>\"")
    print("Eval Module: \"python Run.py eval <module_name> <gpu_number>\"")
    print(" Scene graph classification (recall@100) evaluation for the trained module.")
    print(" Use 'gpi_ling_orig_best' for a pre-trained module")
    print(" Use \"<module_name>_best\" for a self-trained module")


if __name__ == "__main__":
    # Marker file (relative to the current working directory) created by
    # "download"; "train" and "eval" refuse to run until it exists.
    version_filename_flag = '.data_ver2'
    # create logger
    logger = Logger()

    # Positional arguments: <application> [<module_name> [<gpu_number>]].
    # Any argument that is not supplied stays None.
    argv = sys.argv[1:]
    application = argv[0] if len(argv) > 0 else None
    name = argv[1] if len(argv) > 1 else None
    gpu = argv[2] if len(argv) > 2 else None

    if application == "train":
        # check if required data version downloaded
        _exit_if_data_missing(version_filename_flag)
        logger.log("Command: Train(module_name=%s, gpu=%s)" % (name, gpu))
        train(name=name, gpu=gpu)
    elif application == "eval":
        # check if required data version downloaded
        _exit_if_data_missing(version_filename_flag)
        logger.log("Command: Eval(module_name=%s, gpu=%s)" % (name, gpu))
        eval(load_module_name=name, gpu=gpu)
    elif application == "download":
        _download_data(logger, version_filename_flag)
    else:
        _print_usage()