Skip to content

[mlir][vector] transpose(broadcast) -> broadcast canonicalization #135096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6151,12 +6151,106 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};

/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
///
/// Example:
/// ```
/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
/// to vector<8x1xi32>
/// ```
/// can be rewritten as the equivalent
/// ```
/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
/// ```
/// The algorithm works by partitioning dimensions into groups that can be
/// locally permuted while preserving order, and checks that the transpose
/// only permutes within these groups.
///
/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
/// broadcasting from 1x1x4x1x1x7.
/// ^^^ ^ ^^^ ^
/// groups: 0 1 2 3
/// Order preserving permutations for this example are ones that only permute
/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}

LogicalResult matchAndRewrite(vector::TransposeOp transpose,
PatternRewriter &rewriter) const override {

vector::BroadcastOp broadcast =
transpose.getVector().getDefiningOp<vector::BroadcastOp>();
if (!broadcast) {
return rewriter.notifyMatchFailure(transpose,
"not preceded by a broadcast");
}

auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());

// transpose(broadcast(scalar)) -> broadcast(scalar) is always valid, and
// transpose(broadcast(all ones)) -> broadcast(all ones) is always valid
bool inputIsScalar = !inputType;
bool inputIsSizeOneVector = inputType.getNumElements() == 1;
if (inputIsScalar || inputIsSizeOneVector) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
transpose, transpose.getResultVectorType(), transpose.getVector());
return success();
}

ArrayRef<int64_t> permutation = transpose.getPermutation();
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputRank = inputType.getRank();
int64_t outputRank = transpose.getType().getRank();
int64_t deltaRank = outputRank - inputRank;

int low = 0;
for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
bool notOne = inputShape[inputIndex] != 1;
bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
bool groupEndFound = notOne || prevNotOne;
if (groupEndFound) {
int high = inputIndex + deltaRank;
// Return failure if not all permutation destinations for indices in
// [low, high) are in [low, high), i.e. the permutation is not local to
// the group.
for (int i = low; i < high; ++i) {
if (permutation[i] < low || permutation[i] >= high) {
return rewriter.notifyMatchFailure(
transpose, "permutation not local to group");
}
}
}
}

// We don't need to check the final group [low, outputRank) because if it is
// not locally bound, there must be a preceding group that already failed
// the check (impossible to have just 1 non-locally bound group).

// The preceding logic also ensures that at this point, the output of the
// transpose is definitely broadcastable from the input shape, so we don't
// need to check vector::isBroadcastableTo now.

rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
transpose, transpose.getResultVectorType(), transpose.getVector());

return success();
}
};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
TransposeFolder, FoldTransposeSplat>(context);
TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
context);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering whether this qualifies as canonicalisation. I am not an expert, so merely raising my concerns. From https://mlir.llvm.org/docs/Canonicalization/#general-design

Canonicalize shouldn’t lose the semantic of original operation: the original information should always be recoverable from the transformed IR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. The new broadcast is the same rank and 'volume' as the old one, so i can't think of sense in which it'll be more complex. So the removal of the transpose clinches it in my mind!

As an aside: It would be nice if MLIR/dialects defined an 'energy function' defining what classifies as a canonicalization (i.e. something to guarantee every rewrite takes us closer to a fixed point = energy minimum).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks for the discussion!

It would be nice if MLIR/dialects defined an 'energy function' defining what classifies as a canonicalization

So nice :)

}

//===----------------------------------------------------------------------===//
Expand Down
114 changes: 114 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s

// This file contains some canonicalizations tests involving vector.transpose.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, it's totally valid (and something I personally encourage) to document what pattern specifically is being tested:


// CHECK-LABEL: broadcast_transpose_scalar_to_broadcast
// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
func.func @broadcast_transpose_scalar_to_broadcast(%arg0 : i8) -> vector<2x3x4xi8> {
// CHECK: %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
%0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
// CHECK: return %[[BC]] : vector<2x3x4xi8>
return %1 : vector<2x3x4xi8>
}

// -----

// CHECK-LABEL: broadcast_transpose_ones_to_broadcast
// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
// CHECK: return %[[RES]] : vector<2x3x4xi8>
func.func @broadcast_transpose_ones_to_broadcast(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
return %1 : vector<2x3x4xi8>
}

// -----

// CHECK-LABEL: broadcast_transpose_partial_ones_to_broadcast
// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
// CHECK: return %[[RES]] : vector<8x1xi8>
func.func @broadcast_transpose_partial_ones_to_broadcast(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
%0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
%1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
return %1 : vector<8x1xi8>
}

// -----

// CHECK-LABEL: broadcast_transpose_mixed_example
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
func.func @broadcast_transpose_mixed_example(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
%0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
%1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
return %1 : vector<3x2x4x5x6x7xi8>
}

// -----

// CHECK-LABEL: broadcast_transpose_final_group
// CHECK-SAME: %[[ARG:.*]]: vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x7x1x1xi8> to vector<4x7x2x3xi8>
// CHECK: return %[[RES]] : vector<4x7x2x3xi8>
func.func @broadcast_transpose_final_group(%arg0 : vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
%0 = vector.broadcast %arg0 : vector<4x7x1x1xi8> to vector<4x7x3x2xi8>
%1 = vector.transpose %0, [0, 1, 3, 2] : vector<4x7x3x2xi8> to vector<4x7x2x3xi8>
return %1 : vector<4x7x2x3xi8>
}

// -----

// CHECK-LABEL: negative_broadcast_transpose_square
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
// CHECK: return %[[TRP]] : vector<4x4xi8>
func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
%0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
%1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
return %1 : vector<4x4xi8>
}

// -----

// CHECK-LABEL: negative_broadcast_transpose_hypercube
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
// CHECK: return %[[TRP]] : vector<4x4x4x4xi8>
func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
return %1 : vector<4x4x4x4xi8>
}

// -----

// CHECK-LABEL: negative_broadcast_transpose_102
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
%1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
return %1 : vector<3x3x3xi8>
}

// -----

// CHECK-LABEL: negative_broadcast_transpose_021
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
%1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
return %1 : vector<3x3x3xi8>
}