Skip to content

Commit 620724c

Browse files
committed
[MLIR][NVVM] Update dot.accumulate NVVM Ops
This change: - Adds the dot.accumulate.2way Op to the NVVM dialect for 16-bit to 8-bit dot-product accumulate operation. - Refactors the recently added dot.accumulate.4way and adds a verifier.
1 parent 690a30f commit 620724c

File tree

5 files changed

+239
-17
lines changed

5 files changed

+239
-17
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 95 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3445,25 +3445,28 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
34453445
}
34463446

34473447
//===----------------------------------------------------------------------===//
3448-
// NVVM dot.accumulate.4way Op
3448+
// NVVM dot.accumulate Ops
34493449
//===----------------------------------------------------------------------===//
34503450

3451-
def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
3452-
def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
3451+
def DotAccumulateS8 : I32EnumAttrCase<"S8", 1, "s8">;
3452+
def DotAccumulateU8 : I32EnumAttrCase<"U8", 0, "u8">;
3453+
def DotAccumulateS16 : I32EnumAttrCase<"S16", 2, "s16">;
3454+
def DotAccumulateU16 : I32EnumAttrCase<"U16", 3, "u16">;
34533455

3454-
def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
3455-
"NVVM DotAccumulate4WayType",
3456-
[DotAccumulate4WayS8, DotAccumulate4WayU8]> {
3456+
def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
3457+
"NVVM DotAccumulateType",
3458+
[DotAccumulateS8, DotAccumulateU8,
3459+
DotAccumulateS16, DotAccumulateU16]> {
34573460
let cppNamespace = "::mlir::NVVM";
34583461
let genSpecializedAttr = 0;
34593462
}
34603463

3461-
def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
3464+
def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType, "dot_accumulate_type"> {
34623465
let assemblyFormat = "`<` $value `>`";
34633466
}
34643467

34653468
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3466-
let summary = "Four-way byte dot product-accumulate instruction.";
3469+
let summary = "Four-way byte dot product-accumulate instruction";
34673470
let description = [{
34683471
Performs a four-way byte dot-product which is accumulated in a 32-bit
34693472
result.
@@ -3481,11 +3484,13 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
34813484
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
34823485
}];
34833486

3487+
let hasVerifier = 1;
3488+
34843489
let arguments = (ins
34853490
VectorOfLengthAndType<[4], [I8]>:$a,
3486-
DotAccumulate4WayTypeAttr:$a_type,
3491+
DotAccumulateTypeAttr:$a_type,
34873492
VectorOfLengthAndType<[4], [I8]>:$b,
3488-
DotAccumulate4WayTypeAttr:$b_type,
3493+
DotAccumulateTypeAttr:$b_type,
34893494
I32:$c
34903495
);
34913496

@@ -3495,8 +3500,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
34953500

34963501
let extraClassDeclaration = [{
34973502
static llvm::Intrinsic::ID
3498-
getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3499-
NVVM::DotAccumulate4WayType b_type);
3503+
getIntrinsicID(NVVM::DotAccumulateType a_type,
3504+
NVVM::DotAccumulateType b_type);
35003505
llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
35013506
}];
35023507

@@ -3508,6 +3513,84 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
35083513
}];
35093514
}
35103515

3516+
def DotAccumulate2WayModeLo : I32EnumAttrCase<"LO", 0, "lo">;
3517+
def DotAccumulate2WayModeHi : I32EnumAttrCase<"HI", 1, "hi">;
3518+
3519+
def DotAccumulate2WayMode : I32EnumAttr<"DotAccumulate2WayMode",
3520+
"NVVM DotAccumulate2WayMode",
3521+
[DotAccumulate2WayModeLo, DotAccumulate2WayModeHi]> {
3522+
let cppNamespace = "::mlir::NVVM";
3523+
let genSpecializedAttr = 0;
3524+
}
3525+
3526+
def DotAccumulate2WayModeAttr : EnumAttr<NVVM_Dialect, DotAccumulate2WayMode, "dot_accumulate_2way_mode"> {
3527+
let assemblyFormat = "$value";
3528+
}
3529+
3530+
def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
3531+
let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
3532+
let description = [{
3533+
Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
3534+
32-bit result.
3535+
Operand `a` is a vector of two 16-bit elements and operand `b` a vector
3536+
of four 8-bit elements between which the dot product is computed.
3537+
3538+
The `a_type` and `b_type` attributes specify the type of the elements in `a`
3539+
and `b` respectively.
3540+
If `a_type` is `s16`, then the elements in `a` are sign-extended to
3541+
32-bit before the dot product is computed.
3542+
If `a_type` is `u16`, then the elements in `a` are zero-extended to
3543+
32-bit instead.
3544+
If `b_type` is `s8`, then the elements in `b` are sign-extended to
3545+
32-bit before the dot product is computed.
3546+
If `b_type` is `u8`, then the elements in `b` are zero-extended to
3547+
32-bit instead.
3548+
3549+
The `mode` attribute specifies which two bytes of `b` are used for the dot
3550+
product. If `mode` is `lo`, then the dot product is computed between `a`
3551+
and elements at indices 0 and 1 of `b`. If `mode` is `hi`, then the dot
3552+
product is computed between `a` and elements at indices 2 and 3 of `b`.
3553+
Operand `c` is a 32-bit integer to which the result is accumulated. It is
3554+
treated as holding a signed integer if any of `a_type` or `b_type` is
3555+
signed.
3556+
3557+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
3558+
}];
3559+
3560+
let hasVerifier = 1;
3561+
3562+
let arguments = (ins
3563+
DotAccumulate2WayModeAttr:$mode,
3564+
VectorOfLengthAndType<[2], [I16]>:$a,
3565+
DotAccumulateTypeAttr:$a_type,
3566+
VectorOfLengthAndType<[4], [I8]>:$b,
3567+
DotAccumulateTypeAttr:$b_type,
3568+
I32:$c
3569+
);
3570+
3571+
let results = (outs I32:$res);
3572+
3573+
let assemblyFormat = "$mode $a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3574+
3575+
let extraClassDeclaration = [{
3576+
static llvm::Intrinsic::ID
3577+
getIntrinsicID(NVVM::DotAccumulateType a_type,
3578+
NVVM::DotAccumulateType b_type);
3579+
llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3580+
llvm::Value* isHi(NVVM::DotAccumulate2WayMode mode,
3581+
llvm::IRBuilderBase& builder);
3582+
}];
3583+
3584+
string llvmBuilder = [{
3585+
llvm::Intrinsic::ID id = NVVM::DotAccumulate2WayOp::getIntrinsicID($a_type, $b_type);
3586+
llvm::Value* argA = op.getPackedArg($a, builder);
3587+
llvm::Value* argB = op.getPackedArg($b, builder);
3588+
$res = createIntrinsicCall(builder, id,
3589+
{argA, argB, op.isHi($mode, builder), $c}
3590+
);
3591+
}];
3592+
}
3593+
35113594
//===----------------------------------------------------------------------===//
35123595
// NVVM target attribute.
35133596
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,46 @@ NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
12111211
llvm::Type::getInt32Ty(builder.getContext()));
12121212
}
12131213

1214+
LogicalResult NVVM::DotAccumulate4WayOp::verify() {
1215+
NVVM::DotAccumulateType aType = getAType();
1216+
NVVM::DotAccumulateType bType = getBType();
1217+
1218+
if (aType != NVVM::DotAccumulateType::S8 &&
1219+
aType != NVVM::DotAccumulateType::U8)
1220+
return emitOpError("a_type must be S8 or U8");
1221+
if (bType != NVVM::DotAccumulateType::S8 &&
1222+
bType != NVVM::DotAccumulateType::U8)
1223+
return emitOpError("b_type must be S8 or U8");
1224+
1225+
return success();
1226+
}
1227+
1228+
llvm::Value *
1229+
NVVM::DotAccumulate2WayOp::getPackedArg(llvm::Value *arg,
1230+
llvm::IRBuilderBase &builder) {
1231+
return builder.CreateBitCast(arg,
1232+
llvm::Type::getInt32Ty(builder.getContext()));
1233+
}
1234+
1235+
llvm::Value *NVVM::DotAccumulate2WayOp::isHi(NVVM::DotAccumulate2WayMode mode,
1236+
llvm::IRBuilderBase &builder) {
1237+
return builder.getInt1(mode == NVVM::DotAccumulate2WayMode::HI);
1238+
}
1239+
1240+
LogicalResult NVVM::DotAccumulate2WayOp::verify() {
1241+
NVVM::DotAccumulateType aType = getAType();
1242+
NVVM::DotAccumulateType bType = getBType();
1243+
1244+
if (aType != NVVM::DotAccumulateType::S16 &&
1245+
aType != NVVM::DotAccumulateType::U16)
1246+
return emitOpError("a_type must be S16 or U16");
1247+
if (bType != NVVM::DotAccumulateType::S8 &&
1248+
bType != NVVM::DotAccumulateType::U8)
1249+
return emitOpError("b_type must be S8 or U8");
1250+
1251+
return success();
1252+
}
1253+
12141254
//===----------------------------------------------------------------------===//
12151255
// getIntrinsicID/getIntrinsicIDAndArgs methods
12161256
//===----------------------------------------------------------------------===//
@@ -1599,10 +1639,10 @@ static void nvvmInferResultRanges(Operation *op, Value result,
15991639
}
16001640

16011641
llvm::Intrinsic::ID
1602-
DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
1603-
NVVM::DotAccumulate4WayType b_type) {
1604-
bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
1605-
bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
1642+
DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulateType a_type,
1643+
NVVM::DotAccumulateType b_type) {
1644+
bool is_a_siext = a_type == NVVM::DotAccumulateType::S8;
1645+
bool is_b_siext = b_type == NVVM::DotAccumulateType::S8;
16061646
unsigned type = (is_a_siext << 1) | is_b_siext;
16071647
switch (type) {
16081648
case 0:
@@ -1618,6 +1658,26 @@ DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
16181658
}
16191659
}
16201660

1661+
llvm::Intrinsic::ID
1662+
DotAccumulate2WayOp::getIntrinsicID(NVVM::DotAccumulateType a_type,
1663+
NVVM::DotAccumulateType b_type) {
1664+
bool is_a_siext = a_type == NVVM::DotAccumulateType::S16;
1665+
bool is_b_siext = b_type == NVVM::DotAccumulateType::S8;
1666+
unsigned type = (is_a_siext << 1) | is_b_siext;
1667+
switch (type) {
1668+
case 0:
1669+
return llvm::Intrinsic::nvvm_idp2a_u_u;
1670+
case 1:
1671+
return llvm::Intrinsic::nvvm_idp2a_u_s;
1672+
case 2:
1673+
return llvm::Intrinsic::nvvm_idp2a_s_u;
1674+
case 3:
1675+
return llvm::Intrinsic::nvvm_idp2a_s_s;
1676+
default:
1677+
llvm_unreachable("Invalid DP2a type");
1678+
}
1679+
}
1680+
16211681
//===----------------------------------------------------------------------===//
16221682
// NVVMDialect initialization, type parsing, and registration.
16231683
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,14 +579,23 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
579579
}
580580

581581
// CHECK-LABEL: @dot_accumulate_4way
582-
func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
582+
func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
583583
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
584584
%1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
585585
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
586586
%3 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <s8>, %c: vector<4xi8>, vector<4xi8>
587587
return
588588
}
589589

590+
// CHECK-LABEL: @dot_accumulate_2way
591+
func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c: i32) {
592+
// CHECK: nvvm.dot.accumulate.2way lo %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi16>, vector<4xi8>
593+
%1 = nvvm.dot.accumulate.2way lo %a_vec <u16>, %b_vec <u8>, %c: vector<2xi16>, vector<4xi8>
594+
// CHECK: nvvm.dot.accumulate.2way hi %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi16>, vector<4xi8>
595+
%3 = nvvm.dot.accumulate.2way hi %a_vec <s16>, %b_vec <s8>, %c: vector<2xi16>, vector<4xi8>
596+
return
597+
}
598+
590599
// -----
591600

592601
// Just check these don't emit errors.

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,35 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
248248
%res = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
249249
llvm.return
250250
}
251+
252+
// -----
253+
254+
llvm.func @nvvm_dot_accumulate_4way_invalid_type_a(%a_vec : vector<4xi8>, %b_vec : vector<4xi8>, %c : i32) {
255+
// expected-error @below {{a_type must be S8 or U8}}
256+
%res = nvvm.dot.accumulate.4way %a_vec <u16>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
257+
llvm.return
258+
}
259+
260+
// -----
261+
262+
llvm.func @nvvm_dot_accumulate_4way_invalid_type_b(%a_vec : vector<4xi8>, %b_vec : vector<4xi8>, %c : i32) {
263+
// expected-error @below {{b_type must be S8 or U8}}
264+
%res = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u16>, %c: vector<4xi8>, vector<4xi8>
265+
llvm.return
266+
}
267+
268+
// -----
269+
270+
llvm.func @nvvm_dot_accumulate_2way_invalid_type_a(%a_vec : vector<2xi16>, %b_vec : vector<4xi8>, %c : i32) {
271+
// expected-error @below {{a_type must be S16 or U16}}
272+
%res = nvvm.dot.accumulate.2way lo %a_vec <u8>, %b_vec <u8>, %c: vector<2xi16>, vector<4xi8>
273+
llvm.return
274+
}
275+
276+
// -----
277+
278+
llvm.func @nvvm_dot_accumulate_2way_invalid_type_b(%a_vec : vector<2xi16>, %b_vec : vector<4xi8>, %c : i32) {
279+
// expected-error @below {{b_type must be S8 or U8}}
280+
%res = nvvm.dot.accumulate.2way lo %a_vec <u16>, %b_vec <u16>, %c: vector<2xi16>, vector<4xi8>
281+
llvm.return
282+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,3 +866,41 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
866866
%3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
867867
llvm.return
868868
}
869+
870+
// -----
871+
// CHECK-LABEL: @nvvm_dot_accumulate_2way
872+
llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
873+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
874+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
875+
// CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
876+
%0 = nvvm.dot.accumulate.2way lo %a <u16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
877+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
878+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
879+
// CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
880+
%1 = nvvm.dot.accumulate.2way hi %a <u16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
881+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
882+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
883+
// CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
884+
%2 = nvvm.dot.accumulate.2way lo %a <s16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
885+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
886+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
887+
// CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
888+
%3 = nvvm.dot.accumulate.2way hi %a <s16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
889+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
890+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
891+
// CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
892+
%4 = nvvm.dot.accumulate.2way lo %a <u16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>
893+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
894+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
895+
// CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
896+
%5 = nvvm.dot.accumulate.2way hi %a <u16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>
897+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
898+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
899+
// CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
900+
%6 = nvvm.dot.accumulate.2way lo %a <s16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>
901+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
902+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
903+
// CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
904+
%7 = nvvm.dot.accumulate.2way hi %a <s16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>
905+
llvm.return
906+
}

0 commit comments

Comments
 (0)