//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that breaks a reduction
// dimension into a parallel and a reduction dimension.
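//
// For illustration only (a sketch of the intent, not verbatim output; SSA
// names and shapes are illustrative), splitting a 1-D add reduction with a
// ratio of 4 and an insertion index of 0:
//
//   %r = linalg.generic {
//          indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
//          iterator_types = ["reduction"]}
//        ins(%in : tensor<32xf32>) outs(%out : tensor<f32>) { arith.addf }
//
// becomes an expansion of the input to tensor<4x8xf32>, a partial reduction
// over the trailing dimension into a tensor<4xf32> filled with the addf
// identity (0.0):
//
//   %p = linalg.generic {
//          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                           affine_map<(d0, d1) -> (d0)>],
//          iterator_types = ["parallel", "reduction"]}
//        ins(%expanded : tensor<4x8xf32>)
//        outs(%filled : tensor<4xf32>) { arith.addf }
//
// followed by a final reduction of the remaining dimension:
//
//   %r = linalg.generic {
//          indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
//          iterator_types = ["reduction"]}
//        ins(%p : tensor<4xf32>) outs(%out : tensor<f32>) { arith.addf }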
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

#include <limits>

using namespace mlir;
using namespace mlir::linalg;

/// Return the identity numeric value associated with the given op.
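/// E.g., the identity is +0.0 for arith.addf, 1.0 for arith.mulf, 0 for
/// arith.addi/ori/xori, and all-ones for arith.andi.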
static Optional<Attribute> getIdentity(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/true));
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/false));
    return llvm::None;
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  if (isa<arith::MaxSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
  if (isa<arith::MinSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return llvm::None;
}

FailureOr<LinalgOp>
mlir::linalg::splitReduction(PatternRewriter &b, LinalgOp op,
                             ControlSplitReductionFn controlSplitReductionFn,
                             LinalgTransformationFilter filter) {
  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
      !op.hasOnlyProjectedPermutations())
    return b.notifyMatchFailure(op, "precondition not met");
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertDimIndex = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1);
  unsigned reductionDim = dims[0];
  Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges();
  if (!loopRanges)
    return b.notifyMatchFailure(op, "Cannot analyze loops");
  int64_t reductionDimSize = (*loopRanges)[reductionDim];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges->size())
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
  Operation *reductionOp = combinerOps[0];
  Optional<Attribute> identity = getIdentity(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op,
                                "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
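  // E.g. (an illustrative case, not tied to a specific test): with ratio = 4,
  // insertion index 0, and reduction dimension d1, an input of shape <MxK>
  // with map (d0, d1) -> (d0, d1) is expanded to shape <Mx4x(K/4)> with map
  // (d0, d1, d2) -> (d1, d0, d2), where d0 is the new parallel dimension.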
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index++, index++});
        exprs.push_back(b.getAffineDimExpr(insertDimIndex));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged (the operand does not use the reduction
    // dimension), the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }
  // Calculate the new output map and shape; the new parallel dimension is
  // inserted at the index returned by `controlSplitReductionFn`.
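  // Continuing the illustrative case above: an output of shape <M> with map
  // (d0, d1) -> (d0) becomes an output of shape <4xM> with map
  // (d0, d1, d2) -> (d0, d1).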
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertDimIndex) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
      continue;
    }
    unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
  }
  Value initTensor = b.create<linalg::InitTensorOp>(
      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertDimIndex == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
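  // Continuing the illustrative case: the intermediate op iterates over
  // (d0, d1, d2) with iterator types ["parallel", "parallel", "reduction"],
  // where d0 is the newly inserted parallel dimension.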
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction op that only reduces the newly added dimension
  // of the previous op's result.
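  // Continuing the illustrative case: the <4xM> intermediate result is reduced
  // with input map (d0, d1) -> (d0, d1), output map (d0, d1) -> (d1), and
  // iterator types ["reduction", "parallel"].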
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertDimIndex == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());
  filter.replaceLinalgTransformationFilter(b, genericOp);
  filter.replaceLinalgTransformationFilter(b, reduction);
  return cast<LinalgOp>(genericOp.getOperation());
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that satisfy `filter`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       LinalgTransformationFilter f, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(controlSplitReductionFn),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, filter);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    ControlSplitReductionFn controlSplitReductionFn,
    LinalgTransformationFilter f) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, f);
}