Skip to content

Allow removing permute pairs in addition to transpose pairs #10501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 30 additions & 37 deletions backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import operator
from collections import deque
from numbers import Number
from typing import cast, Sequence
from typing import cast

# Import these for the cadence function signatures.
import executorch.backends.cadence.aot.ops_registrations # noqa: F401
Expand Down Expand Up @@ -881,9 +881,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
"""
Fuse transpose op pairs to a single view op.
Fuse transpose or permute op pairs to a single view op.
"""

# A list of ops that can be bypassed when looking for a
Expand All @@ -907,42 +907,28 @@ def can_fuse_for_chain(
if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
return False

def get_dims(node: torch.fx.Node) -> tuple[int, int]:
def canonicalize(dim: int) -> int:
if dim < 0:
dim += len(node.meta["val"].shape)
return dim

return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])

def is_equivalent(
shape: Sequence[int],
transpose0: tuple[int, int],
transpose1: tuple[int, int],
) -> bool:
def permute_order(
order: Sequence[int], dims: tuple[int, int]
) -> Sequence[int]:
new_order = list(order)
new_order[dims[0]], new_order[dims[1]] = (
new_order[dims[1]],
new_order[dims[0]],
)
return new_order
input_shape = list(cast(torch.fx.Node, producer.args[0]).meta["val"].shape)

order = permute_order(range(len(shape)), transpose0)
order = permute_order(order, transpose1)
intermediate_shape = (
get_transposed_dims(producer, input_shape)
if producer.target == exir_ops.edge.aten.transpose_copy.int
else get_permuted_dims(producer, input_shape)
)

non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]
final_shape = (
get_transposed_dims(consumer, intermediate_shape)
if consumer.target == exir_ops.edge.aten.transpose_copy.int
else get_permuted_dims(consumer, intermediate_shape)
)

return non_unit_dims == non_unit_dims_permuted
non_unit_dims = [
input_shape[dim] for dim in range(len(input_shape)) if input_shape[dim] != 1
]
non_unit_dims_permuted = [
final_shape[dim] for dim in range(len(final_shape)) if final_shape[dim] != 1
]

return is_equivalent(
cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
get_dims(producer),
get_dims(consumer),
)
return non_unit_dims == non_unit_dims_permuted

def get_fused_node(
self,
Expand All @@ -960,13 +946,20 @@ def get_fused_node(
return view

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# Remove any dequantize op that has only quantize ops as its users.
# Remove any transpose op pair that cancel each other.
self.find_and_fuse(
graph_module,
producer_op_packets={exir_ops.edge.aten.transpose_copy},
consumer_op_packets={exir_ops.edge.aten.transpose_copy},
bypass_ops=self.bypass_ops,
)
# Remove any permute op pair that cancel each other.
self.find_and_fuse(
graph_module,
producer_op_packets={exir_ops.edge.aten.permute_copy},
consumer_op_packets={exir_ops.edge.aten.permute_copy},
bypass_ops=self.bypass_ops,
)
result = super().call(graph_module)
return result

Expand Down Expand Up @@ -1028,5 +1021,5 @@ class CadenceFuseOpsInGraph:
FuseQuantDequantToRequantizePass,
FuseMulIntoDequantPass,
FuseFullThenReshapePass,
FuseTransposeOpPairsPass,
FuseTransposeOrPermuteOpPairsPass,
]
4 changes: 2 additions & 2 deletions backends/cadence/aot/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from executorch.backends.cadence.aot.fuse_ops import (
CadenceFuseOpsInGraph,
FuseFullThenReshapePass,
FuseTransposeOpPairsPass,
FuseTransposeOrPermuteOpPairsPass,
)
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
Expand Down Expand Up @@ -83,7 +83,7 @@ def get_passes_in_default_order() -> List[ExportPass]:
CadenceSimplifyOpsInGraph.passes,
FinalizePipeline,
FuseFullThenReshapePass,
FuseTransposeOpPairsPass,
FuseTransposeOrPermuteOpPairsPass,
RemoveNopSliceOrViewOpPass,
]
return pytree.tree_flatten(passes)[0]
Expand Down
99 changes: 88 additions & 11 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
FuseFullThenReshapePass,
FuseMulIntoDequantPass,
FuseQuantDequantToRequantizePass,
FuseTransposeOpPairsPass,
FuseTransposeOrPermuteOpPairsPass,
)
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
Expand Down Expand Up @@ -510,7 +510,7 @@ def test_fuse_then_transpose_pass(self):
)


class TestFuseTransposeOpPairsPass(TestFusionPassesBase):
class TestFuseTransposeOrPermuteOpPairsPass(TestFusionPassesBase):
def _create_operator(
self, builder: GraphBuilder, op: torch._ops.OpOverload, x: ProxyValue
) -> ProxyValue:
Expand All @@ -536,17 +536,17 @@ def _create_operator(
def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
# Create a graph with transpose -> quant -> transpose.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(2, 3))
transpose_node = builder.call_operator(
x = builder.placeholder("x", torch.randn(2, 3, 4))
transpose_node0 = builder.call_operator(
op=exir_ops.edge.aten.transpose_copy.int,
args=(x, 0, 1),
)
quant_node = self._create_operator(builder, op, transpose_node)
transpose_node = builder.call_operator(
quant_node = self._create_operator(builder, op, transpose_node0)
transpose_node1 = builder.call_operator(
op=exir_ops.edge.aten.transpose_copy.int,
args=(quant_node, 0, 1),
args=(quant_node, 1, 2),
)
builder.output([transpose_node])
builder.output([transpose_node1])
gm = builder.get_graph_module()
self.check_op_counts(
gm,
Expand All @@ -557,7 +557,7 @@ def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
)

# Check that the pass fuses the two transpose ops.
fusion_pass_result = FuseTransposeOpPairsPass()(gm)
fusion_pass_result = FuseTransposeOrPermuteOpPairsPass()(gm)
self.assertIsNotNone(fusion_pass_result)
gm_after_pass = fusion_pass_result.graph_module
self.check_op_counts(
Expand All @@ -568,6 +568,47 @@ def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
},
)

@parameterized.expand(
[
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.cadence.quantized_relu.per_tensor,
],
)
def test_fuse_permute_pairs(self, op: torch._ops.OpOverload):
# Create a graph with permute -> quant -> permute.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(8, 2, 3, 4))
permute_node0 = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(x, [0, 3, 1, 2]),
)
quant_node = self._create_operator(builder, op, permute_node0)
permute_node1 = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(quant_node, [0, 2, 3, 1]),
)
builder.output([permute_node1])
gm = builder.get_graph_module()
self.check_op_counts(
gm,
expected_op_counts={
exir_ops.edge.aten.permute_copy.default: 2,
op: 1,
},
)

# Check that the pass fuses the two transpose ops.
fusion_pass_result = FuseTransposeOrPermuteOpPairsPass()(gm)
self.assertIsNotNone(fusion_pass_result)
gm_after_pass = fusion_pass_result.graph_module
self.check_op_counts(
gm_after_pass,
expected_op_counts={
exir_ops.edge.aten.permute_copy.default: 0,
op: 1,
},
)

def test_no_fusion_for_transpose_pairs(self):
# Create a graph with transpose -> quant -> transpose.
builder = GraphBuilder()
Expand Down Expand Up @@ -595,7 +636,7 @@ def test_no_fusion_for_transpose_pairs(self):
)

# No fusion.
gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
gm_after_pass = FuseTransposeOrPermuteOpPairsPass()(gm).graph_module
self.check_op_counts(
gm_after_pass,
expected_op_counts={
Expand All @@ -604,6 +645,42 @@ def test_no_fusion_for_transpose_pairs(self):
},
)

def test_no_fusion_for_permute_pairs(self):
# Create a graph with permute -> quant -> permute.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(2, 3, 4))
permute_node = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(x, [2, 0, 1]),
)
quant_node = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(permute_node, 1.2, 3, 0, 127, torch.int8),
)
permute_node = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(quant_node, [2, 0, 1]),
)
builder.output(permute_node)
gm = builder.get_graph_module()
self.check_op_counts(
gm,
expected_op_counts={
exir_ops.edge.aten.permute_copy.default: 2,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
},
)

# No fusion.
gm_after_pass = FuseTransposeOrPermuteOpPairsPass()(gm).graph_module
self.check_op_counts(
gm_after_pass,
expected_op_counts={
exir_ops.edge.aten.permute_copy.default: 2,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
},
)

def test_fusion_for_forked_transposes(self):
# Create a graph with transpose -> quant -> transpose.
builder = GraphBuilder()
Expand Down Expand Up @@ -636,7 +713,7 @@ def test_fusion_for_forked_transposes(self):
)

# Fuse the all the transpose ops.
gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
gm_after_pass = FuseTransposeOrPermuteOpPairsPass()(gm).graph_module
self.check_op_counts(
gm_after_pass,
expected_op_counts={
Expand Down