
Commit 6af0428

mcremon-meta authored and kirklandsign committed

Add cadence.where_Scalar op

Differential Revision: D70539497
Pull Request resolved: #9764
1 parent d3687fb commit 6af0428

File tree: 3 files changed, +146 −0 lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 13 additions & 0 deletions
@@ -162,6 +162,10 @@
     "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
 )
+lib.define("where_Scalar(Tensor condition, float self, float other) -> (Tensor Z)")
+lib.define(
+    "where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
+)

 # ------------------------------------ #
 # Migrated from custom_ops.yaml #

@@ -935,3 +939,12 @@ def transposed_im2row_meta(
     output_size = torch.Size((batch_size, output_length, n_output_plane))

     return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::where_Scalar")
+def where_Scalar_meta(
+    condition: torch.Tensor,
+    self: float,
+    other: float,
+) -> torch.Tensor:
+    return condition.new_empty(condition.size(), dtype=torch.float32)
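
For intuition: the new schema behaves like torch.where with both value operands fixed to Python floats, and the fake (meta) kernel above only pins down the output shape (the condition's shape) and dtype (float32). A minimal eager-mode sketch of the intended semantics; the helper name is hypothetical and not part of this diff:

import torch

def where_scalar_reference(
    condition: torch.Tensor, self_val: float, other_val: float
) -> torch.Tensor:
    # Hypothetical reference: select `self_val` where the condition is True
    # and `other_val` elsewhere. Mirrors the fake kernel above: the output
    # takes the condition's shape and is float32.
    return torch.where(
        condition,
        torch.tensor(self_val, dtype=torch.float32),
        torch.tensor(other_val, dtype=torch.float32),
    )

mask = torch.tensor([[True, False], [False, True]])
print(where_scalar_reference(mask, 0.0, 1.0))
# tensor([[0., 1.],
#         [1., 0.]])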

backends/cadence/aot/replace_ops.py

Lines changed: 49 additions & 0 deletions
@@ -2062,6 +2062,54 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return PassResult(ret.graph_module, modified)


+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class ReplaceWhereWithFullArgsWithWhereScalar(ExportPass):
+    """Replaces where ops using two full ops as tensors with a scalar
+    version.
+    """
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in {
+            exir_ops.edge.aten.where.self,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # If the args are not full ops, bail
+        # pyre-ignore[16]: `ProxyValue` has no attribute `node`.
+        if (args[1].node.target != exir_ops.edge.aten.full.default) or (
+            args[2].node.target != exir_ops.edge.aten.full.default
+        ):
+            return super().call_operator(op, args, kwargs, meta)
+
+        # If one of the full ops is a different size than the cond tensor, we need to broadcast. Bail.
+        if (
+            # pyre-ignore[16]: `ProxyValue` has no attribute `node`.
+            list(args[0].to_tensor().shape) != args[1].node.args[0]
+            or list(args[0].to_tensor().shape) != args[2].node.args[0]
+        ):
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the scalar values from the full ops
+        scalar_value_1 = args[1].node.args[1]
+        scalar_value_2 = args[2].node.args[1]
+
+        # Replace the where op with a scalar where op
+        return super().call_operator(
+            exir_ops.edge.cadence.where_Scalar.default,
+            (args[0], scalar_value_1, scalar_value_2),
+            kwargs,
+            meta,
+        )
+
+        return super().call_operator(op, args, kwargs, meta)
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:

@@ -2100,4 +2148,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
+        ReplaceWhereWithFullArgsWithWhereScalar,
     ]
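
Concretely, the pass matches an aten.where whose two value inputs are produced by aten.full ops whose requested shape equals the condition's shape, and folds the two fill values into the scalar arguments of cadence.where_Scalar, so the full tensors never have to be materialized. An illustrative eager-mode sketch of the pattern (shapes and values are made up for the example):

import torch

cond = torch.randn(4, 8) > 0

# Before the pass: both branches of the where are constant-filled tensors
# whose shape matches the condition exactly (no broadcasting involved).
out = torch.where(cond, torch.full((4, 8), 0.0), torch.full((4, 8), 1.0))

# After the pass, the edge graph carries the equivalent of
#   cadence.where_Scalar(cond, 0.0, 1.0)
# i.e. the fill values travel as scalars instead of as tensors.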

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 84 additions & 0 deletions
@@ -44,6 +44,7 @@
     ReplaceTCopyWithTransposePass,
     ReplaceTransposedConvWithLinearPass,
     ReplaceTrivialConvWithLinear,
+    ReplaceWhereWithFullArgsWithWhereScalar,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass

@@ -1217,6 +1218,89 @@ def forward(self, x: torch.Tensor):
             1,
         )

+    def test_replace_aten_where_with_cadence_where_Scalar(self):
+        class WhereScalarModel(torch.nn.Module):
+            def forward(self, cond: torch.Tensor):
+                a = torch.ops.aten.full.default(a_shape, val1)
+                b = torch.ops.aten.full.default(b_shape, val2)
+                return torch.where(cond > 0, a, b)
+
+        cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (4, 8), (4, 8), 0.0, 1.0]
+        cond = torch.randn(cond_shape)
+
+        graph_module = (
+            export_to_edge(WhereScalarModel(), (cond,)).exported_program().graph_module
+        )
+
+        p = ReplaceWhereWithFullArgsWithWhereScalar()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that aten.where op was replaced by a
+        # cadence.where_Scalar op
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.where.self,
+            ),
+            0,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.cadence.where_Scalar.default),
+            1,
+        )
+
+        class WhereBroadcastModel(torch.nn.Module):
+            def forward(self, cond: torch.Tensor):
+                a = torch.ops.aten.full.default(a_shape, val1)
+                b = torch.ops.aten.full.default(b_shape, val2)
+                return torch.where(cond > 0, a, b)
+
+        # a tensor bigger than cond and b
+        cond_shape, a_shape, b_shape, val1, val2 = [(8,), (4, 8), (8,), 0.0, 1.0]
+        cond = torch.randn(cond_shape)
+
+        graph_module = (
+            export_to_edge(WhereBroadcastModel(), (cond,))
+            .exported_program()
+            .graph_module
+        )
+
+        p = ReplaceWhereWithFullArgsWithWhereScalar()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that aten.where op is still in the graph since where_Scalar does not
+        # support broadcast
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.where.self,
+            ),
+            1,
+        )
+
+        # cond tensor bigger than a and b
+        cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (8,), (8,), 0.0, 1.0]
+        cond = torch.randn(cond_shape)
+
+        graph_module = (
+            export_to_edge(WhereBroadcastModel(), (cond,))
+            .exported_program()
+            .graph_module
+        )
+
+        p = ReplaceWhereWithFullArgsWithWhereScalar()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that aten.where op is still in the graph since where_Scalar does not
+        # support broadcast
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.where.self,
+            ),
+            1,
+        )
+

 class TestReplaceIm2rowWithViewPass(unittest.TestCase):
     def test_no_replacement_for_conv(self):
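
To exercise just the new test locally, a standard pytest invocation from the repository root should work (assuming the usual ExecuTorch development setup; the -k filter selects tests whose names contain "where_Scalar"):

python -m pytest backends/cadence/aot/tests/test_replace_ops_passes.py -k where_Scalar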
