|
10 | 10 | //
|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 |
|
| 13 | +#include "mlir/Dialect/Affine/Utils.h" |
13 | 14 | #include "mlir/Dialect/Linalg/IR/Linalg.h"
|
14 | 15 |
|
15 | 16 | #include "mlir/AsmParser/AsmParser.h"
|
|
55 | 56 | using namespace mlir;
|
56 | 57 | using namespace mlir::linalg;
|
57 | 58 |
|
| 59 | + |
| 60 | +SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos, |
| 61 | + int64_t rank) { |
| 62 | + SmallVector<int64_t> interchangeVector; |
| 63 | + interchangeVector.reserve(dimsPos.size()); |
| 64 | + // First map dims and their position. For example, dims_pos = [2, 0] will map |
| 65 | + // to: |
| 66 | + // [ |
| 67 | + // [ key: 2, value: 0] |
| 68 | + // [ key: 0, value: 1] |
| 69 | + // ] |
| 70 | + // where key is the idx in dims_pos while value its position in dims_pos. |
| 71 | + DenseMap<int64_t, int64_t> dimsAndPosMapping; |
| 72 | + for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++) { |
| 73 | + dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx; |
| 74 | + } |
| 75 | + |
| 76 | + // Scan the position in order and insert the value in the map |
| 77 | + // to compute the interchange vector. |
| 78 | + for (int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) { |
| 79 | + if (dimsAndPosMapping.count(dimsIdx)) { |
| 80 | + interchangeVector.push_back(dimsAndPosMapping[dimsIdx]); |
| 81 | + } |
| 82 | + } |
| 83 | + return interchangeVector; |
| 84 | +} |
| 85 | + |
| 86 | +template <typename T> |
| 87 | +SmallVector<T> interchange(ArrayRef<T> elements, |
| 88 | + ArrayRef<int64_t> interchangeVector, |
| 89 | + int offset = 0) { |
| 90 | + SmallVector<T> vec = llvm::to_vector(elements); |
| 91 | + for (auto [idx, val] : llvm::enumerate(interchangeVector)) { |
| 92 | + vec[idx + offset] = elements[val + offset]; |
| 93 | + } |
| 94 | + return vec; |
| 95 | +} |
| 96 | + |
| 97 | + |
58 | 98 | /// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
|
59 | 99 | static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
|
60 | 100 | int64_t dim) {
|
@@ -4756,6 +4796,140 @@ RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
|
4756 | 4796 | return RankedTensorType::get(resultShape, sourceType.getElementType());
|
4757 | 4797 | }
|
4758 | 4798 |
|
| 4799 | +/// Generate the body of the innermost loop of the scalar implementation |
| 4800 | +/// of `pack` operation. |
| 4801 | +static void generatePackOpScalarImplementationBody(PackOp packOp, |
| 4802 | + OpBuilder &builder, |
| 4803 | + Location loc, |
| 4804 | + ValueRange ivs) { |
| 4805 | + // Note: `ivs` are already in the correct order, possibly interchanged based |
| 4806 | + // on `dims_pos`. However, connecting the loops with the access patterns is |
| 4807 | + // difficult - What is the relation between the position of the tile loop and |
| 4808 | + // the point loop? However, if we interchange `ivs` once more to go to the |
| 4809 | + // canonical blocking format: ABCabc, this connection becomes trivial: Each |
| 4810 | + // point loop is pointLoopsOffset + inputRank away from the tiled loop. |
| 4811 | + ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos(); |
| 4812 | + ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm(); |
| 4813 | + |
| 4814 | + SmallVector<Value> interchangedIvs = ivs; |
| 4815 | + SmallVector<int64_t> interchangeVector = |
| 4816 | + computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getInputRank()); |
| 4817 | + interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector, |
| 4818 | + /*offset=*/packOp.getInputRank()); |
| 4819 | + if (!dimsToOuterBlock.empty()) { |
| 4820 | + interchangeVector = |
| 4821 | + computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getInputRank()); |
| 4822 | + interchangedIvs = |
| 4823 | + interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0); |
| 4824 | + } |
| 4825 | + |
| 4826 | + SmallVector<OpFoldResult> tiles = packOp.getMixedTiles(); |
| 4827 | + DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 4828 | + packOp.getDimAndTileMapping(); |
| 4829 | + SmallVector<OpFoldResult> sourceIndices; |
| 4830 | + size_t pointLoopsOffset = 0; |
| 4831 | + int64_t inputRank = packOp.getInputRank(); |
| 4832 | + for (auto dim : llvm::seq<int64_t>(0, inputRank)) { |
| 4833 | + if (dimAndTileMapping.count(dim)) { |
| 4834 | + AffineExpr i, j, tile; |
| 4835 | + bindDims(builder.getContext(), i, j); |
| 4836 | + bindSymbols(builder.getContext(), tile); |
| 4837 | + OpFoldResult sourceIndex = affine::makeComposedFoldedAffineApply( |
| 4838 | + builder, loc, i * tile + j, |
| 4839 | + ArrayRef<OpFoldResult>{ |
| 4840 | + interchangedIvs[dim], |
| 4841 | + interchangedIvs[pointLoopsOffset + packOp.getInputRank()], |
| 4842 | + dimAndTileMapping[dim]}); |
| 4843 | + sourceIndices.push_back(sourceIndex); |
| 4844 | + ++pointLoopsOffset; |
| 4845 | + } else { |
| 4846 | + sourceIndices.push_back(interchangedIvs[dim]); |
| 4847 | + } |
| 4848 | + } |
| 4849 | + |
| 4850 | + auto createLoad = [&]() -> Value { |
| 4851 | + return builder.create<memref::LoadOp>( |
| 4852 | + loc, packOp.getInput(), |
| 4853 | + getValueOrCreateConstantIndexOp(builder, loc, sourceIndices)); |
| 4854 | + }; |
| 4855 | + Value scalar; |
| 4856 | + if (auto paddingValue = packOp.getPaddingValue()) { |
| 4857 | + ArithBuilder arithBuilder(builder, loc); |
| 4858 | + Value isInBounds; |
| 4859 | + for (auto dim : llvm::seq<int64_t>(0, inputRank)) { |
| 4860 | + Value idx = |
| 4861 | + getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]); |
| 4862 | + Value dimValue = getValueOrCreateConstantIndexOp( |
| 4863 | + builder, loc, getDimValue(builder, loc, packOp.getInput(), dim)); |
| 4864 | + Value cond = arithBuilder.slt( |
| 4865 | + idx, dimValue); |
| 4866 | + isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond); |
| 4867 | + } |
| 4868 | + scalar = builder |
| 4869 | + .create<scf::IfOp>( |
| 4870 | + loc, isInBounds, /*thenBuilder=*/ |
| 4871 | + [&](OpBuilder &b, Location l) { |
| 4872 | + b.create<scf::YieldOp>(l, createLoad()); |
| 4873 | + }, |
| 4874 | + /*elseBuilder=*/ |
| 4875 | + [&](OpBuilder &b, Location l) { |
| 4876 | + b.create<scf::YieldOp>(l, paddingValue); |
| 4877 | + }) |
| 4878 | + .getResult(0); |
| 4879 | + } else { |
| 4880 | + scalar = createLoad(); |
| 4881 | + } |
| 4882 | + |
| 4883 | + builder.create<memref::StoreOp>(loc, scalar, packOp.getOutput(), ivs); |
| 4884 | +} |
| 4885 | + |
| 4886 | +LogicalResult PackOp::generateScalarImplementation(OpBuilder &builder, |
| 4887 | + Location loc, |
| 4888 | + ValueRange ivs) { |
| 4889 | + OpBuilder::InsertionGuard g(builder); |
| 4890 | + // The `ivs` already represent the position into the output tensor for the |
| 4891 | + // non data-tile dimensions. |
| 4892 | + SmallVector<Value> ivVec = llvm::to_vector(ivs); |
| 4893 | + ReifiedRankedShapedTypeDims outputShape; |
| 4894 | + if (failed(reifyResultShapes(builder, outputShape))) { |
| 4895 | + return getOperation()->emitOpError("failed to reify result shape"); |
| 4896 | + } |
| 4897 | + if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) { |
| 4898 | + return getOperation()->emitOpError( |
| 4899 | + "expected shape of one result value of rank") |
| 4900 | + << getOutputRank(); |
| 4901 | + } |
| 4902 | + |
| 4903 | + // Generate the loops that iterate over the data tile. |
| 4904 | + Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
| 4905 | + Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
| 4906 | + |
| 4907 | + // All loops except the innermost are simple loops that just iterate |
| 4908 | + // over the tile dimensions. |
| 4909 | + for (auto dataTileDim : |
| 4910 | + llvm::seq<unsigned>(getInputRank(), getOutputRank() - 1)) { |
| 4911 | + Value ub = getValueOrCreateConstantIndexOp(builder, loc, |
| 4912 | + outputShape[0][dataTileDim]); |
| 4913 | + scf::ForOp loop = builder.create<scf::ForOp>(loc, zero, ub, one); |
| 4914 | + builder.setInsertionPointToStart(loop.getBody()); |
| 4915 | + ivVec.push_back(loop.getInductionVar()); |
| 4916 | + } |
| 4917 | + // The body of the innermost loops does the actual data movement. |
| 4918 | + builder.create<scf::ForOp>( |
| 4919 | + loc, zero, |
| 4920 | + getValueOrCreateConstantIndexOp(builder, loc, outputShape[0].back()), one, |
| 4921 | + ValueRange{}, |
| 4922 | + [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, |
| 4923 | + ValueRange regionIterArgs) { |
| 4924 | + ivVec.push_back(iv); |
| 4925 | + generatePackOpScalarImplementationBody(*this, bodyBuilder, bodyLoc, |
| 4926 | + ivVec); |
| 4927 | + bodyBuilder.create<scf::YieldOp>(bodyLoc); |
| 4928 | + }); |
| 4929 | + return success(); |
| 4930 | +} |
| 4931 | + |
| 4932 | + |
4759 | 4933 | Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
|
4760 | 4934 | ArrayRef<OpFoldResult> innerTileSizes,
|
4761 | 4935 | ArrayRef<int64_t> innerDimsPos,
|
@@ -5080,6 +5254,65 @@ void UnPackOp::getAsmResultNames(
|
5080 | 5254 | setNameFn(getResult(), "unpack");
|
5081 | 5255 | }
|
5082 | 5256 |
|
| 5257 | +LogicalResult UnPackOp::generateScalarImplementation(OpBuilder &builder, |
| 5258 | + Location loc, |
| 5259 | + ValueRange ivs) { |
| 5260 | + return llvm::success(); |
| 5261 | + OpBuilder::InsertionGuard g(builder); |
| 5262 | + ReifiedRankedShapedTypeDims outputShape; |
| 5263 | + |
| 5264 | + if (failed(reifyResultShapes(builder, outputShape))) { |
| 5265 | + return getOperation()->emitError("failed to reify result shapes"); |
| 5266 | + } |
| 5267 | + if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) { |
| 5268 | + return getOperation()->emitError( |
| 5269 | + "expected shape of one result value of rank") |
| 5270 | + << getOutputRank(); |
| 5271 | + } |
| 5272 | + |
| 5273 | + DenseMap<int64_t, OpFoldResult> dimAndTileMapping = getDimAndTileMapping(); |
| 5274 | + // untiled loops and tile loops induction variables. |
| 5275 | + SmallVector<Value> inputIvs; |
| 5276 | + SmallVector<Value> inputIvsPointLoops; |
| 5277 | + inputIvs.reserve(getOutputRank()); |
| 5278 | + inputIvsPointLoops.reserve(dimAndTileMapping.size()); |
| 5279 | + for (auto dim : llvm::seq<int64_t>(0, getOutputRank())) { |
| 5280 | + if (dimAndTileMapping.count(dim)) { |
| 5281 | + affine::DivModValue divMod = |
| 5282 | + affine::getDivMod(builder, loc, ivs[dim], |
| 5283 | + getValueOrCreateConstantIndexOp( |
| 5284 | + builder, loc, dimAndTileMapping[dim])); |
| 5285 | + inputIvsPointLoops.push_back(divMod.remainder); |
| 5286 | + inputIvs.push_back(divMod.quotient); |
| 5287 | + } else { |
| 5288 | + inputIvs.push_back(ivs[dim]); |
| 5289 | + } |
| 5290 | + } |
| 5291 | + |
| 5292 | + // TODO: (lorenzo) simplify the logic a bit. There is `ivs`, |
| 5293 | + // `inputIvsPointLoops` and `inputIvs`. |
| 5294 | + assert(inputIvsPointLoops.size() + inputIvs.size() == getInputRank() && |
| 5295 | + "expect same number of iduction variables equals to input rank"); |
| 5296 | + // interchange the point loops induction variables based on `inner_dim_pos`. |
| 5297 | + ArrayRef<int64_t> innerDims = getInnerDimsPos(); |
| 5298 | + SmallVector<int64_t> interchangeVector = |
| 5299 | + computeInterchangeFromDimPos(innerDims, getOutputRank()); |
| 5300 | + SmallVector<Value> interchangedInputIvsPointLoops = inputIvsPointLoops; |
| 5301 | + interchangedInputIvsPointLoops = interchange<Value>( |
| 5302 | + interchangedInputIvsPointLoops, interchangeVector, /*offset=*/0); |
| 5303 | + // interchange the tiled loops induction variables based on `outer_dims_perm`. |
| 5304 | + ArrayRef<int64_t> outerDims = getOuterDimsPerm(); |
| 5305 | + if (!outerDims.empty()) { |
| 5306 | + inputIvs = interchange<Value>(inputIvs, outerDims, /*offset=*/0); |
| 5307 | + } |
| 5308 | + |
| 5309 | + llvm::append_range(inputIvs, interchangedInputIvsPointLoops); |
| 5310 | + Value scalar = builder.create<memref::LoadOp>(loc, getInput(), inputIvs); |
| 5311 | + builder.create<memref::StoreOp>(loc, scalar, getOutput(), ivs); |
| 5312 | + return success(); |
| 5313 | +} |
| 5314 | + |
| 5315 | + |
5083 | 5316 | LogicalResult
|
5084 | 5317 | UnPackOp::reifyResultShapes(OpBuilder &builder,
|
5085 | 5318 | ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
|
|
0 commit comments