-
Notifications
You must be signed in to change notification settings - Fork 2
/
s3-4_IZH.py
107 lines (86 loc) · 2.76 KB
/
s3-4_IZH.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
ゼロから学ぶスパイキングニューラルネットワーク
- Spiking Neural Networks from Scratch
Copyright (c) 2020 HiroshiARAKI. All Rights Reserved.
"""
import numpy as np
import matplotlib.pyplot as plt
class Izhikevich:
def __init__(self, a, b, c, d):
"""
Izhikevich neuron model
:param a: uのスケーリング係数
:param b: vに対するuの感受性度合い
:param c: 静止膜電位
:param d: 発火後の膜電位が落ち着くまでを司る係数
"""
self.a = a
self.b = b
self.c = c
self.d = d
def calc(self, inputs, time=300, dt=0.5, tci=10):
"""
膜電位(Membrane potential) v と回復変数(Recovery variable) u を計算する
:param inputs:
:param weights:
:param time:
:param dt:
:param tci:
:return:
"""
v = self.c
u = self.d
i = 0
monitor = {'v': [], 'u': []}
for t in range(int(time / dt)):
# uを計算
du = self.a * (self.b * v - u)
u += du * dt
monitor['u'].append(u)
# vを計算
dv = 0.04 * v ** 2 + 5 * v + 140 - u + inputs[t]
v += dv * dt
monitor['v'].append(v)
# 発火処理
if v >= 30:
v = self.c
u += self.d
return monitor
if __name__ == '__main__':
time = 300 # 実験時間 (観測時間)
dt = 0.125 # 時間分解能
pre = 50 # 前ニューロンの数
t = np.arange(0, time, dt)
# 入力データ (面倒臭いので適当な矩形波とノイズを合成して作った)
input_data = np.sin(0.5 * np.arange(0, time, dt))
input_data = np.where(input_data > 0, 20, 0) + 10 * np.random.rand(int(time/dt))
input_data_2 = np.cos(0.4 * np.arange(0, time, dt) + 0.5)
input_data_2 = np.where(input_data_2 > 0, 10, 0)
input_data += input_data_2
# Izhikevichニューロンの生成 (今回はRegular Spiking Neuronのパラメータ)
neuron = Izhikevich(
a=0.02,
b=0.2,
c=-65,
d=8
)
history = neuron.calc(input_data, time=time, dt=dt)
# 結果の描画
plt.figure(figsize=(10, 4))
# 入力データ
plt.subplot(3, 1, 1)
plt.plot(t, input_data)
plt.xlim(0, time)
plt.ylim(-1, pre)
plt.ylabel('Input current')
# 膜電位
# plt.subplot(3, 1, 2)
plt.plot(t, history['v'], label=f'a=0.2, b=2, c=-56, d=-16, I(t)=-99')
plt.ylabel('Membrane potential $v$ [mV]')
# 膜電位
plt.subplot(3, 1, 3)
plt.plot(t, history['u'], c='tab:orange')
plt.xlabel('time [ms]')
plt.ylabel('Recovery variable $u$')
plt.legend()
plt.show()