-
Notifications
You must be signed in to change notification settings - Fork 0
/
postprocess.py
142 lines (109 loc) · 5.3 KB
/
postprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import copy
from collections import defaultdict
from preferences import Paths
import conlluplus as cplus
#=============================================================================
## TODO: tee lista yleisimmistä virheistä
## TODO: kokeile sanavektoreita monitulkintaisimpien logogrammien
## disambiguointiin
class Postprocessor:
    """Container for the last post-processing step.

    Holds a model's predictions (a ``cplus.ConlluPlus`` object) and
    rewrites their lemmas, XPOS tags and ``score`` values using data
    stored under ``Paths.models/<model_name>/``:

        conllu/train.conllu       training data        (``self.train``)
        override/override.conllu  override dictionary  (``self.override``)
        lex/train-types.xlit      training vocabulary; the form is the
                                  first tab-separated column of each line
    """

    def __init__(self, predictions, model_name):
        """
        :param predictions: predictions to post-process, either a path to
            a CoNLL-U(+) file or an already loaded ``ConlluPlus`` object
        :type predictions: str | cplus.ConlluPlus
        :param model_name: name of the model directory under ``Paths.models``
        :type model_name: str
        """
        self.model_name = model_name
        model_dir = os.path.join(Paths.models, self.model_name)
        self.train = os.path.join(model_dir, 'conllu', 'train.conllu')
        self.override = os.path.join(model_dir, 'override', 'override.conllu')
        if isinstance(predictions, str):
            predictions = cplus.ConlluPlus(predictions, validate=False)
        self.predictions = predictions
        # Training data is parsed lazily by _generate_lemmadict and cached.
        self.train_data = None

    def _generate_lemmadict(self, fields, threshold):
        """Yield naive disambiguation entries mapping FORM + an arbitrary
        tag to a lemma.

        Counts how often each lemma is given to each ``(fields[0],
        fields[1])`` key in the training data. It then yields every lemma
        whose share of that key's occurrences is at least ``threshold``.

        :param fields: indices of the two token fields forming the key,
            e.g. ``(cplus.FORM, cplus.XPOS)``; extra items are ignored
        :type fields: tuple
        :param threshold: minimum relative frequency (0-1) of a lemma
        :type threshold: float
        :returns: generator of ``((key1, key2), lemma, 1.0)``. The third
            item is a deliberately flat 1.0, not the lemma's share.

        NOTE(review): with ``threshold <= 0.5`` one key can yield several
        lemmas. A caller that builds a dict from them keeps the last one.
        """
        # TODO: try to improve POS tagging, e.g.
        # xlit + left context + right context --> lemma + POS
        if self.train_data is None:
            self.train_data = cplus.ConlluPlus(self.train, validate=False)
        lemma_counts = defaultdict(lambda: defaultdict(int))
        for token in self.train_data.get_contents():
            key = (token[fields[0]], token[fields[1]])
            # Index 2 is LEMMA in the standard CoNLL-U column order
            # (assumes get_contents() rows keep that order).
            lemma_counts[key][token[2]] += 1
        # Collect the lemmas given to each key at least `threshold` of the time.
        for key, counts in lemma_counts.items():
            total = sum(counts.values())  # once per key, not once per lemma
            for lemma, count in counts.items():
                if count / total >= threshold:
                    yield key, lemma, 1.0  # flat on purpose, not count / total

    @staticmethod
    def _score_form(form, invocab):
        """Return the initial confidence score of one predicted form.

        The score is 3.0 if ``form`` is in ``invocab``. Otherwise it is
        2.0 for an all-lowercase form, 0.0 for an all-uppercase form
        (presumably logograms; TODO confirm) and 1.0 for mixed case.
        Because the lowercase test comes first, forms with no cased
        characters score 2.0.
        """
        if form in invocab:
            return 3.0
        if form.lower() == form:
            return 2.0
        if form.upper() == form:
            return 0.0
        return 1.0

    def initialize_scores(self):
        """Initialize the ``score`` field of every predicted token.

        The training vocabulary is read from ``lex/train-types.xlit``.
        Each token is scored with :meth:`_score_form`.
        """
        types_path = os.path.join(
            Paths.models, self.model_name, 'lex', 'train-types.xlit')
        # Explicit UTF-8: the file may hold non-ASCII transliteration
        # characters, and the platform default encoding is locale-dependent.
        with open(types_path, encoding='utf-8') as f:
            invocab = {line.rstrip().split('\t')[0] for line in f}

        # Kept as a generator function so that get_contents() is only
        # called once update_value starts consuming the scores.
        def scores():
            for token in self.predictions.get_contents():
                yield self._score_form(token[cplus.FORM], invocab)

        self.predictions.update_value(field='score', values=scores())

    def _substitutions(self, fields, threshold):
        """Build ``{(key1, key2): {cplus.LEMMA: lemma, 'score': score}}``
        from :meth:`_generate_lemmadict`.

        This is the mapping passed to ``ConlluPlus.conditional_update_value``:
        input field values mapped to ``{output_index: output, 'score': score}``.
        """
        return {key: {cplus.LEMMA: lemma, 'score': score}
                for key, lemma, score in self._generate_lemmadict(
                    fields=fields, threshold=threshold)}

    def fill_unambiguous(self, threshold=0.9):
        """First post-processing step.

        Finds close-to-unambiguous FORM + XPOS pairs in the training data
        and overwrites the lemmatizations of matching predicted tokens.
        The substitution also carries ``'score': 1.0``.

        :param threshold: unambiguity threshold, i.e. the minimum share of
            a pair's training occurrences one lemma must have
        :type threshold: float
        """
        print(f'> Post-processor ({self.model_name}): '
              f'filling in unambiguous words (t ≥ {threshold})')
        unambiguous = self._substitutions(
            (cplus.FORM, cplus.XPOS), threshold)
        # Populate the CoNLL-U with the substitutions.
        self.predictions.conditional_update_value(
            unambiguous, fields=('form', 'xpos'))

    def disambiguate_by_pos_context(self, threshold=0.9):
        """Disambiguate lemmata by their XPOS context.

        Works like :meth:`fill_unambiguous` (without the progress
        message), but the second part of the key is the XPOS context
        (``cplus.XPOSCTX`` / ``'xposctx'``) instead of the XPOS tag.

        :param threshold: unambiguity threshold, as in fill_unambiguous
        :type threshold: float
        """
        # TODO: run this step first and try changing the POS tag instead.
        # It will not work if the context itself is wrong, though.
        # Markov chain?
        unambiguous = self._substitutions(
            (cplus.FORM, cplus.XPOSCTX), threshold)
        self.predictions.conditional_update_value(
            unambiguous, fields=('form', 'xposctx'))

    def apply_override(self):
        """Load the override dictionary and apply it to the predictions.

        Maps every FORM in ``self.override`` to its LEMMA and XPOS; if a
        form occurs more than once, its last entry wins. The mapping is
        passed to ``ConlluPlus.override_form``. Per the original author's
        note, every predicted form found there is overwritten with the
        override's lemma + XPOS combination.
        """
        override = cplus.ConlluPlus(self.override, validate=False)
        # Forms are matched verbatim; stripping '*' from them is disabled.
        replacements = {
            form: {'lemma': lemma, 'xpos': xpos}
            for form, lemma, xpos in override.get_contents(
                'form', 'lemma', 'xpos')}
        self.predictions.override_form(replacements)
if __name__ == "__main__":
    # Ad-hoc run: post-process the example predictions with model 'lbtest2'
    # by applying that model's override dictionary.
    postprocessor = Postprocessor('input/example_nn.conllu', 'lbtest2')
    # Context-based disambiguation step, currently disabled:
    # postprocessor.disambiguate_by_pos_context()
    postprocessor.apply_override()