-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] transpose(broadcast) -> broadcast canonicalization #135096
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
2498d7d
d3fe38a
af69672
7df6355
384a5ca
4f49639
81101c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6151,12 +6151,106 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> { | |
} | ||
}; | ||
|
||
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
///
/// Example:
/// ```
///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
///                                  to vector<8x1xi32>
/// ```
/// can be rewritten as the equivalent
/// ```
///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>
/// ```
/// The algorithm works by partitioning dimensions into groups that can be
/// locally permuted while preserving order, and checks that the transpose
/// only permutes within these groups.
///
/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
/// broadcasting from 1x1x4x1x1x7.
///                   ^^^ ^ ^^^ ^
///          groups:   0  1  2  3
/// Order preserving permutations for this example are ones that only permute
/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5).
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  /// Matches a transpose whose operand is produced by a vector.broadcast and,
  /// when the permutation only shuffles dimensions inside the groups described
  /// on the class, replaces the transpose with a single broadcast from the
  /// original broadcast source to the transpose's result type.
  ///
  /// Returns failure (via notifyMatchFailure) when the operand is not defined
  /// by a broadcast or when the permutation moves a dimension out of its
  /// group; returns success after the transpose has been replaced.
  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {

    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast) {
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");
    }

    // Null when the broadcast source is a scalar rather than a vector.
    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid, and
    // transpose(broadcast(all ones)) -> broadcast(all ones) is always valid.
    // The vector query is guarded so that a scalar source never dereferences
    // the null `inputType`.
    bool inputIsScalar = !inputType;
    bool inputIsSizeOneVector =
        !inputIsScalar && inputType.getNumElements() == 1;
    if (inputIsScalar || inputIsSizeOneVector) {
      // Broadcast from the *original* source: the old broadcast's result
      // already has the pre-transpose shape and is not a valid source here.
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          transpose, transpose.getResultVectorType(), broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    // Number of leading dimensions the broadcast prepends to the input.
    int64_t deltaRank = outputRank - inputRank;

    // `low` is the first output dimension of the group currently being
    // accumulated; the prepended dimensions behave like leading 1s and so
    // belong to the first group.
    int64_t low = 0;
    for (int64_t inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      // A group closes just before a non-1 dimension, and just after one.
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int64_t high = inputIndex + deltaRank;
        // Return failure if not all permutation destinations for indices in
        // [low, high) are in [low, high), i.e. the permutation is not local to
        // the group.
        for (int64_t i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        // The next group starts where this one ended.
        low = high;
      }
    }

    // We don't need to check the final group [low, outputRank) because if it is
    // not locally bound, there must be a preceding group that already failed
    // the check (impossible to have just 1 non-locally bound group).

    // The preceding logic also ensures that at this point, the output of the
    // transpose is definitely broadcastable from the input shape, so we don't
    // need to check vector::isBroadcastableTo now.
    // NOTE(review): scalable dimensions are not inspected anywhere in this
    // pattern - confirm whether they need separate handling.

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        transpose, transpose.getResultVectorType(), broadcast.getSource());

    return success();
  }
};
|
||
} // namespace | ||
|
||
// Adds the canonicalization patterns for vector.transpose to `results`,
// constructing each pattern with `context`.
// NOTE(review): this span is a rendered diff. The first `TransposeFolder, ...`
// line below looks like the removed (pre-change) side and the next two lines
// like the added side that also registers FoldTransposeBroadcast; the text
// between the pattern list and the closing brace is review-thread discussion,
// not code - confirm against the actual source file.
void vector::TransposeOp::getCanonicalizationPatterns( | ||
RewritePatternSet &results, MLIRContext *context) { | ||
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast, | ||
TransposeFolder, FoldTransposeSplat>(context); | ||
TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>( | ||
context); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am wondering whether this qualifies as canonicalisation. I am not an expert, so merely raising my concerns. From https://mlir.llvm.org/docs/Canonicalization/#general-design
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think so. The new broadcast is the same rank and 'volume' as the old one, so i can't think of sense in which it'll be more complex. So the removal of the transpose clinches it in my mind! As an aside: It would be nice if MLIR/dialects defined an 'energy function' defining what classifies as a canonicalization (i.e. something to guarantee every rewrite takes us closer to a fixed point = energy minimum). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, thanks for the discussion!
So nice :) |
||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s | ||
|
||
// This file contains canonicalization tests involving vector.transpose. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note, it's totally valid (and something I personally encourage) to document what pattern specifically is being tested: |
||
|
||
// Scalar source: the broadcast + transpose pair is expected to collapse into
// one broadcast that targets the transposed shape directly.
// CHECK-LABEL: broadcast_transpose_scalar_to_broadcast
// CHECK-SAME: %[[SRC:.*]]: i8) -> vector<2x3x4xi8> {
// CHECK: %[[FOLDED:.*]] = vector.broadcast %[[SRC]] : i8 to vector<2x3x4xi8>
// CHECK: return %[[FOLDED]] : vector<2x3x4xi8>
func.func @broadcast_transpose_scalar_to_broadcast(%src : i8) -> vector<2x3x4xi8> {
  %bcast = vector.broadcast %src : i8 to vector<3x4x2xi8>
  %transposed = vector.transpose %bcast, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
  return %transposed : vector<2x3x4xi8>
}
|
||
// ----- | ||
|
||
// Source vector whose dimensions are all 1: the pair is expected to collapse
// into one broadcast that targets the transposed shape directly.
// CHECK-LABEL: broadcast_transpose_ones_to_broadcast
// CHECK-SAME: %[[SRC:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
// CHECK: %[[FOLDED:.*]] = vector.broadcast %[[SRC]] : vector<1x1x1xi8> to vector<2x3x4xi8>
// CHECK: return %[[FOLDED]] : vector<2x3x4xi8>
func.func @broadcast_transpose_ones_to_broadcast(%src : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
  %bcast = vector.broadcast %src : vector<1x1x1xi8> to vector<3x4x2xi8>
  %transposed = vector.transpose %bcast, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
  return %transposed : vector<2x3x4xi8>
}
|
||
// ----- | ||
|
||
// Single-element source of lower rank than the result: the pair is expected
// to collapse into one broadcast to the transposed shape.
// CHECK-LABEL: broadcast_transpose_partial_ones_to_broadcast
// CHECK-SAME: %[[SRC:.*]]: vector<1xi8>) -> vector<8x1xi8> {
// CHECK: %[[FOLDED:.*]] = vector.broadcast %[[SRC]] : vector<1xi8> to vector<8x1xi8>
// CHECK: return %[[FOLDED]] : vector<8x1xi8>
func.func @broadcast_transpose_partial_ones_to_broadcast(%src : vector<1xi8>) -> vector<8x1xi8> {
  %bcast = vector.broadcast %src : vector<1xi8> to vector<1x8xi8>
  %transposed = vector.transpose %bcast, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
  return %transposed : vector<8x1xi8>
}
|
||
// ----- | ||
|
||
// Mixed 1 / non-1 source dimensions where the transpose only swaps the two
// leading dimensions added by the broadcast: expected to fold.
// CHECK-LABEL: broadcast_transpose_mixed_example
// CHECK-SAME: %[[SRC:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
// CHECK: %[[FOLDED:.*]] = vector.broadcast %[[SRC]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
// CHECK: return %[[FOLDED]] : vector<3x2x4x5x6x7xi8>
func.func @broadcast_transpose_mixed_example(%src : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
  %bcast = vector.broadcast %src : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
  %transposed = vector.transpose %bcast, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
  return %transposed : vector<3x2x4x5x6x7xi8>
}
|
||
// ----- | ||
|
||
// The transpose swaps only the two trailing dimensions, both of which are 1
// in the source: expected to fold.
// CHECK-LABEL: broadcast_transpose_final_group
// CHECK-SAME: %[[SRC:.*]]: vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
// CHECK: %[[FOLDED:.*]] = vector.broadcast %[[SRC]] : vector<4x7x1x1xi8> to vector<4x7x2x3xi8>
// CHECK: return %[[FOLDED]] : vector<4x7x2x3xi8>
func.func @broadcast_transpose_final_group(%src : vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
  %bcast = vector.broadcast %src : vector<4x7x1x1xi8> to vector<4x7x3x2xi8>
  %transposed = vector.transpose %bcast, [0, 1, 3, 2] : vector<4x7x3x2xi8> to vector<4x7x2x3xi8>
  return %transposed : vector<4x7x2x3xi8>
}
|
||
// ----- | ||
|
||
// Negative case: the transpose swaps a non-1 source dimension with a 1
// dimension, so both ops are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_square
// CHECK-SAME: %[[SRC:.*]]:
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SRC]]
// CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[BCAST]], [1, 0]
// CHECK: return %[[TRANSPOSED]] : vector<4x4xi8>
func.func @negative_broadcast_transpose_square(%src : vector<4x1xi8>) -> vector<4x4xi8> {
  %bcast = vector.broadcast %src : vector<4x1xi8> to vector<4x4xi8>
  %transposed = vector.transpose %bcast, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
  return %transposed : vector<4x4xi8>
}
|
||
// ----- | ||
|
||
// Negative case: the transpose moves the trailing non-1 source dimension, so
// both ops are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_hypercube
// CHECK-SAME: %[[SRC:.*]]:
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SRC]]
// CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[BCAST]], [1, 0, 3, 2]
// CHECK: return %[[TRANSPOSED]] : vector<4x4x4x4xi8>
func.func @negative_broadcast_transpose_hypercube(%src : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
  %bcast = vector.broadcast %src : vector<1x1x4xi8> to vector<4x4x4x4xi8>
  %transposed = vector.transpose %bcast, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
  return %transposed : vector<4x4x4x4xi8>
}
|
||
// ----- | ||
|
||
// Negative case: permutation [1, 0, 2] swaps the leading non-1 source
// dimension with the 1 dimension, so both ops are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_102
// CHECK-SAME: %[[SRC:.*]]:
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SRC]]
// CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[BCAST]], [1, 0, 2]
// CHECK: return %[[TRANSPOSED]] : vector<3x3x3xi8>
func.func @negative_broadcast_transpose_102(%src : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
  %bcast = vector.broadcast %src : vector<3x1x3xi8> to vector<3x3x3xi8>
  %transposed = vector.transpose %bcast, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
  return %transposed : vector<3x3x3xi8>
}
|
||
// ----- | ||
|
||
// Negative case: permutation [0, 2, 1] swaps the 1 dimension with the
// trailing non-1 source dimension, so both ops are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_021
// CHECK-SAME: %[[SRC:.*]]:
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SRC]]
// CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[BCAST]], [0, 2, 1]
// CHECK: return %[[TRANSPOSED]] : vector<3x3x3xi8>
func.func @negative_broadcast_transpose_021(%src : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
  %bcast = vector.broadcast %src : vector<3x1x3xi8> to vector<3x3x3xi8>
  %transposed = vector.transpose %bcast, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
  return %transposed : vector<3x3x3xi8>
}
|
Uh oh!
There was an error while loading. Please reload this page.