Skip to content

Commit 3d70ba6

Browse files
authored
[mlir][ArmSVE] Add convert.from/to.svbool intrinsics (#68418)
These will be used in future pass to ensure that loads/stores of masks are legal (as the LLVM backend does not support this for any type smaller than an svbool, which is vector<[16]xi1>). Depends on #68399
1 parent 962a049 commit 3d70ba6

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

+24
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ def ArmSVE_Dialect : Dialect {
3030
}];
3131
}
3232

33+
//===----------------------------------------------------------------------===//
34+
// ArmSVE type definitions
35+
//===----------------------------------------------------------------------===//
36+
37+
def SVBool : ScalableVectorOfRankAndLengthAndType<
38+
[1], [16], [I1]>;
39+
40+
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
41+
[1], [16, 8, 4, 2, 1], [I1]>;
42+
3343
//===----------------------------------------------------------------------===//
3444
// ArmSVE op definitions
3545
//===----------------------------------------------------------------------===//
@@ -302,4 +312,18 @@ def ScalableMaskedDivFIntrOp :
302312
ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
303313
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
304314

315+
def ConvertFromSvboolIntrOp :
316+
ArmSVE_IntrOp<"convert.from.svbool",
317+
[TypeIs<"res", SVEPredicate>],
318+
/*overloadedOperands=*/[],
319+
/*overloadedResults=*/[0]>,
320+
Arguments<(ins SVBool:$svbool)>;
321+
322+
def ConvertToSvboolIntrOp :
323+
ArmSVE_IntrOp<"convert.to.svbool",
324+
[TypeIs<"res", SVBool>],
325+
/*overloadedOperands=*/[0],
326+
/*overloadedResults=*/[]>,
327+
Arguments<(ins SVEPredicate:$mask)>;
328+
305329
#endif // ARMSVE_OPS

mlir/test/Target/LLVMIR/arm-sve.mlir

+44
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,47 @@ llvm.func @get_vector_scale() -> i64 {
272272
%0 = "llvm.intr.vscale"() : () -> i64
273273
llvm.return %0 : i64
274274
}
275+
276+
// CHECK-LABEL: @arm_sve_convert_from_svbool(
277+
// CHECK-SAME: <vscale x 16 x i1> %[[SVBOOL:[0-9]+]])
278+
llvm.func @arm_sve_convert_from_svbool(%nxv16i1 : vector<[16]xi1>) {
279+
// CHECK: %[[RES0:.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %[[SVBOOL]])
280+
%res0 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
281+
: (vector<[16]xi1>) -> vector<[8]xi1>
282+
// CHECK: %[[RES1:.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %[[SVBOOL]])
283+
%res1 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
284+
: (vector<[16]xi1>) -> vector<[4]xi1>
285+
// CHECK: %[[RES2:.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %[[SVBOOL]])
286+
%res2 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
287+
: (vector<[16]xi1>) -> vector<[2]xi1>
288+
// CHECK: %[[RES3:.*]] = call <vscale x 1 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv1i1(<vscale x 16 x i1> %[[SVBOOL]])
289+
%res3 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
290+
: (vector<[16]xi1>) -> vector<[1]xi1>
291+
llvm.return
292+
}
293+
294+
// CHECK-LABEL: arm_sve_convert_to_svbool(
295+
// CHECK-SAME: <vscale x 8 x i1> %[[P8:[0-9]+]],
296+
// CHECK-SAME: <vscale x 4 x i1> %[[P4:[0-9]+]],
297+
// CHECK-SAME: <vscale x 2 x i1> %[[P2:[0-9]+]],
298+
// CHECK-SAME: <vscale x 1 x i1> %[[P1:[0-9]+]])
299+
llvm.func @arm_sve_convert_to_svbool(
300+
%nxv8i1 : vector<[8]xi1>,
301+
%nxv4i1 : vector<[4]xi1>,
302+
%nxv2i1 : vector<[2]xi1>,
303+
%nxv1i1 : vector<[1]xi1>
304+
) {
305+
// CHECK-NEXT: %[[RES0:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %[[P8]])
306+
%res0 = "arm_sve.intr.convert.to.svbool"(%nxv8i1)
307+
: (vector<[8]xi1>) -> vector<[16]xi1>
308+
// CHECK-NEXT: %[[RES1:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %[[P4]])
309+
%res1 = "arm_sve.intr.convert.to.svbool"(%nxv4i1)
310+
: (vector<[4]xi1>) -> vector<[16]xi1>
311+
// CHECK-NEXT: %[[RES2:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %[[P2]])
312+
%res2 = "arm_sve.intr.convert.to.svbool"(%nxv2i1)
313+
: (vector<[2]xi1>) -> vector<[16]xi1>
314+
// CHECK-NEXT: %[[RES3:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv1i1(<vscale x 1 x i1> %[[P1]])
315+
%res3 = "arm_sve.intr.convert.to.svbool"(%nxv1i1)
316+
: (vector<[1]xi1>) -> vector<[16]xi1>
317+
llvm.return
318+
}

0 commit comments

Comments
 (0)