Skip to content

Commit e889d2c

Browse files
committed
Temporarily exclude fusion rewrite from Numba Scan tests
Otherwise they fail due to lack of support for multi-output Elemwises in the Numba backend
1 parent 5f07241 commit e889d2c

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tests/link/numba/test_scan.py

+8
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
211211
sequences=[at_C, at_D],
212212
outputs_info=[st0, et0, it0, logp_c, logp_d],
213213
non_sequences=[beta, gamma, delta],
214+
# multi-output Elemwise not supported in NUMBA
215+
mode=get_mode("NUMBA").excluding("fusion"),
214216
)
215217
st.name = "S_t"
216218
et.name = "E_t"
@@ -321,6 +323,8 @@ def power_of_2(previous_power, max_value):
321323
outputs_info=at.constant(1.0),
322324
non_sequences=max_value,
323325
n_steps=1024,
326+
# multi-output Elemwise not supported in NUMBA
327+
mode=get_mode("NUMBA").excluding("fusion"),
324328
)
325329

326330
out_fg = FunctionGraph([max_value], [values])
@@ -370,6 +374,8 @@ def f_pow2(x_tm2, x_tm1):
370374
state_val = np.array([1.0, 2.0])
371375

372376
numba_mode = get_mode("NUMBA").including("scan_save_mem")
377+
# multi-output Elemwise not supported in NUMBA
378+
numba_mode = numba_mode.excluding("fusion")
373379
py_mode = Mode("py").including("scan_save_mem")
374380

375381
out_fg = FunctionGraph([init_x, n_steps], [output])
@@ -409,6 +415,8 @@ def inner_fct(seq, state_old, state_current):
409415
g_outs = grad(out.sum(), [seq, init_x])
410416

411417
numba_mode = get_mode("NUMBA").including("scan_save_mem")
418+
# multi-output Elemwise not supported in NUMBA
419+
numba_mode = numba_mode.excluding("fusion")
412420
py_mode = Mode("py").including("scan_save_mem")
413421

414422
out_fg = FunctionGraph([seq, init_x], g_outs)

0 commit comments

Comments
 (0)