Skip to content

Commit c425124

Browse files
authored
[mlir][ArmSME] Rename slice move operations to insert/extract_tile_slice (#106755)
This renames: - `arm_sme.move_tile_slice_to_vector` to `arm_sme.extract_tile_slice` - `arm_sme.move_vector_to_tile_slice` to `arm_sme.insert_tile_slice` The new names are more consistent with the rest of MLIR and should be easier to understand. The current names (to me personally) are hard to parse and easy to mix up when skimming through code. Additionally, the syntax for `insert_tile_slice` has changed from: ```mlir %4 = arm_sme.insert_tile_slice %0, %1, %2 : vector<[16]xi8> into vector<[16]x[16]xi8> ``` To: ```mlir %4 = arm_sme.insert_tile_slice %0, %1[%2] : vector<[16]xi8> into vector<[16]x[16]xi8> ``` This is for consistency with `extract_tile_slice`, but also helps with readability as it makes it clear which operand is the index.
1 parent b9bba6c commit c425124

File tree

13 files changed

+282
-285
lines changed

13 files changed

+282
-285
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
592592
}];
593593
}
594594

595-
def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
595+
def InsertTileSliceOp : ArmSME_Op<"insert_tile_slice", [
596596
ArmSMETileOpInterface, Pure,
597597
AllTypesMatch<["tile", "result"]>,
598598
TypesMatchWith<
@@ -603,25 +603,25 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
603603
"::llvm::cast<mlir::VectorType>($_self).getElementType(),"
604604
"/*scalableDims=*/{true})">,
605605
]> {
606-
let summary = "Move 1-D scalable vector to slice of 2-D tile";
606+
let summary = "Insert 1-D scalable vector into slice of 2-D tile";
607607
let description = [{
608-
The vector to tile slice operation moves a 1-D scalable vector to a slice
609-
of a 2-D scalable vector tile at the given index. The type of the 1-D
610-
scalable vector to be moved must match the type of the tile slice. A tile
611-
slice is a 1-D vector of horizontally or vertically contiguous elements
612-
within a ZA tile. The updated tile is returned as the result.
608+
Inserts a 1-D scalable vector to a slice of a 2-D scalable vector tile at
609+
the given index. The type of the 1-D scalable vector to be inserted must
610+
match the type of the tile slice. A tile slice is a 1-D vector of
611+
horizontally or vertically contiguous elements within a ZA tile. The updated
612+
tile is returned as the result.
613613

614614
An optional tile slice layout attribute specifies whether the tile slice is
615615
horizontal (default) or vertical.
616616

617-
Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
617+
Example 1: Insert `vector<[16]xi8>` into tile horizontally at the given index.
618618
```mlir
619-
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
619+
%tile_update = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[16]xi8> into vector<[16]x[16]xi8>
620620
```
621621

622-
Example 2: Move a vector<[2]xf64> into tile vertically at given index.
622+
Example 2: Insert `vector<[2]xf64>` into tile vertically at the given index.
623623
```mlir
624-
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
624+
%tile_update = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
625625
```
626626
}];
627627
let arguments = (ins
@@ -636,35 +636,35 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
636636
}];
637637

638638
let assemblyFormat = [{
639-
$vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
639+
$vector `,` $tile `[` $tile_slice_index `]` (`layout` `` $layout^)?
640640
attr-dict `:` type($vector) `into` type($result)
641641
}];
642642
}
643643

644-
def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
644+
def ExtractTileSliceOp : ArmSME_Op<"extract_tile_slice", [
645645
ArmSMETileOpInterface, Pure,
646646
TypesMatchWith<
647647
"type of 'result' matches type of 'tile' slice",
648648
"tile", "result",
649649
"VectorType(VectorType::Builder(::llvm::cast<mlir::VectorType>($_self)).dropDim(0))">,
650650
]> {
651-
let summary = "Move slice of a 2-D tile to a 1-D scalable vector";
651+
let summary = "Extract 1-D scalable vector from slice of 2-D tile";
652652
let description = [{
653-
The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
654-
scalable tile at the given index. A tile slice is a 1-D vector of
655-
horizontally or vertically contiguous elements within a ZA tile.
653+
Extracts a 1-D scalable slice from a 2-D scalable tile at the given index.
654+
A tile slice is a 1-D vector of horizontally or vertically contiguous
655+
elements within a ZA tile.
656656

657657
An optional tile slice layout attribute specifies whether the tile slice is
658658
horizontal (default) or vertical.
659659

660660
Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index.
661661
```mlir
662-
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
662+
%slice = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
663663
```
664664

665665
Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
666666
```mlir
667-
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
667+
%slice = arm_sme.extract_tile_slice %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
668668
```
669669
}];
670670

mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
6464
return success();
6565
}
6666

67-
// Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
67+
// Lower non-zero constants to a loop of 'arm_sme.insert_tile_slice'
6868
// ops that broadcast the constant to each tile slice.
6969
auto loc = constantOp.getLoc();
7070

@@ -79,9 +79,9 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
7979
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
8080
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
8181
Value currentTile) {
82-
// Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
82+
// Create 'arm_sme.insert_tile_slice' to write vector to tile
8383
// slice.
84-
auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
84+
auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
8585
loc, tileType, constantOp1D, currentTile, tileSliceIndex);
8686
return nextTile.getResult();
8787
};

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -575,23 +575,23 @@ struct StoreTileSliceConversion
575575
}
576576
};
577577

578-
/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
579-
struct MoveVectorToTileSliceConversion
580-
: public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
578+
/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
579+
struct InsertTileSliceConversion
580+
: public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
581581
using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
582582

583583
LogicalResult
584-
matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp,
585-
arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor,
584+
matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
585+
arm_sme::InsertTileSliceOp::Adaptor adaptor,
586586
ConversionPatternRewriter &rewriter) const override {
587-
auto loc = moveVectorToTileSliceOp.getLoc();
588-
auto tileType = moveVectorToTileSliceOp.getTileType();
587+
auto loc = insertTileSliceOp.getLoc();
588+
auto tileType = insertTileSliceOp.getTileType();
589589

590-
auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
590+
auto tileId = getTileIdOrError(insertTileSliceOp);
591591
if (!tileId)
592592
return failure();
593593

594-
auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();
594+
auto tileSlice = insertTileSliceOp.getTileSliceIndex();
595595

596596
// Cast tile slice from index to i32 for intrinsic.
597597
auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
@@ -606,42 +606,40 @@ struct MoveVectorToTileSliceConversion
606606
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
607607

608608
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
609-
switch (moveVectorToTileSliceOp.getLayout()) {
609+
switch (insertTileSliceOp.getLayout()) {
610610
case arm_sme::TileSliceLayout::Horizontal:
611611
rewriter.create<arm_sme::aarch64_sme_write_horiz>(
612612
loc, tileId, tileSliceI32, allActiveMask,
613-
moveVectorToTileSliceOp.getVector());
613+
insertTileSliceOp.getVector());
614614
break;
615615
case arm_sme::TileSliceLayout::Vertical:
616616
rewriter.create<arm_sme::aarch64_sme_write_vert>(
617617
loc, tileId, tileSliceI32, allActiveMask,
618-
moveVectorToTileSliceOp.getVector());
618+
insertTileSliceOp.getVector());
619619
break;
620620
}
621621

622-
// Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
622+
// Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
623623
// the input tile to preserve dataflow.
624-
rewriter.replaceOp(moveVectorToTileSliceOp,
625-
moveVectorToTileSliceOp.getTile());
624+
rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());
626625

627626
return success();
628627
}
629628
};
630629

631-
/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
632-
struct MoveTileSliceToVectorConversion
633-
: public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
630+
/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
631+
struct ExtractTileSliceConversion
632+
: public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
634633
using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
635634

636635
LogicalResult
637-
matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
638-
OpAdaptor,
636+
matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
639637
ConversionPatternRewriter &rewriter) const override {
640-
auto loc = moveTileSliceToVector.getLoc();
641-
auto sliceType = moveTileSliceToVector.getSliceType();
642-
auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
638+
auto loc = extractTileSlice.getLoc();
639+
auto sliceType = extractTileSlice.getSliceType();
640+
auto sliceIndex = extractTileSlice.getTileSliceIndex();
643641

644-
auto tileId = getTileIdOrError(moveTileSliceToVector);
642+
auto tileId = getTileIdOrError(extractTileSlice);
645643
if (!tileId)
646644
return failure();
647645

@@ -659,16 +657,16 @@ struct MoveTileSliceToVectorConversion
659657
loc, rewriter.getI32Type(), sliceIndex);
660658

661659
// Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
662-
switch (moveTileSliceToVector.getLayout()) {
660+
switch (extractTileSlice.getLayout()) {
663661
case arm_sme::TileSliceLayout::Horizontal:
664662
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
665-
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
666-
tileId, sliceIndexI32);
663+
extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
664+
sliceIndexI32);
667665
break;
668666
case arm_sme::TileSliceLayout::Vertical:
669667
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
670-
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
671-
tileId, sliceIndexI32);
668+
extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
669+
sliceIndexI32);
672670
break;
673671
}
674672

@@ -985,8 +983,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
985983
});
986984

987985
addArmSMEConversionPatterns<
988-
LoadTileSliceConversion, MoveTileSliceToVectorConversion,
989-
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
986+
LoadTileSliceConversion, ExtractTileSliceConversion,
987+
InsertTileSliceConversion, StoreTileSliceConversion,
990988
StreamingVLOpConversion, OuterProductOpConversion,
991989
OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
992990
arm_sme::aarch64_sme_mopa_wide>,

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
245245
/// : memref<?x?xi32>, vector<[4]xi1>,
246246
/// vector<[4]xi32> into vector<[4]xi32>
247247
/// // Insert slice into tile
248-
/// %tile_update = arm_sme.move_vector_to_tile_slice
249-
/// %slice, %iter_tile, %tile_slice_idx :
248+
/// %tile_update = arm_sme.insert_tile_slice
249+
/// %slice, %iter_tile[%tile_slice_idx] :
250250
/// vector<[4]xi32> into vector<[4]x[4]xi32>
251251
/// scf.yield %tile_update : vector<[4]x[4]xi32>
252252
/// }
@@ -332,11 +332,11 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
332332
loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
333333
/*passthru=*/pad1DOp);
334334

335-
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
336-
auto moveSlice = rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
335+
// Create 'arm_sme.insert_tile_slice' to insert slice into tile.
336+
auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>(
337337
loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
338338
tileLoadOp.getLayout());
339-
rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
339+
rewriter.create<scf::YieldOp>(loc, insertSlice.getResult());
340340

341341
rewriter.setInsertionPointAfter(forOp);
342342

0 commit comments

Comments
 (0)