Skip to content

Commit 7aeccec

Browse files
committed
[MLIR][NVVM] Update dot.accumulate NVVM Ops
This change: - Adds the dot.accumulate.2way Op to the NVVM dialect for the 16-bit to 8-bit dot-product accumulate operation. - Refactors the recently added dot.accumulate.4way Op and adds a verifier.
1 parent 690a30f commit 7aeccec

File tree

4 files changed

+166
-25
lines changed

4 files changed

+166
-25
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 74 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3445,35 +3445,35 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
34453445
}
34463446

34473447
//===----------------------------------------------------------------------===//
3448-
// NVVM dot.accumulate.4way Op
3448+
// NVVM dot.accumulate Ops
34493449
//===----------------------------------------------------------------------===//
34503450

3451-
def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
3452-
def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
3451+
def DotAccumulateSigned : I32EnumAttrCase<"S", 1, "s">;
3452+
def DotAccumulateUnsigned : I32EnumAttrCase<"U", 0, "u">;
34533453

3454-
def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
3455-
"NVVM DotAccumulate4WayType",
3456-
[DotAccumulate4WayS8, DotAccumulate4WayU8]> {
3454+
def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
3455+
"NVVM DotAccumulateType",
3456+
[DotAccumulateSigned, DotAccumulateUnsigned]> {
34573457
let cppNamespace = "::mlir::NVVM";
34583458
let genSpecializedAttr = 0;
34593459
}
34603460

3461-
def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
3461+
def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType, "dot_accumulate_type"> {
34623462
let assemblyFormat = "`<` $value `>`";
34633463
}
34643464

34653465
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3466-
let summary = "Four-way byte dot product-accumulate instruction.";
3466+
let summary = "Four-way byte dot product-accumulate instruction";
34673467
let description = [{
34683468
Performs a four-way byte dot-product which is accumulated in a 32-bit
34693469
result.
34703470
Operands `a` and `b` are vectors of 4 bytes between which the dot product is
34713471
computed.
34723472
The `a_type` and `b_type` attributes specify the type of the elements in `a`
34733473
and `b` respectively.
3474-
If `a_type` or `b_type` is `s8`, then the elements in the corresponding
3474+
If `a_type` or `b_type` is `s`, then the elements in the corresponding
34753475
vector are sign-extended to 32-bit before the dot product is computed.
3476-
If `a_type` or `b_type` is `u8`, then the elements in the corresponding
3476+
If `a_type` or `b_type` is `u`, then the elements in the corresponding
34773477
vector are zero-extended to 32-bit instead.
34783478
Operand `c` is a 32-bit integer to which the result is accumulated. It is
34793479
treated as holding a signed integer if any of `a_type` or `b_type` is `s`.
@@ -3483,9 +3483,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
34833483

34843484
let arguments = (ins
34853485
VectorOfLengthAndType<[4], [I8]>:$a,
3486-
DotAccumulate4WayTypeAttr:$a_type,
3486+
DotAccumulateTypeAttr:$a_type,
34873487
VectorOfLengthAndType<[4], [I8]>:$b,
3488-
DotAccumulate4WayTypeAttr:$b_type,
3488+
DotAccumulateTypeAttr:$b_type,
34893489
I32:$c
34903490
);
34913491

@@ -3495,8 +3495,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
34953495

34963496
let extraClassDeclaration = [{
34973497
static llvm::Intrinsic::ID
3498-
getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3499-
NVVM::DotAccumulate4WayType b_type);
3498+
getIntrinsicID(NVVM::DotAccumulateType a_type,
3499+
NVVM::DotAccumulateType b_type);
35003500
llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
35013501
}];
35023502

@@ -3508,6 +3508,66 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
35083508
}];
35093509
}
35103510

3511+
def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
3512+
let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
3513+
let description = [{
3514+
Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
3515+
32-bit result.
3516+
Operand `a` is a vector of two 16-bit elements and operand `b` a vector
3517+
of four 8-bit elements between which the dot product is computed.
3518+
3519+
The `a_type` and `b_type` attributes specify the type of the elements in `a`
3520+
and `b` respectively.
3521+
If `a_type` or `b_type` is `s`, then the elements in the corresponding
3522+
vector are sign-extended to 32-bit before the dot product is computed.
3523+
If `a_type` or `b_type` is `u`, then the elements in the corresponding
3524+
vector are zero-extended to 32-bit instead.
3525+
3526+
The `hi` boolean attribute specifies which two bytes of `b` are used for
3527+
the dot product. If `hi` is true, then the dot product is computed between
3528+
`a` and elements at indices 2 and 3 of `b`. If `hi` is false, then the dot
3529+
product is computed between `a` and elements at indices 0 and 1 of `b`.
3530+
By default, `hi` is false.
3531+
3532+
Operand `c` is a 32-bit integer to which the result is accumulated. It is
3533+
treated as holding a signed integer if any of `a_type` or `b_type` is
3534+
signed.
3535+
3536+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
3537+
}];
3538+
3539+
let arguments = (ins
3540+
VectorOfLengthAndType<[2], [I16]>:$a,
3541+
DotAccumulateTypeAttr:$a_type,
3542+
VectorOfLengthAndType<[4], [I8]>:$b,
3543+
DotAccumulateTypeAttr:$b_type,
3544+
I32:$c,
3545+
DefaultValuedAttr<BoolAttr, "false">:$hi
3546+
);
3547+
3548+
let results = (outs I32:$res);
3549+
3550+
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3551+
3552+
let extraClassDeclaration = [{
3553+
static llvm::Intrinsic::ID
3554+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3555+
llvm::IRBuilderBase &builder,
3556+
llvm::SmallVector<llvm::Value *> &args);
3557+
llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3558+
}];
3559+
3560+
string llvmBuilder = [{
3561+
llvm::SmallVector<llvm::Value *> args;
3562+
3563+
llvm::Intrinsic::ID
3564+
id = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
3565+
*op, moduleTranslation, builder, args);
3566+
3567+
$res = createIntrinsicCall(builder, id, args);
3568+
}];
3569+
}
3570+
35113571
//===----------------------------------------------------------------------===//
35123572
// NVVM target attribute.
35133573
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,13 @@ NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
12111211
llvm::Type::getInt32Ty(builder.getContext()));
12121212
}
12131213

1214+
llvm::Value *
1215+
NVVM::DotAccumulate2WayOp::getPackedArg(llvm::Value *arg,
1216+
llvm::IRBuilderBase &builder) {
1217+
return builder.CreateBitCast(arg,
1218+
llvm::Type::getInt32Ty(builder.getContext()));
1219+
}
1220+
12141221
//===----------------------------------------------------------------------===//
12151222
// getIntrinsicID/getIntrinsicIDAndArgs methods
12161223
//===----------------------------------------------------------------------===//
@@ -1599,10 +1606,10 @@ static void nvvmInferResultRanges(Operation *op, Value result,
15991606
}
16001607

16011608
llvm::Intrinsic::ID
1602-
DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
1603-
NVVM::DotAccumulate4WayType b_type) {
1604-
bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
1605-
bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
1609+
DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulateType a_type,
1610+
NVVM::DotAccumulateType b_type) {
1611+
bool is_a_siext = a_type == NVVM::DotAccumulateType::S;
1612+
bool is_b_siext = b_type == NVVM::DotAccumulateType::S;
16061613
unsigned type = (is_a_siext << 1) | is_b_siext;
16071614
switch (type) {
16081615
case 0:
@@ -1618,6 +1625,33 @@ DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
16181625
}
16191626
}
16201627

1628+
llvm::Intrinsic::ID DotAccumulate2WayOp::getIntrinsicIDAndArgs(
1629+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder,
1630+
llvm::SmallVector<llvm::Value *> &args) {
1631+
auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
1632+
1633+
args.push_back(curOp.getPackedArg(mt.lookupValue(curOp.getA()), builder));
1634+
args.push_back(curOp.getPackedArg(mt.lookupValue(curOp.getB()), builder));
1635+
args.push_back(builder.getInt1(curOp.getHi()));
1636+
args.push_back(mt.lookupValue(curOp.getC()));
1637+
1638+
bool is_a_siext = curOp.getAType() == NVVM::DotAccumulateType::S;
1639+
bool is_b_siext = curOp.getBType() == NVVM::DotAccumulateType::S;
1640+
unsigned type = (is_a_siext << 1) | is_b_siext;
1641+
switch (type) {
1642+
case 0:
1643+
return llvm::Intrinsic::nvvm_idp2a_u_u;
1644+
case 1:
1645+
return llvm::Intrinsic::nvvm_idp2a_u_s;
1646+
case 2:
1647+
return llvm::Intrinsic::nvvm_idp2a_s_u;
1648+
case 3:
1649+
return llvm::Intrinsic::nvvm_idp2a_s_s;
1650+
default:
1651+
llvm_unreachable("Invalid DP2a type");
1652+
}
1653+
}
1654+
16211655
//===----------------------------------------------------------------------===//
16221656
// NVVMDialect initialization, type parsing, and registration.
16231657
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,20 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
579579
}
580580

581581
// CHECK-LABEL: @dot_accumulate_4way
582-
func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
582+
func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
583583
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
584-
%1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
584+
%1 = nvvm.dot.accumulate.4way %a_vec <u>, %b_vec <u>, %c: vector<4xi8>, vector<4xi8>
585585
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
586-
%3 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <s8>, %c: vector<4xi8>, vector<4xi8>
586+
%3 = nvvm.dot.accumulate.4way %a_vec <s>, %b_vec <s>, %c: vector<4xi8>, vector<4xi8>
587+
return
588+
}
589+
590+
// CHECK-LABEL: @dot_accumulate_2way
591+
func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c: i32) {
592+
// CHECK: nvvm.dot.accumulate.2way %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi16>, vector<4xi8>
593+
%1 = nvvm.dot.accumulate.2way %a_vec <u>, %b_vec <u>, %c: vector<2xi16>, vector<4xi8>
594+
// CHECK: nvvm.dot.accumulate.2way %{{.*}}, %{{.*}}, %{{.*}} {hi = true} : vector<2xi16>, vector<4xi8>
595+
%3 = nvvm.dot.accumulate.2way %a_vec <s>, %b_vec <s>, %c {hi = true}: vector<2xi16>, vector<4xi8>
587596
return
588597
}
589598

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -851,18 +851,56 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
851851
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
852852
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
853853
// CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
854-
%0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
854+
%0 = nvvm.dot.accumulate.4way %a <u>, %b <u>, %c: vector<4xi8>, vector<4xi8>
855855
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
856856
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
857857
// CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
858-
%1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
858+
%1 = nvvm.dot.accumulate.4way %a <s>, %b <u>, %c: vector<4xi8>, vector<4xi8>
859859
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
860860
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
861861
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
862-
%2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
862+
%2 = nvvm.dot.accumulate.4way %a <u>, %b <s>, %c: vector<4xi8>, vector<4xi8>
863863
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
864864
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
865865
// CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
866-
%3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
866+
%3 = nvvm.dot.accumulate.4way %a <s>, %b <s>, %c: vector<4xi8>, vector<4xi8>
867+
llvm.return
868+
}
869+
870+
// -----
871+
// CHECK-LABEL: @nvvm_dot_accumulate_2way
872+
llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
873+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
874+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
875+
// CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
876+
%0 = nvvm.dot.accumulate.2way %a <u>, %b <u>, %c: vector<2xi16>, vector<4xi8>
877+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
878+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
879+
// CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
880+
%1 = nvvm.dot.accumulate.2way %a <u>, %b <u>, %c {hi = true}: vector<2xi16>, vector<4xi8>
881+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
882+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
883+
// CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
884+
%2 = nvvm.dot.accumulate.2way %a <s>, %b <u>, %c: vector<2xi16>, vector<4xi8>
885+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
886+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
887+
// CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
888+
%3 = nvvm.dot.accumulate.2way %a <s>, %b <u>, %c {hi = true}: vector<2xi16>, vector<4xi8>
889+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
890+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
891+
// CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
892+
%4 = nvvm.dot.accumulate.2way %a <u>, %b <s>, %c: vector<2xi16>, vector<4xi8>
893+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
894+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
895+
// CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
896+
%5 = nvvm.dot.accumulate.2way %a <u>, %b <s>, %c {hi = true}: vector<2xi16>, vector<4xi8>
897+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
898+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
899+
// CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
900+
%6 = nvvm.dot.accumulate.2way %a <s>, %b <s>, %c: vector<2xi16>, vector<4xi8>
901+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
902+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
903+
// CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
904+
%7 = nvvm.dot.accumulate.2way %a <s>, %b <s>, %c {hi = true}: vector<2xi16>, vector<4xi8>
867905
llvm.return
868906
}

0 commit comments

Comments
 (0)