Skip to content

Commit 13e8181

Browse files
avikchaudhuripytorchmergebot
authored andcommitted
relax assertion on fake shape (#121599)
Summary: Seems like if you use `capture_pre_autograd_graph` fake tensor shapes can be ints instead of symints. Test Plan: fixes the AssertionError in N5057219 Differential Revision: D54729142 Pull Request resolved: #121599 Approved by: https://github.com/angelayi, https://github.com/BoyuanFeng
1 parent 660ec3d commit 13e8181

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

torch/fx/experimental/symbolic_shapes.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -2930,8 +2930,11 @@ def is_dim(src):
29302930
def get_expression(tensor_dim_src):
29312931
fake = placeholders[source_index[tensor_dim_src.base.name()]]
29322932
symint = fake.shape[tensor_dim_src.idx]
2933-
assert isinstance(symint, torch.SymInt)
2934-
return symint.node.expr
2933+
if isinstance(symint, torch.SymInt):
2934+
return symint.node.expr
2935+
else:
2936+
assert type(symint) is int, f"Expected int, got {type(symint)}"
2937+
return symint
29352938

29362939
for src1, src2 in equalities_inputs.source_pairs:
29372940
expr1, expr2 = get_expression(src1), get_expression(src2)

0 commit comments

Comments
 (0)