Skip to content

Commit a480d75

Browse files
committed
[mlir][vector] Fold transpose(broadcast(<scalar>))
For such cases, the transpose op can be elided. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D122903
1 parent 4cf98f9 commit a480d75

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4212,11 +4212,33 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
42124212
}
42134213
};
42144214

4215+
// Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>).
4216+
struct FoldTransposedScalarBroadcast final
4217+
: public OpRewritePattern<vector::TransposeOp> {
4218+
using OpRewritePattern::OpRewritePattern;
4219+
4220+
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
4221+
PatternRewriter &rewriter) const override {
4222+
auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
4223+
if (!bcastOp)
4224+
return failure();
4225+
4226+
auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
4227+
if (!srcVectorType || srcVectorType.getNumElements() == 1) {
4228+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
4229+
transposeOp, transposeOp.getResultType(), bcastOp.getSource());
4230+
return success();
4231+
}
4232+
4233+
return failure();
4234+
}
4235+
};
4236+
42154237
} // namespace
42164238

42174239
void vector::TransposeOp::getCanonicalizationPatterns(
42184240
RewritePatternSet &results, MLIRContext *context) {
4219-
results.add<TransposeFolder>(context);
4241+
results.add<FoldTransposedScalarBroadcast, TransposeFolder>(context);
42204242
}
42214243

42224244
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,3 +1348,27 @@ func @shuffle_nofold2(%v0 : vector<[4]xi32>, %v1 : vector<[2]xi32>) -> vector<4x
13481348
%shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<[4]xi32>, vector<[2]xi32>
13491349
return %shuffle : vector<4xi32>
13501350
}
1351+
1352+
// -----
1353+
1354+
// CHECK-LABEL: func @transpose_scalar_broadcast1
1355+
// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
1356+
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
1357+
// CHECK: return %[[V]] : vector<1x8xf32>
1358+
func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
1359+
%bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
1360+
%t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
1361+
return %t : vector<1x8xf32>
1362+
}
1363+
1364+
// -----
1365+
1366+
// CHECK-LABEL: func @transpose_scalar_broadcast2
1367+
// CHECK-SAME: (%[[ARG:.+]]: f32)
1368+
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
1369+
// CHECK: return %[[V]] : vector<1x8xf32>
1370+
func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
1371+
%bcast = vector.broadcast %value : f32 to vector<8x1xf32>
1372+
%t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
1373+
return %t : vector<1x8xf32>
1374+
}

0 commit comments

Comments
 (0)