@@ -44,6 +44,22 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
44
44
return true ;
45
45
}
46
46
47
+ static bool isLessThanOrEqualTargetBitWidth (mlir::Type t,
48
+ unsigned targetBitWidth) {
49
+ VectorType vecType = dyn_cast<VectorType>(t);
50
+ // Reject index since getElementTypeBitWidth will abort for Index types.
51
+ if (!vecType || vecType.getElementType ().isIndex ())
52
+ return false ;
53
+ // There are no dimension to fold if it is a 0-D vector.
54
+ if (vecType.getRank () == 0 )
55
+ return false ;
56
+ unsigned trailingVecDimBitWidth =
57
+ vecType.getShape ().back () * vecType.getElementTypeBitWidth ();
58
+ if (trailingVecDimBitWidth > targetBitWidth)
59
+ return false ;
60
+ return true ;
61
+ }
62
+
47
63
namespace {
48
64
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
49
65
using OpConversionPattern::OpConversionPattern;
@@ -355,6 +371,88 @@ struct LinearizeVectorExtract final
355
371
return success ();
356
372
}
357
373
374
+ private:
375
+ unsigned targetVectorBitWidth;
376
+ };
377
+
378
+ // / This pattern converts the InsertOp to a ShuffleOp that works on a
379
+ // / linearized vector.
380
+ // / Following,
381
+ // / vector.insert %source %destination [ position ]
382
+ // / is converted to :
383
+ // / %source_1d = vector.shape_cast %source
384
+ // / %destination_1d = vector.shape_cast %destination
385
+ // / %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
386
+ // / ] %out_nd = vector.shape_cast %out_1d
387
+ // / `shuffle_indices_1d` is computed using the position of the original insert.
388
+ struct LinearizeVectorInsert final
389
+ : public mlir::OpConversionPattern<mlir::vector::InsertOp> {
390
+ using OpConversionPattern::OpConversionPattern;
391
+ LinearizeVectorInsert (
392
+ const TypeConverter &typeConverter, MLIRContext *context,
393
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
394
+ PatternBenefit benefit = 1 )
395
+ : OpConversionPattern(typeConverter, context, benefit),
396
+ targetVectorBitWidth (targetVectBitWidth) {}
397
+ mlir::LogicalResult
398
+ matchAndRewrite (mlir::vector::InsertOp insertOp, OpAdaptor adaptor,
399
+ mlir::ConversionPatternRewriter &rewriter) const override {
400
+ Type dstTy = getTypeConverter ()->convertType (insertOp.getDestVectorType ());
401
+ assert (!(insertOp.getDestVectorType ().isScalable () ||
402
+ cast<VectorType>(dstTy).isScalable ()) &&
403
+ " scalable vectors are not supported." );
404
+
405
+ if (!isLessThanOrEqualTargetBitWidth (insertOp.getSourceType (),
406
+ targetVectorBitWidth))
407
+ return rewriter.notifyMatchFailure (
408
+ insertOp, " Can't flatten since targetBitWidth < OpSize" );
409
+
410
+ // dynamic position is not supported
411
+ if (insertOp.hasDynamicPosition ())
412
+ return rewriter.notifyMatchFailure (insertOp,
413
+ " dynamic position is not supported." );
414
+ auto srcTy = insertOp.getSourceType ();
415
+ auto srcAsVec = mlir::dyn_cast<mlir::VectorType>(srcTy);
416
+ uint64_t srcSize = 0 ;
417
+ if (srcAsVec) {
418
+ srcSize = srcAsVec.getNumElements ();
419
+ } else {
420
+ return rewriter.notifyMatchFailure (insertOp,
421
+ " scalars are not supported." );
422
+ }
423
+
424
+ auto dstShape = insertOp.getDestVectorType ().getShape ();
425
+ const auto dstSize = insertOp.getDestVectorType ().getNumElements ();
426
+ auto dstSizeForOffsets = dstSize;
427
+
428
+ // compute linearized offset
429
+ int64_t linearizedOffset = 0 ;
430
+ auto offsetsNd = insertOp.getStaticPosition ();
431
+ for (auto [dim, offset] : llvm::enumerate (offsetsNd)) {
432
+ dstSizeForOffsets /= dstShape[dim];
433
+ linearizedOffset += offset * dstSizeForOffsets;
434
+ }
435
+
436
+ llvm::SmallVector<int64_t , 2 > indices (dstSize);
437
+ auto origValsUntil = indices.begin ();
438
+ std::advance (origValsUntil, linearizedOffset);
439
+ std::iota (indices.begin (), origValsUntil,
440
+ 0 ); // original values that remain [0, offset)
441
+ auto newValsUntil = origValsUntil;
442
+ std::advance (newValsUntil, srcSize);
443
+ std::iota (origValsUntil, newValsUntil,
444
+ dstSize); // new values [offset, offset+srcNumElements)
445
+ std::iota (newValsUntil, indices.end (),
446
+ linearizedOffset + srcSize); // the rest of original values
447
+ // [offset+srcNumElements, end)
448
+
449
+ rewriter.replaceOpWithNewOp <mlir::vector::ShuffleOp>(
450
+ insertOp, dstTy, adaptor.getDest (), adaptor.getSource (),
451
+ rewriter.getI64ArrayAttr (indices));
452
+
453
+ return mlir::success ();
454
+ }
455
+
358
456
private:
359
457
unsigned targetVectorBitWidth;
360
458
};
@@ -410,6 +508,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
410
508
: true ;
411
509
});
412
510
patterns.add <LinearizeVectorShuffle, LinearizeVectorExtract,
413
- LinearizeVectorExtractStridedSlice>(
511
+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
414
512
typeConverter, patterns.getContext (), targetBitWidth);
415
513
}
0 commit comments