We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 007abae commit 0905becCopy full SHA for 0905bec
pytensor/link/pytorch/dispatch/scalar.py
@@ -55,7 +55,7 @@ def cast(x):
55
56
@pytorch_funcify.register(ScalarLoop)
57
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
58
- update = pytorch_funcify(op.fgraph)
+ update = pytorch_funcify(op.fgraph, **kwargs)
59
state_length = op.nout
60
if op.is_while:
61
@@ -84,4 +84,4 @@ def scalar_loop(steps, *start_and_constants):
84
else:
85
return carry
86
87
- return torch.compiler.disable(scalar_loop, recursive=False)
+ return scalar_loop
0 commit comments