Skip to content

Commit b59a0c8

Browse files
committed
Fix some floatX issues
1 parent 0c2a2b0 commit b59a0c8

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

pytensor/link/numba/dispatch/elemwise.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -708,17 +708,19 @@ def numba_funcify_Sum(op, node, **kwargs):
708708

709709
np_acc_dtype = np.dtype(acc_dtype)
710710

711+
out_dtype = np.dtype(node.outputs[0].dtype)
712+
711713
if ndim_input == len(axes):
712714

713715
@numba_njit(fastmath=True)
714716
def impl_sum(array):
715-
return np.asarray(array.sum(), dtype=np_acc_dtype)
717+
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
716718

717719
elif len(axes) == 0:
718720

719721
@numba_njit(fastmath=True)
720722
def impl_sum(array):
721-
return array
723+
return np.asarray(array, dtype=out_dtype)
722724

723725
else:
724726
impl_sum = numba_funcify_CAReduce(op, node, **kwargs)

tests/link/numba/test_scalar.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_Clip(v, min, max):
9797
],
9898
)
9999
def test_Composite(inputs, input_values, scalar_fn):
100-
composite_inputs = [aes.float64(i.name) for i in inputs]
100+
composite_inputs = [aes.ScalarType(config.floatX)(name=i.name) for i in inputs]
101101
comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)]))
102102
out_fg = FunctionGraph(inputs, [comp_op(*inputs)])
103103
compare_numba_and_py(out_fg, input_values)

0 commit comments

Comments
 (0)