Quantisations.py
import numpy as np
from Helper_Functions import mean
from sklearn import svm
from copy import deepcopy
# A node in the quantisation's decision tree over continuous RNN state vectors: an internal
# node routes a vector to a child either by a threshold on a single dimension or by a binary
# SVM classifier; a leaf is one partition (cluster) of the state space.
class SVMDecisionTreeNode:
    def __init__(self, id):
        self.id = id
        self.has_children = False

    def get_node(self, vector):
        # descend until the leaf responsible for this vector is reached
        if self.has_children:
            return self._choose_child(vector).get_node(vector)
        return self

    def _choose_child(self, vector):
        if self.is_dim_split:
            return self._dim_choose_child(vector)
        childnum = self.clf.predict([vector]).tolist()[0]
        if childnum == 0:
            return self.zero_child
        return self.one_child

    def _dim_choose_child(self, vector):
        if vector[self.split_dim] > self.split_val:
            return self.high
        return self.low
    def _dim_split_aux(self, split_tuples, split_vals, new_id, split_depth):
        if split_depth == 0 or len(split_tuples) == 0:
            return new_id
        # copy the list: otherwise all subtrees pop from the same list, run out of
        # dimensions to split on, and split asymmetrically
        split_tuples = deepcopy(split_tuples)
        margin, self.split_dim = split_tuples.pop(0)
        self.split_val = split_vals[self.split_dim]
        self.high = SVMDecisionTreeNode(self.id)
        self.low = SVMDecisionTreeNode(new_id)
        new_id += 1  # the next node will need the next id, now that low has taken this one
        self.has_children = True
        self.is_dim_split = True
        new_id = self.high._dim_split_aux(split_tuples, split_vals, new_id, split_depth - 1)
        new_id = self.low._dim_split_aux(split_tuples, split_vals, new_id, split_depth - 1)
        return new_id
    # reminder: this function (dim_split) and the next (split) are called only by the overall
    # SVMDecisionTreeQuantisation class, and only on leaves of the tree (which represent actual
    # clusters in the partitioning), never on internal nodes. A worked example of dim_split
    # follows the class definition below.
    def dim_split(self, agreeing_continuous_visitors, conflicted_continuous_visitor, new_id, split_depth):
        # print("making initial split of depth " + str(split_depth))
        mean_agreeing_vector = []
        for i in range(len(conflicted_continuous_visitor)):
            mean_agreeing_vector.append(mean([visitor[i] for visitor in agreeing_continuous_visitors]))
        margins = [abs(m - v) for m, v in zip(mean_agreeing_vector, conflicted_continuous_visitor)]
        numbered_margins_by_largest = sorted([(margin, i) for i, margin in enumerate(margins)], reverse=True)
        split_vals = [(a + b) / 2.0 for a, b in zip(mean_agreeing_vector, conflicted_continuous_visitor)]
        return self._dim_split_aux(numbered_margins_by_largest, split_vals, new_id, split_depth)
    def split(self, agreeing_continuous_visitors, conflicted_continuous_visitor, new_id):
        # print("trying regular svm split")
        x = agreeing_continuous_visitors + [conflicted_continuous_visitor]
        y = [0] * len(agreeing_continuous_visitors) + [1]
        self.clf = svm.SVC(C=10000)
        self.clf.fit(x, y)
        # print("clf used this many support vectors: " + str(self.clf.n_support_))
        self.zero_child = SVMDecisionTreeNode(self.id)
        self.one_child = SVMDecisionTreeNode(new_id)
        new_id += 1
        self.has_children = True
        self.is_dim_split = False
        if self.clf.predict(x).tolist() != y:
            print("svm classifier failed to obtain perfect split :(")
        return new_id
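# Illustrative worked example of dim_split (added sketch, not part of the original file; the
# numbers are made up). Suppose the agreeing visitors are [0.0, 0.0, 0.0] and [0.2, 0.0, 0.4]
# (mean [0.1, 0.0, 0.2]) and the conflicted visitor is [0.9, 0.1, 0.8]. Then:
#   margins    = [0.8, 0.1, 0.6]   -> dimensions ordered by margin: 0, 2, 1
#   split_vals = [0.5, 0.05, 0.5]  (midpoints between the mean and the conflicted vector)
# With split_depth=2, the leaf splits on dimension 0 at 0.5 and each child splits on
# dimension 2 at 0.5, giving four leaves; the conflicted visitor (0.9 > 0.5, 0.8 > 0.5) lands
# in a different leaf from both agreeing visitors.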
# Partitions the continuous RNN state space using the decision tree above, refining the
# partitioning on demand.
class SVMDecisionTreeQuantisation:
    def __init__(self, num_dims_initial_split):
        self.num_dims_initial_split = num_dims_initial_split
        self.top_id = 1  # 1-indexed, so it also serves as a count of how many ids exist overall
        self.head = SVMDecisionTreeNode(self.top_id)
        self.had_initial_refine = False
        self.initiated_with_all_rnn_states_to_some_depth = False
        self.refinement_doesnt_hurt_other_clusters = True
        # this is a trait of decision-tree refinements: they affect only the cluster being refined,
        # and all the rest remain exactly the same (as opposed to, for instance, splitting a
        # dimension across the board). If you wish to implement a different quantisation, think
        # about whether yours satisfies this property and fill this field appropriately.
        # (A minimal usage sketch appears at the end of this file.)
    def _get_node(self, vector):
        if not self.had_initial_refine:
            return self.head
        if self.initiated_with_all_rnn_states_to_some_depth:
            return self.nodes[self.clf.predict([vector])[0]].get_node(vector)
        return self.head.get_node(vector)

    def get_partition(self, vector):
        return self._get_node(vector).id
    def refine(self, agreeing_continuous_visitors, conflicted_continuous_visitor):
        # print("refining, H size is " + str(len(agreeing_continuous_visitors)))
        relevant_node = self._get_node(conflicted_continuous_visitor)
        if self.had_initial_refine:
            next_id = relevant_node.split(agreeing_continuous_visitors,
                                          conflicted_continuous_visitor,
                                          self.top_id + 1)
        else:
            next_id = relevant_node.dim_split(agreeing_continuous_visitors,
                                              conflicted_continuous_visitor,
                                              self.top_id + 1,
                                              self.num_dims_initial_split)
        self.refined_something = True
        # print("refining - added " + str((next_id - 1) - self.top_id) + " states")
        self.top_id = next_id - 1
        self.had_initial_refine = True
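
# Minimal usage sketch (added for illustration, not part of the original file). The
# three-dimensional vectors below are made-up stand-ins for continuous RNN states; in the
# actual extraction procedure the agreeing/conflicted visitors come from the refinement loop.
if __name__ == "__main__":
    quantisation = SVMDecisionTreeQuantisation(num_dims_initial_split=2)

    agreeing = [[0.0, 0.0, 0.0], [0.2, 0.0, 0.4]]
    conflicted = [0.9, 0.1, 0.8]

    # before any refinement, every vector maps to the single head partition
    print(quantisation.get_partition(conflicted))

    # the first refinement uses the axis-aligned dim_split on the largest-margin dimensions
    quantisation.refine(agreeing, conflicted)
    print(quantisation.get_partition(conflicted), quantisation.get_partition(agreeing[0]))

    # later refinements fit an SVM on the offending leaf only; other partitions are untouched
    quantisation.refine(agreeing, [0.1, 0.05, 0.1])
    print(quantisation.get_partition([0.1, 0.05, 0.1]))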