Skip to content

Commit c6d79c1

Browse files
pianpwk authored and pytorchmergebot committed
[dynamic shapes] allow duck typing for 0/1 (pytorch#150222)
Fixes pytorch#150184, e.g. for config.backed_size_oblivious=True with compile. Pull Request resolved: pytorch#150222. Approved by: https://github.com/laithsakka
1 parent 7df6f93 commit c6d79c1

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

test/dynamo/test_misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10211,8 +10211,8 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self):
10211 10211      > Left: {44, 93}
10212 10212      > Right: {}
10213 10213  ==> val_to_var: values don't match.
10214       -    > Left: {0: 0, 1: 1, 2: s44, 3: s93}
10215       -    > Right: {0: 0, 1: 1}
      10214 +    > Left: {2: s44, 3: s93}
      10215 +    > Right: {}
10216 10216  ==> var_to_range: values don't match.
10217 10217      > Left: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
10218 10218      > Right: {}

test/test_dynamic_shapes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,22 @@ def test_tensor_factory_with_symint(self):
1329 1329          res = Tensor(sym_args)
1330 1330          self.assertEqual(res, expected, exact_dtype=False)
1331 1331
      1332 +    def test_backed_size_oblivious_01_spec(self):
      1333 +        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
      1334 +
      1335 +        @torch.compile(dynamic=True, fullgraph=True)
      1336 +        def f(a, b):
      1337 +            if guard_size_oblivious(a.size(0) == 1):
      1338 +                return b * 10
      1339 +            else:
      1340 +                return b * 20
      1341 +
      1342 +        with torch.fx.experimental._config.patch(backed_size_oblivious=True):
      1343 +            # always go to the >= 2 branch.
      1344 +            self.assertEqual(
      1345 +                f(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
      1346 +            )
      1347 +
1332 1348
1333 1349      @skipIfTorchDynamo(
1334 1350          "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"

torch/fx/experimental/symbolic_shapes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3330,8 +3330,6 @@ def _init(
3330 3330          # Duck-shaping says that if two input tensors have the same size,
3331 3331          # they get assigned the same symbolic variable
3332 3332          self.val_to_var: dict[int, sympy.Symbol] = {}
3333       -        if specialize_zero_one:
3334       -            self.val_to_var = {0: sympy.S.Zero, 1: sympy.S.One}
3335 3333          self.unbacked_symfloat_counter = itertools.count()
3336 3334          self.unbacked_symint_counter = itertools.count()
3337 3335          # Similar to guards, but these MUST evaluate to true and can
@@ -4541,7 +4539,10 @@ def create_symbol(
4541 4539          sloc = self._get_sloc()
4542 4540
4543 4541          if val in (0, 1) and specialize_zero_one:
4544       -            r = self.val_to_var[val]
      4542 +            if val == 0:
      4543 +                return sympy.S.Zero
      4544 +            else:
      4545 +                return sympy.S.One
4545 4546          elif not duck or val not in self.val_to_var:
4546 4547              # If we're not duck shaping, we always create a new symbol
4547 4548              # Even if we're duck shaping, if we haven't seen this particular

0 commit comments

Comments (0)