[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svdupq_lane #135356

Closed
wants to merge 1 commit into from

Conversation

momchil-velikov
Collaborator

No description provided.

@llvmbot
Member

llvmbot commented Apr 11, 2025

@llvm/pr-subscribers-mlir-sve
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/135356.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+62-2)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+4)
  • (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+43)
  • (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+34)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index cdcf4d8752e87..5223575cfcabe 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
   "a 1-D scalable vector with length " # length,
   "::mlir::VectorType">;
 
+def SVEVector : AnyTypeOf<[
+  Scalable1DVectorOfLength<2, [I64, F64]>,
+  Scalable1DVectorOfLength<4, [I32, F32]>,
+  Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+  Scalable1DVectorOfLength<16, [I8]>],
+  "an SVE vector with element size <= 64-bit">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
                     list<Trait> traits = [],
                     list<int> overloadedOperands = [],
                     list<int> overloadedResults = [],
-                    int numResults = 1> :
+                    int numResults = 1,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = []> :
   LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
                   /*string opName=*/"intr." # mnemonic,
                   /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
                   /*list<int> overloadedResults=*/overloadedResults,
                   /*list<int> overloadedOperands=*/overloadedOperands,
                   /*list<Trait> traits=*/traits,
-                  /*int numResults=*/numResults>;
+                  /*int numResults=*/numResults,
+                  /*bit requiresAccessGroup=*/0,
+                  /*bit requiresAliasAnalysis=*/0,
+                  /*bit requiresFastmath=*/0,
+                  /*bit requiresOpBundles=*/0,
+                  /*list<int> immArgPositions=*/immArgPositions,
+                  /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                                     list<Trait> traits = []>:
@@ -509,6 +524,41 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
 
 def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
+def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
+  let summary = "Broadcast indexed 128-bit segment to vector";
+
+  let description = [{
+    This operation fills each 128-bit segment of a vector with the elements
+    from the indexed 128-bit segment of the source vector. If the VL is
+    128 bits the operation is a NOP.
+
+    Example:
+    ```mlir
+    // VL == 256
+    // %X = [A B C D x x x x]
+    %Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
+    // %Y = [A B C D A B C D]
+
+    // %U = [x x x x x x x x A B C D E F G H]
+    %V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
+    // %V = [A B C D E F G H A B C D E F G H]
+    ```
+  }];
+
+  let arguments = (ins SVEVector:$src,
+                       I64Attr:$lane);
+  let results = (outs SVEVector:$dst);
+
+  let builders = [
+    OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
+      build($_builder, $_state, src.getType(), src, lane);
+    }]>];
+
+  let assemblyFormat = [{
+    $src `[` $lane `]` attr-dict `:` type($dst)
+  }];
+}
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -610,4 +660,14 @@ def WhileLTIntrOp :
     /*overloadedResults=*/[0]>,
   Arguments<(ins I64:$base, I64:$n)>;
 
+def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
+    /*traits=*/[],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[],
+    /*numResults=*/1,
+    /*immArgPositions=*/[1],
+    /*immArgAttrNames=*/["lane"]>,
+    Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
+                   Arg<I64Attr, "lane">:$lane)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 2bdb640699d03..087ebdbbf6afb 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                  ScalableMaskedAddIIntrOp>;
@@ -192,6 +193,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
+               DupQLaneLowering,
                ScalableMaskedAddIOpLowering,
                ScalableMaskedAddFOpLowering,
                ScalableMaskedSubIOpLowering,
@@ -219,6 +221,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
+                    DupQLaneIntrOp,
                     ScalableMaskedAddIIntrOp,
                     ScalableMaskedAddFIntrOp,
                     ScalableMaskedSubIIntrOp,
@@ -238,6 +241,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
+                      DupQLaneOp,
                       ScalableMaskedAddIOp,
                       ScalableMaskedAddFOp,
                       ScalableMaskedSubIOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index bdb69a95a52de..5d044517e0ea8 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -271,3 +271,46 @@ func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[
   %0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
   return %0 : vector<[8]xi1>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_dupq_lane(
+// CHECK-SAME: %[[A0:[a-z0-9]+]]: vector<[16]xi8>
+// CHECK-SAME: %[[A1:[a-z0-9]+]]: vector<[8]xi16>
+// CHECK-SAME: %[[A2:[a-z0-9]+]]: vector<[8]xf16>
+// CHECK-SAME: %[[A3:[a-z0-9]+]]: vector<[8]xbf16>
+// CHECK-SAME: %[[A4:[a-z0-9]+]]: vector<[4]xi32>
+// CHECK-SAME: %[[A5:[a-z0-9]+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A6:[a-z0-9]+]]: vector<[2]xi64>
+// CHECK-SAME: %[[A7:[a-z0-9]+]]: vector<[2]xf64>
+// CHECK-SAME: -> !llvm.struct<(vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>, vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>)> {
+
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A0]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A1]]) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A2]]) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A3]]) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A4]]) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A5]]) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A6]]) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A7]]) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+func.func @arm_sve_dupq_lane(
+  %v16i8: vector<[16]xi8>, %v8i16: vector<[8]xi16>,
+  %v8f16: vector<[8]xf16>, %v8bf16: vector<[8]xbf16>,
+  %v4i32: vector<[4]xi32>, %v4f32: vector<[4]xf32>,
+  %v2i64: vector<[2]xi64>, %v2f64: vector<[2]xf64>)
+  -> (vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+      vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>) {
+
+  %0 = arm_sve.dupq_lane %v16i8[0]  : vector<[16]xi8>
+  %1 = arm_sve.dupq_lane %v8i16[1]  : vector<[8]xi16>
+  %2 = arm_sve.dupq_lane %v8f16[2]  : vector<[8]xf16>
+  %3 = arm_sve.dupq_lane %v8bf16[3] : vector<[8]xbf16>
+  %4 = arm_sve.dupq_lane %v4i32[4]  : vector<[4]xi32>
+  %5 = arm_sve.dupq_lane %v4f32[5]  : vector<[4]xf32>
+  %6 = arm_sve.dupq_lane %v2i64[6]  : vector<[2]xi64>
+  %7 = arm_sve.dupq_lane %v2f64[7]  : vector<[2]xf64>
+
+  return %0, %1, %2, %3, %4, %5, %6, %7
+    : vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+      vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index ed5a1fc7ba2e4..edb85f7470902 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -390,3 +390,37 @@ llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[
   "arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: @arm_sve_dupq_lane
+// CHECK-SAME: <vscale x 16 x i8> %0
+// CHECK-SAME: <vscale x 8 x i16> %1
+// CHECK-SAME: <vscale x 8 x half> %2
+// CHECK-SAME: <vscale x 8 x bfloat> %3
+// CHECK-SAME: <vscale x 4 x i32> %4
+// CHECK-SAME: <vscale x 4 x float> %5
+// CHECK-SAME: <vscale x 2 x i64> %6
+// CHECK-SAME: <vscale x 2 x double> %7
+
+// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8> %0, i64 0)
+// CHECK: call <vscale x 8 x i16> @llvm.aarch64.sve.dupq.lane.nxv8i16(<vscale x 8 x i16> %1, i64 1)
+// CHECK: call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %2, i64 2)
+// CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sve.dupq.lane.nxv8bf16(<vscale x 8 x bfloat> %3, i64 3)
+// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.dupq.lane.nxv4i32(<vscale x 4 x i32> %4, i64 4)
+// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %5, i64 5)
+// CHECK: call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %6, i64 6)
+// CHECK: call <vscale x 2 x double> @llvm.aarch64.sve.dupq.lane.nxv2f64(<vscale x 2 x double> %7, i64 7)
+
+llvm.func @arm_sve_dupq_lane(%arg0: vector<[16]xi8>, %arg1: vector<[8]xi16>,
+                             %arg2: vector<[8]xf16>, %arg3: vector<[8]xbf16>,
+                             %arg4: vector<[4]xi32>, %arg5: vector<[4]xf32>,
+                             %arg6: vector<[2]xi64>, %arg7: vector<[2]xf64>) {
+  %0 = "arm_sve.intr.dupq_lane"(%arg0) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+  %1 = "arm_sve.intr.dupq_lane"(%arg1) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+  %2 = "arm_sve.intr.dupq_lane"(%arg2) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+  %3 = "arm_sve.intr.dupq_lane"(%arg3) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+  %4 = "arm_sve.intr.dupq_lane"(%arg4) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+  %5 = "arm_sve.intr.dupq_lane"(%arg5) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+  %6 = "arm_sve.intr.dupq_lane"(%arg6) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+  %7 = "arm_sve.intr.dupq_lane"(%arg7) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+  llvm.return
+}
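
For readers who want to check the documented semantics of `arm_sve.dupq_lane` by hand, here is a minimal reference model in plain C++. It is only a sketch and is not part of the patch: the helper name `dupqLaneModel` is hypothetical, the vector length is modelled as the size of a `std::vector`, and an out-of-range lane is simply rejected with an assertion, since the op description does not say what happens in that case.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference model (not from the patch): copies the 128-bit
// segment selected by `lane` into every 128-bit segment of the result.
// `src.size()` stands in for the number of elements at the runtime VL.
template <typename T>
std::vector<T> dupqLaneModel(const std::vector<T> &src, std::size_t lane) {
  constexpr std::size_t kSegmentBytes = 16; // one 128-bit segment
  static_assert(kSegmentBytes % sizeof(T) == 0,
                "element size must divide 128 bits");
  constexpr std::size_t elemsPerSegment = kSegmentBytes / sizeof(T);

  // The vector must consist of whole 128-bit segments.
  assert(src.size() % elemsPerSegment == 0);
  // Assumption of this sketch: the lane must name an existing segment.
  assert(lane < src.size() / elemsPerSegment);

  std::vector<T> dst(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    // Position of element i within its own segment.
    const std::size_t offsetInSegment = i % elemsPerSegment;
    dst[i] = src[lane * elemsPerSegment + offsetInSegment];
  }
  return dst;
}
```

With a 256-bit vector of `std::int32_t` (8 elements), lane 0 repeats the first four elements, which matches the first example in the op description. With a single 128-bit segment and lane 0 the model returns its input unchanged, matching the "NOP" remark.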

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp -- mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 087ebdbbf..fe13ed033 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,7 +24,8 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
-using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
+using DupQLaneLowering =
+    OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                  ScalableMaskedAddIIntrOp>;

Contributor

@banach-space banach-space left a comment

Nice, thanks Momchil!

Some minor comments/requests inline.

@@ -219,6 +221,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
DupQLaneIntrOp,
Contributor

[nit] I know we’ve already diverged from alphabetical order, but let’s try to course-correct. Would you mind moving this to the top of the list? Same comment applies below.

Collaborator Author

I've put the new op in this position (instead of top or bottom) as it makes for a smaller diff.
Why not rearrange all in alphabetical order then?


3 participants