
Commit 9cc11b9

[mlir] [linalg] Add pattern to swap transpose with broadcast (llvm#97063)
Add a pattern that implements: transpose(broadcast(input)) -> broadcast(transpose(input))

Parent: d7e8a74

File tree: 4 files changed, +168 -2 lines

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 8 additions & 0 deletions
@@ -243,6 +243,14 @@ SmallVector<int64_t>
 computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                          ArrayRef<int64_t> desiredPositions);

+/// Returns a permutation vector that drops the input dims in
+/// dropPositions from inputPerm.
+///
+/// For example, inputPerm = {2, 4, 0, 1, 3} and dropPositions = {1, 2} would
+/// result in a {2, 0, 1} permutation vector.
+SmallVector<int64_t> dropDims(ArrayRef<int64_t> inputPerm,
+                              ArrayRef<int64_t> dropPositions);
+
 /// Helper to return a subset of `arrayAttr` as a vector of int64_t.
 // TODO: Port everything relevant to DenseArrayAttr and drop this util.
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
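To make the documented example concrete, here is a minimal standalone sketch of the declared semantics (plain C++ with no MLIR dependencies; dropDimsSketch is a hypothetical name, not part of this patch). Entries of dropPositions are matched against the values in inputPerm, and each surviving value is renumbered down by the number of smaller dropped values:

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the documented dropDims behavior, using std::vector in place of
// ArrayRef/SmallVector so it compiles on its own.
std::vector<int64_t> dropDimsSketch(const std::vector<int64_t> &inputPerm,
                                    const std::vector<int64_t> &dropValues) {
  std::vector<int64_t> res;
  for (int64_t value : inputPerm) {
    bool shouldDrop = false;
    int64_t target = value;
    for (int64_t drop : dropValues) {
      if (drop == value) {
        shouldDrop = true;
        break;
      }
      if (drop < value)
        --target; // renumber past the smaller dropped value
    }
    if (!shouldDrop)
      res.push_back(target);
  }
  return res;
}

int main() {
  // The example from the header comment: {2, 4, 0, 1, 3} minus {1, 2}.
  assert((dropDimsSketch({2, 4, 0, 1, 3}, {1, 2}) ==
          std::vector<int64_t>{2, 0, 1}));
  return 0;
}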

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 60 additions & 1 deletion
@@ -1895,9 +1895,68 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };

+/// This pattern canonicalizes transpose by swapping the order of
+/// broadcast and transpose:
+///   transpose(broadcast(input)) -> broadcast(transpose(input))
+struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = transposeOp.getInput();
+    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
+    if (!input.hasOneUse() || !broadcastOp)
+      return failure();
+
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    ArrayRef<int64_t> perms = transposeOp.getPermutation();
+
+    // Get new perms and new dimensions.
+    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
+    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
+    SmallVector<int64_t> resultDimensions;
+    unsigned dimensionSize = dimensions.size();
+    for (unsigned i = 0; i < dimensionSize; ++i)
+      resultDimensions.push_back(invertPerm[dimensions[i]]);
+
+    // Create transpose result.
+    Value broadcastInput = broadcastOp.getInput();
+    Location loc = transposeOp.getLoc();
+    MLIRContext *ctx = transposeOp.getContext();
+    SmallVector<OpFoldResult> dims;
+    auto broadcastInputTy =
+        mlir::cast<RankedTensorType>(broadcastInput.getType());
+    unsigned inputRank = broadcastInputTy.getRank();
+    for (unsigned i = 0; i < inputRank; ++i) {
+      if (broadcastInputTy.isDynamicDim(i)) {
+        dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
+                           ->getResult(0));
+      } else {
+        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
+                                        broadcastInputTy.getDimSize(i)));
+      }
+    }
+    SmallVector<OpFoldResult> transposeResultShapes =
+        applyPermutation(dims, resultPerms);
+    Value transposeInit = rewriter.create<tensor::EmptyOp>(
+        transposeOp.getLoc(), transposeResultShapes,
+        broadcastInputTy.getElementType());
+
+    // Create broadcast(transpose(input)).
+    Value transposeResult =
+        rewriter
+            .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
+                                 resultPerms)
+            ->getResult(0);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
+    return success();
+  }
+};
+
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldTransposeWithTranspose>(context);
+  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
 }

 //===----------------------------------------------------------------------===//
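The two index computations above (resultPerms and resultDimensions) can be checked by hand against the first test case added below, where the broadcast dimensions are [0, 2, 5] and the transpose permutation is [0, 5, 1, 2, 4, 3]. A minimal sketch of that arithmetic in plain C++, with a hypothetical invert helper standing in for invertPermutationVector:

#include <cassert>
#include <cstdint>
#include <vector>

// invert[perm[i]] = i, i.e. where output dim perm[i] came from.
std::vector<int64_t> invert(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

int main() {
  std::vector<int64_t> perm = {0, 5, 1, 2, 4, 3}; // transpose permutation
  std::vector<int64_t> dims = {0, 2, 5};          // broadcast dimensions

  // resultDimensions[i] = invertPerm[dimensions[i]]: each broadcast dim is
  // mapped to its position after the transpose.
  std::vector<int64_t> inv = invert(perm);
  std::vector<int64_t> resultDims;
  for (int64_t d : dims)
    resultDims.push_back(inv[d]);

  // Matches the CHECK line: linalg.broadcast ... dimensions = [0, 3, 1].
  assert((resultDims == std::vector<int64_t>{0, 3, 1}));

  // dropDims(perm, dims) restricts the permutation to the non-broadcast dims,
  // giving [0, 2, 1] -- the CHECK'd linalg.transpose permutation.
  return 0;
}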

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 26 additions & 0 deletions
@@ -252,6 +252,32 @@ mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
   return res;
 }

+SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
+                                    ArrayRef<int64_t> dropPositions) {
+  assert(inputPerm.size() >= dropPositions.size() &&
+         "expect inputPerm size to be at least the number of positions to drop");
+  SmallVector<int64_t> res;
+  unsigned permSize = inputPerm.size();
+  for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
+    int64_t targetIndex = inputPerm[inputIndex];
+    bool shouldDrop = false;
+    unsigned dropSize = dropPositions.size();
+    for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
+      if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
+        shouldDrop = true;
+        break;
+      }
+      if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
+        targetIndex--;
+      }
+    }
+    if (!shouldDrop) {
+      res.push_back(targetIndex);
+    }
+  }
+  return res;
+}
+
 SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                           unsigned dropFront,
                                           unsigned dropBack) {
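A hedged usage sketch for the new utility (this assumes compilation inside an LLVM/MLIR build that provides IndexingUtils.h and the LLVM ADT headers). The values are exactly what SwapTransposeWithBroadcast passes for the first test case in canonicalize.mlir:

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include <cassert>

int main() {
  // Drop the broadcast output dims [0, 2, 5] from the transpose permutation
  // [0, 5, 1, 2, 4, 3]; what survives is the permutation on the smaller,
  // pre-broadcast operand.
  llvm::SmallVector<int64_t> res =
      mlir::dropDims({0, 5, 1, 2, 4, 3}, {0, 2, 5});
  assert((res == llvm::SmallVector<int64_t>{0, 2, 1}));
  return 0;
}

Note that dropDims matches dropPositions against the values stored in inputPerm, not against array indices, and the nested loop is quadratic in the two sizes, which is fine at permutation rank.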

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 74 additions & 1 deletion
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
   return %0 : tensor<2x3xf32>
 }

-// ----
+// -----

 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
@@ -1096,3 +1096,76 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
   func.return %transpose2 : tensor<3x4x5xf32>
 }

+// -----
+
+func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
+                                    %init1: tensor<1x2x3x4x5x6xf32>,
+                                    %init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold
+  // CHECK-SAME:  %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
+  // CHECK-SAME:  %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
+  // CHECK-SAME:  %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
+  // CHECK:       %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
+  // CHECK:       %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
+  // CHECK:       %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
+  // CHECK:       return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<2x4x5xf32>)
+      outs(%init1 : tensor<1x2x3x4x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
+      outs(%init2 : tensor<1x6x2x3x5x4xf32>)
+      permutation = [0, 5, 1, 2, 4, 3]
+  func.return %transpose : tensor<1x6x2x3x5x4xf32>
+}
+
+// -----
+
+func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
+                                            %init1: tensor<1x?x3x?x5x6xf32>,
+                                            %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold_dynamic
+  // CHECK-SAME:  %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
+  // CHECK-SAME:  %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
+  // CHECK-SAME:  %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
+  // CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+  // CHECK:       %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
+  // CHECK:       %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
+  // CHECK:       %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
+  // CHECK:       %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
+  // CHECK:       %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
+  // CHECK:       return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<?x?x5xf32>)
+      outs(%init1 : tensor<1x?x3x?x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
+      outs(%init2 : tensor<1x3x?x6x5x?xf32>)
+      permutation = [0, 2, 3, 5, 4, 1]
+  func.return %transpose : tensor<1x3x?x6x5x?xf32>
+}
+
+// -----
+
+func.func @broadcast_transpose_fold_2dim(%input: tensor<2xf32>,
+                                         %init1: tensor<2x4xf32>,
+                                         %init2: tensor<4x2xf32>) -> tensor<4x2xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold_2dim
+  // CHECK-SAME:  %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+  // CHECK-SAME:  %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+  // CHECK-SAME:  %[[INIT2:[a-zA-Z0-9]+]]: tensor<4x2xf32>
+  // CHECK:       %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<4x2xf32>) dimensions = [0]
+  // CHECK:       return %[[BROADCAST]] : tensor<4x2xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<2xf32>)
+      outs(%init1 : tensor<2x4xf32>)
+      dimensions = [1]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<2x4xf32>)
+      outs(%init2 : tensor<4x2xf32>)
+      permutation = [1, 0]
+  func.return %transpose : tensor<4x2xf32>
+}
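One detail worth noting about the last test: for @broadcast_transpose_fold_2dim, dropping the broadcast dim [1] from the permutation [1, 0] leaves the identity permutation [0] on the 1-D input, so the transpose created by the swap is trivial. That is why the CHECK lines show only a linalg.broadcast, the identity transpose presumably being removed by existing canonicalization.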
