Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement at.eye using existing Ops #1217

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6617b69
added Ops for solve_discrete_lyapunov and solve_continuous_lyapunov
jessegrabowski Jun 29, 2022
34acf8f
ran pre-commit
jessegrabowski Jun 29, 2022
f28cee4
add `method` to `SolveDiscreteLyapunov` `__props__`
jessegrabowski Jun 30, 2022
10472ba
Remove shape checks from `SolveDiscreteLyapunov` and `SolveContinuous…
jessegrabowski Jun 30, 2022
0ec6c10
Merge branch 'solve_lyapunov' of github.com:jessegrabowski/aesara int…
jessegrabowski Jun 30, 2022
bf445ab
Update signature of `solve_discrete_lyapunov` and `solve_continuous_l…
jessegrabowski Jun 30, 2022
8a9da62
Update signature of `solve_discrete_lyapunov` and `solve_continuous_l…
jessegrabowski Jun 30, 2022
1738058
Delete a scratchpad file
jessegrabowski Jun 30, 2022
5c2442b
Add a direct aesara solution via `at.solve` for `solve_discrete_lyapu…
jessegrabowski Sep 27, 2022
e0b5d2e
add `method` to `SolveDiscreteLyapunov` `__props__`
jessegrabowski Jun 30, 2022
a14303a
Update signature of `solve_discrete_lyapunov` and `solve_continuous_l…
jessegrabowski Jun 30, 2022
4c095b6
Update signature of `solve_discrete_lyapunov` and `solve_continuous_l…
jessegrabowski Jun 30, 2022
810f5d9
Delete a scratchpad file
jessegrabowski Jun 30, 2022
837ec91
Add a direct aesara solution via `at.solve` for `solve_discrete_lyapu…
jessegrabowski Sep 27, 2022
2487597
Merge branch 'solve_lyapunov' of github.com:jessegrabowski/aesara int…
jessegrabowski Sep 27, 2022
2861abd
Rewrite function `eye` using `at.zeros` and `at.set_subtensor`
jessegrabowski Sep 27, 2022
4ac63b3
remove changes unrelated to eye
jessegrabowski Sep 27, 2022
af2dc7d
remove changes unrelated to eye
jessegrabowski Sep 27, 2022
94ce70d
remove changes unrelated to eye
jessegrabowski Sep 27, 2022
29991e3
remove changes unrelated to eye
jessegrabowski Sep 27, 2022
b4d9b7e
Merge branch 'main' into aesara_native_eye
jessegrabowski Sep 27, 2022
4c2a622
reverted unnecessary changes to `test_eye`
jessegrabowski Sep 27, 2022
b823eec
Merge branch 'aesara_native_eye' of github.com:jessegrabowski/aesara …
jessegrabowski Sep 27, 2022
680179e
Merge branch 'aesara-devs:main' into aesara_native_eye
jessegrabowski Oct 1, 2022
5ece747
Add typing information to arguments in function signature
jessegrabowski Oct 1, 2022
741a73c
Merge branch 'aesara_native_eye' of github.com:jessegrabowski/aesara …
jessegrabowski Oct 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions aesara/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,7 +1284,7 @@ def grad(self, inp, grads):
return [grad_undefined(self, i, inp[i]) for i in range(3)]


def eye(n, m=None, k=0, dtype=None):
def eye(n: int, m: int = None, k: int = 0, dtype=None) -> TensorVariable:
"""Return a 2-D array with ones on the diagonal and zeros elsewhere.

Parameters
Expand All @@ -1302,17 +1302,41 @@ def eye(n, m=None, k=0, dtype=None):

Returns
-------
ndarray of shape (N,M)
An array where all elements are equal to zero, except for the `k`-th
diagonal, whose values are equal to one.

aesara tensor of shape (N,M)
A symbolic tensor representing a matrix where all elements are equal to zero,
except for the `k`-th diagonal, whose values are equal to one.
brandonwillard marked this conversation as resolved.
Show resolved Hide resolved
"""
if dtype is None:
dtype = config.floatX

if m is None:
m = n
localop = Eye(dtype)
return localop(n, m, k)
if dtype is None:
dtype = aesara.config.floatX

n = aesara.scalar.as_scalar(n)
m = aesara.scalar.as_scalar(m)
k = aesara.scalar.as_scalar(k)

i = aesara.scalar.switch(k >= 0, k, -k * m)
i_comp_op = aesara.scalar.Composite([n, m, k], [i])
i_comp = i_comp_op(n, m, k)

mkm = (m - k) * m
mkm_comp_op = aesara.scalar.Composite([m, k], [mkm])
mkm_comp = mkm_comp_op(m, k)

last_row = aesara.scalar.switch(m - k > 0, m - k, 0)
last_row_op = aesara.scalar.Composite([m, k], [last_row])
last_valid_row = last_row_op(m, k)

eye = zeros(n * m, dtype=dtype)

ones_slice = slice(i_comp, mkm_comp, m + 1)
overflow_rows = slice(last_valid_row, None, None)

eye = aesara.tensor.subtensor.set_subtensor(eye[ones_slice], 1).reshape((n, m))
eye = aesara.tensor.subtensor.set_subtensor(eye[overflow_rows, :], 0)

return eye


def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
Expand Down
3 changes: 3 additions & 0 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,11 +822,14 @@ def check(dtype, N, M_=None, k=0):
# allowed.
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
M = N

N_symb = iscalar()
M_symb = iscalar()
k_symb = iscalar()

f = function([N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype))
result = f(N, M, k)

assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
assert result.dtype == np.dtype(dtype)

Expand Down