Skip to content

Commit 2f055dd

Browse files
authored
[mlir][ArmSME] Add tile slice layout attr to vector <-> tile ops (#69186)
This is used in #69148 when lowering a masked tile_store with a non-zero pad (see #69148). This updates: * `arm_sme.move_vector_to_tile_slice` * `arm_sme.move_tile_slice_to_vector`
1 parent d9cfb82 commit 2f055dd

File tree

4 files changed

+100
-27
lines changed

4 files changed

+100
-27
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -441,21 +441,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
441441
of a 2-D scalable vector tile at the given index. The type of the 1-D
442442
scalable vector to be moved must match the type of the tile slice. A tile
443443
slice is a 1-D vector of horizontally or vertically contiguous elements
444-
within a ZA tile. Horizontal tile slices are currently assumed when
445-
lowering to intrinsics. The updated tile is returned as the result.
444+
within a ZA tile. The updated tile is returned as the result.
446445

447-
Example 1: Move a vector<[16]xi8> into tile at given index.
446+
An optional tile slice layout attribute specifies whether the tile slice is
447+
horizontal (default) or vertical.
448+
449+
Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
448450
```mlir
449451
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
450452
```
451453

452-
Example 2: Move a vector<[2]xf64> into tile at given index.
454+
Example 2: Move a vector<[2]xf64> into tile vertically at given index.
453455
```mlir
454-
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
456+
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
455457
```
456458
}];
457459
let arguments = (ins
458-
SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
460+
SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
461+
ArmSME_TileSliceLayoutAttr:$layout);
459462
let results = (outs SMETile:$result);
460463

461464
let extraClassDeclaration = [{
@@ -465,7 +468,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
465468
}];
466469

467470
let assemblyFormat = [{
468-
$vector `,` $tile `,` $tile_slice_index
471+
$vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
469472
attr-dict `:` type($vector) `into` type($result)
470473
}];
471474
}
@@ -480,29 +483,34 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
480483
let description = [{
481484
The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
482485
scalable tile at the given index. A tile slice is a 1-D vector of
483-
horizontally or vertically contiguous elements within a ZA tile. Horizontal
484-
tile slices are currently assumed when lowering to intrinsics.
486+
horizontally or vertically contiguous elements within a ZA tile.
487+
488+
An optional tile slice layout attribute specifies whether the tile slice is
489+
horizontal (default) or vertical.
485490

486-
Example 1: Extract `vector<[16]xi8>` from tile at the given index.
491+
Example 1: Extract `vector<[16]xi8>` from tile horizontally (default) at the given index.
487492
```mlir
488493
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
489494
```
490495

491-
Example 2: Extract `vector<[2]xf64>` from tile at the given index.
496+
Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
492497
```mlir
493-
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
498+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
494499
```
495500
}];
496501

497-
let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
502+
let arguments = (ins
503+
SMETile:$tile, Index:$tile_slice_index,
504+
ArmSME_TileSliceLayoutAttr:$layout
505+
);
498506
let results = (outs SVEVector:$result);
499507

500508
let extraClassDeclaration = [{
501509
VectorType getSliceType() { return getResult().getType(); }
502510
}];
503511

504512
let assemblyFormat = [{
505-
$tile `[` $tile_slice_index `]` attr-dict
513+
$tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
506514
`:` type($result) `from` type($tile)
507515
}];
508516
}

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,7 @@ struct StoreTileSliceToArmSMELowering
350350
}
351351
};
352352

353-
/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
354-
/// tile slices are currently supported.
353+
/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
355354
struct MoveVectorToTileSliceToArmSMELowering
356355
: public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
357356
using ConvertOpToLLVMPattern<
@@ -388,10 +387,19 @@ struct MoveVectorToTileSliceToArmSMELowering
388387

389388
auto tileI32 = castTileIDToI32(tile, loc, rewriter);
390389

391-
// Create 'arm_sme.intr.write.horiz' to write vector to tile slice.
392-
rewriter.create<arm_sme::aarch64_sme_write_horiz>(
393-
loc, tileI32, tileSliceI32, allActiveMask,
394-
moveVectorToTileSliceOp.getVector());
390+
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
391+
switch (moveVectorToTileSliceOp.getLayout()) {
392+
case arm_sme::TileSliceLayout::Horizontal:
393+
rewriter.create<arm_sme::aarch64_sme_write_horiz>(
394+
loc, tileI32, tileSliceI32, allActiveMask,
395+
moveVectorToTileSliceOp.getVector());
396+
break;
397+
case arm_sme::TileSliceLayout::Vertical:
398+
rewriter.create<arm_sme::aarch64_sme_write_vert>(
399+
loc, tileI32, tileSliceI32, allActiveMask,
400+
moveVectorToTileSliceOp.getVector());
401+
break;
402+
}
395403

396404
// Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
397405
// 'arm_sme.cast_tile_to_vector' to preserve dataflow.
@@ -402,8 +410,7 @@ struct MoveVectorToTileSliceToArmSMELowering
402410
}
403411
};
404412

405-
/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
406-
/// tile slices are currently supported.
413+
/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
407414
struct MoveTileSliceToVectorArmSMELowering
408415
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
409416
using ConvertOpToLLVMPattern<
@@ -435,10 +442,19 @@ struct MoveTileSliceToVectorArmSMELowering
435442
auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
436443
loc, rewriter.getI32Type(), sliceIndex);
437444

438-
// Create 'arm_sme.intr.read.horiz' to extract the tile slice.
439-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
440-
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
441-
tileIdI32, sliceIndexI32);
445+
// Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
446+
switch (moveTileSliceToVector.getLayout()) {
447+
case arm_sme::TileSliceLayout::Horizontal:
448+
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
449+
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
450+
tileIdI32, sliceIndexI32);
451+
break;
452+
case arm_sme::TileSliceLayout::Vertical:
453+
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
454+
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
455+
tileIdI32, sliceIndexI32);
456+
break;
457+
}
442458

443459
return success();
444460
}
@@ -680,7 +696,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
680696
arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
681697
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
682698
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
683-
arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
699+
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
700+
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
684701
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
685702
target.addLegalOp<GetTileID>();
686703
target.addIllegalOp<vector::OuterProductOp>();

mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
400400
return
401401
}
402402

403+
//===----------------------------------------------------------------------===//
404+
// arm_sme.move_vector_to_tile_slice
405+
//===----------------------------------------------------------------------===//
406+
407+
// -----
408+
409+
// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
410+
// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
411+
func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
412+
%c0 = arith.constant 0 : index
413+
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
414+
return
415+
}
416+
417+
// -----
418+
419+
// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
420+
// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
421+
func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
422+
%c0 = arith.constant 0 : index
423+
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
424+
return
425+
}
403426

404427
//===----------------------------------------------------------------------===//
405428
// arm_sme.move_tile_slice_to_vector
@@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
485508
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
486509
return %slice : vector<[2]xf64>
487510
}
511+
512+
// -----
513+
514+
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
515+
// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
516+
func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
517+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
518+
return %slice : vector<[1]xi128>
519+
}

mlir/test/Dialect/ArmSME/roundtrip.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
10591059
return
10601060
}
10611061

1062+
// -----
1063+
1064+
func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
1065+
// CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
1066+
%c0 = arith.constant 0 : index
1067+
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
1068+
return
1069+
}
10621070

10631071
//===----------------------------------------------------------------------===//
10641072
// arm_sme.move_tile_slice_to_vector
@@ -1135,3 +1143,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
11351143
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
11361144
return %slice : vector<[2]xf64>
11371145
}
1146+
1147+
// -----
1148+
1149+
func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
1150+
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
1151+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
1152+
return %slice : vector<[2]xf64>
1153+
}

0 commit comments

Comments
 (0)