-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
134 lines (112 loc) · 5.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from ngrammer.Ngrammer import *
from ngrammer import CorpusHandler
def generate(trie, sos="[Start]", eos="[End]"):
    """Generate a random sentence by walking an n-gram prefix tree.

    Starting from ``sos``, the node for the last ``trie.n - 1`` generated
    tokens is looked up and one of its children is picked uniformly at
    random; this repeats until ``eos`` is drawn.

    Args:
        trie: object exposing an integer ``n`` (the n-gram order) and a
            ``get(context)`` method returning a node whose ``children``
            mapping is keyed by the candidate next tokens.
        sos: token the sentence starts with.
        eos: token that terminates generation.

    Returns:
        The list of generated tokens, beginning with ``sos`` and ending
        with ``eos`` (just ``[sos]`` when ``sos == eos``).

    Raises:
        IndexError: if a visited node has no children to choose from.
    """
    import random

    context_size = trie.n - 1
    sentence = [sos]
    word = sos
    while word != eos:
        # A plain ``sentence[-context_size:]`` is wrong for order-1 models:
        # ``sentence[-0:]`` is the whole sentence, not an empty context.
        # NOTE(review): assumes ``trie.get([])`` returns the root node - confirm.
        context = sentence[-context_size:] if context_size > 0 else []
        node = trie.get(context)
        word = random.choice(list(node.children.keys()))
        sentence.append(word)
    return sentence
def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', print_end="\r"):
    """
    Call in a loop to create terminal progress bar
    @params:
    iteration - Required : current iteration (Int)
    total - Required : total iterations (Int)
    prefix - Optional : prefix string (Str)
    suffix - Optional : suffix string (Str)
    decimals - Optional : positive number of decimals in percent complete (Int)
    length - Optional : character length of bar (Int)
    fill - Optional : bar fill character (Str)
    print_end - Optional : end character (e.g. "\r", "\r\n") (Str)

    When total is 0 the bar is drawn as fully complete instead of
    raising ZeroDivisionError.
    """
    if total:
        fraction = iteration / float(total)
        filled_length = int(length * iteration // total)
    else:
        # Nothing to do: an empty job is reported as 100% done.
        fraction = 1.0
        filled_length = length
    percent = ("{0:." + str(decimals) + "f}").format(100 * fraction)
    bar = fill * filled_length + '-' * (length - filled_length)
    # The leading "\r" rewinds to the start of the line so the bar redraws in place.
    print(f'\r{prefix} |{bar}| {percent}% {suffix}', end=print_end)
    # Print New Line on Complete
    if iteration == total:
        print()
def extract_coefficients(n, corpus):
    """Compute interpolation coefficients from ``corpus``.

    A throw-away ``CachedPrefixTree`` of order ``n`` is filled with every
    sentence of ``corpus`` and trained; the value returned by its
    ``_deleted_interpolation()`` method is handed back unchanged.
    """
    scratch_tree = CachedPrefixTree(n)
    for sentence in corpus:
        scratch_tree.store(sentence)
    scratch_tree.train()
    return scratch_tree._deleted_interpolation()
def test_trees(trees_to_test, test_data):
    """Print the mean perplexity of every tree over ``test_data``.

    Args:
        trees_to_test: mapping from a display name to a trained tree.
        test_data: sequence of test sentences; each one is passed to
            ``calc_perplexity`` together with each tree.

    The per-sentence perplexities are summed per tree name and divided by
    the number of sentences. Results are reported for the trees that were
    passed in, not for a module-level ``trees`` variable, and totals are
    keyed by name so the trees need not be hashable.
    """
    print("Start testing")
    n_test_data = len(test_data)
    totals = {name: 0 for name in trees_to_test}
    print_progress_bar(0, n_test_data, prefix='Progress:', suffix='Complete', length=50)
    for done, sentence in enumerate(test_data, start=1):
        for name, tree in trees_to_test.items():
            totals[name] += calc_perplexity(tree, sentence)
        print_progress_bar(done, n_test_data, prefix='Progress:', suffix='Complete', length=50)
    print("")
    for name, total in totals.items():
        # With no test sentences the total stays 0 and is reported as is.
        mean = total / n_test_data if n_test_data else total
        print("Mean Perplexity {}: {}".format(name, mean))
# Script entry point: load the ptb dataset files, fill and train every
# prefix-tree variant, then report each tree's mean perplexity on the test set.
if __name__ == '__main__':
    print("Start")
    annotations = True
    train = "dataset/ptbdataset/ptb.train.txt"
    validation = "dataset/ptbdataset/ptb.valid.txt"
    test = "dataset/ptbdataset/ptb.test.txt"

    print("Loading datasets")
    train_corpus, validation_corpus, test_corpus = (
        CorpusHandler.load_corpus(path, annotations)
        for path in (train, validation, test)
    )

    N = 3
    print("Calculating interpolation coefficients")
    coefficients = extract_coefficients(N, validation_corpus)
    # NOTE(review): this result is not used below; the call is kept in case
    # it has side effects - confirm and remove if not.
    turing = CorpusHandler.good_turing_estimate(test_corpus)

    print("Instantiating trees")
    # Insertion order is the order in which the trees are filled, trained
    # and reported.
    trees = {
        "Standard prefix tree": PrefixTree(N),
        "Single cache(400) prefix tree": CachedPrefixTree(N, 400),
        "Single cache(200) prefix tree": CachedPrefixTree(N, 200),
        "Single cache(100) prefix tree": CachedPrefixTree(N, 100),
        "Single cache(50) prefix tree": CachedPrefixTree(N, 50),
        "Single cache(25) prefix tree": CachedPrefixTree(N, 25),
        "Multi ngram cached(400) prefix tree": CachedMultiNgramPrefixTree(N, 400),
        "Multi ngram cached(200) prefix tree": CachedMultiNgramPrefixTree(N, 200),
        "Multi ngram cached(100) prefix tree": CachedMultiNgramPrefixTree(N, 100),
        "Multi ngram cached(50) prefix tree": CachedMultiNgramPrefixTree(N, 50),
        "Multi ngram cached(25) prefix tree": CachedMultiNgramPrefixTree(N, 25),
        "Multi cache(200,100,50) prefix tree": CachedPrefixTree(N, [200, 100, 50]),
        "Multi ngram cached(200,100,50) prefix tree": CachedMultiNgramPrefixTree(N, [200, 100, 50]),
    }
    # A "Pos prefix tree" (PosTree(N)) entry is disabled: the author could
    # not make it work.

    print("Loading train data")
    print("")
    n_samples = len(train_corpus)

    # Author's notes:
    # - The pos tree takes a long time (more than 5 minutes) to store all the
    #   data, because it uses spacy to parse every sentence one at a time;
    #   batch loading did not work properly on the author's machine.
    # - When running from PyCharm, enable terminal emulation for the run/debug
    #   console, otherwise the progress bar is not rendered correctly.
    for name, tree in trees.items():
        print("Loading data into: {}".format(name))
        print_progress_bar(0, n_samples, prefix='Progress:', suffix='Complete', length=50)
        for stored, sentence in enumerate(train_corpus, start=1):
            tree.store(sentence)
            print_progress_bar(stored, n_samples, prefix='Progress:', suffix='Complete', length=50)
        print("")

    logs = True
    smoothing = True
    for name, tree in trees.items():
        print("Training: ", name)
        tree.train(logs, smoothing)

    # Only tree classes that define a callable set_coefficients receive the
    # interpolation coefficients.
    for name, tree in trees.items():
        if callable(getattr(type(tree), 'set_coefficients', None)):
            tree.set_coefficients(coefficients)

    test_trees(trees, test_corpus)