@@ -3533,36 +3533,38 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
3533
3533
}
3534
3534
3535
3535
//===----------------------------------------------------------------------===//
3536
- // NVVM dot.accumulate.4way Op
3536
+ // NVVM dot.accumulate Ops
3537
3537
//===----------------------------------------------------------------------===//
3538
3538
3539
- def DotAccumulate4WayS8 : I32EnumAttrCase<"S8 ", 1 , "s8 ">;
3540
- def DotAccumulate4WayU8 : I32EnumAttrCase<"U8 ", 0 , "u8 ">;
3539
+ def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED ", 0 , "unsigned ">;
3540
+ def DotAccumulateSigned : I32EnumAttrCase<"SIGNED ", 1 , "signed ">;
3541
3541
3542
- def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType ",
3543
- "NVVM DotAccumulate4WayType ",
3544
- [DotAccumulate4WayS8, DotAccumulate4WayU8 ]> {
3542
+ def DotAccumulateType : I32EnumAttr<"DotAccumulateType ",
3543
+ "NVVM DotAccumulateType ",
3544
+ [DotAccumulateSigned, DotAccumulateUnsigned ]> {
3545
3545
let cppNamespace = "::mlir::NVVM";
3546
3546
let genSpecializedAttr = 0;
3547
3547
}
3548
3548
3549
- def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType , "dot_accumulate_4way_type "> {
3549
+ def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType , "dot_accumulate_type "> {
3550
3550
let assemblyFormat = "`<` $value `>`";
3551
3551
}
3552
3552
3553
3553
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3554
- let summary = "Four-way byte dot product-accumulate instruction. ";
3554
+ let summary = "Four-way byte dot product-accumulate instruction";
3555
3555
let description = [{
3556
3556
Performs a four-way byte dot-product which is accumulated in a 32-bit
3557
3557
result.
3558
3558
Operand `a` and `b` are vectors of 4 bytes between which the dot product is
3559
3559
computed.
3560
+
3560
3561
The `a_type` and `b_type` attributes specify the type of the elements in `a`
3561
3562
and `b` respectively.
3562
- If `a_type` or `b_type` is `s8 `, then the elements in the corresponding
3563
+ If `a_type` or `b_type` is `signed `, then the elements in the corresponding
3563
3564
vector are sign-extended to 32-bit before the dot product is computed.
3564
- If `a_type` or `b_type` is `u8`, then the elements in the corresponding
3565
- vector are zero-extended to 32-bit instead.
3565
+ If `a_type` or `b_type` is `unsigned`, then the elements in the
3566
+ corresponding vector are zero-extended to 32-bit instead.
3567
+
3566
3568
Operand `c` is a 32-bit integer to which the result is accumulated. It is
3567
3569
treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.
3568
3570
@@ -3571,9 +3573,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3571
3573
3572
3574
let arguments = (ins
3573
3575
VectorOfLengthAndType<[4], [I8]>:$a,
3574
- DotAccumulate4WayTypeAttr :$a_type,
3576
+ DotAccumulateTypeAttr :$a_type,
3575
3577
VectorOfLengthAndType<[4], [I8]>:$b,
3576
- DotAccumulate4WayTypeAttr :$b_type,
3578
+ DotAccumulateTypeAttr :$b_type,
3577
3579
I32:$c
3578
3580
);
3579
3581
@@ -3582,17 +3584,69 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3582
3584
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3583
3585
3584
3586
let extraClassDeclaration = [{
3585
- static llvm::Intrinsic::ID
3586
- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3587
- NVVM::DotAccumulate4WayType b_type);
3588
- llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3587
+ static mlir::NVVM::IDArgPair
3588
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3589
+ llvm::IRBuilderBase &builder);
3590
+ }];
3591
+
3592
+ string llvmBuilder = [{
3593
+ auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
3594
+ *op, moduleTranslation, builder);
3595
+ $res = createIntrinsicCall(builder, id, args);
3589
3596
}];
3597
+ }
3598
+
3599
+ def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
3600
+ let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
3601
+ let description = [{
3602
+ Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
3603
+ 32-bit result.
3604
+ Operand `a` is a vector of two 16-bit elements and operand `b` a vector
3605
+ of four 8-bit elements between which the dot product is computed.
3606
+
3607
+ The `a_type` and `b_type` attributes specify the type of the elements in `a`
3608
+ and `b` respectively.
3609
+ If `a_type` or `b_type` is `s`, then the elements in the corresponding
3610
+ vector are sign-extended to 32-bit before the dot product is computed.
3611
+ If `a_type` or `b_type` is `u`, then the elements in the corresponding
3612
+ vector are zero-extended to 32-bit instead.
3613
+
3614
+ The `b_hi` boolean attribute specifies which two bytes of `b` are used for
3615
+ the dot product. If `b_hi` is true, then the dot product is computed
3616
+ between `a` and elements at indices 2 and 3 of `b`. If `b_hi` is false,
3617
+ then the dot product is computed between `a` and elements at indices 0 and
3618
+ 1 of `b`.
3590
3619
3620
+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
3621
+ treated as holding a signed integer if any of `a_type` or `b_type` is
3622
+ signed.
3623
+
3624
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
3625
+ }];
3626
+
3627
+ let arguments = (ins
3628
+ VectorOfLengthAndType<[2], [I16]>:$a,
3629
+ DotAccumulateTypeAttr:$a_type,
3630
+ VectorOfLengthAndType<[4], [I8]>:$b,
3631
+ DotAccumulateTypeAttr:$b_type,
3632
+ I32:$c,
3633
+ BoolAttr:$b_hi
3634
+ );
3635
+
3636
+ let results = (outs I32:$res);
3637
+
3638
+ let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3639
+
3640
+ let extraClassDeclaration = [{
3641
+ static mlir::NVVM::IDArgPair
3642
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3643
+ llvm::IRBuilderBase &builder);
3644
+ }];
3645
+
3591
3646
string llvmBuilder = [{
3592
- llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
3593
- llvm::Value* argA = op.getPackedArg($a, builder);
3594
- llvm::Value* argB = op.getPackedArg($b, builder);
3595
- $res = createIntrinsicCall(builder, id, {argA, argB, $c});
3647
+ auto [id, args] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
3648
+ *op, moduleTranslation, builder);
3649
+ $res = createIntrinsicCall(builder, id, args);
3596
3650
}];
3597
3651
}
3598
3652
0 commit comments