Commit 7830c76

[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svdupq_lane
1 parent 4e63e04 commit 7830c76

4 files changed: +145 −2 lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 63 additions & 2 deletions
@@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
     "a 1-D scalable vector with length " # length,
     "::mlir::VectorType">;
 
+def SVEVector : AnyTypeOf<[
+  Scalable1DVectorOfLength<2, [I64, F64]>,
+  Scalable1DVectorOfLength<4, [I32, F32]>,
+  Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+  Scalable1DVectorOfLength<16, [I8]>],
+  "an SVE vector with element size <= 64-bit">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
                     list<Trait> traits = [],
                     list<int> overloadedOperands = [],
                     list<int> overloadedResults = [],
-                    int numResults = 1> :
+                    int numResults = 1,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = []> :
   LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
                   /*string opName=*/"intr." # mnemonic,
                   /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
                   /*list<int> overloadedResults=*/overloadedResults,
                   /*list<int> overloadedOperands=*/overloadedOperands,
                   /*list<Trait> traits=*/traits,
-                  /*int numResults=*/numResults>;
+                  /*int numResults=*/numResults,
+                  /*bit requiresAccessGroup=*/0,
+                  /*bit requiresAliasAnalysis=*/0,
+                  /*bit requiresFastmath=*/0,
+                  /*bit requiresOpBundles=*/0,
+                  /*list<int> immArgPositions=*/immArgPositions,
+                  /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                                     list<Trait> traits = []>:
@@ -509,6 +524,42 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
 
 def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
+def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
+  let summary = "Broadcast indexed 128-bit segment to vector";
+
+  let description = [{
+    This operation fills each 128-bit segment of a vector with the elements
+    from the indexed 128-bit segment of the source vector. If the VL is
+    128 bits, the operation is a NOP. If the index exceeds the number of
+    128-bit segments in the vector, the result is an all-zeroes vector.
+
+    Example:
+    ```mlir
+    // VL == 256
+    // %X = [A B C D x x x x]
+    %Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
+    // %Y = [A B C D A B C D]
+
+    // %U = [x x x x x x x x A B C D E F G H]
+    %V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
+    // %V = [A B C D E F G H A B C D E F G H]
+    ```
+  }];
+
+  let arguments = (ins SVEVector:$src,
+                       I64Attr:$lane);
+  let results = (outs SVEVector:$dst);
+
+  let builders = [
+    OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
+      build($_builder, $_state, src.getType(), src, lane);
+    }]>];
+
+  let assemblyFormat = [{
+    $src `[` $lane `]` attr-dict `:` type($dst)
+  }];
+}
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -610,4 +661,14 @@ def WhileLTIntrOp :
                /*overloadedResults=*/[0]>,
   Arguments<(ins I64:$base, I64:$n)>;
 
+def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
+                       /*traits=*/[],
+                       /*overloadedOperands=*/[0],
+                       /*overloadedResults=*/[],
+                       /*numResults=*/1,
+                       /*immArgPositions=*/[1],
+                       /*immArgAttrNames=*/["lane"]>,
+  Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
+                 Arg<I64Attr, "lane">:$lane)>;
+
 #endif // ARMSVE_OPS
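
Worth spelling out from the op description above: the zeroing rule for an out-of-range index. Below is a minimal sketch in the same style as the documented examples, assuming VL == 256 so that a `vector<[4]xi32>` holds exactly two 128-bit segments (the values are illustrative, not an additional test from this commit):

```mlir
// VL == 256, so vector<[4]xi32> has two 128-bit segments: lanes 0 and 1.
// %X = [A B C D E F G H]
%Y = arm_sve.dupq_lane %X[2] : vector<[4]xi32>
// %Y = [0 0 0 0 0 0 0 0] -- lane 2 selects a nonexistent segment at this VL
```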

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,8 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using DupQLaneLowering =
+    OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                  ScalableMaskedAddIIntrOp>;
@@ -192,6 +194,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
       SmmlaOpLowering,
       UdotOpLowering,
       UmmlaOpLowering,
+      DupQLaneLowering,
       ScalableMaskedAddIOpLowering,
       ScalableMaskedAddFOpLowering,
       ScalableMaskedSubIOpLowering,
@@ -219,6 +222,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
       SmmlaIntrOp,
       UdotIntrOp,
       UmmlaIntrOp,
+      DupQLaneIntrOp,
       ScalableMaskedAddIIntrOp,
       ScalableMaskedAddFIntrOp,
       ScalableMaskedSubIIntrOp,
@@ -238,6 +242,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
       SmmlaOp,
       UdotOp,
       UmmlaOp,
+      DupQLaneOp,
       ScalableMaskedAddIOp,
       ScalableMaskedAddFOp,
       ScalableMaskedSubIOp,
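
Since `DupQLaneLowering` is a plain `OneToOneConvertToLLVMPattern`, the rewrite it performs is mechanical: the source operand is forwarded and the `lane` attribute is carried onto the intrinsic op, where the `immArgPositions`/`immArgAttrNames` plumbing added to `ArmSVE_IntrOp` exposes it as the intrinsic's immediate argument. A hedged before/after sketch for a single op (types and values chosen for illustration):

```mlir
// Before legalization: the high-level ArmSVE dialect op.
%y = arm_sve.dupq_lane %x[1] : vector<[8]xf16>

// After populateArmSVELegalizeForLLVMExportPatterns runs, the op above is
// replaced one-to-one by the intrinsic op. The target configuration marks
// DupQLaneOp illegal and DupQLaneIntrOp legal, so conversion must yield:
%y = "arm_sve.intr.dupq_lane"(%x) <{lane = 1 : i64}>
       : (vector<[8]xf16>) -> vector<[8]xf16>
```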

mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir

Lines changed: 43 additions & 0 deletions
@@ -271,3 +271,46 @@ func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[
   %0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
   return %0 : vector<[8]xi1>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_dupq_lane(
+// CHECK-SAME: %[[A0:[a-z0-9]+]]: vector<[16]xi8>
+// CHECK-SAME: %[[A1:[a-z0-9]+]]: vector<[8]xi16>
+// CHECK-SAME: %[[A2:[a-z0-9]+]]: vector<[8]xf16>
+// CHECK-SAME: %[[A3:[a-z0-9]+]]: vector<[8]xbf16>
+// CHECK-SAME: %[[A4:[a-z0-9]+]]: vector<[4]xi32>
+// CHECK-SAME: %[[A5:[a-z0-9]+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A6:[a-z0-9]+]]: vector<[2]xi64>
+// CHECK-SAME: %[[A7:[a-z0-9]+]]: vector<[2]xf64>
+// CHECK-SAME: -> !llvm.struct<(vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>, vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>)> {
+
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A0]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A1]]) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A2]]) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A3]]) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A4]]) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A5]]) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A6]]) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A7]]) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+func.func @arm_sve_dupq_lane(
+    %v16i8: vector<[16]xi8>, %v8i16: vector<[8]xi16>,
+    %v8f16: vector<[8]xf16>, %v8bf16: vector<[8]xbf16>,
+    %v4i32: vector<[4]xi32>, %v4f32: vector<[4]xf32>,
+    %v2i64: vector<[2]xi64>, %v2f64: vector<[2]xf64>)
+    -> (vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+        vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>) {
+
+  %0 = arm_sve.dupq_lane %v16i8[0] : vector<[16]xi8>
+  %1 = arm_sve.dupq_lane %v8i16[1] : vector<[8]xi16>
+  %2 = arm_sve.dupq_lane %v8f16[2] : vector<[8]xf16>
+  %3 = arm_sve.dupq_lane %v8bf16[3] : vector<[8]xbf16>
+  %4 = arm_sve.dupq_lane %v4i32[4] : vector<[4]xi32>
+  %5 = arm_sve.dupq_lane %v4f32[5] : vector<[4]xf32>
+  %6 = arm_sve.dupq_lane %v2i64[6] : vector<[2]xi64>
+  %7 = arm_sve.dupq_lane %v2f64[7] : vector<[2]xf64>
+
+  return %0, %1, %2, %3, %4, %5, %6, %7
+    : vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+      vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>
+}
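
The CHECK lines above are exercised by this file's existing RUN line, which sits outside the diff; going by how the sibling ArmSVE legalization tests are driven, it is presumably of roughly this shape (an assumption, not part of the commit):

```mlir
// Presumed driver for this test file (not shown in the diff):
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" \
// RUN:   -convert-func-to-llvm -reconcile-unrealized-casts | FileCheck %s
```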

mlir/test/Target/LLVMIR/arm-sve.mlir

Lines changed: 34 additions & 0 deletions
@@ -390,3 +390,37 @@ llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[
   "arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: @arm_sve_dupq_lane
+// CHECK-SAME: <vscale x 16 x i8> %0
+// CHECK-SAME: <vscale x 8 x i16> %1
+// CHECK-SAME: <vscale x 8 x half> %2
+// CHECK-SAME: <vscale x 8 x bfloat> %3
+// CHECK-SAME: <vscale x 4 x i32> %4
+// CHECK-SAME: <vscale x 4 x float> %5
+// CHECK-SAME: <vscale x 2 x i64> %6
+// CHECK-SAME: <vscale x 2 x double> %7
+llvm.func @arm_sve_dupq_lane(%arg0: vector<[16]xi8>, %arg1: vector<[8]xi16>,
+                             %arg2: vector<[8]xf16>, %arg3: vector<[8]xbf16>,
+                             %arg4: vector<[4]xi32>, %arg5: vector<[4]xf32>,
+                             %arg6: vector<[2]xi64>, %arg7: vector<[2]xf64>) {
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8> %0, i64 0)
+  %0 = "arm_sve.intr.dupq_lane"(%arg0) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sve.dupq.lane.nxv8i16(<vscale x 8 x i16> %1, i64 1)
+  %1 = "arm_sve.intr.dupq_lane"(%arg1) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %2, i64 2)
+  %2 = "arm_sve.intr.dupq_lane"(%arg2) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sve.dupq.lane.nxv8bf16(<vscale x 8 x bfloat> %3, i64 3)
+  %3 = "arm_sve.intr.dupq_lane"(%arg3) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.dupq.lane.nxv4i32(<vscale x 4 x i32> %4, i64 4)
+  %4 = "arm_sve.intr.dupq_lane"(%arg4) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %5, i64 5)
+  %5 = "arm_sve.intr.dupq_lane"(%arg5) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %6, i64 6)
+  %6 = "arm_sve.intr.dupq_lane"(%arg6) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sve.dupq.lane.nxv2f64(<vscale x 2 x double> %7, i64 7)
+  %7 = "arm_sve.intr.dupq_lane"(%arg7) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+  llvm.return
+}
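
Likewise, this Target/LLVMIR test is checked after translation to LLVM IR rather than after a conversion pass; its RUN line also sits outside the diff, but for files in this directory it is conventionally (again an assumption):

```mlir
// Presumed driver (standard for mlir/test/Target/LLVMIR tests):
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
```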
