merge_conll_to_oneline_add_jp_name.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ----------------------------------------------------------------------------
from __future__ import print_function
from __future__ import division
import argparse
import io
import random
parser = argparse.ArgumentParser(description='merge_conll_to_oneline.py')
##
## **Preprocess Options**
##
parser.add_argument('-input_file', required=True,
                    help="Path to the training data")
parser.add_argument('-output_file', required=True,
                    help="Path to the output")
parser.add_argument('-augment_jp_name', action='store_true', default=False,
                    help="Augment Japanese person names")
# parser.add_argument('-want_type', default='per',
#                     help="What kind of label to show")
opt = parser.parse_args()
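# Example invocation (a sketch; the input is expected to be a two-column,
# tab-separated file with one "token<TAB>label" pair per line and blank
# lines between sentences). train.conll / train.oneline are hypothetical
# file names:
#
#   python merge_conll_to_oneline_add_jp_name.py \
#       -input_file train.conll -output_file train.oneline -augment_jp_name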
def makeData(filename):
    """Read a two-column (token <TAB> label) file into a list of
    per-sentence (src, tgt) pairs, also collecting every token seen
    under each label."""
    src, tgt = [], []
    pos_instances = dict()
    outputs = []
    sent_count = 0
    line_num = 0
    print('Processing file %s...' % (filename))
    # Read as UTF-8 so character counts are correct for multibyte text.
    with io.open(filename, encoding='utf-8') as inputFile:
        for oneline in inputFile:
            line_num += 1
            oneline = oneline.rstrip()
            if len(oneline) < 1:
                # A blank line ends the current sentence.
                if len(src) > 0:
                    outputs.append((src, tgt))
                    sent_count += 1
                    if sent_count % 10000 == 0:
                        print('... %d sentences read' % sent_count)
                    src = []
                    tgt = []
                continue
            oneline = oneline.split('\t')
            if len(oneline) != 2:
                print("Error on line %d, length not 2." % (line_num))
                continue
            srcWords, tgtWords = oneline
            src.append(srcWords)
            tgt.append(tgtWords)
            pos_instances.setdefault(tgtWords, []).append(srcWords)
    # Flush the final sentence if the file does not end with a blank line.
    if len(src) > 0:
        outputs.append((src, tgt))
    return outputs, pos_instances
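# Illustrative shapes (hypothetical tokens and labels): for an input file
# containing the two lines "山田\tper" and "は\tO" followed by a blank line,
# makeData returns:
#   outputs       == [(['山田', 'は'], ['per', 'O'])]
#   pos_instances == {'per': ['山田'], 'O': ['は']}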
def join_words_pos(words, poss):
    """Concatenate the words into one string and expand each word-level
    label into character-level BIES tags (B/I/E for multi-character words,
    S for single-character words)."""
    joined_words = "".join(words)
    joined_pos = ""
    for w_id, word in enumerate(words):
        pos_tag = poss[w_id]
        # Labels other than per/loc/org are marked as non-entities.
        if pos_tag.lower() not in ["per", "loc", "org"]:
            pos_tag = "N-{}".format(pos_tag)
        if len(word) == 1:
            joined_pos += 'S-{} '.format(pos_tag)
            continue
        if len(word) == 2:
            joined_pos += 'B-{0} E-{0} '.format(pos_tag)
            continue
        joined_pos += 'B-{} '.format(pos_tag)
        for i in range(len(word) - 2):
            joined_pos += 'I-{} '.format(pos_tag)
        joined_pos += 'E-{} '.format(pos_tag)
    joined_pos = joined_pos.strip()
    return joined_words, joined_pos
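# For example (tokens and labels are illustrative):
#   join_words_pos(['山田太郎', 'は'], ['per', 'O'])
#   == ('山田太郎は', 'B-per I-per I-per E-per S-N-O')
# '山田太郎' (4 characters) expands to B/I/I/E; the non-entity label 'O'
# is prefixed with 'N-' and the single character 'は' gets an S- tag.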
def replace_with_jp_name(words, poss):
    """Return a copy of words in which every word labelled 'per' is
    replaced by a random Japanese last name + first name."""
    ret_words = []
    for w_id, word in enumerate(words):
        newword = word
        if poss[w_id].lower() == "per":
            newword = u"{}{}".format(random.choice(jp_last_names),
                                     random.choice(jp_first_names))
        ret_words.append(newword)
    return ret_words
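# E.g. with jp_last_names == [u'山田'] and jp_first_names == [u'太郎']
# (hypothetical singleton lists), replace_with_jp_name(['John', 'runs'],
# ['per', 'O']) returns [u'山田太郎', 'runs'].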
def main():
    sents, pos_instances = makeData(opt.input_file)
    sent_count = 0
    with io.open(opt.output_file, 'w', encoding='utf-8') as ofile:
        for sent in sents:
            sent_count += 1
            # if sent_count > 3: exit()
            words, poss = sent
            joined_words, joined_pos = join_words_pos(words, poss)
            ofile.write(joined_words)
            ofile.write(u'\t')
            ofile.write(joined_pos)
            ofile.write(u'\n')
            # Optionally write a second, augmented copy of any sentence
            # containing a person name, with the names randomly replaced.
            if opt.augment_jp_name and \
                    any(x.lower() == "per" for x in poss):
                new_words = replace_with_jp_name(words, poss)
                joined_words, joined_pos = join_words_pos(new_words, poss)
                ofile.write(joined_words)
                ofile.write(u'\t')
                ofile.write(joined_pos)
                ofile.write(u'\n')
                sent_count += 1
            if sent_count % 10000 == 0:
                print("Output sentence {}".format(sent_count))
    print("Output sentence {}".format(sent_count))
# want_type = opt.want_type
# all_data = {}
# all_data['src'], all_data['tgt'], all_data['sizes'] = makeData(opt.input_file)
# found_words = set()
# linenum = 0
# for i in range(len(all_data['tgt'])):
# linenum += 1
# assert len(all_data['tgt'][i]) == len(all_data['src'][i]), "Line %s length mismatch" % linenum
# for c in range(len(all_data['tgt'][i])):
# if want_type in (all_data['tgt'][i][c]).lower():
# # print (linenum)
# found_words.add(all_data['src'][i][c])
# for w in found_words:
# print(w, end=' ')
if __name__ == "__main__":
    if opt.augment_jp_name:
        # The name lists are only needed for augmentation; one name per line.
        jp_first_names = [n.strip() for n in
                          io.open('JP_first_name.txt', encoding='utf-8')]
        jp_last_names = [n.strip() for n in
                         io.open('JP_last_name.txt', encoding='utf-8')]
    main()
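# Output format, end to end (illustrative): the two input lines
# "山田太郎\tper" and "は\tO" become the single output line
# "山田太郎は\tB-per I-per I-per E-per S-N-O"; with -augment_jp_name, a
# second copy follows in which 山田太郎 is replaced by a random name drawn
# from JP_last_name.txt + JP_first_name.txt.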