@@ -3445,35 +3445,35 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
3445
3445
}
3446
3446
3447
3447
//===----------------------------------------------------------------------===//
3448
- // NVVM dot.accumulate.4way Op
3448
+ // NVVM dot.accumulate Ops
3449
3449
//===----------------------------------------------------------------------===//
3450
3450
3451
- def DotAccumulate4WayS8 : I32EnumAttrCase<"S8 ", 1, "s8 ">;
3452
- def DotAccumulate4WayU8 : I32EnumAttrCase<"U8 ", 0, "u8 ">;
3451
+ def DotAccumulateSigned : I32EnumAttrCase<"S ", 1, "s ">;
3452
+ def DotAccumulateUnsigned : I32EnumAttrCase<"U ", 0, "u ">;
3453
3453
3454
- def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType ",
3455
- "NVVM DotAccumulate4WayType ",
3456
- [DotAccumulate4WayS8, DotAccumulate4WayU8 ]> {
3454
+ def DotAccumulateType : I32EnumAttr<"DotAccumulateType ",
3455
+ "NVVM DotAccumulateType ",
3456
+ [DotAccumulateSigned, DotAccumulateUnsigned ]> {
3457
3457
let cppNamespace = "::mlir::NVVM";
3458
3458
let genSpecializedAttr = 0;
3459
3459
}
3460
3460
3461
- def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType , "dot_accumulate_4way_type "> {
3461
+ def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType , "dot_accumulate_type "> {
3462
3462
let assemblyFormat = "`<` $value `>`";
3463
3463
}
3464
3464
3465
3465
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3466
- let summary = "Four-way byte dot product-accumulate instruction. ";
3466
+ let summary = "Four-way byte dot product-accumulate instruction";
3467
3467
let description = [{
3468
3468
Performs a four-way byte dot-product which is accumulated in a 32-bit
3469
3469
result.
3470
3470
Operand `a` and `b` are vectors of 4 bytes between which the dot product is
3471
3471
computed.
3472
3472
The `a_type` and `b_type` attributes specify the type of the elements in `a`
3473
3473
and `b` respectively.
3474
- If `a_type` or `b_type` is `s8 `, then the elements in the corresponding
3474
+ If `a_type` or `b_type` is `s `, then the elements in the corresponding
3475
3475
vector are sign-extended to 32-bit before the dot product is computed.
3476
- If `a_type` or `b_type` is `u8 `, then the elements in the corresponding
3476
+ If `a_type` or `b_type` is `u `, then the elements in the corresponding
3477
3477
vector are zero-extended to 32-bit instead.
3478
3478
Operand `c` is a 32-bit integer to which the result is accumulated. It is
3479
3479
treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.
@@ -3483,9 +3483,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3483
3483
3484
3484
let arguments = (ins
3485
3485
VectorOfLengthAndType<[4], [I8]>:$a,
3486
- DotAccumulate4WayTypeAttr :$a_type,
3486
+ DotAccumulateTypeAttr :$a_type,
3487
3487
VectorOfLengthAndType<[4], [I8]>:$b,
3488
- DotAccumulate4WayTypeAttr :$b_type,
3488
+ DotAccumulateTypeAttr :$b_type,
3489
3489
I32:$c
3490
3490
);
3491
3491
@@ -3495,8 +3495,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3495
3495
3496
3496
let extraClassDeclaration = [{
3497
3497
static llvm::Intrinsic::ID
3498
- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3499
- NVVM::DotAccumulate4WayType b_type);
3498
+ getIntrinsicID(NVVM::DotAccumulateType a_type,
3499
+ NVVM::DotAccumulateType b_type);
3500
3500
llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3501
3501
}];
3502
3502
@@ -3508,6 +3508,66 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3508
3508
}];
3509
3509
}
3510
3510
3511
+ def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
3512
+ let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
3513
+ let description = [{
3514
+ Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
3515
+ 32-bit result.
3516
+ Operand `a` is a vector of two 16-bit elements and operand `b` a vector
3517
+ of four 8-bit elements between which the dot product is computed.
3518
+
3519
+ The `a_type` and `b_type` attributes specify the type of the elements in `a`
3520
+ and `b` respectively.
3521
+ If `a_type` or `b_type` is `s`, then the elements in the corresponding
3522
+ vector are sign-extended to 32-bit before the dot product is computed.
3523
+ If `a_type` or `b_type` is `u`, then the elements in the corresponding
3524
+ vector are zero-extended to 32-bit instead.
3525
+
3526
+ The `hi` boolean attribute specifies which two bytes of `b` are used for
3527
+ the dot product. If `hi` is true, then the dot product is computed between
3528
+ `a` and elements at indices 2 and 3 of `b`. If `hi` is false, then the dot
3529
+ product is computed between `a` and elements at indices 0 and 1 of `b`.
3530
+ By default, `hi` is false.
3531
+
3532
+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
3533
+ treated as holding a signed integer if any of `a_type` or `b_type` is
3534
+ signed.
3535
+
3536
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
3537
+ }];
3538
+
3539
+ let arguments = (ins
3540
+ VectorOfLengthAndType<[2], [I16]>:$a,
3541
+ DotAccumulateTypeAttr:$a_type,
3542
+ VectorOfLengthAndType<[4], [I8]>:$b,
3543
+ DotAccumulateTypeAttr:$b_type,
3544
+ I32:$c,
3545
+ DefaultValuedAttr<BoolAttr, "false">:$hi
3546
+ );
3547
+
3548
+ let results = (outs I32:$res);
3549
+
3550
+ let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3551
+
3552
+ let extraClassDeclaration = [{
3553
+ static llvm::Intrinsic::ID
3554
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3555
+ llvm::IRBuilderBase &builder,
3556
+ llvm::SmallVector<llvm::Value *> &args);
3557
+ llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3558
+ }];
3559
+
3560
+ string llvmBuilder = [{
3561
+ llvm::SmallVector<llvm::Value *> args;
3562
+
3563
+ llvm::Intrinsic::ID
3564
+ id = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
3565
+ *op, moduleTranslation, builder, args);
3566
+
3567
+ $res = createIntrinsicCall(builder, id, args);
3568
+ }];
3569
+ }
3570
+
3511
3571
//===----------------------------------------------------------------------===//
3512
3572
// NVVM target attribute.
3513
3573
//===----------------------------------------------------------------------===//
0 commit comments