
[mlir][vector] Allow lowering multi-dim scatters to LLVM #132227


Open · wants to merge 1 commit into base: main

Conversation

@Groverkss (Member) commented Mar 20, 2025

This patch adds an UnrollScatter pattern for vector.scatter, exactly the same as UnrollGather for vector.gather, allowing us to lower multi-dimensional vector.scatter ops by unrolling them to 1-D vectors.

Discourse Discussion: https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/13
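For illustration, here is a minimal sketch of what the unrolling produces for a 2-D scatter, adapted from the pattern's doc comment in this patch; the SSA names and shapes are illustrative, not taken from an actual test:

```mlir
// Before: a 2-D scatter.
vector.scatter %base[%c0][%idxs], %mask, %vals
    : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>

// After UnrollScatter: one 1-D scatter per slice of the leading dimension.
%i0 = vector.extract %idxs[0] : vector<3xi32> from vector<2x3xi32>
%m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1>
%v0 = vector.extract %vals[0] : vector<3xf32> from vector<2x3xf32>
vector.scatter %base[%c0][%i0], %m0, %v0
    : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
%i1 = vector.extract %idxs[1] : vector<3xi32> from vector<2x3xi32>
%m1 = vector.extract %mask[1] : vector<3xi1> from vector<2x3xi1>
%v1 = vector.extract %vals[1] : vector<3xf32> from vector<2x3xf32>
vector.scatter %base[%c0][%i1], %m1, %v1
    : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
```

The resulting 1-D scatters can then be handled by the existing VectorScatterOpConversion when lowering to LLVM.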

@Groverkss (Member, Author)

Depends on #132217

@llvmbot (Member) commented Mar 20, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

This patch adds an UnrollScatter pattern for vector.scatter, exactly the same as UnrollGather for vector.gather, allowing us to lower multi-dimensional vector.scatter ops by unrolling them to 1-D vectors.


Patch is 27.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132227.diff

13 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+6-6)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+12-4)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+18-32)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1-1)
  • (renamed) mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp (+69-16)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (-46)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+39-3)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+10-8)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2-1)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..5fab2ee1194e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2034,9 +2034,9 @@ def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$valueToStore)> {
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$valueToStore)> {
 
   let summary = [{
     scatters elements from a vector into memory as defined by an index vector
@@ -2044,9 +2044,9 @@ def Vector_ScatterOp :
   }];
 
   let description = [{
-    The scatter operation stores elements from a 1-D vector into memory as
-    defined by a base with indices and an additional 1-D index vector, but
-    only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+    The scatter operation stores elements from a n-D vector into memory as
+    defined by a base with indices and an additional n-D index vector, but
+    only if the corresponding bit in a n-D mask vector is set. Otherwise, no
     action is taken for that element. Informally the semantics are:
     ```
     if (mask[0]) base[index[0]] = value[0]
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 601a65333d026..528de2340f7b7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -241,16 +241,24 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
 
 /// Populate the pattern set with the following patterns:
 ///
-/// [FlattenGather]
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// [UnrollGather]
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension.
 ///
+/// [UnrollScatter]
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension.
+void populateVectorGatherScatterLoweringPatterns(RewritePatternSet &patterns,
+                                                 PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
 /// [Gather1DToConditionalLoads]
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit = 1);
 
 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 94efec61a466c..4127f5b065bc8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -269,6 +269,10 @@ class VectorGatherOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
+    VectorType vType = gather.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
     auto loc = gather->getLoc();
 
     // Resolve alignment.
@@ -276,42 +280,21 @@ class VectorGatherOpConversion
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
+    // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value base = adaptor.getBase();
 
-    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
     // Handle the simple case of 1-D vector.
-    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
-      auto vType = gather.getVectorType();
-      // Resolve address.
-      Value ptrs =
-          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                         base, ptr, adaptor.getIndexVec(), vType);
-      // Replace with the gather intrinsic.
-      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-      return success();
-    }
-
-    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
-    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
-                     &typeConverter](Type llvm1DVectorTy,
-                                     ValueRange vectorOperands) {
-      // Resolve address.
-      Value ptrs = getIndexedPtrs(
-          rewriter, loc, typeConverter, memRefType, base, ptr,
-          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
-      // Create the gather intrinsic.
-      return rewriter.create<LLVM::masked_gather>(
-          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
-          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
-    };
-    SmallVector<Value> vectorOperands = {
-        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
-    return LLVM::detail::handleMultidimensionalVectors(
-        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+    // Resolve address.
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+                       base, ptr, adaptor.getIndexVec(), vType);
+    // Replace with the gather intrinsic.
+    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+    return success();
   }
 };
 
@@ -330,13 +313,16 @@ class VectorScatterOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
+    VectorType vType = scatter.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
-    VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value ptrs =
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index eb1555df5d574..dfa188bdfc5cc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorInsertExtractStridedSliceTransforms(patterns);
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
+    populateVectorGatherScatterLoweringPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..59da2ebe4aae0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
     return emitOpError("base and valueToStore element type should match");
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
-  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
+  if (valueVType.getShape() != indVType.getShape())
     return emitOpError("expected valueToStore dim to match indices dim");
-  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+  if (valueVType.getShape() != maskVType.getShape())
     return emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..623b9aa83fff3 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -138,7 +138,7 @@ void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
 
 void transform::ApplyLowerGatherPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorGatherLoweringPatterns(patterns);
+  vector::populateVectorGatherScatterLoweringPatterns(patterns);
 }
 
 void transform::ApplyLowerScanPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..8abaa6ac527eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorBitCast.cpp
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
-  LowerVectorGather.cpp
+  LowerVectorGatherScatter.cpp
   LowerVectorInterleave.cpp
   LowerVectorMask.cpp
   LowerVectorMultiReduction.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
similarity index 76%
rename from mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
rename to mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
index 3b38505becd18..72892859df200 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
@@ -38,7 +38,8 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension. For example:
 /// ```
 /// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +57,14 @@ namespace {
 /// When applied exhaustively, this will produce a sequence of 1-d gather ops.
 ///
 /// Supports vector types with a fixed leading dimension.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+struct UnrollGather : OpRewritePattern<vector::GatherOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::GatherOp op,
                                 PatternRewriter &rewriter) const override {
     VectorType resultTy = op.getType();
     if (resultTy.getRank() < 2)
-      return rewriter.notifyMatchFailure(op, "already flat");
+      return rewriter.notifyMatchFailure(op, "already 1-D");
 
     // Unrolling doesn't take vscale into account. Pattern is disabled for
     // vectors with leading scalable dim(s).
@@ -81,19 +82,14 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
     VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
 
     for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
-      int64_t thisIdx[1] = {i};
-
-      Value indexSubVec =
-          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
-      Value maskSubVec =
-          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
+      Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
+      Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
       Value passThruSubVec =
-          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
+          rewriter.create<vector::ExtractOp>(loc, passThruVec, i);
       Value subGather = rewriter.create<vector::GatherOp>(
           loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
           passThruSubVec);
-      result =
-          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
+      result = rewriter.create<vector::InsertOp>(loc, subGather, result, i);
     }
 
     rewriter.replaceOp(op, result);
@@ -101,13 +97,65 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
   }
 };
 
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %g = vector.scatter %base[%c0][%v], %mask, %valueToStore : ...
+/// vector<2x3xf32>
+///
+/// ==>
+///
+/// %g0  = vector.extract %valueToStore[0] : vector<3xf32> from vector<2x3xf32>
+///        vector.scatter %base[%c0][%v0], %mask0, %g0
+/// %g1  = vector.extract %valueToStore[1] : vector<3xf32> from vector<2x3xf32>
+///        vector.scatter %base[%c0][%v0], %mask0, %g1
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
+///
+/// Supports vector types with a fixed leading dimension.
+struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vectorTy = op.getVectorType();
+    if (vectorTy.getRank() < 2)
+      return rewriter.notifyMatchFailure(op, "already 1-D");
+
+    // Unrolling doesn't take vscale into account. Pattern is disabled for
+    // vectors with leading scalable dim(s).
+    if (vectorTy.getScalableDims().front())
+      return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+    Location loc = op.getLoc();
+    Value indexVec = op.getIndexVec();
+    Value maskVec = op.getMask();
+    Value valueToStoreVec = op.getValueToStore();
+
+    for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
+      Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
+      Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
+      Value valueToStoreSubVec =
+          rewriter.create<vector::ExtractOp>(loc, valueToStoreVec, i);
+      rewriter.create<vector::ScatterOp>(loc, op.getBase(), op.getIndices(),
+                                         indexSubVec, maskSubVec,
+                                         valueToStoreSubVec);
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
 /// MemRef with updated indices that model the strided access.
 ///
 /// ```mlir
 ///   %subview = memref.subview %M (...)
 ///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
+///   strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir
@@ -267,8 +315,13 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 };
 } // namespace
 
-void mlir::vector::populateVectorGatherLoweringPatterns(
+void mlir::vector::populateVectorGatherScatterLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollGather, UnrollScatter>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
-               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
+  patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
+      patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..44b4a25a051f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
 
 // -----
 
-func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
-  return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
-  return %1 : vector<2x[3]xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d_scalable
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-
-// -----
-
 
 func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..e8171e3f4853f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-ll...
[truncated]

@Groverkss (Member, Author)

Please only review the top commit; the other commits are from dependent PRs.

@Groverkss (Member, Author)

All dependencies have landed; this is ready for review now.

@banach-space (Contributor) left a comment

Thanks for the addition! Could you add tests for the new pattern?

Basically, if we are reaching feature parity between vector.gather and vector.scatter, we should also make sure that the testing coverage is similar. I see two test files that exercise unrolling/lowering of vector.gather:


@Groverkss (Member, Author) commented Mar 27, 2025


I thought about this, and my reasoning is:

  • This patch only adds a single pattern, UnrollScatter, which is used by ConvertToLLVM to lower multi-dimensional scatters. We have the same tests for gather in test/Conversion/VectorToLLVM/vector-to-llvm.mlir. Previous patches added a negative test for lowering multi-dimensional scatters, which this patch turns positive (i.e. we can lower them now).
  • test/Dialect/Vector/vector-gather-lowering.mlir contains tests for unroll + linearize + lower-to-conditional-loads for gather operations. Currently, the scatter op only implements unroll, so adding tests there would not be correct.
  • test/Dialect/Vector/vector-transfer-unroll contains tests for a separate style of unrolling, where a single dimension is unrolled into smaller chunks. This pattern unrolls the outer dimension completely, so those tests are unrelated.

Hopefully this makes it clear why the test coverage in this patch is enough.

Note that I do plan to add support for vector-gather-lowering and vector-transfer-unroll, but I'm waiting until the work on passing index_vecs per dimension lands (https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/12), because it makes the lowering not require linearization, which makes it more robust. Just mentioning this in case you are wondering why this patch adds only one of the patterns we should eventually have.

@Groverkss requested a review from banach-space on March 27, 2025 at 12:02
@dcaballe (Contributor) left a comment

Thanks! LGTM % minor comments. Please wait for the other discussions to resolve.

Comment on lines 84 to 93
    for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
-      int64_t thisIdx[1] = {i};
-
-      Value indexSubVec =
-          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
-      Value maskSubVec =
-          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
+      Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
+      Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
       Value passThruSubVec =
-          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
+          rewriter.create<vector::ExtractOp>(loc, passThruVec, i);
       Value subGather = rewriter.create<vector::GatherOp>(
           loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
           passThruSubVec);
-      result =
-          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
+      result = rewriter.create<vector::InsertOp>(loc, subGather, result, i);
     }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonder if we could have a generic utility for this?
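One possible shape for such a utility, sketched here under the assumption that every unrolled pattern extracts all of its vector operands along the leading dimension; the helper name and signature are hypothetical, not part of the patch:

```c++
// Hypothetical helper (assumes the usual MLIR includes and namespaces of
// LowerVectorGatherScatter.cpp). For each index along the leading dimension,
// extract that slice of every vector operand and let the caller build the
// corresponding rank-reduced op.
static void unrollLeadingDim(PatternRewriter &rewriter, Location loc,
                             int64_t count, ValueRange vectorOperands,
                             function_ref<void(int64_t, ArrayRef<Value>)> buildSubOp) {
  for (int64_t i = 0; i < count; ++i) {
    SmallVector<Value> subVecs;
    for (Value operand : vectorOperands)
      subVecs.push_back(rewriter.create<vector::ExtractOp>(loc, operand, i));
    buildSubOp(i, subVecs);
  }
}
```

UnrollGather's callback would additionally insert the per-iteration gather result into the accumulated vector, while UnrollScatter's callback would only create the 1-D scatter.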

/// %g0 = vector.extract %valueToStore[0] : vector<3xf32> from vector<2x3xf32>
/// vector.scatter %base[%c0][%v0], %mask0, %g0
/// %g1 = vector.extract %valueToStore[1] : vector<3xf32> from vector<2x3xf32>
/// vector.scatter %base[%c0][%v0], %mask0, %g1

%c0 -> %c1?

@banach-space (Contributor)


Thanks for the explanation and the thoughtful analysis - that makes sense!

That said, I think it would still be valuable to include some finer-grained testing, for a couple of reasons:

  • While LLVM is probably the most widely used and well-tested target for Vector (*), it's not the only one — so it's important that we have coverage across other parts of the stack.
  • ConvertVectorToLLVM is a large integration pipeline — it exercises many transformations together. It's helpful to complement that with more focused unit tests to ensure the individual pieces are working in isolation.

Looking at your patch, it seems like this TD op is a good fit for testing the new UnrollScatter pattern directly: transform.apply_patterns.vector.lower_gather (ApplyLowerGatherPatternsOp, which this patch updates to use the new populate function).
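For example, a harness along these lines could exercise the pattern in isolation; this is a hedged sketch using the transform op updated in this patch, and the exact payload matching may differ:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.any_op
    // Applies UnrollGather and, with this patch, UnrollScatter.
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_gather
    } : !transform.any_op
    transform.yield
  }
}
```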

Now, even though this patch is only about scatter, I think it would be valuable to add symmetric tests for both gather and scatter, probably in a new dedicated test file. I realise this goes beyond the immediate scope, so this is just a kind request.

Also, this could be a good opportunity to rename the op to something like:

  • transform.apply_patterns.vector.lower_gather_scatter

Thanks again for all the work here — it’s coming together nicely!

-Andrzej

(*) Just an assumption based on experience.
