Skip to content

Commit 2ce8ae9

Browse files
committed
Add cast in numba elemwise between func type and output type
1 parent 02b616c commit 2ce8ae9

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pytensor/link/numba/dispatch/elemwise_codegen.py

+3
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,10 @@ def extract_array(aryty, obj):
188188

189189
if isinstance(scalar_signature.return_type, (types.Tuple, types.UniTuple)):
190190
output_values = cgutils.unpack_tuple(builder, output_values)
191+
func_output_types = scalar_signature.return_type.types
191192
else:
192193
output_values = [output_values]
194+
func_output_types = [scalar_signature.return_type]
193195

194196
# Update output value or accumulators respectively
195197
for i, ((accu, _), value) in enumerate(zip(output_accumulator, output_values)):
@@ -206,6 +208,7 @@ def extract_array(aryty, obj):
206208
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, output_bc[i])]
207209
ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc)
208210
# store = builder.store(value, ptr)
211+
value = context.cast(builder, value, func_output_types[i], output_types[i].dtype)
209212
arrayobj.store_item(context, builder, output_types[i], value, ptr)
210213
# store.set_metadata("alias.scope", output_scope_set)
211214
# store.set_metadata("noalias", input_scope_set)

0 commit comments

Comments
 (0)