Skip to content

Commit

Permalink
Merge branch 'main' into python_types
Browse files Browse the repository at this point in the history
  • Loading branch information
markusschmaus authored Sep 23, 2022
2 parents 2db7b43 + 9a85dbc commit 4e5add1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 30 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ exclude: |
)$
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.1.0
rev: v4.3.0
hooks:
- id: debug-statements
exclude: |
Expand All @@ -20,16 +20,16 @@ repos:
)$
- id: check-merge-conflict
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 22.8.0
hooks:
- id: black
language_version: python3
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.4
rev: 3.9.2
hooks:
- id: flake8
- repo: https://github.com/pycqa/isort
rev: 5.6.4
rev: 5.10.1
hooks:
- id: isort
- repo: https://github.com/humitos/mirrors-autoflake.git
Expand All @@ -47,7 +47,7 @@ repos:
)$
args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.961
rev: v0.971
hooks:
- id: mypy
additional_dependencies:
Expand Down
43 changes: 19 additions & 24 deletions doc/library/scan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -254,40 +254,35 @@ Another useful feature of scan, is that it can handle shared variables.
For example, if we want to implement a Gibbs chain of length 10 we would do
the following:

.. testsetup:: scan1

import aesara
import numpy
W_values = numpy.random.random((2, 2))
bvis_values = numpy.random.random((2,))
bhid_values = numpy.random.random((2,))

.. testcode:: scan1

import aesara
from aesara import tensor as at
import aesara
import aesara.tensor as at
import numpy as np

W = aesara.shared(W_values) # we assume that ``W_values`` contains the
# initial values of your weight matrix
rng = np.random.default_rng(203940)
W_values = rng.uniform(size=(2, 2))
bvis_values = rng.uniform(size=(2,))
bhid_values = rng.uniform(size=(2,))

bvis = aesara.shared(bvis_values)
bhid = aesara.shared(bhid_values)
W = aesara.shared(W_values)
bvis = aesara.shared(bvis_values)
bhid = aesara.shared(bhid_values)

trng = aesara.tensor.random.utils.RandomStream(1234)
srng = at.random.RandomStream(1234)

def OneStep(vsample) :
hmean = at.sigmoid(aesara.dot(vsample, W) + bhid)
hsample = trng.binomial(size=hmean.shape, n=1, p=hmean)
vmean = at.sigmoid(aesara.dot(hsample, W.T) + bvis)
return trng.binomial(size=vsample.shape, n=1, p=vmean,
dtype=aesara.config.floatX)
def one_step(vsample):
hmean = at.sigmoid(at.dot(vsample, W) + bhid)
hsample = srng.binomial(1, hmean, size=hmean.shape)
vmean = at.sigmoid(at.dot(hsample, W.T) + bvis)

sample = aesara.tensor.vector()
return srng.binomial(1, vmean, size=vsample.shape)

values, updates = aesara.scan(OneStep, outputs_info=sample, n_steps=10)
sample = at.lvector()

gibbs10 = aesara.function([sample], values[-1], updates=updates)
values, updates = aesara.scan(one_step, outputs_info=sample, n_steps=10)

gibbs10 = aesara.function([sample], values[-1], updates=updates)

The first, and probably most crucial observation is that the updates
dictionary becomes important in this case. It links a shared variable
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ lines_after_imports = 2
lines_between_sections = 1
honor_noqa = True
skip_gitignore = True
skip = aesara/version.py, **/__init__.py
skip = aesara/version.py
skip_glob = **/*.pyx

[mypy]
ignore_missing_imports = True
Expand Down

0 comments on commit 4e5add1

Please sign in to comment.