@@ -3533,36 +3533,37 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
3533
3533
}
3534
3534
3535
3535
//===----------------------------------------------------------------------===//
3536
- // NVVM dot.accumulate.4way Op
3536
+ // NVVM dot.accumulate Ops
3537
3537
//===----------------------------------------------------------------------===//
3538
3538
3539
- def DotAccumulate4WayS8 : I32EnumAttrCase<"S8 ", 1 , "s8 ">;
3540
- def DotAccumulate4WayU8 : I32EnumAttrCase<"U8 ", 0 , "u8 ">;
3539
+ def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED ", 0 , "unsigned ">;
3540
+ def DotAccumulateSigned : I32EnumAttrCase<"SIGNED ", 1 , "signed ">;
3541
3541
3542
- def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType ",
3543
- "NVVM DotAccumulate4WayType ",
3544
- [DotAccumulate4WayS8, DotAccumulate4WayU8 ]> {
3542
+ def DotAccumulateType : I32EnumAttr<"DotAccumulateType ",
3543
+ "NVVM DotAccumulateType ",
3544
+ [DotAccumulateSigned, DotAccumulateUnsigned ]> {
3545
3545
let cppNamespace = "::mlir::NVVM";
3546
3546
let genSpecializedAttr = 0;
3547
3547
}
3548
3548
3549
- def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType , "dot_accumulate_4way_type "> {
3549
+ def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType , "dot_accumulate_type "> {
3550
3550
let assemblyFormat = "`<` $value `>`";
3551
3551
}
3552
3552
3553
3553
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3554
- let summary = "Four-way byte dot product-accumulate instruction. ";
3554
+ let summary = "Four-way byte dot product-accumulate instruction";
3555
3555
let description = [{
3556
3556
Performs a four-way byte dot-product which is accumulated in a 32-bit
3557
3557
result.
3558
3558
Operand `a` and `b` are vectors of 4 bytes between which the dot product is
3559
3559
computed.
3560
+
3560
3561
The `a_type` and `b_type` attributes specify the type of the elements in `a`
3561
3562
and `b` respectively.
3562
- If `a_type` or `b_type` is `s8 `, then the elements in the corresponding
3563
+ If `a_type` or `b_type` is `signed`, then the elements in the corresponding
3563
3564
vector are sign-extended to 32-bit before the dot product is computed.
3564
- If `a_type` or `b_type` is `u8 `, then the elements in the corresponding
3565
- vector are zero-extended to 32-bit instead.
3565
+ If `a_type` or `b_type` is `unsigned`, then the elements in the corresponding vector are zero-extended to 32-bit instead.
3566
+
3566
3567
Operand `c` is a 32-bit integer to which the result is accumulated. It is
3567
3568
treated as holding a signed integer if any of `a_type` or `b_type` is `signed`.
3568
3569
@@ -3571,9 +3572,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3571
3572
3572
3573
let arguments = (ins
3573
3574
VectorOfLengthAndType<[4], [I8]>:$a,
3574
- DotAccumulate4WayTypeAttr :$a_type,
3575
+ DotAccumulateTypeAttr :$a_type,
3575
3576
VectorOfLengthAndType<[4], [I8]>:$b,
3576
- DotAccumulate4WayTypeAttr :$b_type,
3577
+ DotAccumulateTypeAttr :$b_type,
3577
3578
I32:$c
3578
3579
);
3579
3580
@@ -3582,17 +3583,69 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3582
3583
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3583
3584
3584
3585
let extraClassDeclaration = [{
3585
- static llvm::Intrinsic::ID
3586
- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3587
- NVVM::DotAccumulate4WayType b_type);
3588
- llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3586
+ static mlir::NVVM::IDArgPair
3587
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3588
+ llvm::IRBuilderBase &builder);
3589
+ }];
3590
+
3591
+ string llvmBuilder = [{
3592
+ auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
3593
+ *op, moduleTranslation, builder);
3594
+ $res = createIntrinsicCall(builder, id, args);
3589
3595
}];
3596
+ }
3597
+
3598
+ def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
3599
+ let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
3600
+ let description = [{
3601
+ Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
3602
+ 32-bit result.
3603
+ Operand `a` is a vector of two 16-bit elements and operand `b` a vector
3604
+ of four 8-bit elements between which the dot product is computed.
3605
+
3606
+ The `a_type` and `b_type` attributes specify the type of the elements in `a`
3607
+ and `b` respectively.
3608
+ If `a_type` or `b_type` is `signed`, then the elements in the corresponding
3609
+ vector are sign-extended to 32-bit before the dot product is computed.
3610
+ If `a_type` or `b_type` is `unsigned`, then the elements in the corresponding
3611
+ vector are zero-extended to 32-bit instead.
3612
+
3613
+ The `b_hi` boolean attribute specifies which two bytes of `b` are used for
3614
+ the dot product. If `b_hi` is true, then the dot product is computed
3615
+ between `a` and elements at indices 2 and 3 of `b`. If `b_hi` is false,
3616
+ then the dot product is computed between `a` and elements at indices 0 and
3617
+ 1 of `b`.
3590
3618
3619
+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
3620
+ treated as holding a signed integer if any of `a_type` or `b_type` is
3621
+ signed.
3622
+
3623
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
3624
+ }];
3625
+
3626
+ let arguments = (ins
3627
+ VectorOfLengthAndType<[2], [I16]>:$a,
3628
+ DotAccumulateTypeAttr:$a_type,
3629
+ VectorOfLengthAndType<[4], [I8]>:$b,
3630
+ DotAccumulateTypeAttr:$b_type,
3631
+ I32:$c,
3632
+ BoolAttr:$b_hi
3633
+ );
3634
+
3635
+ let results = (outs I32:$res);
3636
+
3637
+ let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3638
+
3639
+ let extraClassDeclaration = [{
3640
+ static mlir::NVVM::IDArgPair
3641
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3642
+ llvm::IRBuilderBase &builder);
3643
+ }];
3644
+
3591
3645
string llvmBuilder = [{
3592
- llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
3593
- llvm::Value* argA = op.getPackedArg($a, builder);
3594
- llvm::Value* argB = op.getPackedArg($b, builder);
3595
- $res = createIntrinsicCall(builder, id, {argA, argB, $c});
3646
+ auto [id, args] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
3647
+ *op, moduleTranslation, builder);
3648
+ $res = createIntrinsicCall(builder, id, args);
3596
3649
}];
3597
3650
}
3598
3651
0 commit comments