[mlir][vector] Allow lowering multi-dim scatters to LLVM #132227
base: main
Conversation
Depends on #132217
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

This patch adds an UnrollScatter pattern for vector.scatter, exactly the same as UnrollGather for vector.gather, allowing us to lower multi-dimensional vector.scatter ops by unrolling them to 1-D vectors.

Patch is 27.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132227.diff

13 Files Affected:
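To illustrate the transformation, here is a rough sketch of what the unrolling does to a 2-D scatter (illustrative IR only, not taken from the patch; the function and value names are invented):

```mlir
// Before: a 2-D scatter into a 1-D memref.
func.func @scatter_2d(%base: memref<?xf32>, %idxs: vector<2x3xi32>,
                      %mask: vector<2x3xi1>, %vals: vector<2x3xf32>) {
  %c0 = arith.constant 0 : index
  vector.scatter %base[%c0][%idxs], %mask, %vals
    : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
  return
}
```

After exhaustive application of UnrollScatter, each leading-dimension slice becomes its own 1-D scatter:

```mlir
// After: one 1-D scatter per slice of the leading dimension.
%i0 = vector.extract %idxs[0] : vector<3xi32> from vector<2x3xi32>
%m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1>
%v0 = vector.extract %vals[0] : vector<3xf32> from vector<2x3xf32>
vector.scatter %base[%c0][%i0], %m0, %v0
    : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
// ...and likewise for slice [1].
```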
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..5fab2ee1194e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2034,9 +2034,9 @@ def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
- VectorOfRankAndType<[1], [I1]>:$mask,
- VectorOfRank<[1]>:$valueToStore)> {
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+ VectorOfNonZeroRankOf<[I1]>:$mask,
+ AnyVectorOfNonZeroRank:$valueToStore)> {
let summary = [{
scatters elements from a vector into memory as defined by an index vector
@@ -2044,9 +2044,9 @@ def Vector_ScatterOp :
}];
let description = [{
- The scatter operation stores elements from a 1-D vector into memory as
- defined by a base with indices and an additional 1-D index vector, but
- only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+ The scatter operation stores elements from a n-D vector into memory as
+ defined by a base with indices and an additional n-D index vector, but
+ only if the corresponding bit in a n-D mask vector is set. Otherwise, no
action is taken for that element. Informally the semantics are:
```
if (mask[0]) base[index[0]] = value[0]
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 601a65333d026..528de2340f7b7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -241,16 +241,24 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
/// Populate the pattern set with the following patterns:
///
-/// [FlattenGather]
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// [UnrollGather]
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension.
///
+/// [UnrollScatter]
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension.
+void populateVectorGatherScatterLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
/// [Gather1DToConditionalLoads]
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 94efec61a466c..4127f5b065bc8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -269,6 +269,10 @@ class VectorGatherOpConversion
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
+ VectorType vType = gather.getVectorType();
+ if (vType.getRank() > 1)
+ return failure();
+
auto loc = gather->getLoc();
// Resolve alignment.
@@ -276,42 +280,21 @@ class VectorGatherOpConversion
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
+ // Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value base = adaptor.getBase();
- auto llvmNDVectorTy = adaptor.getIndexVec().getType();
// Handle the simple case of 1-D vector.
- if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
- auto vType = gather.getVectorType();
- // Resolve address.
- Value ptrs =
- getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- base, ptr, adaptor.getIndexVec(), vType);
- // Replace with the gather intrinsic.
- rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
- gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
- adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
- return success();
- }
-
- const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
- auto callback = [align, memRefType, base, ptr, loc, &rewriter,
- &typeConverter](Type llvm1DVectorTy,
- ValueRange vectorOperands) {
- // Resolve address.
- Value ptrs = getIndexedPtrs(
- rewriter, loc, typeConverter, memRefType, base, ptr,
- /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
- // Create the gather intrinsic.
- return rewriter.create<LLVM::masked_gather>(
- loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
- /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
- };
- SmallVector<Value> vectorOperands = {
- adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
- return LLVM::detail::handleMultidimensionalVectors(
- gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+ // Resolve address.
+ Value ptrs =
+ getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+ base, ptr, adaptor.getIndexVec(), vType);
+ // Replace with the gather intrinsic.
+ rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+ gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+ adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+ return success();
}
};
@@ -330,13 +313,16 @@ class VectorScatterOpConversion
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
+ VectorType vType = scatter.getVectorType();
+ if (vType.getRank() > 1)
+ return failure();
+
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
// Resolve address.
- VectorType vType = scatter.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptrs =
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index eb1555df5d574..dfa188bdfc5cc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorInsertExtractStridedSliceTransforms(patterns);
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
+ populateVectorGatherScatterLoweringPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..59da2ebe4aae0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
return emitOpError("base and valueToStore element type should match");
if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
- if (valueVType.getDimSize(0) != indVType.getDimSize(0))
+ if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
- if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+ if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore dim to match mask dim");
return success();
}
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..623b9aa83fff3 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -138,7 +138,7 @@ void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
void transform::ApplyLowerGatherPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::populateVectorGatherLoweringPatterns(patterns);
+ vector::populateVectorGatherScatterLoweringPatterns(patterns);
}
void transform::ApplyLowerScanPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..8abaa6ac527eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
- LowerVectorGather.cpp
+ LowerVectorGatherScatter.cpp
LowerVectorInterleave.cpp
LowerVectorMask.cpp
LowerVectorMultiReduction.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
similarity index 76%
rename from mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
rename to mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
index 3b38505becd18..72892859df200 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
@@ -38,7 +38,8 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +57,14 @@ namespace {
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
///
/// Supports vector types with a fixed leading dimension.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+struct UnrollGather : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
VectorType resultTy = op.getType();
if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already flat");
+ return rewriter.notifyMatchFailure(op, "already 1-D");
// Unrolling doesn't take vscale into account. Pattern is disabled for
// vectors with leading scalable dim(s).
@@ -81,19 +82,14 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
- int64_t thisIdx[1] = {i};
-
- Value indexSubVec =
- rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
- Value maskSubVec =
- rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
+ Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
+ Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
Value passThruSubVec =
- rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
+ rewriter.create<vector::ExtractOp>(loc, passThruVec, i);
Value subGather = rewriter.create<vector::GatherOp>(
loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
passThruSubVec);
- result =
- rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
+ result = rewriter.create<vector::InsertOp>(loc, subGather, result, i);
}
rewriter.replaceOp(op, result);
@@ -101,13 +97,65 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
}
};
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %g = vector.scatter %base[%c0][%v], %mask, %valueToStore : ...
+/// vector<2x3xf32>
+///
+/// ==>
+///
+/// %g0 = vector.extract %valueToStore[0] : vector<3xf32> from vector<2x3xf32>
+/// vector.scatter %base[%c0][%v0], %mask0, %g0
+/// %g1 = vector.extract %valueToStore[1] : vector<3xf32> from vector<2x3xf32>
+/// vector.scatter %base[%c0][%v0], %mask0, %g1
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
+///
+/// Supports vector types with a fixed leading dimension.
+struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ScatterOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType vectorTy = op.getVectorType();
+ if (vectorTy.getRank() < 2)
+ return rewriter.notifyMatchFailure(op, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (vectorTy.getScalableDims().front())
+ return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+ Location loc = op.getLoc();
+ Value indexVec = op.getIndexVec();
+ Value maskVec = op.getMask();
+ Value valueToStoreVec = op.getValueToStore();
+
+ for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
+ Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
+ Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
+ Value valueToStoreSubVec =
+ rewriter.create<vector::ExtractOp>(loc, valueToStoreVec, i);
+ rewriter.create<vector::ScatterOp>(loc, op.getBase(), op.getIndices(),
+ indexSubVec, maskSubVec,
+ valueToStoreSubVec);
+ }
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
/// %subview = memref.subview %M (...)
/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
+/// strided<[3]>>
/// ```
/// ==>
/// ```mlir
@@ -267,8 +315,13 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
};
} // namespace
-void mlir::vector::populateVectorGatherLoweringPatterns(
+void mlir::vector::populateVectorGatherScatterLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollGather, UnrollScatter>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<FlattenGather, RemoveStrideFromGatherSource,
- Gather1DToConditionalLoads>(patterns.getContext(), benefit);
+ patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
+ patterns.getContext(), benefit);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..44b4a25a051f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
// -----
-func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
- %0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
- return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
- %0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
- return %1 : vector<2x[3]xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d_scalable
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-
-// -----
-
func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
%0 = arith.constant 3 : index
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..e8171e3f4853f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-ll...
[truncated]
Please only review the top commit; the other commits are from depending PRs.
Force-pushed from 9cf5167 to cfc0377.
All dependencies have landed; this is ready for review now.
Thanks for the addition! Could you add tests for the new pattern?

Basically, if we are reaching feature parity between vector.gather and vector.scatter, we should also make sure that the testing coverage is similar. I see two test files that exercise unrolling/lowering of vector.gather:
I thought about this, and my reasoning is:

Hopefully this makes it clear why the test coverage in this patch is enough.
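For concreteness, a minimal end-to-end check for the new lowering might look like the following. This is a hypothetical sketch modeled on the existing gather tests, not a file from this patch; the RUN line and CHECK expectations are assumptions:

```mlir
// RUN: mlir-opt %s --convert-vector-to-llvm | FileCheck %s

// The 2x3 scatter should unroll into two 1-D scatters, each lowered to the
// masked-scatter intrinsic.
// CHECK-LABEL: @scatter_2d_to_llvm
// CHECK-COUNT-2: llvm.intr.masked.scatter
func.func @scatter_2d_to_llvm(%base: memref<?xf32>, %idxs: vector<2x3xi32>,
                              %mask: vector<2x3xi1>, %vals: vector<2x3xf32>) {
  %c0 = arith.constant 0 : index
  vector.scatter %base[%c0][%idxs], %mask, %vals
    : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
  return
}
```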
Note that I do plan to add support for vector-gather-lowering and vector-transfer-unroll, but I'm waiting until the work on passing index_vecs per dimension lands (https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/12), because it makes the lowering not require linearization, which makes it more robust. Just mentioning this in case you are wondering why this patch adds only one of the patterns that we eventually should.
Thanks! LGTM % minor comments. Please wait for the other discussions to resolve. Thanks!
```diff
   for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
-    int64_t thisIdx[1] = {i};
-
-    Value indexSubVec =
-        rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
-    Value maskSubVec =
-        rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
+    Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
+    Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
     Value passThruSubVec =
-        rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
+        rewriter.create<vector::ExtractOp>(loc, passThruVec, i);
     Value subGather = rewriter.create<vector::GatherOp>(
         loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
         passThruSubVec);
-    result =
-        rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
+    result = rewriter.create<vector::InsertOp>(loc, subGather, result, i);
   }
```
Wonder if we could have a generic utility for this?
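One possible shape for such a utility (purely a hedged sketch; the helper name and signature are invented, not code from the patch):

```cpp
// Hypothetical helper that UnrollGather and UnrollScatter could share: peel
// the leading dimension, extract the i-th slice of every vector operand, and
// let a callback build the 1-D op for that slice.
static void unrollLeadingDim(
    PatternRewriter &rewriter, Location loc, int64_t leadingSize,
    ValueRange vectorOperands,
    llvm::function_ref<void(int64_t, ArrayRef<Value>)> buildSlice) {
  SmallVector<Value> subOperands;
  for (int64_t i = 0; i < leadingSize; ++i) {
    subOperands.clear();
    // vector.extract with a single index drops the leading dimension.
    for (Value operand : vectorOperands)
      subOperands.push_back(
          rewriter.create<vector::ExtractOp>(loc, operand, i));
    buildSlice(i, subOperands);
  }
}
```

UnrollScatter's callback would create the 1-D vector.scatter directly, while UnrollGather's would additionally vector.insert each sub-result into an accumulator.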
```
/// %g0 = vector.extract %valueToStore[0] : vector<3xf32> from vector<2x3xf32>
/// vector.scatter %base[%c0][%v0], %mask0, %g0
/// %g1 = vector.extract %valueToStore[1] : vector<3xf32> from vector<2x3xf32>
/// vector.scatter %base[%c0][%v0], %mask0, %g1
```
%c0 -> %c1?
Thanks for the explanation and the thoughtful analysis - that makes sense! That said, I think it would still be valuable to include some finer-grained testing, for a couple of reasons:
Looking at your patch, it seems like this TD op is a good fit for testing the new patterns.

Now, even though this patch is only about scatter, I think it would be great and valuable to add symmetric tests for both gather and scatter - probably in a new dedicated test file. I realise this goes beyond the immediate scope though, so this is just a kind request. Also, this could be a good opportunity to rename the op to something like:
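As a rough illustration of the TD-based testing being suggested (hypothetical sketch; it assumes the op's mnemonic is apply_patterns.vector.lower_gather):

```mlir
// Hypothetical transform script that applies the gather/scatter lowering
// patterns to every function in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_gather
    } : !transform.any_op
    transform.yield
  }
}
```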
Thanks again for all the work here — it’s coming together nicely!

-Andrzej

(*) Just an assumption based on experience.
Discourse Discussion: https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/13