Skip to content

Commit 8cbbdad

Browse files
MacDue authored and justinfargnoli committed
[mlir][ArmSME] Add arm_sme.streaming_vl operation (llvm#77321)
This operation provides a convenient way to query the streaming vector length regardless of the streaming mode. This is most useful for functions that call/pass data to streaming functions, but are not streaming themselves.

Example:

```mlir
%svl_w = arm_sme.streaming_vl <word>
```

Created based on discussion here: llvm#76086 (comment)
1 parent 8c0adee commit 8cbbdad

File tree

4 files changed

+166
-3
lines changed

4 files changed

+166
-3
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,21 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
223223
let defaultValue = "CombiningKind::Add";
224224
}
225225

226+
// Enumerates the vector element sizes that can be named when querying the
// streaming vector length. The last string of each case is the keyword used
// in the textual IR (e.g. `arm_sme.streaming_vl <word>`).
def TypeSize : I32EnumAttr<"TypeSize", "Size of a vector element type", [
227+
I32EnumAttrCase<"Byte" , 0, "byte">,
228+
I32EnumAttrCase<"Half" , 1, "half">,
229+
I32EnumAttrCase<"Word" , 2, "word">,
230+
I32EnumAttrCase<"Double", 3, "double">,
231+
]> {
232+
let cppNamespace = "::mlir::arm_sme";
233+
let genSpecializedAttr = 0;
234+
}
235+
236+
// ArmSME dialect attribute wrapping the `TypeSize` enum, registered under the
// name "type_size". Its assembly form is the enum keyword in angle brackets,
// e.g. `<half>`.
def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
237+
"type_size"> {
238+
let assemblyFormat = "`<` $value `>`";
239+
}
240+
226241
//===----------------------------------------------------------------------===//
227242
// ArmSME op definitions
228243
//===----------------------------------------------------------------------===//
@@ -768,4 +783,33 @@ let arguments = (ins
768783
}];
769784
}
770785

786+
// Declares `arm_sme.streaming_vl`: takes a single `type_size` attribute and
// produces one `index` result. The custom assembly prints the attribute
// followed by the optional attribute dictionary.
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
787+
{
788+
let summary = "Query the streaming vector length";
789+
790+
let description = [{
791+
This operation returns the streaming vector length (SVL) for a given type
792+
size. Unlike `vector.vscale` the value returned is invariant to the
793+
streaming mode.
794+
795+
Example:
796+
```mlir
797+
// Streaming vector length in:
798+
// - bytes (8-bit, SVL.B)
799+
%svl_b = arm_sme.streaming_vl <byte>
800+
// - half words (16-bit, SVL.H)
801+
%svl_h = arm_sme.streaming_vl <half>
802+
// - words (32-bit, SVL.W)
803+
%svl_w = arm_sme.streaming_vl <word>
804+
// - double words (64-bit, SVL.D)
805+
%svl_d = arm_sme.streaming_vl <double>
806+
```
807+
}];
808+
809+
let arguments = (ins ArmSME_TypeSizeAttr: $type_size);
810+
let results = (outs Index);
811+
812+
let assemblyFormat = "$type_size attr-dict";
813+
}
814+
771815
#endif // ARMSME_OPS

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,45 @@ struct OuterProductOpConversion
518518
}
519519
};
520520

521+
/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
522+
///
523+
/// Example:
524+
///
525+
/// %0 = arm_sme.streaming_vl <half>
526+
///
527+
/// is converted to:
528+
///
529+
/// %cnt = "arm_sme.intr.cntsh"() : () -> i64
530+
/// %0 = arith.index_cast %cnt : i64 to index
531+
///
532+
struct StreamingVLOpConversion
533+
: public ConvertOpToLLVMPattern<arm_sme::StreamingVLOp> {
534+
using ConvertOpToLLVMPattern<arm_sme::StreamingVLOp>::ConvertOpToLLVMPattern;
535+
536+
LogicalResult
537+
matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
538+
arm_sme::StreamingVLOp::Adaptor adaptor,
539+
ConversionPatternRewriter &rewriter) const override {
540+
auto loc = streamingVlOp.getLoc();
541+
auto i64Type = rewriter.getI64Type();
542+
auto *intrOp = [&]() -> Operation * {
543+
switch (streamingVlOp.getTypeSize()) {
544+
case arm_sme::TypeSize::Byte:
545+
return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
546+
case arm_sme::TypeSize::Half:
547+
return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
548+
case arm_sme::TypeSize::Word:
549+
return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
550+
case arm_sme::TypeSize::Double:
551+
return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
552+
}
553+
}();
554+
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
555+
streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
556+
return success();
557+
}
558+
};
559+
521560
} // namespace
522561

523562
namespace {
@@ -555,7 +594,9 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
555594
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
556595
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
557596
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
558-
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
597+
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
598+
arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
599+
arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
559600
target.addLegalDialect<arith::ArithDialect>();
560601
target.addLegalOp<UnrealizedConversionCastOp>();
561602
}
@@ -572,8 +613,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
572613

573614
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
574615
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
575-
OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
576-
converter);
616+
OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
617+
StreamingVLOpConversion>(converter);
577618
}
578619

579620
std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {

mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,3 +559,45 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index)
559559
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
560560
return %slice : vector<[1]xi128>
561561
}
562+
563+
//===----------------------------------------------------------------------===//
564+
// arm_sme.streaming_vl
565+
//===----------------------------------------------------------------------===//
566+
567+
// -----
568+
569+
// Byte-sized query: expects the cntsb intrinsic (i64 result), an index cast
// of that result, and the cast value being returned.
// CHECK-LABEL: @arm_sme_streaming_vl_bytes
570+
// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
571+
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
572+
// CHECK: return %[[INDEX_COUNT]] : index
573+
func.func @arm_sme_streaming_vl_bytes() -> index {
574+
%svl_b = arm_sme.streaming_vl <byte>
575+
return %svl_b : index
576+
}
577+
578+
// -----
579+
580+
// Half-word query: only the selected intrinsic (cntsh) is verified here; the
// cast and return are verified in @arm_sme_streaming_vl_bytes.
// CHECK-LABEL: @arm_sme_streaming_vl_half_words
581+
// CHECK: "arm_sme.intr.cntsh"() : () -> i64
582+
func.func @arm_sme_streaming_vl_half_words() -> index {
583+
%svl_h = arm_sme.streaming_vl <half>
584+
return %svl_h : index
585+
}
586+
587+
// -----
588+
589+
// Word query: only the selected intrinsic (cntsw) is verified here; the cast
// and return are verified in @arm_sme_streaming_vl_bytes.
// CHECK-LABEL: @arm_sme_streaming_vl_words
590+
// CHECK: "arm_sme.intr.cntsw"() : () -> i64
591+
func.func @arm_sme_streaming_vl_words() -> index {
592+
%svl_w = arm_sme.streaming_vl <word>
593+
return %svl_w : index
594+
}
595+
596+
// -----
597+
598+
// Double-word query: only the selected intrinsic (cntsd) is verified here;
// the cast and return are verified in @arm_sme_streaming_vl_bytes.
// CHECK-LABEL: @arm_sme_streaming_vl_double_words
599+
// CHECK: "arm_sme.intr.cntsd"() : () -> i64
600+
func.func @arm_sme_streaming_vl_double_words() -> index {
601+
%svl_d = arm_sme.streaming_vl <double>
602+
return %svl_d : index
603+
}

mlir/test/Dialect/ArmSME/roundtrip.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,3 +1095,39 @@ func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: v
10951095
%result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
10961096
return %result : vector<[16]x[16]xi8>
10971097
}
1098+
1099+
//===----------------------------------------------------------------------===//
1100+
// arm_sme.streaming_vl
1101+
//===----------------------------------------------------------------------===//
1102+
1103+
// -----
1104+
1105+
// Round-trip case for the <byte> form: the op is expected to print back
// exactly as written.
func.func @arm_sme_streaming_vl_bytes() -> index {
1106+
// CHECK: arm_sme.streaming_vl <byte>
1107+
%svl_b = arm_sme.streaming_vl <byte>
1108+
return %svl_b : index
1109+
}
1110+
1111+
// -----
1112+
1113+
// Round-trip case for the <half> form: the op is expected to print back
// exactly as written.
func.func @arm_sme_streaming_vl_half_words() -> index {
1114+
// CHECK: arm_sme.streaming_vl <half>
1115+
%svl_h = arm_sme.streaming_vl <half>
1116+
return %svl_h : index
1117+
}
1118+
1119+
// -----
1120+
1121+
// Round-trip case for the <word> form: the op is expected to print back
// exactly as written.
func.func @arm_sme_streaming_vl_words() -> index {
1122+
// CHECK: arm_sme.streaming_vl <word>
1123+
%svl_w = arm_sme.streaming_vl <word>
1124+
return %svl_w : index
1125+
}
1126+
1127+
// -----
1128+
1129+
// Round-trip case for the <double> form: the op is expected to print back
// exactly as written.
func.func @arm_sme_streaming_vl_double_words() -> index {
1130+
// CHECK: arm_sme.streaming_vl <double>
1131+
%svl_d = arm_sme.streaming_vl <double>
1132+
return %svl_d : index
1133+
}

0 commit comments

Comments
 (0)