@@ -211,6 +211,8 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
211
211
sequences = [at_C , at_D ],
212
212
outputs_info = [st0 , et0 , it0 , logp_c , logp_d ],
213
213
non_sequences = [beta , gamma , delta ],
214
+ # multi-output Elemwise not supported in NUMBA
215
+ mode = get_mode ("NUMBA" ).excluding ("fusion" ),
214
216
)
215
217
st .name = "S_t"
216
218
et .name = "E_t"
@@ -321,6 +323,8 @@ def power_of_2(previous_power, max_value):
321
323
outputs_info = at .constant (1.0 ),
322
324
non_sequences = max_value ,
323
325
n_steps = 1024 ,
326
+ # multi-output Elemwise not supported in NUMBA
327
+ mode = get_mode ("NUMBA" ).excluding ("fusion" ),
324
328
)
325
329
326
330
out_fg = FunctionGraph ([max_value ], [values ])
@@ -370,6 +374,8 @@ def f_pow2(x_tm2, x_tm1):
370
374
state_val = np .array ([1.0 , 2.0 ])
371
375
372
376
numba_mode = get_mode ("NUMBA" ).including ("scan_save_mem" )
377
+ # multi-output Elemwise not supported in NUMBA
378
+ numba_mode = numba_mode .excluding ("fusion" )
373
379
py_mode = Mode ("py" ).including ("scan_save_mem" )
374
380
375
381
out_fg = FunctionGraph ([init_x , n_steps ], [output ])
@@ -409,6 +415,8 @@ def inner_fct(seq, state_old, state_current):
409
415
g_outs = grad (out .sum (), [seq , init_x ])
410
416
411
417
numba_mode = get_mode ("NUMBA" ).including ("scan_save_mem" )
418
+ # multi-output Elemwise not supported in NUMBA
419
+ numba_mode = numba_mode .excluding ("fusion" )
412
420
py_mode = Mode ("py" ).including ("scan_save_mem" )
413
421
414
422
out_fg = FunctionGraph ([seq , init_x ], g_outs )
0 commit comments