-
Notifications
You must be signed in to change notification settings - Fork 3
/
mnli_extract.py
108 lines (81 loc) · 2.99 KB
/
mnli_extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import ujson
import sys
import argparse
import re
import spacy
# Module-level side effect: loads the small English spaCy pipeline once at
# import time. Requires `en_core_web_sm` to be installed (python -m spacy
# download en_core_web_sm); raises OSError otherwise.
spacy_nlp = spacy.load('en_core_web_sm')
# tokenize and tag pos
def tokenize_spacy(text):
    """Tokenize *text* with spaCy.

    Returns three parallel lists (whitespace tokens excluded):
    surface forms, universal POS tags, and lemmas. A lemma that is empty
    after stripping internal spaces falls back to the surface form.
    """
    doc = spacy_nlp(text)
    words = [w for w in doc if not w.is_space]
    toks = [w.text for w in words]
    pos = [w.pos_ for w in words]
    lemma = [w.lemma_.replace(' ', '') for w in words]
    # fall back to the raw token when spaCy yields an empty lemma
    lemma = [t if l == '' else l for l, t in zip(lemma, toks)]
    return toks, pos, lemma
def write_to(ls, out_file):
    """Write each string in *ls* to *out_file*, one per line (with trailing newline)."""
    print('writing to {0}'.format(out_file))
    with open(out_file, 'w+') as f:
        f.writelines(line + '\n' for line in ls)
def extract(opt, csv_file):
    """Parse a tab-separated (M)NLI file and tokenize every example.

    Args:
        opt: parsed command-line options (accepted for interface
            compatibility; not consulted in this function).
        csv_file: path to a TSV file whose first row is a header containing
            at least 'gold_label', 'sentence1', and 'sentence2' columns.

    Returns:
        A 7-tuple of parallel lists:
        (sent1 tokens, sent2 tokens, sent1 POS, sent2 POS,
         sent1 lemmas, sent2 lemmas, labels),
        where each sentence-level entry is a single space-joined string.

    Raises:
        ValueError: if a gold label is not one of
            entailment / neutral / contradiction (and not '-').

    Examples labeled '-' (no annotator consensus) are skipped and counted.
    """
    all_sent1 = []
    all_sent2 = []
    all_label = []
    all_sent1_pos = []
    all_sent2_pos = []
    all_sent1_lemma = []
    all_sent2_lemma = []
    max_sent_l = 0  # longest tokenized sentence seen (tracked but not returned)
    skip_cnt = 0
    headers = {}
    with open(csv_file, 'r') as f:
        for line_idx, l in enumerate(f, start=1):
            if line_idx == 1:
                # header row: map column name -> column index
                for k, h in enumerate(l.rstrip().split('\t')):
                    headers[h] = k
                continue
            if l.strip() == '':
                continue
            cells = l.rstrip().split('\t')
            label = cells[headers['gold_label']]
            sent1 = cells[headers['sentence1']]
            sent2 = cells[headers['sentence2']]
            if label == '-':
                # '-' marks examples with no gold annotation; drop them
                print('skipping label {0}'.format(label))
                skip_cnt += 1
                continue
            # Validate explicitly: `assert` is stripped under `python -O`,
            # which would silently admit malformed labels.
            if label not in ('entailment', 'neutral', 'contradiction'):
                raise ValueError('unexpected gold_label: {0}'.format(label))
            sent1_toks, sent1_pos, sent1_lemma = tokenize_spacy(sent1)
            sent2_toks, sent2_pos, sent2_lemma = tokenize_spacy(sent2)
            # internal invariant: tokenizer output lists must stay parallel
            assert len(sent1_toks) == len(sent1_pos) == len(sent1_lemma)
            assert len(sent2_toks) == len(sent2_pos) == len(sent2_lemma)
            max_sent_l = max(max_sent_l, len(sent1_toks), len(sent2_toks))
            all_sent1.append(' '.join(sent1_toks))
            all_sent2.append(' '.join(sent2_toks))
            all_sent1_pos.append(' '.join(sent1_pos))
            all_sent2_pos.append(' '.join(sent2_pos))
            all_sent1_lemma.append(' '.join(sent1_lemma))
            all_sent2_lemma.append(' '.join(sent2_lemma))
            all_label.append(label)
    print('skipped {0} examples'.format(skip_cnt))
    return (all_sent1, all_sent2, all_sent1_pos, all_sent2_pos, all_sent1_lemma, all_sent2_lemma, all_label)
# Command-line interface (module level so both `main` and importers see it).
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--data', help="Path to SNLI txt file", default="data/bert_nli/multinli_1.0_dev_mismatched.txt")
parser.add_argument('--output', help="Prefix to the path of output", default="data/bert_nli/mnli.dev")
# NOTE(review): --filter is parsed but never read anywhere in this file — confirm whether it is dead or consumed elsewhere.
parser.add_argument('--filter', help="List of pos tags to filter out", default="")
def main(args):
    """CLI entry point: tokenize the NLI file and write raw sentence/label files."""
    opt = parser.parse_args(args)
    (sent1, sent2,
     _sent1_pos, _sent2_pos,
     _sent1_lemma, _sent2_lemma,
     labels) = extract(opt, opt.data)
    print('{0} examples processed.'.format(len(sent1)))
    # NOTE(review): POS and lemma lists are computed by extract() but never
    # written here — confirm whether those outputs were meant to be saved.
    outputs = (
        (sent1, '.raw.sent1.txt'),
        (sent2, '.raw.sent2.txt'),
        (labels, '.label.txt'),
    )
    for content, suffix in outputs:
        write_to(content, opt.output + suffix)
if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))