Skip to content

Commit 200214a

Browse files
javedabsar1 authored and frederik-h committed
[mlir][linalg][elementwise] Fold transpose into new elementwise (llvm#130207)
Fold transpose into new elementwise Op which has affine-map attached. Will add broadcast folding in next diff.
1 parent 4f4e718 commit 200214a

File tree

6 files changed

+154
-1
lines changed

6 files changed

+154
-1
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,18 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
601601
[{
602602
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
603603
attributes, ElementwiseOp::getRegionBuilder());
604-
}]>
604+
}]>,
605+
606+
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
607+
"ElementwiseKindAttr":$kind,
608+
"ArrayAttr":$indexingMaps,
609+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
610+
[{
611+
$_state.addAttribute("kind", kind);
612+
$_state.addAttribute("indexing_maps", indexingMaps);
613+
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
614+
attributes, ElementwiseOp::getRegionBuilder());
615+
}]>
605616
];
606617

607618
let hasCustomAssemblyFormat = 1;

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
9999
let dependentDialects = ["linalg::LinalgDialect"];
100100
}
101101

102+
// Pass registered as `linalg-fold-into-elementwise`; depends on the Linalg
// dialect. The summary names the ops intended to be folded into the
// elementwise op ("transpose", not "transform").
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
  let summary = "Fold transpose, broadcast and other ops into elementwise";
  let dependentDialects = ["linalg::LinalgDialect"];
}
106+
102107
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
103108
let summary = "Detensorize linalg ops";
104109
let dependentDialects = [];

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
17101710
void populateLinalgGenericOpsSpecializationPatterns(
17111711
RewritePatternSet &patterns);
17121712

1713+
/// Populates `patterns` with patterns that fold operations like
1714+
/// `linalg.transpose` into elementwise op map.
1715+
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
1716+
17131717
/// Linalg decompose convolutions patterns
17141718

17151719
/// Populates patterns to decompose high-D convolution ops into low-D ones.

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
1414
EliminateEmptyTensors.cpp
1515
EraseUnusedOperandsAndResults.cpp
1616
FoldAddIntoDest.cpp
17+
FoldIntoElementwise.cpp
1718
FusePadOpWithLinalgProducer.cpp
1819
Fusion.cpp
1920
Generalization.cpp
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements folding ops such as transpose and broadcast into the
10+
// affine maps of the elementwise op.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Dialect/Linalg/Passes.h"
16+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
#include "llvm/ADT/SmallVector.h"
20+
#include "llvm/ADT/TypeSwitch.h"
21+
22+
namespace mlir {
23+
#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
24+
#include "mlir/Dialect/Linalg/Passes.h.inc"
25+
} // namespace mlir
26+
27+
using namespace mlir;
28+
using namespace mlir::linalg;
29+
30+
#define DEBUG_TYPE "linalg-fold-into-elementwise"
31+
32+
namespace {
33+
/// Rewrite pattern on `ElementwiseOp` that bypasses `TransposeOp` producers.
///
/// For every DPS input of the elementwise op that (a) is currently accessed
/// through an identity indexing map and (b) is defined by a `TransposeOp`, the
/// input is replaced by the transpose's own input, and its indexing map is
/// replaced by the map the transpose uses for that input. All other inputs and
/// their maps are kept untouched. The pattern fails to match when no input
/// qualifies.
struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ElementwiseOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedInputs;
    SmallVector<AffineMap> foldedMaps;
    bool foldedAny = false;

    for (OpOperand *input : op.getDpsInputOperands()) {
      AffineMap inputMap = op.getMatchingIndexingMap(input);
      Value inputValue = input->get();
      auto producer = inputValue.getDefiningOp<TransposeOp>();

      if (producer && inputMap.isIdentity()) {
        // Read straight from the transpose's source, using the indexing map
        // the transpose itself applies to its (single) input operand.
        foldedInputs.push_back(producer.getInput());
        foldedMaps.push_back(
            producer.getMatchingIndexingMap(producer.getDpsInputOperand(0)));
        foldedAny = true;
      } else {
        // Not foldable: keep the operand and its map exactly as they were.
        foldedInputs.push_back(inputValue);
        foldedMaps.push_back(inputMap);
      }
    }

    if (!foldedAny)
      return failure();

    // The last indexing map of the original op (the one for the output) is
    // carried over unchanged.
    foldedMaps.push_back(op.getIndexingMapsArray().back());

    ArrayAttr newIndexingMaps = rewriter.getAffineMapArrayAttr(foldedMaps);
    rewriter.replaceOpWithNewOp<ElementwiseOp>(
        op, foldedInputs, op.getDpsInits()[0], op.getKindAttr(),
        newIndexingMaps);
    return success();
  }
};
67+
68+
struct LinalgFoldIntoElementwisePass
69+
: public impl::LinalgFoldIntoElementwisePassBase<
70+
LinalgFoldIntoElementwisePass> {
71+
using impl::LinalgFoldIntoElementwisePassBase<
72+
LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
73+
74+
void runOnOperation() override {
75+
llvm::outs() << "Hellow from fold into elemenwise \n";
76+
Operation *op = getOperation();
77+
RewritePatternSet patterns(op->getContext());
78+
populateLinalgFoldIntoElementwisePatterns(patterns);
79+
80+
if (failed(applyPatternsGreedily(op, std::move(patterns))))
81+
return signalPassFailure();
82+
}
83+
};
84+
} // namespace
85+
86+
/// Registers the patterns that fold producer ops into `ElementwiseOp` with
/// `patterns`. At present this is only `FoldTransposePattern`.
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldTransposePattern>(context);
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: mlir-opt %s -linalg-fold-into-elementwise -split-input-file | FileCheck %s
2+
3+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
4+
// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
5+
//
6+
// CHECK: func.func @unary_transpose(%[[A:.+]]: tensor<16x8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
7+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
8+
// CHECK-SAME: indexing_maps = [#[[TRANSPOSED]], #[[IDENTITY]]]
9+
// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
10+
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
11+
//
12+
// Test input: %A is transposed with permutation [1, 0, 2] and the result feeds
// a unary `exp` elementwise op. The CHECK lines above expect the elementwise op
// to read %A directly with a transposed indexing map.
func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
13+
%empty = tensor.empty() : tensor<8x16x32xf32>
14+
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
15+
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
16+
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
17+
return %result : tensor<8x16x32xf32>
18+
}
19+
20+
// -----
21+
22+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
23+
// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1) -> (d1, d0)>
24+
//
25+
// CHECK: func.func @binary_transposed(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
26+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
27+
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[TRANSPOSED]], #[[IDENTITY]]]
28+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
29+
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
30+
//
31+
// Test input: only the second operand (%B) of a binary `add` elementwise op is
// produced by a transpose (permutation [1, 0]); %A is used as-is. The CHECK
// lines above expect an identity map for %A and a transposed map for %B.
// NOTE(review): %empty is sized (%dim1, %dim0) taken from %A; since all shapes
// are dynamic this is not verified statically - confirm the intended sizes.
func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
32+
%c0 = arith.constant 0 : index
33+
%c1 = arith.constant 1 : index
34+
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
35+
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
36+
37+
%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
38+
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
39+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
40+
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
41+
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
42+
return %result : tensor<?x?xf32>
43+
}

0 commit comments

Comments
 (0)