Skip to content

Commit

Permalink
Merge pull request #1701 from PrincetonUniversity/devel
Browse files Browse the repository at this point in the history
Composition: set variable of TD PredictionErrorMechanism as in System
  • Loading branch information
dillontsmith authored Jul 9, 2020
2 parents eb2420c + 4a96b42 commit 627974b
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 24 deletions.
18 changes: 16 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ arch:
- amd64
- arm64
- ppc64le
- s390x

stages:
- precache
Expand Down Expand Up @@ -51,8 +52,21 @@ jobs:
stage: precache
env: PYTHON=3.6
arch: ppc64le
allow_failures:
- arch: ppc64le
- script: true
after_script: true
stage: precache
env: PYTHON=3.8
arch: s390x
- script: true
after_script: true
stage: precache
env: PYTHON=3.7
arch: s390x
- script: true
after_script: true
stage: precache
env: PYTHON=3.6
arch: s390x

env:
jobs:
Expand Down
4 changes: 2 additions & 2 deletions psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6409,9 +6409,9 @@ def _create_td_related_mechanisms(self,

objective_mechanism = PredictionErrorMechanism(name='PredictionError',
sample={NAME: SAMPLE,
VARIABLE: output_source.defaults.value},
VARIABLE: np.zeros_like(output_source.output_ports[0].defaults.value)},
target={NAME: TARGET,
VARIABLE: output_source.defaults.value},
VARIABLE: np.zeros_like(output_source.output_ports[0].defaults.value)},
function=PredictionErrorDeltaFunction(gamma=1.0))

learning_mechanism = LearningMechanism(function=TDLearning(learning_rate=learning_rate),
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pytest_plugins = ['pytest_profiling', 'helpers_namespace', 'benchmark']
xfail_strict = True

filterwarnings =
error::numpy.VisibleDeprecationWarning
error:Creating an ndarray from ragged nested sequences \(which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes\) is deprecated.*:numpy.VisibleDeprecationWarning

[pycodestyle]
# for code explanation see https://pep8.readthedocs.io/en/latest/intro.html#error-codes
Expand Down
48 changes: 29 additions & 19 deletions tests/composition/test_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,18 +1379,24 @@ def test_td_montague_et_al_figure_a(self):
delta_vals = comparator_mechanism.log.nparray_dictionary()['TD_Learning'][pnl.VALUE]

trial_1_expected = [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.003, 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., -0.003, 0.]
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]


trial_30_expected = [0.] * 40
trial_30_expected +=[.0682143186, .0640966042, .0994344173, .133236921, .152270799, .145592903, .113949692,
.0734420009, .0450652924, .0357386468, .0330810871, .0238007805, .0102892090, -.998098988,
-.0000773996815, -.0000277845011, -.00000720338916, -.00000120056486, -.0000000965971727, 0.]
trial_30_expected += [
0.06521536244675225, 0.0640993870383315, 0.09944290863181729, 0.13325956499595726, 0.15232363406006394,
0.14570077419644378, 0.11414216814982991, 0.07374140787058237, 0.04546975436471501, 0.036210519138262454,
0.03355295938927161, 0.024201157062338496, 0.010573534379529015, -0.9979331317238949, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0
]
trial_50_expected = [0.] * 40
trial_50_expected += [.717416347, .0816522429, .0595516548, .0379308899, .0193587853, .00686581694,
.00351883747, .00902310583, .0149133617, .000263272179, -.0407611997, -.0360124387,
.0539085146, .0723714910, -.000000550934336, -.000000111783778, -.0000000166486478,
-.00000000161861854, -.0000000000770770722, 0.]
trial_50_expected += [
0.7149863408177357, 0.08193033235388536, 0.05988592388364977, 0.03829793050401187, 0.01972582584273075,
0.007198872281648616, 0.0037918828476545263, 0.009224297157983563, 0.015045769646998886,
0.00034051016062952577, -0.040721638768680624, -0.03599485605332753, 0.0539151932684796,
0.07237361605659998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
]

assert np.allclose(trial_1_expected, delta_vals[0][0])
assert np.allclose(trial_30_expected, delta_vals[29][0])
Expand Down Expand Up @@ -1502,14 +1508,16 @@ def test_td_enabled_learning_false(self):
delta_vals = comparator_mechanism.log.nparray_dictionary()['TD_Learning'][pnl.VALUE]

trial_1_expected = [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.003, 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., -0.003, 0.]
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]

trial_30_expected = [0.] * 40
trial_30_expected +=[.0682143186, .0640966042, .0994344173, .133236921, .152270799, .145592903, .113949692,
.0734420009, .0450652924, .0357386468, .0330810871, .0238007805, .0102892090, -.998098988,
-.0000773996815, -.0000277845011, -.00000720338916, -.00000120056486, -.0000000965971727, 0.]

trial_30_expected += [
0.06521536244675225, 0.0640993870383315, 0.09944290863181729, 0.13325956499595726, 0.15232363406006394,
0.14570077419644378, 0.11414216814982991, 0.07374140787058237, 0.04546975436471501, 0.036210519138262454,
0.03355295938927161, 0.024201157062338496, 0.010573534379529015, -0.9979331317238949, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0
]

assert np.allclose(trial_1_expected, delta_vals[0][0])
assert np.allclose(trial_30_expected, delta_vals[29][0])
Expand All @@ -1522,10 +1530,12 @@ def test_td_enabled_learning_false(self):
delta_vals = comparator_mechanism.log.nparray_dictionary()['TD_Learning'][pnl.VALUE]

trial_50_expected = [0.] * 40
trial_50_expected += [.717416347, .0816522429, .0595516548, .0379308899, .0193587853, .00686581694,
.00351883747, .00902310583, .0149133617, .000263272179, -.0407611997, -.0360124387,
.0539085146, .0723714910, -.000000550934336, -.000000111783778, -.0000000166486478,
-.00000000161861854, -.0000000000770770722, 0.]
trial_50_expected += [
0.7149863408177357, 0.08193033235388536, 0.05988592388364977, 0.03829793050401187, 0.01972582584273075,
0.007198872281648616, 0.0037918828476545263, 0.009224297157983563, 0.015045769646998886,
0.00034051016062952577, -0.040721638768680624, -0.03599485605332753, 0.0539151932684796,
0.07237361605659998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
]

assert np.allclose(trial_50_expected, delta_vals[49][0])

Expand Down

0 comments on commit 627974b

Please sign in to comment.