We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 8baf721 commit 1dd6322Copy full SHA for 1dd6322
pytensor/link/pytorch/dispatch/elemwise.py
@@ -21,13 +21,11 @@ def elemwise_fn(*inputs):
21
22
def elemwise_fn(*inputs):
23
Elemwise._check_runtime_broadcast(node, inputs)
24
- shaped_inputs = torch.broadcast_tensors(*inputs)
+ broadcast_inputs = torch.broadcast_tensors(*inputs)
25
ufunc = base_fn
26
- for _ in range(shaped_inputs[0].dim()):
+ for _ in range(broadcast_inputs[0].dim()):
27
ufunc = torch.vmap(ufunc)
28
- # @todo: This will fail for anything that calls
29
- # `.item()`
30
- return ufunc(*shaped_inputs)
+ return ufunc(*broadcast_inputs)
31
32
return elemwise_fn
33
0 commit comments