Skip to content

Commit cfc0377

Browse files
committed
[mlir][vector] Allow lowering multi-dim scatters to LLVM
1 parent 24a8e18 commit cfc0377

File tree

8 files changed

+74
-22
lines changed

8 files changed

+74
-22
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,12 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
244244
/// [UnrollGather]
245245
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
246246
/// outermost dimension.
247-
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
248-
PatternBenefit benefit = 1);
247+
///
248+
/// [UnrollScatter]
249+
/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
250+
/// outermost dimension.
251+
void populateVectorGatherScatterLoweringPatterns(RewritePatternSet &patterns,
252+
PatternBenefit benefit = 1);
249253

250254
/// Populate the pattern set with the following patterns:
251255
///

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,9 @@ class VectorGatherOpConversion
286286
// Resolve address.
287287
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
288288
adaptor.getIndices(), rewriter);
289-
Value base = adaptor.getBase();
290289
Value ptrs =
291290
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
292-
base, ptr, adaptor.getIndexVec(), vType);
291+
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
293292

294293
// Replace with the gather intrinsic.
295294
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -308,7 +307,7 @@ class VectorScatterOpConversion
308307
LogicalResult
309308
matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
310309
ConversionPatternRewriter &rewriter) const override {
311-
auto loc = scatter->getLoc();
310+
Location loc = scatter->getLoc();
312311
MemRefType memRefType = scatter.getMemRefType();
313312

314313
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
8181
populateVectorInsertExtractStridedSliceTransforms(patterns);
8282
populateVectorStepLoweringPatterns(patterns);
8383
populateVectorRankReducingFMAPattern(patterns);
84-
populateVectorGatherLoweringPatterns(patterns);
84+
populateVectorGatherScatterLoweringPatterns(patterns);
8585
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
8686
}
8787

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
138138

139139
void transform::ApplyLowerGatherPatternsOp::populatePatterns(
140140
RewritePatternSet &patterns) {
141-
vector::populateVectorGatherLoweringPatterns(patterns);
141+
vector::populateVectorGatherScatterLoweringPatterns(patterns);
142142
}
143143

144144
void transform::ApplyLowerScanPatternsOp::populatePatterns(

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
33
LowerVectorBitCast.cpp
44
LowerVectorBroadcast.cpp
55
LowerVectorContract.cpp
6-
LowerVectorGather.cpp
6+
LowerVectorGatherScatter.cpp
77
LowerVectorInterleave.cpp
88
LowerVectorMask.cpp
99
LowerVectorMultiReduction.cpp

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp renamed to mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ using namespace mlir;
3838
using namespace mlir::vector;
3939

4040
namespace {
41+
4142
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
4243
/// outermost dimension. For example:
4344
/// ```
@@ -81,26 +82,72 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
8182
VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
8283

8384
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
84-
int64_t thisIdx[1] = {i};
85-
86-
Value indexSubVec =
87-
rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
88-
Value maskSubVec =
89-
rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
85+
Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
86+
Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
9087
Value passThruSubVec =
91-
rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
88+
rewriter.create<vector::ExtractOp>(loc, passThruVec, i);
9289
Value subGather = rewriter.create<vector::GatherOp>(
9390
loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
9491
passThruSubVec);
95-
result =
96-
rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
92+
result = rewriter.create<vector::InsertOp>(loc, subGather, result, i);
9793
}
9894

9995
rewriter.replaceOp(op, result);
10096
return success();
10197
}
10298
};
10399

100+
/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
101+
/// outermost dimension. For example:
102+
/// ```
103+
///   vector.scatter %base[%c0][%v], %mask, %valueToStore : ...
104+
/// vector<2x3xf32>
105+
///
106+
/// ==>
107+
///
108+
/// %g0 = vector.extract %valueToStore[0] : vector<3xf32> from vector<2x3xf32>
109+
/// vector.scatter %base[%c0][%v0], %mask0, %g0
110+
/// %g1 = vector.extract %valueToStore[1] : vector<3xf32> from vector<2x3xf32>
111+
/// vector.scatter %base[%c0][%v1], %mask1, %g1
112+
/// ```
113+
///
114+
/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
115+
///
116+
/// Supports vector types with a fixed leading dimension.
117+
struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
118+
using OpRewritePattern::OpRewritePattern;
119+
120+
LogicalResult matchAndRewrite(vector::ScatterOp op,
121+
PatternRewriter &rewriter) const override {
122+
VectorType vectorTy = op.getVectorType();
123+
if (vectorTy.getRank() < 2)
124+
return rewriter.notifyMatchFailure(op, "already 1-D");
125+
126+
// Unrolling doesn't take vscale into account. Pattern is disabled for
127+
// vectors with leading scalable dim(s).
128+
if (vectorTy.getScalableDims().front())
129+
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
130+
131+
Location loc = op.getLoc();
132+
Value indexVec = op.getIndexVec();
133+
Value maskVec = op.getMask();
134+
Value valueToStoreVec = op.getValueToStore();
135+
136+
for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
137+
Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
138+
Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
139+
Value valueToStoreSubVec =
140+
rewriter.create<vector::ExtractOp>(loc, valueToStoreVec, i);
141+
rewriter.create<vector::ScatterOp>(loc, op.getBase(), op.getIndices(),
142+
indexSubVec, maskSubVec,
143+
valueToStoreSubVec);
144+
}
145+
146+
rewriter.eraseOp(op);
147+
return success();
148+
}
149+
};
150+
104151
/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
105152
/// MemRef with updated indices that model the strided access.
106153
///
@@ -268,9 +315,9 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
268315
};
269316
} // namespace
270317

271-
void mlir::vector::populateVectorGatherLoweringPatterns(
318+
void mlir::vector::populateVectorGatherScatterLoweringPatterns(
272319
RewritePatternSet &patterns, PatternBenefit benefit) {
273-
patterns.add<UnrollGather>(patterns.getContext(), benefit);
320+
patterns.add<UnrollGather, UnrollScatter>(patterns.getContext(), benefit);
274321
}
275322

276323
void mlir::vector::populateVectorGatherToConditionalLoadPatterns(

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,7 +1734,8 @@ func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
17341734
}
17351735

17361736
// CHECK-LABEL: func @scatter_with_mask
1737-
// CHECK: vector.scatter
1737+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
1738+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
17381739

17391740
// -----
17401741

@@ -1749,7 +1750,8 @@ func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]x
17491750
}
17501751

17511752
// CHECK-LABEL: func @scatter_with_mask_scalable
1752-
// CHECK: vector.scatter
1753+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
1754+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
17531755

17541756
// -----
17551757

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ struct TestVectorGatherLowering
781781

782782
void runOnOperation() override {
783783
RewritePatternSet patterns(&getContext());
784-
populateVectorGatherLoweringPatterns(patterns);
784+
populateVectorGatherScatterLoweringPatterns(patterns);
785785
populateVectorGatherToConditionalLoadPatterns(patterns);
786786
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
787787
}

0 commit comments

Comments
 (0)