-
Notifications
You must be signed in to change notification settings - Fork 10
/
5. recursive.py
29 lines (24 loc) · 1.31 KB
/
5. recursive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from tree import *
# Sample training rows for the decision tree; every row holds six categorical
# feature values, all stored as strings.
# NOTE(review): the columns look like the UCI Car Evaluation attributes
# (buying, maint, doors, persons, lug_boot, safety) — confirm against the data source.
car_data = [
    ['med', 'low', '3', '4', 'med', 'med'],
    ['med', 'vhigh', '4', 'more', 'small', 'high'],
    ['high', 'med', '3', '2', 'med', 'low'],
    ['med', 'low', '4', '4', 'med', 'low'],
    ['med', 'low', '5more', '2', 'big', 'med'],
    ['med', 'med', '2', 'more', 'big', 'high'],
    ['med', 'med', '2', 'more', 'med', 'med'],
    ['vhigh', 'vhigh', '2', '2', 'med', 'low'],
    ['high', 'med', '4', '2', 'big', 'low'],
    ['low', 'low', '2', '4', 'big', 'med'],
]
# Class label for each row of car_data, positionally aligned with it.
car_labels = [
    'acc', 'acc', 'unacc', 'unacc', 'unacc',
    'vgood', 'acc', 'unacc', 'unacc', 'good',
]
def find_best_split(dataset, labels):
    """Return the feature index whose split yields the highest information gain.

    Each feature (column) of ``dataset`` is tried in turn. ``split`` partitions
    the rows/labels on that feature, and ``information_gain`` scores the
    resulting label subsets against the unsplit ``labels``.

    Args:
        dataset: Sequence of rows; every row is assumed to have as many
            features as the first row.
        labels: Class label for each row of ``dataset``, in the same order.

    Returns:
        A ``(best_feature, best_gain)`` tuple. When no feature has a gain
        greater than 0 (including when ``dataset`` is empty), the result is
        ``(0, 0)``.
    """
    best_gain = 0
    best_feature = 0
    if not dataset:
        # No rows means no features to try; report "no useful split" instead
        # of raising IndexError on dataset[0].
        return best_feature, best_gain
    for feature in range(len(dataset[0])):
        # Only the label partition is needed to score the split.
        _, label_subsets = split(dataset, labels, feature)
        gain = information_gain(labels, label_subsets)
        # Strict ">" keeps the lowest-indexed feature when gains tie.
        if gain > best_gain:
            best_gain, best_feature = gain, feature
    return best_feature, best_gain
def build_tree(data, labels):
    """Recursively build a decision tree from ``data`` and ``labels``.

    Args:
        data: Rows of feature values.
        labels: Class label for each row of ``data``, in the same order.

    Returns:
        ``Counter(labels)`` as a leaf when the best split has zero information
        gain. Otherwise, a list with one subtree per subset produced by
        splitting on the best feature, in the order ``split`` returns them.
    """
    best_feature, best_gain = find_best_split(data, labels)
    # No feature improves on the current labels: stop and summarize them.
    if best_gain == 0:
        return Counter(labels)
    data_subsets, label_subsets = split(data, labels, best_feature)
    # One branch per subset; zip keeps each data subset paired with its labels.
    return [
        build_tree(subset_data, subset_labels)
        for subset_data, subset_labels in zip(data_subsets, label_subsets)
    ]
if __name__ == "__main__":
    # Only build and print the tree when run as a script, not on import.
    # NOTE(review): print_tree presumably comes from `from tree import *` — confirm.
    tree = build_tree(car_data, car_labels)
    print_tree(tree)