@@ -44,6 +44,20 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
44
44
return true ;
45
45
}
46
46
47
+ static bool isLessThanOrEqualTargetBitWidth (Type t,
48
+ unsigned targetBitWidth) {
49
+ VectorType vecType = dyn_cast<VectorType>(t);
50
+ // Reject index since getElementTypeBitWidth will abort for Index types.
51
+ if (!vecType || vecType.getElementType ().isIndex ())
52
+ return false ;
53
+ // There are no dimension to fold if it is a 0-D vector.
54
+ if (vecType.getRank () == 0 )
55
+ return false ;
56
+ unsigned trailingVecDimBitWidth =
57
+ vecType.getShape ().back () * vecType.getElementTypeBitWidth ();
58
+ return trailingVecDimBitWidth <= targetBitWidth
59
+ }
60
+
47
61
namespace {
48
62
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
49
63
using OpConversionPattern::OpConversionPattern;
@@ -355,6 +369,88 @@ struct LinearizeVectorExtract final
355
369
return success ();
356
370
}
357
371
372
+ private:
373
+ unsigned targetVectorBitWidth;
374
+ };
375
+
376
+ // / This pattern converts the InsertOp to a ShuffleOp that works on a
377
+ // / linearized vector.
378
+ // / Following,
379
+ // / vector.insert %source %destination [ position ]
380
+ // / is converted to :
381
+ // / %source_1d = vector.shape_cast %source
382
+ // / %destination_1d = vector.shape_cast %destination
383
+ // / %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
384
+ // / ] %out_nd = vector.shape_cast %out_1d
385
+ // / `shuffle_indices_1d` is computed using the position of the original insert.
386
+ struct LinearizeVectorInsert final
387
+ : public OpConversionPattern<vector::InsertOp> {
388
+ using OpConversionPattern::OpConversionPattern;
389
+ LinearizeVectorInsert (
390
+ const TypeConverter &typeConverter, MLIRContext *context,
391
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
392
+ PatternBenefit benefit = 1 )
393
+ : OpConversionPattern(typeConverter, context, benefit),
394
+ targetVectorBitWidth (targetVectBitWidth) {}
395
+ LogicalResult
396
+ matchAndRewrite (vector::InsertOp insertOp, OpAdaptor adaptor,
397
+ ConversionPatternRewriter &rewriter) const override {
398
+ Type dstTy = getTypeConverter ()->convertType (insertOp.getDestVectorType ());
399
+ assert (!(insertOp.getDestVectorType ().isScalable () ||
400
+ cast<VectorType>(dstTy).isScalable ()) &&
401
+ " scalable vectors are not supported." );
402
+
403
+ if (!isLessThanOrEqualTargetBitWidth (insertOp.getSourceType (),
404
+ targetVectorBitWidth))
405
+ return rewriter.notifyMatchFailure (
406
+ insertOp, " Can't flatten since targetBitWidth < OpSize" );
407
+
408
+ // dynamic position is not supported
409
+ if (insertOp.hasDynamicPosition ())
410
+ return rewriter.notifyMatchFailure (insertOp,
411
+ " dynamic position is not supported." );
412
+ auto srcTy = insertOp.getSourceType ();
413
+ auto srcAsVec = dyn_cast<VectorType>(srcTy);
414
+ uint64_t srcSize = 0 ;
415
+ if (srcAsVec) {
416
+ srcSize = srcAsVec.getNumElements ();
417
+ } else {
418
+ return rewriter.notifyMatchFailure (insertOp,
419
+ " scalars are not supported." );
420
+ }
421
+
422
+ auto dstShape = insertOp.getDestVectorType ().getShape ();
423
+ const auto dstSize = insertOp.getDestVectorType ().getNumElements ();
424
+ auto dstSizeForOffsets = dstSize;
425
+
426
+ // compute linearized offset
427
+ int64_t linearizedOffset = 0 ;
428
+ auto offsetsNd = insertOp.getStaticPosition ();
429
+ for (auto [dim, offset] : llvm::enumerate (offsetsNd)) {
430
+ dstSizeForOffsets /= dstShape[dim];
431
+ linearizedOffset += offset * dstSizeForOffsets;
432
+ }
433
+
434
+ llvm::SmallVector<int64_t , 2 > indices (dstSize);
435
+ auto origValsUntil = indices.begin ();
436
+ std::advance (origValsUntil, linearizedOffset);
437
+ std::iota (indices.begin (), origValsUntil,
438
+ 0 ); // original values that remain [0, offset)
439
+ auto newValsUntil = origValsUntil;
440
+ std::advance (newValsUntil, srcSize);
441
+ std::iota (origValsUntil, newValsUntil,
442
+ dstSize); // new values [offset, offset+srcNumElements)
443
+ std::iota (newValsUntil, indices.end (),
444
+ linearizedOffset + srcSize); // the rest of original values
445
+ // [offset+srcNumElements, end)
446
+
447
+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
448
+ insertOp, dstTy, adaptor.getDest (), adaptor.getSource (),
449
+ rewriter.getI64ArrayAttr (indices));
450
+
451
+ return success ();
452
+ }
453
+
358
454
private:
359
455
unsigned targetVectorBitWidth;
360
456
};
@@ -410,6 +506,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
410
506
: true ;
411
507
});
412
508
patterns.add <LinearizeVectorShuffle, LinearizeVectorExtract,
413
- LinearizeVectorExtractStridedSlice>(
509
+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
414
510
typeConverter, patterns.getContext (), targetBitWidth);
415
511
}
0 commit comments