-
Notifications
You must be signed in to change notification settings - Fork 5
/
preprocess_truecase.py
90 lines (59 loc) · 2.02 KB
/
preprocess_truecase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# coding=utf-8
# Copyright 2016 Ottokar Tilk and Tanel Alumäe
"""
Preprocess data for punctuation model training
Code from Punctuator2: https://github.com/ottokart/punctuator2
"""
import os
import codecs
import re
import sys
# The following lines are only needed once
# import nltk
# nltk.download('punkt')
from nltk.tokenize import word_tokenize
# Placeholder token written to the output in place of tokens that
# is_number() classifies as numeric.
NUM = "<NUM>"
# Punctuation marks a line must end with to be kept (see skip()), mapped to
# the tokens that replace them in the output.
EOS_PUNCTS = {".": ".PERIOD", "?": "?QUESTIONMARK", "!": "!EXCLAMATIONMARK"}
# Remaining punctuation marks that are rewritten, mapped to the tokens that
# replace them in the output.
INS_PUNCTS = {",": ",COMMA", ";": ";SEMICOLON", ":": ":COLON", "-": "-DASH"}
# Any of these characters ([ ] ( ) \ > < = + _ *) causes skip() to reject
# the whole line.
forbidden_symbols = re.compile(r"[\[\]\(\)\\\>\<\=\+\_\*]")
# Matches a single digit character.
numbers = re.compile(r"\d")
# Matches a run of two or more punctuation marks; group 1 captures the first
# mark so the run can be collapsed to it.
multiple_punct = re.compile(r"([\.\?\!\,\:\;\-])(?:[\.\?\!\,\:\;\-]){1,}")
def is_number(x):
    """Return True when the token ``x`` is made up mostly of digits.

    A token counts as numeric when its non-digit characters make up less
    than 60% of its length (for example "3.14" or "ab123", but not "abc12").

    Args:
        x: The token to classify.

    Returns:
        True if ``x`` is numeric by the rule above, otherwise False. An
        empty token is never numeric.
    """
    if not x:
        # Guard against dividing by zero: an empty token contains no digits.
        return False
    # str.isdecimal() accepts the Unicode "Nd" category, which is the same
    # set of characters the regex class \d matches on str input.
    non_digit_count = sum(1 for ch in x if not ch.isdecimal())
    return non_digit_count / len(x) < 0.6
def untokenize(line):
    """Return ``line`` unchanged.

    Acts as a hook for re-joining tokens after tokenization; it is
    currently a no-op.
    """
    # NOTE: a disabled variant of this function re-attached split clitics
    # (" '" -> "'", " n't" -> "n't") and rewrote "can not" as "cannot".
    return line
def skip(line):
    """Decide whether ``line`` should be left out of the output.

    A line is skipped when any of the following holds:
      * it is empty or contains only whitespace;
      * its last character is not a key of ``EOS_PUNCTS`` (the check uses
        the line as given, so trailing whitespace also causes a skip);
      * it contains a character matched by ``forbidden_symbols``.

    Returns:
        True if the line should be skipped, False if it should be kept.
    """
    # Blank lines are rejected first; this also makes line[-1] safe below.
    if not line.strip():
        return True
    return (
        line[-1] not in EOS_PUNCTS
        or forbidden_symbols.search(line) is not None
    )
def process_line(line):
    """Tokenize ``line`` and rewrite punctuation and numeric tokens.

    Each token produced by ``word_tokenize`` is mapped as follows, with the
    first matching rule winning:
      * a key of ``INS_PUNCTS`` -> its mapped value;
      * a key of ``EOS_PUNCTS`` -> its mapped value;
      * a token for which ``is_number`` is true -> ``NUM``;
      * anything else -> the token itself.

    Returns:
        The mapped tokens joined by single spaces, followed by one trailing
        space, passed through ``untokenize``.
    """

    def convert(tok):
        # Punctuation lookups come before the numeric check, matching the
        # precedence described in the docstring.
        if tok in INS_PUNCTS:
            return INS_PUNCTS[tok]
        if tok in EOS_PUNCTS:
            return EOS_PUNCTS[tok]
        if is_number(tok):
            return NUM
        return tok

    converted = [convert(tok) for tok in word_tokenize(line)]
    joined = " ".join(converted)
    return untokenize(joined + " ")
def main(in_path, out_path):
    """Preprocess the text file at ``in_path`` and write it to ``out_path``.

    Both files are handled as UTF-8. For every input line, double quotes are
    removed, surrounding whitespace is stripped, and runs of punctuation are
    collapsed to their first mark. Lines rejected by ``skip`` are dropped;
    the rest are written as ``process_line(line)`` followed by a newline.

    Args:
        in_path: Path of the raw text file to read.
        out_path: Path of the file to write (created or overwritten).

    Returns:
        The number of input lines that were skipped.
    """
    skipped = 0
    # Open the input first so that a missing or unreadable input file does
    # not leave behind a freshly created/truncated output file.
    with codecs.open(in_path, "r", encoding="utf-8") as text, \
            codecs.open(out_path, "w", encoding="utf-8") as out_txt:
        for line in text:
            line = line.replace('"', "").strip()
            # Keep only the first mark of any run of punctuation.
            line = multiple_punct.sub(r"\g<1>", line)
            if skip(line):
                skipped += 1
                continue
            out_txt.write(process_line(line) + "\n")
    return skipped


# Only run the conversion when executed as a script, so that importing this
# module for its helper functions performs no file I/O.
if __name__ == "__main__":
    if len(sys.argv) < 3:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit("Usage: {} <input_file> <output_file>".format(sys.argv[0]))
    print("Skipped {} lines".format(main(sys.argv[1], sys.argv[2])))