-
Notifications
You must be signed in to change notification settings - Fork 0
/
cursed_cfg.py
36 lines (27 loc) · 1 KB
/
cursed_cfg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
class CFG:
    """Frequency table of observed patterns with probability-weighted sampling.

    Every distinct pattern passed to ``add_structure`` is stored once in
    ``self.structures`` as a ``Structure`` whose ``count`` tracks how many
    times it was seen.  ``compute_probability`` turns those counts into
    relative frequencies, and ``get_sample_grammar`` draws one structure
    according to them.
    """

    def __init__(self):
        """Create an empty table."""
        # Maps structure id ("".join(pattern)) -> Structure instance.
        self.structures = {}

    def add_structure(self, pattern):
        """Record one occurrence of ``pattern``.

        Args:
            pattern: Sequence of strings (the ``Structure`` class describes
                it as a list of POS tags).  Empty or ``None`` patterns are
                ignored.

        A pattern seen before has its count incremented; a new pattern is
        wrapped in a fresh ``Structure`` (which starts with ``count == 1``).
        """
        # Reject empties before building the key so that any empty sequence
        # (not just a literal []) is skipped instead of being stored as "".
        if not pattern:
            return

        # NOTE(review): joining without a separator means e.g. ["PRP", "$"]
        # and ["PRP$"] share one id.  Kept as-is because the keys of
        # ``self.structures`` are visible to callers; confirm before changing.
        structure_id = "".join(pattern)

        if structure_id in self.structures:
            self.structures[structure_id].count += 1
        else:
            self.structures[structure_id] = Structure(pattern)

    def get_sample_grammar(self):
        """Return one stored structure, drawn according to its probability.

        ``compute_probability`` must have been called after the most recent
        ``add_structure`` so that the stored probabilities sum to 1.

        Returns:
            One of the values of ``self.structures``.

        Raises:
            ValueError: If no structures are stored, or (raised by numpy) if
                the stored probabilities do not form a valid distribution.
        """
        candidates = list(self.structures.values())
        if not candidates:
            raise ValueError("cannot sample: no structures have been added")

        weights = [candidate.probability for candidate in candidates]
        # Sample a position rather than the objects themselves so numpy never
        # has to coerce arbitrary Python objects into an array.
        index = np.random.choice(len(candidates), p=weights)
        return candidates[index]

    def compute_probability(self):
        """Refresh every structure's probability from the current counts.

        Each structure receives ``count / total`` where ``total`` is the sum
        of all counts, so the probabilities form a distribution summing to 1
        (which ``get_sample_grammar`` relies on).
        """
        # The denominator is the number of observations, not the number of
        # distinct structures; the two differ once any pattern repeats.
        total = sum(structure.count for structure in self.structures.values())
        for structure in self.structures.values():
            structure.compute_probability(total)
class Structure:
    """One pattern together with its occurrence count and probability."""

    def __init__(self, pattern):
        """Store ``pattern`` with an initial count of 1 and probability of 0.

        Args:
            pattern: Sequence of strings; intended to be a list of POS tags.

        The pattern is also echoed to stdout, space-separated, on creation.
        """
        rendered = " ".join(pattern)
        print(rendered)
        # pattern should be a list of POS tags; it is stored by reference.
        self.pattern, self.count, self.probability = pattern, 1, 0

    def compute_probability(self, total):
        """Set ``probability`` to ``count / total`` using float division.

        Args:
            total: Denominator for the relative frequency; converted with
                ``float``.  A zero value raises ``ZeroDivisionError``.
        """
        occurrences = float(self.count)
        denominator = float(total)
        self.probability = occurrences / denominator