Skip to content

Commit c577f91

Browse files
authored
[mlir][vector] Add support for linearizing Extract, ExtractStridedSlice, Shuffle VectorOps in VectorLinearize (#88204)
This PR adds support for converting `vector.extract_strided_slice` and `vector.extract` operations to equivalent `vector.shuffle` operations that operate on linearized (1-D) vectors. `vector.shuffle` operations operating on n-D (n > 1) vectors are also converted to equivalent shuffle operations working on linearized vectors.
1 parent 44713f1 commit c577f91

File tree

4 files changed

+371
-0
lines changed

4 files changed

+371
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,13 @@ void populateVectorLinearizeTypeConversionsAndLegality(
389389
TypeConverter &typeConverter, RewritePatternSet &patterns,
390390
ConversionTarget &target, unsigned targetBitWidth);
391391

392+
/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
/// vector shuffle operations.
///
/// The added patterns rewrite vector.shuffle, vector.extract and
/// vector.extract_strided_slice over n-D vectors into equivalent shuffles
/// over their linearized (1-D) forms. `target` is updated with the
/// corresponding dynamic-legality rules, and `targetBitWidth` caps the total
/// vector bit width that will be linearized (wider ops are left untouched).
void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter,
                                                   RewritePatternSet &patterns,
                                                   ConversionTarget &target,
                                                   unsigned targetBitWidth);
398+
392399
} // namespace vector
393400
} // namespace mlir
394401

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,16 @@
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1515
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
16+
#include "mlir/IR/Attributes.h"
17+
#include "mlir/IR/BuiltinAttributes.h"
18+
#include "mlir/IR/Operation.h"
1619
#include "mlir/IR/PatternMatch.h"
1720
#include "mlir/IR/TypeUtilities.h"
21+
#include "mlir/Support/LogicalResult.h"
1822
#include "mlir/Transforms/DialectConversion.h"
23+
#include "llvm/ADT/ArrayRef.h"
24+
#include <cstdint>
25+
#include <numeric>
1926

2027
using namespace mlir;
2128

@@ -103,6 +110,251 @@ struct LinearizeVectorizable final
103110
return success();
104111
}
105112

113+
private:
114+
unsigned targetVectorBitWidth;
115+
};
116+
117+
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
118+
/// on a linearized vector.
119+
/// Following,
120+
/// vector.extract_strided_slice %source
121+
/// { offsets = [..], strides = [..], sizes = [..] }
122+
/// is converted to :
123+
/// %source_1d = vector.shape_cast %source
124+
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
125+
/// %out_nd = vector.shape_cast %out_1d
126+
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
127+
/// extraction.
128+
struct LinearizeVectorExtractStridedSlice final
129+
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
130+
using OpConversionPattern::OpConversionPattern;
131+
LinearizeVectorExtractStridedSlice(
132+
const TypeConverter &typeConverter, MLIRContext *context,
133+
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
134+
PatternBenefit benefit = 1)
135+
: OpConversionPattern(typeConverter, context, benefit),
136+
targetVectorBitWidth(targetVectBitWidth) {}
137+
138+
LogicalResult
139+
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
140+
ConversionPatternRewriter &rewriter) const override {
141+
Type dstType = getTypeConverter()->convertType(extractOp.getType());
142+
assert(!(extractOp.getVector().getType().isScalable() ||
143+
dstType.cast<VectorType>().isScalable()) &&
144+
"scalable vectors are not supported.");
145+
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
146+
return rewriter.notifyMatchFailure(
147+
extractOp, "Can't flatten since targetBitWidth <= OpSize");
148+
149+
ArrayAttr offsets = extractOp.getOffsets();
150+
ArrayAttr sizes = extractOp.getSizes();
151+
ArrayAttr strides = extractOp.getStrides();
152+
if (!isConstantIntValue(strides[0], 1))
153+
return rewriter.notifyMatchFailure(
154+
extractOp, "Strided slice with stride != 1 is not supported.");
155+
Value srcVector = adaptor.getVector();
156+
// If kD offsets are specified for nD source vector (n > k), the granularity
157+
// of the extraction is greater than 1. In this case last (n-k) dimensions
158+
// form the extraction granularity.
159+
// Example :
160+
// vector.extract_strided_slice %src {
161+
// offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
162+
// vector<4x8x8xf32> to vector<2x2x8xf32>
163+
// Here, extraction granularity is 8.
164+
int64_t extractGranularitySize = 1;
165+
int64_t nD = extractOp.getSourceVectorType().getRank();
166+
int64_t kD = (int64_t)offsets.size();
167+
int64_t k = kD;
168+
while (k < nD) {
169+
extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
170+
++k;
171+
}
172+
// Get total number of extracted slices.
173+
int64_t nExtractedSlices = 1;
174+
for (Attribute size : sizes) {
175+
nExtractedSlices *= size.cast<IntegerAttr>().getInt();
176+
}
177+
// Compute the strides of the source vector considering first k dimensions.
178+
llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
179+
for (int i = kD - 2; i >= 0; --i) {
180+
sourceStrides[i] = sourceStrides[i + 1] *
181+
extractOp.getSourceVectorType().getShape()[i + 1];
182+
}
183+
// Final shuffle indices has nExtractedSlices * extractGranularitySize
184+
// elements.
185+
llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
186+
extractGranularitySize);
187+
// Compute the strides of the extracted kD vector.
188+
llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
189+
// Compute extractedStrides.
190+
for (int i = kD - 2; i >= 0; --i) {
191+
extractedStrides[i] =
192+
extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
193+
}
194+
// Iterate over all extracted slices from 0 to nExtractedSlices - 1
195+
// and compute the multi-dimensional index and the corresponding linearized
196+
// index within the source vector.
197+
for (int64_t i = 0; i < nExtractedSlices; ++i) {
198+
int64_t index = i;
199+
// Compute the corresponding multi-dimensional index.
200+
llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
201+
for (int64_t j = 0; j < kD; ++j) {
202+
multiDimIndex[j] = (index / extractedStrides[j]);
203+
index -= multiDimIndex[j] * extractedStrides[j];
204+
}
205+
// Compute the corresponding linearized index in the source vector
206+
// i.e. shift the multiDimIndex by the offsets.
207+
int64_t linearizedIndex = 0;
208+
for (int64_t j = 0; j < kD; ++j) {
209+
linearizedIndex +=
210+
(offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) *
211+
sourceStrides[j];
212+
}
213+
// Fill the indices array form linearizedIndex to linearizedIndex +
214+
// extractGranularitySize.
215+
for (int64_t j = 0; j < extractGranularitySize; ++j) {
216+
indices[i * extractGranularitySize + j] = linearizedIndex + j;
217+
}
218+
}
219+
// Perform a shuffle to extract the kD vector.
220+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
221+
extractOp, dstType, srcVector, srcVector,
222+
rewriter.getI64ArrayAttr(indices));
223+
return success();
224+
}
225+
226+
private:
227+
unsigned targetVectorBitWidth;
228+
};
229+
230+
/// This pattern converts the ShuffleOp that works on nD (n > 1)
231+
/// vectors to a ShuffleOp that works on linearized vectors.
232+
/// Following,
233+
/// vector.shuffle %v1, %v2 [ shuffle_indices ]
234+
/// is converted to :
235+
/// %v1_1d = vector.shape_cast %v1
236+
/// %v2_1d = vector.shape_cast %v2
237+
/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
238+
/// %out_nd = vector.shape_cast %out_1d
239+
// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
240+
/// of the original shuffle operation.
241+
struct LinearizeVectorShuffle final
242+
: public OpConversionPattern<vector::ShuffleOp> {
243+
using OpConversionPattern::OpConversionPattern;
244+
LinearizeVectorShuffle(
245+
const TypeConverter &typeConverter, MLIRContext *context,
246+
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
247+
PatternBenefit benefit = 1)
248+
: OpConversionPattern(typeConverter, context, benefit),
249+
targetVectorBitWidth(targetVectBitWidth) {}
250+
251+
LogicalResult
252+
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
253+
ConversionPatternRewriter &rewriter) const override {
254+
Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
255+
assert(!(shuffleOp.getV1VectorType().isScalable() ||
256+
shuffleOp.getV2VectorType().isScalable() ||
257+
dstType.cast<VectorType>().isScalable()) &&
258+
"scalable vectors are not supported.");
259+
if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
260+
return rewriter.notifyMatchFailure(
261+
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
262+
263+
Value vec1 = adaptor.getV1();
264+
Value vec2 = adaptor.getV2();
265+
int shuffleSliceLen = 1;
266+
int rank = shuffleOp.getV1().getType().getRank();
267+
268+
// If rank > 1, we need to do the shuffle in the granularity of slices
269+
// instead of scalars. Size of the slice is equal to the rank-1 innermost
270+
// dims. Mask of the shuffle op specifies which slice to take from the
271+
// outermost dim.
272+
if (rank > 1) {
273+
llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
274+
for (unsigned i = 1; i < shape.size(); ++i) {
275+
shuffleSliceLen *= shape[i];
276+
}
277+
}
278+
279+
// For each value in the mask, we generate the indices of the source vectors
280+
// that needs to be shuffled to the destination vector. If shuffleSliceLen >
281+
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
282+
// elements) instead of scalars.
283+
ArrayAttr mask = shuffleOp.getMask();
284+
int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
285+
llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
286+
for (auto [i, value] :
287+
llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
288+
289+
int64_t v = value.getZExtValue();
290+
std::iota(indices.begin() + shuffleSliceLen * i,
291+
indices.begin() + shuffleSliceLen * (i + 1),
292+
shuffleSliceLen * v);
293+
}
294+
295+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
296+
shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
297+
return success();
298+
}
299+
300+
private:
301+
unsigned targetVectorBitWidth;
302+
};
303+
304+
/// This pattern converts the ExtractOp to a ShuffleOp that works on a
305+
/// linearized vector.
306+
/// Following,
307+
/// vector.extract %source [ position ]
308+
/// is converted to :
309+
/// %source_1d = vector.shape_cast %source
310+
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
311+
/// %out_nd = vector.shape_cast %out_1d
312+
/// `shuffle_indices_1d` is computed using the position of the original extract.
313+
struct LinearizeVectorExtract final
314+
: public OpConversionPattern<vector::ExtractOp> {
315+
using OpConversionPattern::OpConversionPattern;
316+
LinearizeVectorExtract(
317+
const TypeConverter &typeConverter, MLIRContext *context,
318+
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
319+
PatternBenefit benefit = 1)
320+
: OpConversionPattern(typeConverter, context, benefit),
321+
targetVectorBitWidth(targetVectBitWidth) {}
322+
LogicalResult
323+
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
324+
ConversionPatternRewriter &rewriter) const override {
325+
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
326+
assert(!(extractOp.getVector().getType().isScalable() ||
327+
dstTy.cast<VectorType>().isScalable()) &&
328+
"scalable vectors are not supported.");
329+
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
330+
return rewriter.notifyMatchFailure(
331+
extractOp, "Can't flatten since targetBitWidth <= OpSize");
332+
333+
// Dynamic position is not supported.
334+
if (extractOp.hasDynamicPosition())
335+
return rewriter.notifyMatchFailure(extractOp,
336+
"dynamic position is not supported.");
337+
338+
llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
339+
int64_t size = extractOp.getVector().getType().getNumElements();
340+
341+
// Compute linearized offset.
342+
int64_t linearizedOffset = 0;
343+
llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
344+
for (auto [i, off] : llvm::enumerate(offsets)) {
345+
size /= shape[i];
346+
linearizedOffset += offsets[i] * size;
347+
}
348+
349+
llvm::SmallVector<int64_t, 2> indices(size);
350+
std::iota(indices.begin(), indices.end(), linearizedOffset);
351+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
352+
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
353+
rewriter.getI64ArrayAttr(indices));
354+
355+
return success();
356+
}
357+
106358
private:
107359
unsigned targetVectorBitWidth;
108360
};
@@ -145,3 +397,21 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
145397
patterns.add<LinearizeConstant, LinearizeVectorizable>(
146398
typeConverter, patterns.getContext(), targetBitWidth);
147399
}
400+
401+
// Populates `patterns` with rewrites that turn n-D vector.shuffle,
// vector.extract and vector.extract_strided_slice ops into shuffles over
// linearized (1-D) vectors, and marks vector.shuffle dynamically legal on
// `target` accordingly. Note: spells the last parameter `unsigned` to match
// the declaration in VectorRewritePatterns.h.
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {
  target.addDynamicallyLegalOp<vector::ShuffleOp>(
      [=](vector::ShuffleOp shuffleOp) -> bool {
        // Ops at or above the bit-width cap are never linearized, so they
        // stay legal as-is.
        if (!isLessThanTargetBitWidth(shuffleOp, targetBitWidth))
          return true;
        // Otherwise the shuffle is legal only once it is fully linearized:
        // operand/result types converted and the result is rank 1.
        return typeConverter.isLegal(shuffleOp) &&
               cast<mlir::VectorType>(shuffleOp.getResult().getType())
                       .getRank() == 1;
      });
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorExtractStridedSlice>(
      typeConverter, patterns.getContext(), targetBitWidth);
}

0 commit comments

Comments
 (0)