Skip to content

Commit 33d2a78

Browse files
committed
[mlir][linalg] Add pattern to split reduction dimension in a linalg op
This transformation allows breaking up a reduction dimension into a parallel and a reduction dimension. This is followed by a separate reduction op. This allows generating a tree reduction, which is beneficial on targets that can take advantage of parallelism. Differential Revision: https://reviews.llvm.org/D122045
1 parent 9951578 commit 33d2a78

File tree

6 files changed

+445
-0
lines changed

6 files changed

+445
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
10901090
/*methodName=*/"getRegionBuilder",
10911091
(ins),
10921092
[{ return ConcreteOp::getRegionBuilder(); }]
1093+
>,
1094+
InterfaceMethod<
1095+
/*desc=*/[{
1096+
Return true if all the indexing maps are projected permutations.
1097+
Otherwise return false.
1098+
}],
1099+
/*retTy=*/"bool",
1100+
/*methodName=*/"hasOnlyProjectedPermutations",
1101+
(ins),
1102+
[{
1103+
return llvm::all_of($_op.getIndexingMaps(),
1104+
[](AffineMap map) { return map.isProjectedPermutation(); });
1105+
}]
10931106
>
10941107
];
10951108

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,64 @@ class TilingPatterns<OpTy, OpTypes...> {
14471447
}
14481448
};
14491449

1450+
/// Function signature to control reduction splitting. This returns a pair
1451+
/// containing a ratio and a dimension index. The ratio is used to split the
1452+
/// reduction dimension. The dimension index is used to control where the extra
1453+
/// dimension is added to the intermediate tensor shape. If the ratio value is
1454+
/// less than or equal to 1 then nothing will be done.
1455+
using ControlSplitReductionFn =
1456+
std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;
1457+
1458+
/// Patterns to apply `splitReduction` below.
1459+
void populateSplitReductionPattern(
1460+
RewritePatternSet &patterns,
1461+
ControlSplitReductionFn controlSplitReductionFn,
1462+
LinalgTransformationFilter f = LinalgTransformationFilter());
1463+
1464+
/// Apply transformation to split the single linalg op reduction into a parallel
1465+
/// and reduction dimension. Then create a new linalg.generic op doing the rest
1466+
/// of the reduction. Return the new linalg op with an extra parallel dimension
1467+
/// or failure if the transformation didn't happen.
1468+
/// Example:
1469+
/// ```
1470+
/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
1471+
/// affine_map<(d0) -> ()>],
1472+
/// iterator_types = ["reduction"]}
1473+
/// ins(%in : tensor<32xf32>)
1474+
/// outs(%out : tensor<f32>) {
1475+
/// ^bb0(%arg1: f32, %arg2: f32):
1476+
/// %y = arith.addf %arg1, %arg2 : f32
1477+
/// linalg.yield %y : f32
1478+
/// } -> tensor<f32>
1479+
/// ```
1480+
/// To:
1481+
/// ```
1482+
/// %cst = arith.constant 0.000000e+00 : f32
1483+
/// %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
1484+
/// %1 = linalg.init_tensor [4] : tensor<4xf32>
1485+
/// %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>
1486+
/// %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
1487+
/// affine_map<(d0, d1) -> (d0)>],
1488+
/// iterator_types = ["parallel", "reduction"]}
1489+
/// ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
1490+
/// ^bb0(%arg3: f32, %arg4: f32):
1491+
/// %5 = arith.addf %arg3, %arg4 : f32
1492+
/// linalg.yield %5 : f32
1493+
/// } -> tensor<4xf32>
1494+
/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
1495+
/// affine_map<(d0) -> ()>],
1496+
/// iterator_types = ["reduction"]}
1497+
/// ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
1498+
/// ^bb0(%arg3: f32, %arg4: f32):
1499+
/// %5 = arith.addf %arg3, %arg4 : f32
1500+
/// linalg.yield %5 : f32
1501+
/// } -> tensor<f32>
1502+
/// ```
1503+
FailureOr<LinalgOp>
1504+
splitReduction(PatternRewriter &b, LinalgOp op,
1505+
ControlSplitReductionFn controlSplitReductionFn,
1506+
LinalgTransformationFilter f);
1507+
14501508
} // namespace linalg
14511509
} // namespace mlir
14521510

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2020
PadOpInterchange.cpp
2121
Promotion.cpp
2222
SparseTensorRewriting.cpp
23+
SplitReduction.cpp
2324
Tiling.cpp
2425
Transforms.cpp
2526
Vectorization.cpp
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements linalg transformation to break a reduction dimension
10+
// between a parallel and a reduction dimension.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Analysis/SliceAnalysis.h"
15+
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
17+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
19+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
20+
#include "mlir/IR/PatternMatch.h"
21+
22+
using namespace mlir;
23+
using namespace mlir::linalg;
24+
25+
/// Return the identity numeric value associated to the give op.
26+
static Optional<Attribute> getIdentity(Operation *op) {
27+
// Builder only used as helper for attribute creation.
28+
OpBuilder b(op->getContext());
29+
Type resultType = op->getResult(0).getType();
30+
if (auto floatType = resultType.dyn_cast<FloatType>()) {
31+
const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
32+
if (isa<arith::AddFOp>(op))
33+
return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
34+
if (isa<arith::MulFOp>(op))
35+
return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
36+
if (isa<arith::MaxFOp>(op))
37+
return b.getFloatAttr(resultType,
38+
llvm::APFloat::getLargest(semantic, true));
39+
if (isa<arith::MinFOp>(op))
40+
return b.getFloatAttr(resultType,
41+
llvm::APFloat::getLargest(semantic, true));
42+
return llvm::None;
43+
}
44+
if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
45+
return b.getIntegerAttr(resultType, 0);
46+
if (isa<arith::AndIOp>(op))
47+
return b.getIntegerAttr(resultType, -1);
48+
if (isa<arith::MaxSIOp>(op))
49+
return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
50+
if (isa<arith::MinSIOp>(op))
51+
return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
52+
if (isa<arith::MulIOp>(op))
53+
return b.getIntegerAttr(resultType, 1);
54+
return llvm::None;
55+
}
56+
57+
/// Split the single reduction loop of `op` into a parallel loop of size
/// `ratio` and a reduction loop of size `reductionSize / ratio`, then reduce
/// the intermediate result over the new parallel dimension with a second
/// linalg.generic that replaces `op`.
///
/// `controlSplitReductionFn` returns the pair (ratio, insertDimIndex); the
/// index selects where the new dimension is inserted in the loop nest and in
/// the intermediate tensor shape. The rewrite is rejected, without modifying
/// the IR, when:
///   - `filter` rejects `op`, `op` does not have tensor semantics, does not
///     have exactly one reduction loop and one output, or uses an indexing map
///     that is not a projected permutation;
///   - the ratio is not greater than 1;
///   - the static loop ranges cannot be computed, or the reduction size is
///     dynamic or not a multiple of the ratio;
///   - the insertion index does not fit the loop nest or the output rank;
///   - the combiner is not a single op with a known identity value.
///
/// On success returns the first generic op, the one carrying the extra
/// parallel dimension.
FailureOr<LinalgOp>
mlir::linalg::splitReduction(PatternRewriter &b, LinalgOp op,
                             ControlSplitReductionFn controlSplitReductionFn,
                             LinalgTransformationFilter filter) {
  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
      !op.hasOnlyProjectedPermutations())
    return b.notifyMatchFailure(op, "precondition not met");
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertDimIndex = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1);
  unsigned reductionDim = dims[0];
  Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges();
  if (!loopRanges)
    return b.notifyMatchFailure(op, "Cannot analyze loops");
  int64_t reductionDimSize = (*loopRanges)[reductionDim];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  // The new dimension is inserted both in the loop nest and in the output
  // tensor, so the index must be valid for both. Checking the output rank here
  // prevents an out-of-bounds read of the old output shape below, and rejects
  // the op before any new IR has been created.
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  if (insertDimIndex >= loopRanges->size() ||
      insertDimIndex > oldOutputMap.getNumResults())
    return b.notifyMatchFailure(op, "Insert dimension index out of range");
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
  Operation *reductionOp = combinerOps[0];
  Optional<Attribute> identity = getIdentity(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op, "Unknown identity value for the redution");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    // Position of the next dimension of the expanded operand.
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        // The reduced dimension is expanded into `ratio x (size / ratio)`; the
        // outer part is indexed by the new loop, the inner part by the
        // (shifted) original reduction loop.
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index, index + 1});
        index += 2;
        exprs.push_back(b.getAffineDimExpr(insertDimIndex));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      // Loops at or after the insertion point are shifted by one.
      exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }
  // Calculate the new output map and shape, we insert the new dimension based
  // on the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertDimIndex) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
      continue;
    }
    unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
  }
  // The intermediate result is initialized with the identity of the combiner.
  Value initTensor = b.create<linalg::InitTensorOp>(
      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  // Same iterator types as the original op, with a parallel iterator inserted
  // at `insertDimIndex`.
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertDimIndex == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  // The body of the original op is moved, not copied, into the new op.
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction that only reduce the newly added dimension from
  // the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertDimIndex == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange inputs) {
        // Reuse the matched combiner, rewired to the block arguments of the
        // final reduction.
        Operation *clonedReductionOp = nestedBuilder.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                              clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());
  filter.replaceLinalgTransformationFilter(b, genericOp);
  filter.replaceLinalgTransformationFilter(b, reduction);
  return cast<LinalgOp>(genericOp.getOperation());
}
204+
205+
namespace {
206+
207+
struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
208+
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
209+
LinalgSplitReduction(MLIRContext *context,
210+
ControlSplitReductionFn controlSplitReductionFn,
211+
LinalgTransformationFilter f, PatternBenefit benefit = 1)
212+
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
213+
controlSplitReductionFn(controlSplitReductionFn), filter(std::move(f)) {
214+
}
215+
216+
LogicalResult matchAndRewrite(LinalgOp op,
217+
PatternRewriter &rewriter) const override {
218+
return splitReduction(rewriter, op, controlSplitReductionFn, filter);
219+
}
220+
221+
private:
222+
ControlSplitReductionFn controlSplitReductionFn;
223+
LinalgTransformationFilter filter;
224+
};
225+
226+
} // namespace
227+
228+
/// Add the `LinalgSplitReduction` pattern to `patterns`, configured with the
/// given control function and transformation filter.
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    ControlSplitReductionFn controlSplitReductionFn,
    LinalgTransformationFilter f) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LinalgSplitReduction>(context, controlSplitReductionFn, f);
}

0 commit comments

Comments
 (0)