[mlir][linalg] Add Linalg::generateScalarImplementation #128816


Open · wants to merge 4 commits into main
Conversation

Contributor

@ita9naiwa ita9naiwa commented Feb 26, 2025

We need the memref version for backends that do not vectorize these ops on tensors; e.g., there are some pack/unpack ops that the VMVX backend cannot vectorize.

Add generateScalarImplementation so that linalg.pack/unpack ops on memrefs can be lowered to scalar code.
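
For illustration, here is a rough sketch (mine, not from this patch) of the kind of scalar loops this enables for a pack op on memrefs; the 1-D shapes, tile size of 8, and function name are made up for the example:

```mlir
// Hypothetical op being lowered (1-D, tile size 8):
//   linalg.pack %src inner_dims_pos = [0] inner_tiles = [8]
//     into %dst : memref<16xf32> -> memref<2x8xf32>
func.func @pack_as_scalar_loops(%src: memref<16xf32>, %dst: memref<2x8xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c8 = arith.constant 8 : index
  // Outer loop over tiles; inner (point) loop over elements within a tile.
  scf.for %i = %c0 to %c2 step %c1 {
    scf.for %j = %c0 to %c8 step %c1 {
      // Source index is tile_index * tile_size + position_in_tile.
      %idx = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%i, %j)
      %v = memref.load %src[%idx] : memref<16xf32>
      memref.store %v, %dst[%i, %j] : memref<2x8xf32>
    }
  }
  return
}
```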

This is my first time writing new interface code, so I'd like it reviewed before I go further.

Co-authored-by: Han-Chung Wang [email protected]
Co-authored-by: lorenzo chelini [email protected]

@llvmbot
Member

llvmbot commented Feb 26, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Hyunsung Lee (ita9naiwa)

Changes

> We need the memref version for backends that do not vectorize these ops on tensors; e.g., there are some pack/unpack ops that the VMVX backend cannot vectorize.

Add generateScalarImplementation so that linalg.pack/unpack ops on memrefs can be lowered to scalar code.

This is my first time writing new interface code, so I'd like it reviewed before I go further.


Full diff: https://github.com/llvm/llvm-project/pull/128816.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td (+59-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+227)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 1e48a5e3a20ee..7123d7112f9ac 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -77,7 +77,20 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// with `inner_dims_pos` rather than the packed tensor.
     SmallVector<int64_t> getTiledOuterDims();
   }];
-
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    ShapedType getInputType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    int64_t getOutputRank() {
+      return getOutputType().getRank();
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -179,6 +192,28 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
+    Value getOutput() {
+      return getDpsInitOperand(0)->get();
+    }
+
+    // Return the input operand.
+    Value getInput() {
+      return getDpsInputOperand(0)->get();
+    }
+    ShapedType getInputType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputType() {
+      return cast<ShapedType>(getDest().getType()); // use getDest()
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    int64_t getOutputRank() {
+      return getOutputType().getRank();
+    }
+
+    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
     // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
     // This is a static method to allow getting the shape of the destination
     // expected while creating a `pack` op.
@@ -229,6 +264,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     /// 2. pads the other ones, and
     /// 3. doesn't shuffle the dimensions
     bool isLikePad();
+
   }];
 
   let hasCanonicalizeMethod = 1;
@@ -303,6 +339,28 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
+    Value getOutput() {
+      return getDpsInitOperand(0)->get();
+    }
+
+    // Return the input operand.
+    Value getInput() {
+      return getDpsInputOperand(0)->get();
+    }
+    ShapedType getInputType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputType() {
+      return cast<ShapedType>(getDest().getType()); // use getDest()
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    int64_t getOutputRank() {
+      return getOutputType().getRank();
+    }
+    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
+
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..1d4833e06c776 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 
 #include "mlir/AsmParser/AsmParser.h"
@@ -55,6 +56,43 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+static SmallVector<int64_t>
+computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos, int64_t rank) {
+  SmallVector<int64_t> interchangeVector;
+  interchangeVector.reserve(dimsPos.size());
+  // First map dims and their position. For example, dims_pos = [2, 0] will map
+  // to:
+  // [
+  //  [ key: 2, value: 0]
+  //  [ key: 0, value: 1]
+  // ]
+  // where the key is a dimension from dims_pos and the value is its index in dims_pos.
+  DenseMap<int64_t, int64_t> dimsAndPosMapping;
+  for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++) {
+    dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx;
+  }
+
+  // Scan the position in order and insert the value in the map
+  // to compute the interchange vector.
+  for (int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) {
+    if (dimsAndPosMapping.count(dimsIdx)) {
+      interchangeVector.push_back(dimsAndPosMapping[dimsIdx]);
+    }
+  }
+  return interchangeVector;
+}
+
+template <typename T>
+SmallVector<T> interchange(ArrayRef<T> elements,
+                           ArrayRef<int64_t> interchangeVector,
+                           int offset = 0) {
+  SmallVector<T> vec = llvm::to_vector(elements);
+  for (auto [idx, val] : llvm::enumerate(interchangeVector)) {
+    vec[idx + offset] = elements[val + offset];
+  }
+  return vec;
+}
+
 /// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
 static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                 int64_t dim) {
@@ -4756,6 +4794,138 @@ RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
   return RankedTensorType::get(resultShape, sourceType.getElementType());
 }
 
+/// Generate the body of the innermost loop of the scalar implementation
+/// of `pack` operation.
+static void generatePackOpScalarImplementationBody(PackOp packOp,
+                                                   OpBuilder &builder,
+                                                   Location loc,
+                                                   ValueRange ivs) {
+  // Note: `ivs` are already in the correct order, possibly interchanged based
+  // on `dims_pos`. However, connecting the loops with the access patterns is
+  // difficult: what is the relation between the position of the tile loop and
+  // the point loop? If we interchange `ivs` once more, into the canonical
+  // blocking format ABCabc, this connection becomes trivial: each point loop
+  // is pointLoopsOffset + inputRank away from its tiled loop.
+  ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm();
+
+  SmallVector<Value> interchangedIvs = llvm::to_vector(ivs);
+  SmallVector<int64_t> interchangeVector =
+      computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getInputRank());
+  interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
+                                       /*offset=*/packOp.getInputRank());
+  if (!dimsToOuterBlock.empty()) {
+    interchangeVector =
+        computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getInputRank());
+    interchangedIvs =
+        interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
+  }
+
+  SmallVector<OpFoldResult> tiles = packOp.getMixedTiles();
+  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+      packOp.getDimAndTileMapping();
+  SmallVector<OpFoldResult> sourceIndices;
+  size_t pointLoopsOffset = 0;
+  int64_t inputRank = packOp.getInputRank();
+  for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
+    if (dimAndTileMapping.count(dim)) {
+      AffineExpr i, j, tile;
+      bindDims(builder.getContext(), i, j);
+      bindSymbols(builder.getContext(), tile);
+      OpFoldResult sourceIndex = affine::makeComposedFoldedAffineApply(
+          builder, loc, i * tile + j,
+          ArrayRef<OpFoldResult>{
+              interchangedIvs[dim],
+              interchangedIvs[pointLoopsOffset + packOp.getInputRank()],
+              dimAndTileMapping[dim]});
+      sourceIndices.push_back(sourceIndex);
+      ++pointLoopsOffset;
+    } else {
+      sourceIndices.push_back(interchangedIvs[dim]);
+    }
+  }
+
+  auto createLoad = [&]() -> Value {
+    return builder.create<memref::LoadOp>(
+        loc, packOp.getInput(),
+        getValueOrCreateConstantIndexOp(builder, loc, sourceIndices));
+  };
+  Value scalar;
+  if (auto paddingValue = packOp.getPaddingValue()) {
+    ArithBuilder arithBuilder(builder, loc);
+    Value isInBounds;
+    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
+      Value idx =
+          getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]);
+      Value dimValue = getValueOrCreateConstantIndexOp(
+          builder, loc, getDimValue(builder, loc, packOp.getInput(), dim));
+      Value cond = arithBuilder.slt(idx, dimValue);
+      isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond);
+    }
+    scalar = builder
+                 .create<scf::IfOp>(
+                     loc, isInBounds, /*thenBuilder=*/
+                     [&](OpBuilder &b, Location l) {
+                       b.create<scf::YieldOp>(l, createLoad());
+                     },
+                     /*elseBuilder=*/
+                     [&](OpBuilder &b, Location l) {
+                       b.create<scf::YieldOp>(l, paddingValue);
+                     })
+                 .getResult(0);
+  } else {
+    scalar = createLoad();
+  }
+
+  builder.create<memref::StoreOp>(loc, scalar, packOp.getOutput(), ivs);
+}
+
+LogicalResult PackOp::generateScalarImplementation(OpBuilder &builder,
+                                                   Location loc,
+                                                   ValueRange ivs) {
+  OpBuilder::InsertionGuard g(builder);
+  // The `ivs` already represent the position into the output tensor for the
+  // non data-tile dimensions.
+  SmallVector<Value> ivVec = llvm::to_vector(ivs);
+  ReifiedRankedShapedTypeDims outputShape;
+  if (failed(reifyResultShapes(builder, outputShape))) {
+    return getOperation()->emitOpError("failed to reify result shape");
+  }
+  if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) {
+    return getOperation()->emitOpError(
+               "expected shape of one result value of rank ")
+           << getOutputRank();
+  }
+
+  // Generate the loops that iterate over the data tile.
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+  // All loops except the innermost are simple loops that just iterate
+  // over the tile dimensions.
+  for (auto dataTileDim :
+       llvm::seq<unsigned>(getInputRank(), getOutputRank() - 1)) {
+    Value ub = getValueOrCreateConstantIndexOp(builder, loc,
+                                               outputShape[0][dataTileDim]);
+    scf::ForOp loop = builder.create<scf::ForOp>(loc, zero, ub, one);
+    builder.setInsertionPointToStart(loop.getBody());
+    ivVec.push_back(loop.getInductionVar());
+  }
+  // The body of the innermost loops does the actual data movement.
+  builder.create<scf::ForOp>(
+      loc, zero,
+      getValueOrCreateConstantIndexOp(builder, loc, outputShape[0].back()), one,
+      ValueRange{},
+      [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
+          ValueRange regionIterArgs) {
+        ivVec.push_back(iv);
+        generatePackOpScalarImplementationBody(*this, bodyBuilder, bodyLoc,
+                                               ivVec);
+        bodyBuilder.create<scf::YieldOp>(bodyLoc);
+      });
+  return success();
+}
+
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                       ArrayRef<OpFoldResult> innerTileSizes,
                                       ArrayRef<int64_t> innerDimsPos,
@@ -5080,6 +5250,63 @@ void UnPackOp::getAsmResultNames(
   setNameFn(getResult(), "unpack");
 }
 
+LogicalResult UnPackOp::generateScalarImplementation(OpBuilder &builder,
+                                                     Location loc,
+                                                     ValueRange ivs) {
+  OpBuilder::InsertionGuard g(builder);
+  ReifiedRankedShapedTypeDims outputShape;
+
+  if (failed(reifyResultShapes(builder, outputShape))) {
+    return getOperation()->emitError("failed to reify result shapes");
+  }
+  if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) {
+    return getOperation()->emitError(
+               "expected shape of one result value of rank ")
+           << getOutputRank();
+  }
+
+  DenseMap<int64_t, OpFoldResult> dimAndTileMapping = getDimAndTileMapping();
+  // Induction variables for the untiled loops and the tile (point) loops.
+  SmallVector<Value> inputIvs;
+  SmallVector<Value> inputIvsPointLoops;
+  inputIvs.reserve(getOutputRank());
+  inputIvsPointLoops.reserve(dimAndTileMapping.size());
+  for (auto dim : llvm::seq<int64_t>(0, getOutputRank())) {
+    if (dimAndTileMapping.count(dim)) {
+      affine::DivModValue divMod =
+          affine::getDivMod(builder, loc, ivs[dim],
+                            getValueOrCreateConstantIndexOp(
+                                builder, loc, dimAndTileMapping[dim]));
+      inputIvsPointLoops.push_back(divMod.remainder);
+      inputIvs.push_back(divMod.quotient);
+    } else {
+      inputIvs.push_back(ivs[dim]);
+    }
+  }
+
+  // TODO: (lorenzo) simplify the logic a bit. There is `ivs`,
+  // `inputIvsPointLoops` and `inputIvs`.
+  assert(inputIvsPointLoops.size() + inputIvs.size() == getInputRank() &&
+         "expect same number of iduction variables equals to input rank");
+  // Interchange the point loop induction variables based on `inner_dims_pos`.
+  ArrayRef<int64_t> innerDims = getInnerDimsPos();
+  SmallVector<int64_t> interchangeVector =
+      computeInterchangeFromDimPos(innerDims, getOutputRank());
+  SmallVector<Value> interchangedInputIvsPointLoops =
+      interchange<Value>(inputIvsPointLoops, interchangeVector,
+                         /*offset=*/0);
+  // Interchange the tiled loop induction variables based on `outer_dims_perm`.
+  ArrayRef<int64_t> outerDims = getOuterDimsPerm();
+  if (!outerDims.empty()) {
+    inputIvs = interchange<Value>(inputIvs, outerDims, /*offset=*/0);
+  }
+
+  llvm::append_range(inputIvs, interchangedInputIvsPointLoops);
+  Value scalar = builder.create<memref::LoadOp>(loc, getInput(), inputIvs);
+  builder.create<memref::StoreOp>(loc, scalar, getOutput(), ivs);
+  return success();
+}
+
 LogicalResult
 UnPackOp::reifyResultShapes(OpBuilder &builder,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
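
To make the index arithmetic in UnPackOp::generateScalarImplementation concrete: each output induction variable with a tile is split by affine::getDivMod into a tile index (quotient) and a position within the tile (remainder). Below is a hand-written sketch of the resulting loop, reusing the made-up 1-D shapes and tile size 8 from the description; it is not output of this patch:

```mlir
// Hypothetical inverse of the pack example above:
//   linalg.unpack %src inner_dims_pos = [0] inner_tiles = [8]
//     into %dst : memref<2x8xf32> -> memref<16xf32>
func.func @unpack_as_scalar_loops(%src: memref<2x8xf32>, %dst: memref<16xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  scf.for %i = %c0 to %c16 step %c1 {
    // Tile index = i floordiv 8 (quotient); position in tile = i mod 8
    // (remainder), mirroring the affine::getDivMod call in the patch.
    %tile = affine.apply affine_map<(d0) -> (d0 floordiv 8)>(%i)
    %pos = affine.apply affine_map<(d0) -> (d0 mod 8)>(%i)
    %v = memref.load %src[%tile, %pos] : memref<2x8xf32>
    memref.store %v, %dst[%i] : memref<16xf32>
  }
  return
}
```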

Contributor

@banach-space banach-space left a comment


I don't quite follow the context here and the bigger picture.

Add generateScalarImplementation so that linalg.pack/unpack ops on memrefs can be lowered to scalar codes.

With no documentation and tests, generateScalarImplementation feels like some arbitrary ad-hoc hook.

If this is some initial step to make linalg.pack and linalg.unpack support MemRef, could we start with a higher-level GitHub issue (or Discourse thread) outlining the high-level steps?

Thanks!

}
}

// TODO: (lorenzo) simplify the logic a bit. There is `ivs`,
Contributor


Why is this TODO assigned to (presumably) @chelini?

Contributor


Because we developed it in IREE, and the author is helping upstream the code. I think @chelini put himself in the TODO when he wrote the code in IREE. :)

Contributor

@chelini chelini Mar 3, 2025


Sorry, I did not get any notifications; I just saw this. Yes, I remember. Feel free to remove the TODO and refactor the code a bit. That said, why do we want to upstream the scalar implementation? It serves as a good reference, but it's terribly slow.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment


Please add some tests as well.

@hanhanW
Contributor

hanhanW commented Feb 26, 2025

If this is some initial step to make linalg.pack and linalg.unpack support MemRef, could we start with a higher-level GitHub issue (or Discourse thread) outlining the high-level steps?

I missed this part. So linalg.pack and linalg.unpack ops do not support MemRef now? I thought that all the linalg ops support both tensors and memrefs, so I asked the author to help upstream the rest of the implementation (the memref version) from IREE.

@ita9naiwa Re tests: we need to check the memref support first. Then you can follow the IREE tests and add tests to loops.mlir.
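
For example, a hypothetical test along those lines could look roughly like this; the RUN pipeline, the memref form of linalg.pack, and the CHECK lines are guesses at the eventual lowering, not copied from IREE:

```mlir
// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s

// CHECK-LABEL: func @pack_memref
//       CHECK: scf.for
//       CHECK:   scf.for
//       CHECK:     memref.load
//       CHECK:     memref.store
func.func @pack_memref(%src: memref<16xf32>, %dst: memref<2x8xf32>) {
  linalg.pack %src inner_dims_pos = [0] inner_tiles = [8]
      into %dst : memref<16xf32> -> memref<2x8xf32>
  return
}
```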

@MaheshRavishankar
Contributor

I don't quite follow the context here and the bigger picture.

Add generateScalarImplementation so that linalg.pack/unpack ops on memrefs can be lowered to scalar codes.

With no documentation and tests, generateScalarImplementation feels like some arbitrary ad-hoc hook.

It is the implementation of the generateScalarImplementation method in TilingInterface.

If this is some initial step to make linalg.pack and linalg.unpack support MemRef, could we start with a higher-level GitHub issue (or Discourse thread) outlining the high-level steps?

Thanks!

If the op is moved to Linalg, I think it should be made to support memref now as well. There is no reason not to.

@ita9naiwa
Contributor Author

I have created issue #129004 so that we can discuss it there. If we believe that this work is necessary, I will proceed with implementing memref support.
