Skip to content

Commit

Permalink
Temporarily exclude fusion rewrite from Numba Scan tests
Browse files Browse the repository at this point in the history
Otherwise they fail due to lack of support for multi-output Elemwises in the Numba backend
  • Loading branch information
Ricardo Vieira committed Oct 27, 2022
1 parent d981df8 commit 1302b49
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tests/link/numba/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ def f_pow2(x_tm2, x_tm1):
state_val = np.array([1.0, 2.0])

numba_mode = get_mode("NUMBA").including("scan_save_mem")
# multi-output Elemwise not supported in NUMBA
numba_mode = numba_mode.excluding("fusion")
py_mode = Mode("py").including("scan_save_mem")

out_fg = FunctionGraph([init_x, n_steps], [output])
Expand Down Expand Up @@ -406,6 +408,8 @@ def inner_fct(seq, state_old, state_current):
g_outs = grad(out.sum(), [seq, init_x])

numba_mode = get_mode("NUMBA").including("scan_save_mem")
# multi-output Elemwise not supported in NUMBA
numba_mode = numba_mode.excluding("fusion")
py_mode = Mode("py").including("scan_save_mem")

out_fg = FunctionGraph([seq, init_x], g_outs)
Expand Down

0 comments on commit 1302b49

Please sign in to comment.