Skip to content

Commit c409da2

Browse files
[mlir][ROCDL] Add permlanex16 op to allow subgroup reductions on gfx10+ (#135983)
Adding Permlanex16Op to ROCDL dialect to enable subgroup reduce lowering to DPP ops for gfx 10+ devices. See [this PR](#133204). --------- Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 9dbe107 commit c409da2

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,22 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
668668
}];
669669
}
670670

671+
// PermLaneX16 intrinsic operation
672+
def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
673+
[AllTypesMatch<["res", "old", "src0"]>, AllTypesMatch<["src1", "src2"]>], 1, 0, 0,
674+
[4, 5], ["fi", "boundControl"]>,
675+
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src0, LLVM_Type:$src1, LLVM_Type:$src2,
676+
I1Attr:$fi, I1Attr:$boundControl)> {
677+
let results = (outs LLVM_Type:$res);
678+
let assemblyFormat = [{
679+
attr-dict $old `,` $src0 `,` $src1 `,` $src2 `,` $fi `,` $boundControl `:` type($src0) `,` type($src1)
680+
}];
681+
let description = [{
682+
Performs a `permlanex16` operation with the given operands, applying the
683+
permutation specified by $fi to the provided inputs.
684+
}];
685+
}
686+
671687
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
672688
BuildableType<"::mlir::VectorType::get("
673689
"{2},$_builder.getI16Type())">;

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,16 @@ llvm.func @rocdl.readlane(%src : f32) -> f32 {
889889

890890
// -----
891891

892+
llvm.func @rocdl.permlanex16(%src : f32) -> f32 {
893+
%cst0 = llvm.mlir.constant(-1 : i32) : i32
894+
// CHECK-LABEL: rocdl.permlanex16
895+
// CHECK: rocdl.permlanex16 %{{.*}} %{{.*}}
896+
%ret = rocdl.permlanex16 %src, %src, %cst0, %cst0, 0, -1 : f32, i32
897+
llvm.return %ret : f32
898+
}
899+
900+
// -----
901+
892902
// expected-error@below {{attribute attached to unexpected op}}
893903
func.func private @expected_llvm_func() attributes { rocdl.kernel }
894904

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,20 @@ llvm.func @rocdl.make.buffer.rsrc.p7.p1(%ptr : !llvm.ptr<1>,
872872
llvm.return %rsrc : !llvm.ptr<7>
873873
}
874874

875+
llvm.func @rocdl.permlanex16(%src0 : f32, %src1 : i32, %src2 : vector<2 x f32>, %src3 : vector<2 x i32>) -> f32 {
876+
%cst0 = llvm.mlir.constant(-1 : i32) : i32
877+
// CHECK-LABEL: rocdl.permlanex16
878+
// CHECK: call float @llvm.amdgcn.permlanex16.f32(float %{{.*}}, float %{{.*}}, i32 -1, i32 -1, i1 false, i1 true)
879+
%ret0 = rocdl.permlanex16 %src0, %src0, %cst0, %cst0, 0, -1 : f32, i32
880+
// CHECK: call i32 @llvm.amdgcn.permlanex16.i32(i32 %{{.*}}, i32 %{{.*}}, i32 -1, i32 -1, i1 false, i1 true)
881+
%ret1 = rocdl.permlanex16 %src1, %src1, %cst0, %cst0, 0, -1 : i32, i32
882+
// CHECK: call <2 x float> @llvm.amdgcn.permlanex16.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, i32 -1, i32 -1, i1 false, i1 true)
883+
%ret2 = rocdl.permlanex16 %src2, %src2, %cst0, %cst0, 0, -1 : vector<2 x f32>, i32
884+
// CHECK: call <2 x i32> @llvm.amdgcn.permlanex16.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, i32 -1, i32 -1, i1 false, i1 true)
885+
%ret3 = rocdl.permlanex16 %src3, %src3, %cst0, %cst0, 0, -1 : vector<2 x i32>, i32
886+
llvm.return %ret0 : f32
887+
}
888+
875889
llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
876890
// CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
877891
%r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>

0 commit comments

Comments
 (0)