-
Notifications
You must be signed in to change notification settings - Fork 0
/
fix_square_brackets.py
49 lines (39 loc) · 1.52 KB
/
fix_square_brackets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pickle
import os
from normalize_data2 import skip_brackets
import transformer.vocabulary
def update_vocabulary(
    input_dir="/Users/balazs/token_trace_normalized",
    output_dir="/Users/balazs/token_trace_normalized_square",
    vocab_fname="vocabulary.pkl",
):
    r"""Add the escaped square-bracket tokens to the pickled vocabulary.

    Loads the vocabulary pickle from ``input_dir``, registers the two tokens
    ``\[`` and ``\]`` (a backslash followed by the bracket) and writes the
    result under the same file name into ``output_dir``. The input file is
    not modified.

    Args:
        input_dir: Directory holding the existing vocabulary pickle.
        output_dir: Directory the extended vocabulary is written to; it is
            created if it does not exist yet.
        vocab_fname: File name of the vocabulary pickle in both directories.

    Raises:
        OSError: If the input file cannot be read or the output cannot be
            written.
    """
    # NOTE: unpickling executes arbitrary code; only point this at vocabulary
    # files that were produced locally by this project.
    with open(os.path.join(input_dir, vocab_fname), 'rb') as src:
        vocab = pickle.load(src)

    # Both new ids are derived from the size *before* insertion. The
    # len + 1 / len + 2 numbering is kept from the original; it presumably
    # assumes the existing ids are 1-based and contiguous -- TODO confirm
    # against the code that built the vocabulary.
    size = len(vocab)
    vocab['\\['] = size + 1
    vocab['\\]'] = size + 2

    os.makedirs(output_dir, exist_ok=True)
    # Context managers guarantee the pickle is flushed and the handles closed.
    with open(os.path.join(output_dir, vocab_fname), 'wb') as dst:
        pickle.dump(vocab, dst)
def _escape_square_brackets(tokens):
    r"""Rewrite literal square brackets as ``\[`` / ``\]`` tokens, in place.

    Args:
        tokens: Mutable sequence of ``(formula, trace)`` pairs. ``formula``
            is a mutable sequence of token strings and ``trace`` a mutable
            sequence of ``(char, box)`` pairs.

    Returns:
        The same ``tokens`` object, after modification.
    """
    for index, (formula, trace) in enumerate(tokens):
        # Trace side: every bracket character is escaped unconditionally.
        for j, (char, box) in enumerate(trace):
            if char == '[':
                char = '\\['
            elif char == ']':
                char = '\\]'
            trace[j] = (char, box)

        # Formula side: a '[' that directly follows '\\sqrt' is left alone
        # (presumably the optional root-index argument -- TODO confirm).
        # NOTE(review): the j > 0 guard also leaves a '[' at position 0
        # unescaped, exactly as the original code did; confirm that is
        # intended.
        for j, char in enumerate(formula):
            if char == '[' and j > 0 and formula[j - 1] != '\\sqrt':
                # presumably skip_brackets returns the index just past the
                # matching ']', hence the - 1; verify in normalize_data2.
                k = skip_brackets(formula, j, ('[', ']')) - 1
                formula[j] = '\\['
                formula[k] = '\\]'

        tokens[index] = (formula, trace)
    return tokens


def _convert_token_file(token_file="testing_data.pkl", save=False):
    """Load a pickled token file and escape its square brackets.

    Args:
        token_file: File name, looked up in the normalized input directory
            and reused for the output directory.
        save: When true, pickle the converted tokens to the output
            directory. Defaults to False because the write was disabled
            (commented out) in the original script.

    Returns:
        The converted token list.
    """
    input_tokens_dir = "/Users/balazs/token_trace_normalized"
    input_tokens_path = os.path.join(input_tokens_dir, token_file)
    # NOTE: unpickling executes arbitrary code; use trusted local files only.
    with open(input_tokens_path, 'rb') as src:
        tokens = pickle.load(src)

    _escape_square_brackets(tokens)

    if save:
        output_tokens_dir = "/Users/balazs/token_trace_normalized_square"
        output_tokens_path = os.path.join(output_tokens_dir, token_file)
        with open(output_tokens_path, 'wb') as dst:
            pickle.dump(tokens, dst)
    return tokens


def main():
    """Run the vocabulary update.

    Only the vocabulary step is executed. The token-file conversion used to
    sit after an unconditional ``return`` and therefore never ran; it now
    lives in ``_convert_token_file`` and remains deliberately uncalled so
    that the script's behaviour is unchanged.
    """
    update_vocabulary()
# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()