Skip to content

Commit 1dd6322

Browse files
committed
Address pr comments
1 parent 8baf721 commit 1dd6322

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

pytensor/link/pytorch/dispatch/elemwise.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,11 @@ def elemwise_fn(*inputs):
2121

2222
def elemwise_fn(*inputs):
2323
Elemwise._check_runtime_broadcast(node, inputs)
24-
shaped_inputs = torch.broadcast_tensors(*inputs)
24+
broadcast_inputs = torch.broadcast_tensors(*inputs)
2525
ufunc = base_fn
26-
for _ in range(shaped_inputs[0].dim()):
26+
for _ in range(broadcast_inputs[0].dim()):
2727
ufunc = torch.vmap(ufunc)
28-
# @todo: This will fail for anything that calls
29-
# `.item()`
30-
return ufunc(*shaped_inputs)
28+
return ufunc(*broadcast_inputs)
3129

3230
return elemwise_fn
3331

0 commit comments

Comments
 (0)