Skip to content

Commit 02002ef

Browse files
committed
move iree LinalgExt::generateScalarImplementation to Linalg::generateScalarImplementation
1 parent 01cc1d1 commit 02002ef

File tree

2 files changed

+292
-1
lines changed

2 files changed

+292
-1
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,20 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
7777
/// with `inner_dims_pos` rather than the packed tensor.
7878
SmallVector<int64_t> getTiledOuterDims();
7979
}];
80-
80+
let extraClassDeclaration = commonExtraClassDeclaration # [{
81+
ShapedType getInputType() {
82+
return cast<ShapedType>(getInput().getType());
83+
}
84+
ShapedType getOutputType() {
85+
return cast<ShapedType>(getOutput().getType());
86+
}
87+
int64_t getInputRank() {
88+
return getInputType().getRank();
89+
}
90+
int64_t getOutputRank() {
91+
return getOutputType().getRank();
92+
}
93+
}];
8194
let hasVerifier = 1;
8295
}
8396

@@ -179,6 +192,28 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
179192
];
180193

181194
let extraClassDeclaration = commonExtraClassDeclaration # [{
195+
Value getOutput() {
196+
return getDpsInitOperand(0)->get();
197+
}
198+
199+
// Return the input operand.
200+
Value getInput() {
201+
return getDpsInputOperand(0)->get();
202+
}
203+
ShapedType getInputType() {
204+
return cast<ShapedType>(getInput().getType());
205+
}
206+
ShapedType getOutputType() {
207+
return cast<ShapedType>(getDest().getType()); // getDest() 사용
208+
}
209+
int64_t getInputRank() {
210+
return getInputType().getRank();
211+
}
212+
int64_t getOutputRank() {
213+
return getOutputType().getRank();
214+
}
215+
216+
LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
182217
// Method to get the shape of the result as `SmallVector<OpFoldResult>`.
183218
// This is a static method to allow getting the shape of the destination
184219
// expected while creating a `pack` op.
@@ -229,6 +264,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
229264
/// 2. pads the other ones, and
230265
/// 3. doesn't shuffle the dimensions
231266
bool isLikePad();
267+
232268
}];
233269

234270
let hasCanonicalizeMethod = 1;
@@ -303,6 +339,28 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
303339
];
304340

305341
let extraClassDeclaration = commonExtraClassDeclaration # [{
342+
Value getOutput() {
343+
return getDpsInitOperand(0)->get();
344+
}
345+
346+
// Return the input operand.
347+
Value getInput() {
348+
return getDpsInputOperand(0)->get();
349+
}
350+
ShapedType getInputType() {
351+
return cast<ShapedType>(getInput().getType());
352+
}
353+
ShapedType getOutputType() {
354+
return cast<ShapedType>(getDest().getType()); // getDest() 사용
355+
}
356+
int64_t getInputRank() {
357+
return getInputType().getRank();
358+
}
359+
int64_t getOutputRank() {
360+
return getOutputType().getRank();
361+
}
362+
LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
363+
306364
static Value createDestinationTensor(OpBuilder &b, Location loc,
307365
Value source, ArrayRef<OpFoldResult> innerTileSizes,
308366
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "mlir/Dialect/Affine/Utils.h"
1314
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1415

1516
#include "mlir/AsmParser/AsmParser.h"
@@ -55,6 +56,45 @@
5556
using namespace mlir;
5657
using namespace mlir::linalg;
5758

59+
60+
SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
61+
int64_t rank) {
62+
SmallVector<int64_t> interchangeVector;
63+
interchangeVector.reserve(dimsPos.size());
64+
// First map dims and their position. For example, dims_pos = [2, 0] will map
65+
// to:
66+
// [
67+
// [ key: 2, value: 0]
68+
// [ key: 0, value: 1]
69+
// ]
70+
// where key is the idx in dims_pos while value its position in dims_pos.
71+
DenseMap<int64_t, int64_t> dimsAndPosMapping;
72+
for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++) {
73+
dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx;
74+
}
75+
76+
// Scan the position in order and insert the value in the map
77+
// to compute the interchange vector.
78+
for (int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) {
79+
if (dimsAndPosMapping.count(dimsIdx)) {
80+
interchangeVector.push_back(dimsAndPosMapping[dimsIdx]);
81+
}
82+
}
83+
return interchangeVector;
84+
}
85+
86+
template <typename T>
87+
SmallVector<T> interchange(ArrayRef<T> elements,
88+
ArrayRef<int64_t> interchangeVector,
89+
int offset = 0) {
90+
SmallVector<T> vec = llvm::to_vector(elements);
91+
for (auto [idx, val] : llvm::enumerate(interchangeVector)) {
92+
vec[idx + offset] = elements[val + offset];
93+
}
94+
return vec;
95+
}
96+
97+
5898
/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
5999
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
60100
int64_t dim) {
@@ -4756,6 +4796,140 @@ RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
47564796
return RankedTensorType::get(resultShape, sourceType.getElementType());
47574797
}
47584798

4799+
/// Generate the body of the innermost loop of the scalar implementation
4800+
/// of `pack` operation.
4801+
static void generatePackOpScalarImplementationBody(PackOp packOp,
4802+
OpBuilder &builder,
4803+
Location loc,
4804+
ValueRange ivs) {
4805+
// Note: `ivs` are already in the correct order, possibly interchanged based
4806+
// on `dims_pos`. However, connecting the loops with the access patterns is
4807+
// difficult - What is the relation between the position of the tile loop and
4808+
// the point loop? However, if we interchange `ivs` once more to go to the
4809+
// canonical blocking format: ABCabc, this connection becomes trivial: Each
4810+
// point loop is pointLoopsOffset + inputRank away from the tiled loop.
4811+
ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos();
4812+
ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm();
4813+
4814+
SmallVector<Value> interchangedIvs = ivs;
4815+
SmallVector<int64_t> interchangeVector =
4816+
computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getInputRank());
4817+
interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
4818+
/*offset=*/packOp.getInputRank());
4819+
if (!dimsToOuterBlock.empty()) {
4820+
interchangeVector =
4821+
computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getInputRank());
4822+
interchangedIvs =
4823+
interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
4824+
}
4825+
4826+
SmallVector<OpFoldResult> tiles = packOp.getMixedTiles();
4827+
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
4828+
packOp.getDimAndTileMapping();
4829+
SmallVector<OpFoldResult> sourceIndices;
4830+
size_t pointLoopsOffset = 0;
4831+
int64_t inputRank = packOp.getInputRank();
4832+
for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
4833+
if (dimAndTileMapping.count(dim)) {
4834+
AffineExpr i, j, tile;
4835+
bindDims(builder.getContext(), i, j);
4836+
bindSymbols(builder.getContext(), tile);
4837+
OpFoldResult sourceIndex = affine::makeComposedFoldedAffineApply(
4838+
builder, loc, i * tile + j,
4839+
ArrayRef<OpFoldResult>{
4840+
interchangedIvs[dim],
4841+
interchangedIvs[pointLoopsOffset + packOp.getInputRank()],
4842+
dimAndTileMapping[dim]});
4843+
sourceIndices.push_back(sourceIndex);
4844+
++pointLoopsOffset;
4845+
} else {
4846+
sourceIndices.push_back(interchangedIvs[dim]);
4847+
}
4848+
}
4849+
4850+
auto createLoad = [&]() -> Value {
4851+
return builder.create<memref::LoadOp>(
4852+
loc, packOp.getInput(),
4853+
getValueOrCreateConstantIndexOp(builder, loc, sourceIndices));
4854+
};
4855+
Value scalar;
4856+
if (auto paddingValue = packOp.getPaddingValue()) {
4857+
ArithBuilder arithBuilder(builder, loc);
4858+
Value isInBounds;
4859+
for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
4860+
Value idx =
4861+
getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]);
4862+
Value dimValue = getValueOrCreateConstantIndexOp(
4863+
builder, loc, getDimValue(builder, loc, packOp.getInput(), dim));
4864+
Value cond = arithBuilder.slt(
4865+
idx, dimValue);
4866+
isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond);
4867+
}
4868+
scalar = builder
4869+
.create<scf::IfOp>(
4870+
loc, isInBounds, /*thenBuilder=*/
4871+
[&](OpBuilder &b, Location l) {
4872+
b.create<scf::YieldOp>(l, createLoad());
4873+
},
4874+
/*elseBuilder=*/
4875+
[&](OpBuilder &b, Location l) {
4876+
b.create<scf::YieldOp>(l, paddingValue);
4877+
})
4878+
.getResult(0);
4879+
} else {
4880+
scalar = createLoad();
4881+
}
4882+
4883+
builder.create<memref::StoreOp>(loc, scalar, packOp.getOutput(), ivs);
4884+
}
4885+
4886+
LogicalResult PackOp::generateScalarImplementation(OpBuilder &builder,
4887+
Location loc,
4888+
ValueRange ivs) {
4889+
OpBuilder::InsertionGuard g(builder);
4890+
// The `ivs` already represent the position into the output tensor for the
4891+
// non data-tile dimensions.
4892+
SmallVector<Value> ivVec = llvm::to_vector(ivs);
4893+
ReifiedRankedShapedTypeDims outputShape;
4894+
if (failed(reifyResultShapes(builder, outputShape))) {
4895+
return getOperation()->emitOpError("failed to reify result shape");
4896+
}
4897+
if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) {
4898+
return getOperation()->emitOpError(
4899+
"expected shape of one result value of rank")
4900+
<< getOutputRank();
4901+
}
4902+
4903+
// Generate the loops that iterate over the data tile.
4904+
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
4905+
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
4906+
4907+
// All loops except the innermost are simple loops that just iterate
4908+
// over the tile dimensions.
4909+
for (auto dataTileDim :
4910+
llvm::seq<unsigned>(getInputRank(), getOutputRank() - 1)) {
4911+
Value ub = getValueOrCreateConstantIndexOp(builder, loc,
4912+
outputShape[0][dataTileDim]);
4913+
scf::ForOp loop = builder.create<scf::ForOp>(loc, zero, ub, one);
4914+
builder.setInsertionPointToStart(loop.getBody());
4915+
ivVec.push_back(loop.getInductionVar());
4916+
}
4917+
// The body of the innermost loops does the actual data movement.
4918+
builder.create<scf::ForOp>(
4919+
loc, zero,
4920+
getValueOrCreateConstantIndexOp(builder, loc, outputShape[0].back()), one,
4921+
ValueRange{},
4922+
[&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
4923+
ValueRange regionIterArgs) {
4924+
ivVec.push_back(iv);
4925+
generatePackOpScalarImplementationBody(*this, bodyBuilder, bodyLoc,
4926+
ivVec);
4927+
bodyBuilder.create<scf::YieldOp>(bodyLoc);
4928+
});
4929+
return success();
4930+
}
4931+
4932+
47594933
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
47604934
ArrayRef<OpFoldResult> innerTileSizes,
47614935
ArrayRef<int64_t> innerDimsPos,
@@ -5080,6 +5254,65 @@ void UnPackOp::getAsmResultNames(
50805254
setNameFn(getResult(), "unpack");
50815255
}
50825256

5257+
LogicalResult UnPackOp::generateScalarImplementation(OpBuilder &builder,
5258+
Location loc,
5259+
ValueRange ivs) {
5260+
return llvm::success();
5261+
OpBuilder::InsertionGuard g(builder);
5262+
ReifiedRankedShapedTypeDims outputShape;
5263+
5264+
if (failed(reifyResultShapes(builder, outputShape))) {
5265+
return getOperation()->emitError("failed to reify result shapes");
5266+
}
5267+
if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) {
5268+
return getOperation()->emitError(
5269+
"expected shape of one result value of rank")
5270+
<< getOutputRank();
5271+
}
5272+
5273+
DenseMap<int64_t, OpFoldResult> dimAndTileMapping = getDimAndTileMapping();
5274+
// untiled loops and tile loops induction variables.
5275+
SmallVector<Value> inputIvs;
5276+
SmallVector<Value> inputIvsPointLoops;
5277+
inputIvs.reserve(getOutputRank());
5278+
inputIvsPointLoops.reserve(dimAndTileMapping.size());
5279+
for (auto dim : llvm::seq<int64_t>(0, getOutputRank())) {
5280+
if (dimAndTileMapping.count(dim)) {
5281+
affine::DivModValue divMod =
5282+
affine::getDivMod(builder, loc, ivs[dim],
5283+
getValueOrCreateConstantIndexOp(
5284+
builder, loc, dimAndTileMapping[dim]));
5285+
inputIvsPointLoops.push_back(divMod.remainder);
5286+
inputIvs.push_back(divMod.quotient);
5287+
} else {
5288+
inputIvs.push_back(ivs[dim]);
5289+
}
5290+
}
5291+
5292+
// TODO: (lorenzo) simplify the logic a bit. There is `ivs`,
5293+
// `inputIvsPointLoops` and `inputIvs`.
5294+
assert(inputIvsPointLoops.size() + inputIvs.size() == getInputRank() &&
5295+
"expect same number of iduction variables equals to input rank");
5296+
// interchange the point loops induction variables based on `inner_dim_pos`.
5297+
ArrayRef<int64_t> innerDims = getInnerDimsPos();
5298+
SmallVector<int64_t> interchangeVector =
5299+
computeInterchangeFromDimPos(innerDims, getOutputRank());
5300+
SmallVector<Value> interchangedInputIvsPointLoops = inputIvsPointLoops;
5301+
interchangedInputIvsPointLoops = interchange<Value>(
5302+
interchangedInputIvsPointLoops, interchangeVector, /*offset=*/0);
5303+
// interchange the tiled loops induction variables based on `outer_dims_perm`.
5304+
ArrayRef<int64_t> outerDims = getOuterDimsPerm();
5305+
if (!outerDims.empty()) {
5306+
inputIvs = interchange<Value>(inputIvs, outerDims, /*offset=*/0);
5307+
}
5308+
5309+
llvm::append_range(inputIvs, interchangedInputIvsPointLoops);
5310+
Value scalar = builder.create<memref::LoadOp>(loc, getInput(), inputIvs);
5311+
builder.create<memref::StoreOp>(loc, scalar, getOutput(), ivs);
5312+
return success();
5313+
}
5314+
5315+
50835316
LogicalResult
50845317
UnPackOp::reifyResultShapes(OpBuilder &builder,
50855318
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {

0 commit comments

Comments
 (0)