Skip to content

Commit 51115c2

Browse files
committed
[mlir][vector] Add support for linearizing Insert VectorOp in VectorLinearize
Building on top of llvm#88204, this commit adds support for InsertOp.
1 parent 90d2f8c commit 51115c2

File tree

2 files changed

+128
-1
lines changed

2 files changed

+128
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,22 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
4444
return true;
4545
}
4646

47+
static bool isLessThanOrEqualTargetBitWidth(mlir::Type t,
48+
unsigned targetBitWidth) {
49+
VectorType vecType = dyn_cast<VectorType>(t);
50+
// Reject index since getElementTypeBitWidth will abort for Index types.
51+
if (!vecType || vecType.getElementType().isIndex())
52+
return false;
53+
// There are no dimension to fold if it is a 0-D vector.
54+
if (vecType.getRank() == 0)
55+
return false;
56+
unsigned trailingVecDimBitWidth =
57+
vecType.getShape().back() * vecType.getElementTypeBitWidth();
58+
if (trailingVecDimBitWidth > targetBitWidth)
59+
return false;
60+
return true;
61+
}
62+
4763
namespace {
4864
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
4965
using OpConversionPattern::OpConversionPattern;
@@ -355,6 +371,88 @@ struct LinearizeVectorExtract final
355371
return success();
356372
}
357373

374+
private:
375+
unsigned targetVectorBitWidth;
376+
};
377+
378+
/// This pattern converts the InsertOp to a ShuffleOp that works on a
379+
/// linearized vector.
380+
/// Following,
381+
/// vector.insert %source %destination [ position ]
382+
/// is converted to :
383+
/// %source_1d = vector.shape_cast %source
384+
/// %destination_1d = vector.shape_cast %destination
385+
/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
386+
/// ] %out_nd = vector.shape_cast %out_1d
387+
/// `shuffle_indices_1d` is computed using the position of the original insert.
388+
struct LinearizeVectorInsert final
389+
: public mlir::OpConversionPattern<mlir::vector::InsertOp> {
390+
using OpConversionPattern::OpConversionPattern;
391+
LinearizeVectorInsert(
392+
const TypeConverter &typeConverter, MLIRContext *context,
393+
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
394+
PatternBenefit benefit = 1)
395+
: OpConversionPattern(typeConverter, context, benefit),
396+
targetVectorBitWidth(targetVectBitWidth) {}
397+
mlir::LogicalResult
398+
matchAndRewrite(mlir::vector::InsertOp insertOp, OpAdaptor adaptor,
399+
mlir::ConversionPatternRewriter &rewriter) const override {
400+
Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
401+
assert(!(insertOp.getDestVectorType().isScalable() ||
402+
cast<VectorType>(dstTy).isScalable()) &&
403+
"scalable vectors are not supported.");
404+
405+
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
406+
targetVectorBitWidth))
407+
return rewriter.notifyMatchFailure(
408+
insertOp, "Can't flatten since targetBitWidth < OpSize");
409+
410+
// dynamic position is not supported
411+
if (insertOp.hasDynamicPosition())
412+
return rewriter.notifyMatchFailure(insertOp,
413+
"dynamic position is not supported.");
414+
auto srcTy = insertOp.getSourceType();
415+
auto srcAsVec = mlir::dyn_cast<mlir::VectorType>(srcTy);
416+
uint64_t srcSize = 0;
417+
if (srcAsVec) {
418+
srcSize = srcAsVec.getNumElements();
419+
} else {
420+
return rewriter.notifyMatchFailure(insertOp,
421+
"scalars are not supported.");
422+
}
423+
424+
auto dstShape = insertOp.getDestVectorType().getShape();
425+
const auto dstSize = insertOp.getDestVectorType().getNumElements();
426+
auto dstSizeForOffsets = dstSize;
427+
428+
// compute linearized offset
429+
int64_t linearizedOffset = 0;
430+
auto offsetsNd = insertOp.getStaticPosition();
431+
for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
432+
dstSizeForOffsets /= dstShape[dim];
433+
linearizedOffset += offset * dstSizeForOffsets;
434+
}
435+
436+
llvm::SmallVector<int64_t, 2> indices(dstSize);
437+
auto origValsUntil = indices.begin();
438+
std::advance(origValsUntil, linearizedOffset);
439+
std::iota(indices.begin(), origValsUntil,
440+
0); // original values that remain [0, offset)
441+
auto newValsUntil = origValsUntil;
442+
std::advance(newValsUntil, srcSize);
443+
std::iota(origValsUntil, newValsUntil,
444+
dstSize); // new values [offset, offset+srcNumElements)
445+
std::iota(newValsUntil, indices.end(),
446+
linearizedOffset + srcSize); // the rest of original values
447+
// [offset+srcNumElements, end)
448+
449+
rewriter.replaceOpWithNewOp<mlir::vector::ShuffleOp>(
450+
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
451+
rewriter.getI64ArrayAttr(indices));
452+
453+
return mlir::success();
454+
}
455+
358456
private:
359457
unsigned targetVectorBitWidth;
360458
};
@@ -410,6 +508,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
410508
: true;
411509
});
412510
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
413-
LinearizeVectorExtractStridedSlice>(
511+
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
414512
typeConverter, patterns.getContext(), targetBitWidth);
415513
}

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,32 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
245245
%0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
246246
return %0 : vector<8x2xf32>
247247
}
248+
249+
// -----
250+
// ALL-LABEL: test_vector_insert
251+
// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
252+
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
253+
// DEFAULT: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
254+
// DEFAULT: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
255+
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
256+
// DEFAULT-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
257+
// DEFAULT-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
258+
// DEFAULT-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
259+
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
260+
// DEFAULT: return %[[RES]] : vector<2x8x4xf32>
261+
262+
// BW-128: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
263+
// BW-128: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
264+
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
265+
// BW-128-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
266+
// BW-128-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
267+
// BW-128-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
268+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
269+
// BW-128: return %[[RES]] : vector<2x8x4xf32>
270+
271+
// BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32>
272+
// BW-0: return %[[RES]] : vector<2x8x4xf32>
273+
274+
%0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
275+
return %0 : vector<2x8x4xf32>
276+
}

0 commit comments

Comments
 (0)