Skip to content

Commit b1b46ee

Browse files
: constant fold None
Differential Revision: D74350331 Pull Request resolved: #10762
1 parent 277c39d commit b1b46ee

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

exir/passes/constant_prop_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def is_const(
6666
)
6767
elif isinstance(arg, _PRIMITIVE_TYPES):
6868
return True
69+
elif arg is None:
70+
return True
6971
elif not isinstance(arg, torch.fx.Node):
7072
return False
7173
elif arg in const_node_to_tensor:

exir/tests/test_passes.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,3 +1823,34 @@ def _do_checks(
18231823
self.assertTrue(
18241824
torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0])
18251825
)
1826+
1827+
def test_constant_prop_pass_none(self) -> None:
1828+
"""
1829+
This checks that None arguments are treated as constants in constant_prop_pass.
1830+
"""
1831+
1832+
class M(torch.nn.Module):
1833+
def __init__(self):
1834+
super().__init__()
1835+
self.cst = torch.ones(3, 3, 3, dtype=torch.int8)
1836+
self.w = torch.ones(3, 3, 3, dtype=torch.int8)
1837+
1838+
def forward(self, x):
1839+
# Note: using e.g aten.linear would not work as None is not in the graph
1840+
a = torch.ops.aten.convolution.default(
1841+
self.cst, self.w, None, [1], [0], [1], False, [0], 1
1842+
)
1843+
return a + x
1844+
1845+
mod = M()
1846+
x = torch.randn([3, 3, 3])
1847+
mod(x)
1848+
edge = to_edge(
1849+
export(mod, (x,), strict=True),
1850+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
1851+
)
1852+
# 2 constants: self.w and self.cst
1853+
self.assertEqual(2, len(edge.exported_program().constants))
1854+
pass_result = constant_prop_pass(edge.exported_program())
1855+
# 1 constant: a (= self.w @ self.cst)
1856+
self.assertEqual(1, len(pass_result.constants))

0 commit comments

Comments
 (0)