[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svdupq_lane #135356

Closed · wants to merge 1 commit
64 changes: 62 additions & 2 deletions mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
"a 1-D scalable vector with length " # length,
"::mlir::VectorType">;

def SVEVector : AnyTypeOf<[
Scalable1DVectorOfLength<2, [I64, F64]>,
Scalable1DVectorOfLength<4, [I32, F32]>,
Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
Scalable1DVectorOfLength<16, [I8]>],
"an SVE vector with element size <= 64-bit">;

//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
list<Trait> traits = [],
list<int> overloadedOperands = [],
list<int> overloadedResults = [],
int numResults = 1> :
int numResults = 1,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []> :
LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
/*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
/*int numResults=*/numResults>;
/*int numResults=*/numResults,
/*bit requiresAccessGroup=*/0,
/*bit requiresAliasAnalysis=*/0,
/*bit requiresFastmath=*/0,
/*bit requiresOpBundles=*/0,
/*list<int> immArgPositions=*/immArgPositions,
/*list<string> immArgAttrNames=*/immArgAttrNames>;

class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
list<Trait> traits = []>:
@@ -509,6 +524,41 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",

def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;

def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
let summary = "Broadcast indexed 128-bit segment to vector";

let description = [{
This operation fills each 128-bit segment of a vector with the elements
from the indexed 128-bit segment of the source vector. If the VL is
128 bits, the operation is a NOP.

Example:
```mlir
// VL == 256
// %X = [A B C D x x x x]
%Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
// %Y = [A B C D A B C D]

// %U = [x x x x x x x x A B C D E F G H]
%V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
// %V = [A B C D E F G H A B C D E F G H]
```
}];

let arguments = (ins SVEVector:$src,
I64Attr:$lane);
let results = (outs SVEVector:$dst);

let builders = [
OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
build($_builder, $_state, src.getType(), src, lane);
}]>];

let assemblyFormat = [{
$src `[` $lane `]` attr-dict `:` type($dst)
}];
}

def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -610,4 +660,14 @@ def WhileLTIntrOp :
/*overloadedResults=*/[0]>,
Arguments<(ins I64:$base, I64:$n)>;

def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
/*traits=*/[],
/*overloadedOperands=*/[0],
/*overloadedResults=*/[],
/*numResults=*/1,
/*immArgPositions=*/[1],
/*immArgAttrNames=*/["lane"]>,
Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
Arg<I64Attr, "lane">:$lane)>;

#endif // ARMSVE_OPS
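
For reference, a minimal C++ sketch of creating the new op through the builder declared above; the surrounding pass context and header paths are illustrative assumptions, not part of this patch:

```c++
// Sketch only: assumes the usual ArmSVE dialect headers and an existing OpBuilder.
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/IR/Builders.h"

// Splat the `lane`-th 128-bit segment of `src` across the whole vector using the
// OpBuilder<(ins "Value":$src, "int64_t":$lane)> overload declared on DupQLaneOp;
// the result type is inferred from the source vector.
static mlir::Value broadcastSegment(mlir::OpBuilder &builder, mlir::Location loc,
                                    mlir::Value src, int64_t lane) {
  return builder.create<mlir::arm_sve::DupQLaneOp>(loc, src, lane);
}
```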
mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
ScalableMaskedAddIIntrOp>;
@@ -192,6 +193,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
DupQLaneLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
ScalableMaskedSubIOpLowering,
@@ -219,6 +221,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
DupQLaneIntrOp,
Contributor:
[nit] I know we’ve already diverged from alphabetical order, but let’s try to course-correct. Would you mind moving this to the top of the list? Same comment applies below.

Collaborator (Author):
I've put the new op in this position (instead of top or bottom) as it makes for a smaller diff.
Why not rearrange all in alphabetical order then?

ScalableMaskedAddIIntrOp,
ScalableMaskedAddFIntrOp,
ScalableMaskedSubIIntrOp,
@@ -238,6 +241,7 @@
SmmlaOp,
UdotOp,
UmmlaOp,
DupQLaneOp,
ScalableMaskedAddIOp,
ScalableMaskedAddFOp,
ScalableMaskedSubIOp,
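No further wiring is needed for the new op to take part in the existing ArmSVE-to-LLVM lowering. As a rough sketch, a conversion pass drives the populate/configure hooks touched above along these lines (pass boilerplate and exact include paths are assumptions, not part of this change):

```c++
// Sketch only: illustrative pass body; exact signatures/includes may differ slightly.
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

static void lowerArmSVEToLLVM(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::LLVMTypeConverter converter(ctx);
  mlir::RewritePatternSet patterns(ctx);
  mlir::ConversionTarget target(*ctx);

  // Marks arm_sve.dupq_lane (among others) illegal and its intrinsic form legal.
  mlir::configureArmSVELegalizeForExportTarget(target);
  // Registers DupQLaneLowering together with the other one-to-one patterns.
  mlir::populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);

  if (failed(mlir::applyPartialConversion(module, target, std::move(patterns))))
    module.emitError("failed to legalize ArmSVE ops for LLVM export");
}
```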
43 changes: 43 additions & 0 deletions mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -271,3 +271,46 @@ func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[
%0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
return %0 : vector<[8]xi1>
}

// -----

// CHECK-LABEL: @arm_sve_dupq_lane(
// CHECK-SAME: %[[A0:[a-z0-9]+]]: vector<[16]xi8>
// CHECK-SAME: %[[A1:[a-z0-9]+]]: vector<[8]xi16>
// CHECK-SAME: %[[A2:[a-z0-9]+]]: vector<[8]xf16>
// CHECK-SAME: %[[A3:[a-z0-9]+]]: vector<[8]xbf16>
// CHECK-SAME: %[[A4:[a-z0-9]+]]: vector<[4]xi32>
// CHECK-SAME: %[[A5:[a-z0-9]+]]: vector<[4]xf32>
// CHECK-SAME: %[[A6:[a-z0-9]+]]: vector<[2]xi64>
// CHECK-SAME: %[[A7:[a-z0-9]+]]: vector<[2]xf64>
// CHECK-SAME: -> !llvm.struct<(vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>, vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>)> {

// CHECK: "arm_sve.intr.dupq_lane"(%[[A0]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
// CHECK: "arm_sve.intr.dupq_lane"(%[[A1]]) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
// CHECK: "arm_sve.intr.dupq_lane"(%[[A2]]) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
// CHECK: "arm_sve.intr.dupq_lane"(%[[A3]]) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
// CHECK: "arm_sve.intr.dupq_lane"(%[[A4]]) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
// CHECK: "arm_sve.intr.dupq_lane"(%[[A5]]) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
// CHECK: "arm_sve.intr.dupq_lane"(%[[A6]]) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
// CHECK: "arm_sve.intr.dupq_lane"(%[[A7]]) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
func.func @arm_sve_dupq_lane(
%v16i8: vector<[16]xi8>, %v8i16: vector<[8]xi16>,
%v8f16: vector<[8]xf16>, %v8bf16: vector<[8]xbf16>,
%v4i32: vector<[4]xi32>, %v4f32: vector<[4]xf32>,
%v2i64: vector<[2]xi64>, %v2f64: vector<[2]xf64>)
-> (vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>) {

%0 = arm_sve.dupq_lane %v16i8[0] : vector<[16]xi8>
%1 = arm_sve.dupq_lane %v8i16[1] : vector<[8]xi16>
%2 = arm_sve.dupq_lane %v8f16[2] : vector<[8]xf16>
%3 = arm_sve.dupq_lane %v8bf16[3] : vector<[8]xbf16>
%4 = arm_sve.dupq_lane %v4i32[4] : vector<[4]xi32>
%5 = arm_sve.dupq_lane %v4f32[5] : vector<[4]xf32>
%6 = arm_sve.dupq_lane %v2i64[6] : vector<[2]xi64>
%7 = arm_sve.dupq_lane %v2f64[7] : vector<[2]xf64>

return %0, %1, %2, %3, %4, %5, %6, %7
: vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>
}
34 changes: 34 additions & 0 deletions mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -390,3 +390,37 @@ llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[
"arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
llvm.return
}

// CHECK-LABEL: @arm_sve_dupq_lane
// CHECK-SAME: <vscale x 16 x i8> %0
// CHECK-SAME: <vscale x 8 x i16> %1
// CHECK-SAME: <vscale x 8 x half> %2
// CHECK-SAME: <vscale x 8 x bfloat> %3
// CHECK-SAME: <vscale x 4 x i32> %4
// CHECK-SAME: <vscale x 4 x float> %5
// CHECK-SAME: <vscale x 2 x i64> %6
// CHECK-SAME: <vscale x 2 x double> %7

// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8> %0, i64 0)
// CHECK: call <vscale x 8 x i16> @llvm.aarch64.sve.dupq.lane.nxv8i16(<vscale x 8 x i16> %1, i64 1)
// CHECK: call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %2, i64 2)
// CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sve.dupq.lane.nxv8bf16(<vscale x 8 x bfloat> %3, i64 3)
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.dupq.lane.nxv4i32(<vscale x 4 x i32> %4, i64 4)
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %5, i64 5)
// CHECK: call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %6, i64 6)
// CHECK: call <vscale x 2 x double> @llvm.aarch64.sve.dupq.lane.nxv2f64(<vscale x 2 x double> %7, i64 7)

llvm.func @arm_sve_dupq_lane(%arg0: vector<[16]xi8>, %arg1: vector<[8]xi16>,
%arg2: vector<[8]xf16>, %arg3: vector<[8]xbf16>,
%arg4: vector<[4]xi32>, %arg5: vector<[4]xf32>,
%arg6: vector<[2]xi64>, %arg7: vector<[2]xf64>) {
%0 = "arm_sve.intr.dupq_lane"(%arg0) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
%1 = "arm_sve.intr.dupq_lane"(%arg1) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
%2 = "arm_sve.intr.dupq_lane"(%arg2) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
%3 = "arm_sve.intr.dupq_lane"(%arg3) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
%4 = "arm_sve.intr.dupq_lane"(%arg4) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
%5 = "arm_sve.intr.dupq_lane"(%arg5) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
%6 = "arm_sve.intr.dupq_lane"(%arg6) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
%7 = "arm_sve.intr.dupq_lane"(%arg7) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
llvm.return
}