-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_catastrophic_forgetting.py
107 lines (92 loc) · 3.79 KB
/
test_catastrophic_forgetting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from SequentialGroup import Column, Group, Node, Link
from utils.dataset import generate_dataset as _generate_dataset
from pathlib import Path
import pickle
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import matplotlib.animation as animation
from tqdm import tqdm
from utils.Profiler import Profiler
from typing import Set
from utils.draw_group import draw_group
# Fix the global NumPy RNG so dataset generation (np.random.choice below) is
# reproducible across runs.
np.random.seed(137)
# NOTE(review): sparse_rate, mu and sigma are not referenced anywhere in this
# chunk — presumably consumed by SequentialGroup or another module; confirm
# before removing.
sparse_rate = 1.00 # 0.05
mu = 0.2
sigma = 0.01
def generate_dataset(length: int=4, number: int=2, n_train=1000, n_test=1000, n_chars=26):
    '''Draw `number` random subsequences of `length` uppercase characters and
    build train/test datasets from them.

    Characters are sampled (with replacement, via the seeded global NumPy RNG)
    from the alphabet 'A'..chr(64 + n_chars); the actual dataset construction
    is delegated to utils.dataset.generate_dataset.

    Args:
        length (int): the length of each subsequence (must be >= 4).
        number (int): the number of subsequences (must be >= 1).
        n_train (int): training-set size forwarded to the backend generator.
        n_test (int): test-set size forwarded to the backend generator.
        n_chars (int): alphabet size.

    Returns:
        (D_train, D_test, dictionary) as returned by utils.dataset.generate_dataset.

    Raises:
        ValueError: if number < 1 or length < 4.
    '''
    length = int(length)
    number = int(number)
    # Validate explicitly rather than with `assert`, which is stripped under -O.
    if number < 1 or length < 4:
        raise ValueError(f"need number >= 1 and length >= 4, got number={number}, length={length}")
    # The alphabet is loop-invariant, and its characters are already unique, so
    # a plain list replaces the per-iteration OrderedSet (and drops the
    # third-party `ordered_set` dependency) with identical sampling behaviour.
    alphabet = [chr(i) for i in range(65, n_chars + 65)]
    seqs = []  # the `number` random subsequences
    for _ in range(number):
        seqs.append(list(np.random.choice(alphabet, length)))
    D_train, D_test, _dictionary = _generate_dataset(seqs, n_train, n_test, n_chars=n_chars)
    return D_train, D_test, _dictionary
def test_catastrophic_forgetting(length, number, n_nodes=16, n_chars=26, observe_all=True, n_patterns=3, n_periods=5):
    '''Train a Group on several interleaved symbol patterns and record how well
    it anticipates each next symbol, to probe catastrophic forgetting.

    Cycles `n_periods` times over `n_patterns` distinct datasets (generated on
    first use and cached under ./cache), feeding symbols one at a time and
    scoring the model's anticipated next symbols against the true next symbol.

    Args:
        length (int): length of each subsequence in a pattern.
        number (int): number of subsequences per pattern.
        n_nodes (int): nodes per column in the Group.
        n_chars (int): alphabet size (one Group column per character).
        observe_all (bool): if False, only score positions strictly inside a
            subsequence window (see the boundary mask below).
        n_patterns (int): how many distinct datasets to interleave.
        n_periods (int): how many times each pattern is revisited.

    Returns:
        (accuracies, g): the running accuracy after each scored step, and the
        trained Group.
    '''
    # Create the model: one column per alphabet character.
    chars = [chr(i) for i in range(65,65+n_chars)]
    g = Group(len(chars), n_nodes)
    # Learning hyper-parameters; semantics are defined in SequentialGroup.
    g.thresh = 0.8
    g.p_plus = 2#1.95
    g.p_minus = 0.99#0.9
    g.p_minus2 = 0.001
    # Tag each column with its index so anticipated nodes can be mapped back
    # to symbol indices.
    for idx, column in enumerate(g.columns):
        column.mark = idx
    accuracies = []       # running accuracy after each scored step
    # NOTE(review): n_anticipations is collected but never returned — confirm
    # whether it should be part of the return value.
    n_anticipations = []  # number of anticipated symbols at each scored step
    profiler = Profiler(100)  # presumably a 100-step sliding accuracy window — confirm
    datasets = {}  # cache-file path -> [full training sequence, read offset]
    for _ in range(n_periods):
        for i_pattern in range(n_patterns):
            # Prepare the dataset, persisted on disk so repeated runs reuse it.
            root_cache = Path('./cache')
            root_cache.mkdir(parents=True, exist_ok=True)
            file_cache = root_cache/f"dataset_capacity-test_{length}_{number}_{n_chars}_{i_pattern}.pkl"
            n_data = length*number*50  # symbols consumed per visit to a pattern
            if file_cache not in datasets:
                # First visit this run: load from disk, or generate and persist.
                if not file_cache.exists():
                    D_train, D_test, _dictionary = generate_dataset(length, number, n_chars=n_chars, n_train=length*number*200)
                    with open(file_cache, 'wb') as f:
                        pickle.dump((D_train, D_test, _dictionary), f)
                else:
                    with open(file_cache, "rb") as f:
                        D_train, D_test, _dictionary = pickle.load(f)
                datasets[file_cache] = [D_train, n_data]
                data = D_train[:n_data]
            else:
                # Revisit: read the window starting at the stored offset.
                # NOTE(review): the stored offset is never advanced, so every
                # revisit replays D_train[n_data:2*n_data] — confirm whether
                # this repetition is intended.
                D_train, cnt = datasets[file_cache]
                data = D_train[cnt:cnt+n_data]
            duration = length*2-2  # period of the boundary mask used below
            for i, (idx, idx_next) in enumerate((tqdm(zip(data[:-1], data[1:]), total=len(data)-1))):
                # Feed the current symbol; the Group returns candidate nodes
                # that predict the next symbol.
                candidates: Set[Node] = g.activate(idx)
                anticipations = set()
                for candidate in candidates:
                    anticipations.add(candidate.column.mark)
                # if len(candidates) > 0:
                #     candidate = max(candidates, key=lambda node: node.activity_pred)
                #     anticipations.add(candidate.column.mark)
                # Unless scoring everything, skip boundary positions: only
                # score steps with 0 < i % duration < length-1.
                if not observe_all and not (((remainder := i%duration) > 0) and (remainder < length-1)):
                    continue
                acc = profiler.observe(anticipations, idx_next)
                accuracies.append(acc)
                n_anticipations.append(len(anticipations))
    return accuracies, g
# Script entry point: run the forgetting experiment on 10 subsequences of
# length 10 and plot the running anticipation accuracy. Guarded so merely
# importing this module does not trigger the (slow, file-writing) experiment.
if __name__ == "__main__":
    pair = (10, 10)  # (length, number) of subsequences per pattern
    accuracies, g = test_catastrophic_forgetting(*pair, 16, observe_all=True, n_patterns=10)
    plt.figure()
    plt.plot(accuracies)
    # draw_group(g)
    plt.show()