@@ -3533,35 +3533,35 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
3533
3533
}
3534
3534
3535
3535
//===----------------------------------------------------------------------===//
3536
- // NVVM dot.accumulate.4way Op
3536
+ // NVVM dot.accumulate Ops
3537
3537
//===----------------------------------------------------------------------===//
3538
3538
3539
- def DotAccumulate4WayS8 : I32EnumAttrCase<"S8 ", 1, "s8 ">;
3540
- def DotAccumulate4WayU8 : I32EnumAttrCase<"U8 ", 0, "u8 ">;
3539
+ def DotAccumulateSigned : I32EnumAttrCase<"SIGNED ", 1, "signed ">;
3540
+ def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED ", 0, "unsigned ">;
3541
3541
3542
- def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType ",
3543
- "NVVM DotAccumulate4WayType ",
3544
- [DotAccumulate4WayS8, DotAccumulate4WayU8 ]> {
3542
+ def DotAccumulateType : I32EnumAttr<"DotAccumulateType ",
3543
+ "NVVM DotAccumulateType ",
3544
+ [DotAccumulateSigned, DotAccumulateUnsigned ]> {
3545
3545
let cppNamespace = "::mlir::NVVM";
3546
3546
let genSpecializedAttr = 0;
3547
3547
}
3548
3548
3549
- def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType , "dot_accumulate_4way_type "> {
3549
+ def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType , "dot_accumulate_type "> {
3550
3550
let assemblyFormat = "`<` $value `>`";
3551
3551
}
3552
3552
3553
3553
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3554
- let summary = "Four-way byte dot product-accumulate instruction. ";
3554
+ let summary = "Four-way byte dot product-accumulate instruction";
3555
3555
let description = [{
3556
3556
Performs a four-way byte dot-product which is accumulated in a 32-bit
3557
3557
result.
3558
3558
Operands `a` and `b` are vectors of 4 bytes between which the dot product is
3559
3559
computed.
3560
3560
The `a_type` and `b_type` attributes specify the type of the elements in `a`
3561
3561
and `b` respectively.
3562
- If `a_type` or `b_type` is `s8 `, then the elements in the corresponding
3562
+ If `a_type` or `b_type` is `signed`, then the elements in the corresponding
3563
3563
vector are sign-extended to 32-bit before the dot product is computed.
3564
- If `a_type` or `b_type` is `u8 `, then the elements in the corresponding
3564
+ If `a_type` or `b_type` is `unsigned`, then the elements in the corresponding
3565
3565
vector are zero-extended to 32-bit instead.
3566
3566
Operand `c` is a 32-bit integer to which the result is accumulated. It is
3567
3567
treated as holding a signed integer if any of `a_type` or `b_type` is `signed`.
@@ -3571,9 +3571,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3571
3571
3572
3572
let arguments = (ins
3573
3573
VectorOfLengthAndType<[4], [I8]>:$a,
3574
- DotAccumulate4WayTypeAttr :$a_type,
3574
+ DotAccumulateTypeAttr :$a_type,
3575
3575
VectorOfLengthAndType<[4], [I8]>:$b,
3576
- DotAccumulate4WayTypeAttr :$b_type,
3576
+ DotAccumulateTypeAttr :$b_type,
3577
3577
I32:$c
3578
3578
);
3579
3579
@@ -3582,17 +3582,69 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3582
3582
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3583
3583
3584
3584
let extraClassDeclaration = [{
3585
- static llvm::Intrinsic::ID
3586
- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3587
- NVVM::DotAccumulate4WayType b_type);
3588
- llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3585
+ static mlir::NVVM::IDArgPair
3586
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3587
+ llvm::IRBuilderBase &builder);
3588
+ }];
3589
+
3590
+ string llvmBuilder = [{
3591
+ auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
3592
+ *op, moduleTranslation, builder);
3593
+ $res = createIntrinsicCall(builder, id, args);
3594
+ }];
3595
+ }
3596
+
3597
+ def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
3598
+ let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
3599
+ let description = [{
3600
+ Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
3601
+ 32-bit result.
3602
+ Operand `a` is a vector of two 16-bit elements and operand `b` a vector
3603
+ of four 8-bit elements between which the dot product is computed.
3604
+
3605
+ The `a_type` and `b_type` attributes specify the type of the elements in `a`
3606
+ and `b` respectively.
3607
+ If `a_type` or `b_type` is `signed`, then the elements in the corresponding
3608
+ vector are sign-extended to 32-bit before the dot product is computed.
3609
+ If `a_type` or `b_type` is `unsigned`, then the elements in the corresponding
3610
+ vector are zero-extended to 32-bit instead.
3611
+
3612
+ The `b_hi` boolean attribute specifies which two bytes of `b` are used for
3613
+ the dot product. If `b_hi` is true, then the dot product is computed
3614
+ between `a` and elements at indices 2 and 3 of `b`. If `b_hi` is false,
3615
+ then the dot product is computed between `a` and elements at indices 0 and
3616
+ 1 of `b`.
3617
+
3618
+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
3619
+ treated as holding a signed integer if any of `a_type` or `b_type` is
3620
+ signed.
3621
+
3622
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
3589
3623
}];
3590
3624
3625
+ let arguments = (ins
3626
+ VectorOfLengthAndType<[2], [I16]>:$a,
3627
+ DotAccumulateTypeAttr:$a_type,
3628
+ VectorOfLengthAndType<[4], [I8]>:$b,
3629
+ DotAccumulateTypeAttr:$b_type,
3630
+ I32:$c,
3631
+ BoolAttr:$b_hi
3632
+ );
3633
+
3634
+ let results = (outs I32:$res);
3635
+
3636
+ let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3637
+
3638
+ let extraClassDeclaration = [{
3639
+ static mlir::NVVM::IDArgPair
3640
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3641
+ llvm::IRBuilderBase &builder);
3642
+ }];
3643
+
3591
3644
string llvmBuilder = [{
3592
- llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
3593
- llvm::Value* argA = op.getPackedArg($a, builder);
3594
- llvm::Value* argB = op.getPackedArg($b, builder);
3595
- $res = createIntrinsicCall(builder, id, {argA, argB, $c});
3645
+ auto [id, args] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
3646
+ *op, moduleTranslation, builder);
3647
+ $res = createIntrinsicCall(builder, id, args);
3596
3648
}];
3597
3649
}
3598
3650
0 commit comments