-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svdupq_lane
#135356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-sve @llvm/pr-subscribers-mlir Author: Momchil Velikov (momchil-velikov) ChangesFull diff: https://github.com/llvm/llvm-project/pull/135356.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index cdcf4d8752e87..5223575cfcabe 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
"a 1-D scalable vector with length " # length,
"::mlir::VectorType">;
+def SVEVector : AnyTypeOf<[
+ Scalable1DVectorOfLength<2, [I64, F64]>,
+ Scalable1DVectorOfLength<4, [I32, F32]>,
+ Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+ Scalable1DVectorOfLength<16, [I8]>],
+ "an SVE vector with element size <= 64-bit">;
+
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
list<Trait> traits = [],
list<int> overloadedOperands = [],
list<int> overloadedResults = [],
- int numResults = 1> :
+ int numResults = 1,
+ list<int> immArgPositions = [],
+ list<string> immArgAttrNames = []> :
LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
/*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
- /*int numResults=*/numResults>;
+ /*int numResults=*/numResults,
+ /*bit requiresAccessGroup=*/0,
+ /*bit requiresAliasAnalysis=*/0,
+ /*bit requiresFastmath=*/0,
+ /*bit requiresOpBundles=*/0,
+ /*list<int> immArgPositions=*/immArgPositions,
+ /*list<string> immArgAttrNames=*/immArgAttrNames>;
class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
list<Trait> traits = []>:
@@ -509,6 +524,41 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
+def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
+ let summary = "Broadcast indexed 128-bit segment to vector";
+
+ let description = [{
+ This operation fills each 128-bit segment of a vector with the elements
+ from the indexed 128-bit sgement of the source vector. If the VL is
+ 128 bits the operation is a NOP.
+
+ Example:
+ ```mlir
+ // VL == 256
+ // %X = [A B C D x x x x]
+ %Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
+ // Y = [A B C D A B C D]
+
+ // %U = [x x x x x x x x A B C D E F G H]
+ %V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
+ // %V = [A B C D E F H A B C D E F H]
+ ```
+ }];
+
+ let arguments = (ins SVEVector:$src,
+ I64Attr:$lane);
+ let results = (outs SVEVector:$dst);
+
+ let builders = [
+ OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
+ build($_builder, $_state, src.getType(), src, lane);
+ }]>];
+
+ let assemblyFormat = [{
+ $src `[` $lane `]` attr-dict `:` type($dst)
+ }];
+}
+
def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -610,4 +660,14 @@ def WhileLTIntrOp :
/*overloadedResults=*/[0]>,
Arguments<(ins I64:$base, I64:$n)>;
+def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
+ /*traits=*/[],
+ /*overloadedOperands=*/[0],
+ /*overloadedResults=*/[],
+ /*numResults=*/1,
+ /*immArgPositions*/[1],
+ /*immArgAttrNames*/["lane"]>,
+ Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
+ Arg<I64Attr, "lane">:$lane)>;
+
#endif // ARMSVE_OPS
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 2bdb640699d03..087ebdbbf6afb 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
ScalableMaskedAddIIntrOp>;
@@ -192,6 +193,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
+ DupQLaneLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
ScalableMaskedSubIOpLowering,
@@ -219,6 +221,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
+ DupQLaneIntrOp,
ScalableMaskedAddIIntrOp,
ScalableMaskedAddFIntrOp,
ScalableMaskedSubIIntrOp,
@@ -238,6 +241,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaOp,
UdotOp,
UmmlaOp,
+ DupQLaneOp,
ScalableMaskedAddIOp,
ScalableMaskedAddFOp,
ScalableMaskedSubIOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index bdb69a95a52de..5d044517e0ea8 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -271,3 +271,46 @@ func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[
%0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
return %0 : vector<[8]xi1>
}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_dupq_lane(
+// CHECK-SAME: %[[A0:[a-z0-9]+]]: vector<[16]xi8>
+// CHECK-SAME: %[[A1:[a-z0-9]+]]: vector<[8]xi16>
+// CHECK-SAME: %[[A2:[a-z0-9]+]]: vector<[8]xf16>
+// CHECK-SAME: %[[A3:[a-z0-9]+]]: vector<[8]xbf16>
+// CHECK-SAME: %[[A4:[a-z0-9]+]]: vector<[4]xi32>
+// CHECK-SAME: %[[A5:[a-z0-9]+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A6:[a-z0-9]+]]: vector<[2]xi64>
+// CHECK-SAME: %[[A7:[a-z0-9]+]]: vector<[2]xf64>
+// CHECK-SAME: -> !llvm.struct<(vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>, vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>)> {
+
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A0]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A1]]) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A2]]) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A3]]) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A4]]) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A5]]) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A6]]) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A7]]) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+func.func @arm_sve_dupq_lane(
+ %v16i8: vector<[16]xi8>, %v8i16: vector<[8]xi16>,
+ %v8f16: vector<[8]xf16>, %v8bf16: vector<[8]xbf16>,
+ %v4i32: vector<[4]xi32>, %v4f32: vector<[4]xf32>,
+ %v2i64: vector<[2]xi64>, %v2f64: vector<[2]xf64>)
+ -> (vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+ vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>) {
+
+ %0 = arm_sve.dupq_lane %v16i8[0] : vector<[16]xi8>
+ %1 = arm_sve.dupq_lane %v8i16[1] : vector<[8]xi16>
+ %2 = arm_sve.dupq_lane %v8f16[2] : vector<[8]xf16>
+ %3 = arm_sve.dupq_lane %v8bf16[3] : vector<[8]xbf16>
+ %4 = arm_sve.dupq_lane %v4i32[4] : vector<[4]xi32>
+ %5 = arm_sve.dupq_lane %v4f32[5] : vector<[4]xf32>
+ %6 = arm_sve.dupq_lane %v2i64[6] : vector<[2]xi64>
+ %7 = arm_sve.dupq_lane %v2f64[7] : vector<[2]xf64>
+
+ return %0, %1, %2, %3, %4, %5, %6, %7
+ : vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+ vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index ed5a1fc7ba2e4..edb85f7470902 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -390,3 +390,37 @@ llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[
"arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
llvm.return
}
+
+// CHECK-LABEL: @arm_sve_dupq_lane
+// CHECK-SAME: <vscale x 16 x i8> %0
+// CHECK-SAME: <vscale x 8 x i16> %1
+// CHECK-SAME: <vscale x 8 x half> %2
+// CHECK-SAME: <vscale x 8 x bfloat> %3
+// CHECK-SAME: <vscale x 4 x i32> %4
+// CHECK-SAME: <vscale x 4 x float> %5
+// CHECK-SAME: <vscale x 2 x i64> %6
+// CHECK-SAME: <vscale x 2 x double> %7
+
+// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8> %0, i64 0)
+// CHECK: call <vscale x 8 x i16> @llvm.aarch64.sve.dupq.lane.nxv8i16(<vscale x 8 x i16> %1, i64 1)
+// CHECK: call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %2, i64 2)
+// CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sve.dupq.lane.nxv8bf16(<vscale x 8 x bfloat> %3, i64 3)
+// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.dupq.lane.nxv4i32(<vscale x 4 x i32> %4, i64 4)
+// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %5, i64 5)
+// CHECK: call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %6, i64 6)
+// CHECK: call <vscale x 2 x double> @llvm.aarch64.sve.dupq.lane.nxv2f64(<vscale x 2 x double> %7, i64 7)
+
+llvm.func @arm_sve_dupq_lane(%arg0: vector<[16]xi8>, %arg1: vector<[8]xi16>,
+ %arg2: vector<[8]xf16>, %arg3: vector<[8]xbf16>,
+ %arg4: vector<[4]xi32>,%arg5: vector<[4]xf32>,
+ %arg6: vector<[2]xi64>, %arg7: vector<[2]xf64>) {
+ %0 = "arm_sve.intr.dupq_lane"(%arg0) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+ %1 = "arm_sve.intr.dupq_lane"(%arg1) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+ %2 = "arm_sve.intr.dupq_lane"(%arg2) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+ %3 = "arm_sve.intr.dupq_lane"(%arg3) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+ %4 = "arm_sve.intr.dupq_lane"(%arg4) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+ %5 = "arm_sve.intr.dupq_lane"(%arg5) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+ %6 = "arm_sve.intr.dupq_lane"(%arg6) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+ %7 = "arm_sve.intr.dupq_lane"(%arg7) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+ llvm.return
+}
|
You can test this locally with the following command:git-clang-format --diff HEAD~1 HEAD --extensions cpp -- mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp View the diff from clang-format here.diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 087ebdbbf..fe13ed033 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,7 +24,8 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
-using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
+using DupQLaneLowering =
+ OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
ScalableMaskedAddIIntrOp>;
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks Momchil!
Some minor comments/requests inline.
@@ -219,6 +221,7 @@ void mlir::configureArmSVELegalizeForExportTarget( | |||
SmmlaIntrOp, | |||
UdotIntrOp, | |||
UmmlaIntrOp, | |||
DupQLaneIntrOp, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] I know we’ve already diverged from alphabetical order, but let’s try to course-correct. Would you mind moving this to the top of the list? Same comment applies below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've put the new op in this position (instead of top or bottom) as it makes for a smaller diff.
Why not rearrange all in alphabetical order then?
No description provided.