Skip to content

Commit 18c92bb

Browse files
committed
[mlir][vector] Add support for linearizing Insert VectorOp in VectorLinearize
Building on top of llvm#88204, this commit adds support for InsertOp.
1 parent 90d2f8c commit 18c92bb

File tree

2 files changed

+126
-1
lines changed

2 files changed

+126
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,20 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
4444
return true;
4545
}
4646

47+
static bool isLessThanOrEqualTargetBitWidth(Type t,
48+
unsigned targetBitWidth) {
49+
VectorType vecType = dyn_cast<VectorType>(t);
50+
// Reject index since getElementTypeBitWidth will abort for Index types.
51+
if (!vecType || vecType.getElementType().isIndex())
52+
return false;
53+
// There are no dimension to fold if it is a 0-D vector.
54+
if (vecType.getRank() == 0)
55+
return false;
56+
unsigned trailingVecDimBitWidth =
57+
vecType.getShape().back() * vecType.getElementTypeBitWidth();
58+
return trailingVecDimBitWidth <= targetBitWidth
59+
}
60+
4761
namespace {
4862
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
4963
using OpConversionPattern::OpConversionPattern;
@@ -355,6 +369,88 @@ struct LinearizeVectorExtract final
355369
return success();
356370
}
357371

372+
private:
373+
unsigned targetVectorBitWidth;
374+
};
375+
376+
/// This pattern converts the InsertOp to a ShuffleOp that works on a
377+
/// linearized vector.
378+
/// Following,
379+
/// vector.insert %source %destination [ position ]
380+
/// is converted to :
381+
/// %source_1d = vector.shape_cast %source
382+
/// %destination_1d = vector.shape_cast %destination
383+
/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
384+
/// ] %out_nd = vector.shape_cast %out_1d
385+
/// `shuffle_indices_1d` is computed using the position of the original insert.
386+
struct LinearizeVectorInsert final
387+
: public OpConversionPattern<vector::InsertOp> {
388+
using OpConversionPattern::OpConversionPattern;
389+
LinearizeVectorInsert(
390+
const TypeConverter &typeConverter, MLIRContext *context,
391+
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
392+
PatternBenefit benefit = 1)
393+
: OpConversionPattern(typeConverter, context, benefit),
394+
targetVectorBitWidth(targetVectBitWidth) {}
395+
LogicalResult
396+
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
397+
ConversionPatternRewriter &rewriter) const override {
398+
Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
399+
assert(!(insertOp.getDestVectorType().isScalable() ||
400+
cast<VectorType>(dstTy).isScalable()) &&
401+
"scalable vectors are not supported.");
402+
403+
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
404+
targetVectorBitWidth))
405+
return rewriter.notifyMatchFailure(
406+
insertOp, "Can't flatten since targetBitWidth < OpSize");
407+
408+
// dynamic position is not supported
409+
if (insertOp.hasDynamicPosition())
410+
return rewriter.notifyMatchFailure(insertOp,
411+
"dynamic position is not supported.");
412+
auto srcTy = insertOp.getSourceType();
413+
auto srcAsVec = dyn_cast<VectorType>(srcTy);
414+
uint64_t srcSize = 0;
415+
if (srcAsVec) {
416+
srcSize = srcAsVec.getNumElements();
417+
} else {
418+
return rewriter.notifyMatchFailure(insertOp,
419+
"scalars are not supported.");
420+
}
421+
422+
auto dstShape = insertOp.getDestVectorType().getShape();
423+
const auto dstSize = insertOp.getDestVectorType().getNumElements();
424+
auto dstSizeForOffsets = dstSize;
425+
426+
// compute linearized offset
427+
int64_t linearizedOffset = 0;
428+
auto offsetsNd = insertOp.getStaticPosition();
429+
for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
430+
dstSizeForOffsets /= dstShape[dim];
431+
linearizedOffset += offset * dstSizeForOffsets;
432+
}
433+
434+
llvm::SmallVector<int64_t, 2> indices(dstSize);
435+
auto origValsUntil = indices.begin();
436+
std::advance(origValsUntil, linearizedOffset);
437+
std::iota(indices.begin(), origValsUntil,
438+
0); // original values that remain [0, offset)
439+
auto newValsUntil = origValsUntil;
440+
std::advance(newValsUntil, srcSize);
441+
std::iota(origValsUntil, newValsUntil,
442+
dstSize); // new values [offset, offset+srcNumElements)
443+
std::iota(newValsUntil, indices.end(),
444+
linearizedOffset + srcSize); // the rest of original values
445+
// [offset+srcNumElements, end)
446+
447+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
448+
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
449+
rewriter.getI64ArrayAttr(indices));
450+
451+
return success();
452+
}
453+
358454
private:
359455
unsigned targetVectorBitWidth;
360456
};
@@ -410,6 +506,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
410506
: true;
411507
});
412508
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
413-
LinearizeVectorExtractStridedSlice>(
509+
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
414510
typeConverter, patterns.getContext(), targetBitWidth);
415511
}

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,32 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
245245
%0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
246246
return %0 : vector<8x2xf32>
247247
}
248+
249+
// -----
250+
// ALL-LABEL: test_vector_insert
251+
// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
252+
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
253+
// DEFAULT: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
254+
// DEFAULT: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
255+
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
256+
// DEFAULT-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
257+
// DEFAULT-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
258+
// DEFAULT-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
259+
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
260+
// DEFAULT: return %[[RES]] : vector<2x8x4xf32>
261+
262+
// BW-128: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
263+
// BW-128: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
264+
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
265+
// BW-128-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
266+
// BW-128-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
267+
// BW-128-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
268+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
269+
// BW-128: return %[[RES]] : vector<2x8x4xf32>
270+
271+
// BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32>
272+
// BW-0: return %[[RES]] : vector<2x8x4xf32>
273+
274+
%0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
275+
return %0 : vector<2x8x4xf32>
276+
}

0 commit comments

Comments
 (0)