@@ -3445,25 +3445,28 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
3445
3445
}
3446
3446
3447
3447
//===----------------------------------------------------------------------===//
3448
- // NVVM dot.accumulate.4way Op
3448
+ // NVVM dot.accumulate Ops
3449
3449
//===----------------------------------------------------------------------===//
3450
3450
3451
- def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
3452
- def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
3451
+ def DotAccumulateS8 : I32EnumAttrCase<"S8", 1, "s8">;
3452
+ def DotAccumulateU8 : I32EnumAttrCase<"U8", 0, "u8">;
3453
+ def DotAccumulateS16 : I32EnumAttrCase<"S16", 2, "s16">;
3454
+ def DotAccumulateU16 : I32EnumAttrCase<"U16", 3, "u16">;
3453
3455
3454
- def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
3455
- "NVVM DotAccumulate4WayType",
3456
- [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
3456
+ def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
3457
+ "NVVM DotAccumulateType",
3458
+ [DotAccumulateS8, DotAccumulateU8,
3459
+ DotAccumulateS16, DotAccumulateU16]> {
3457
3460
let cppNamespace = "::mlir::NVVM";
3458
3461
let genSpecializedAttr = 0;
3459
3462
}
3460
3463
3461
- def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType , "dot_accumulate_4way_type "> {
3464
+ def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType , "dot_accumulate_type "> {
3462
3465
let assemblyFormat = "`<` $value `>`";
3463
3466
}
3464
3467
3465
3468
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3466
- let summary = "Four-way byte dot product-accumulate instruction. ";
3469
+ let summary = "Four-way byte dot product-accumulate instruction";
3467
3470
let description = [{
3468
3471
Performs a four-way byte dot-product which is accumulated in a 32-bit
3469
3472
result.
@@ -3481,11 +3484,13 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3481
3484
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
3482
3485
}];
3483
3486
3487
+ let hasVerifier = 1;
3488
+
3484
3489
let arguments = (ins
3485
3490
VectorOfLengthAndType<[4], [I8]>:$a,
3486
- DotAccumulate4WayTypeAttr :$a_type,
3491
+ DotAccumulateTypeAttr :$a_type,
3487
3492
VectorOfLengthAndType<[4], [I8]>:$b,
3488
- DotAccumulate4WayTypeAttr :$b_type,
3493
+ DotAccumulateTypeAttr :$b_type,
3489
3494
I32:$c
3490
3495
);
3491
3496
@@ -3495,8 +3500,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3495
3500
3496
3501
let extraClassDeclaration = [{
3497
3502
static llvm::Intrinsic::ID
3498
- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3499
- NVVM::DotAccumulate4WayType b_type);
3503
+ getIntrinsicID(NVVM::DotAccumulateType a_type,
3504
+ NVVM::DotAccumulateType b_type);
3500
3505
llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3501
3506
}];
3502
3507
@@ -3508,6 +3513,84 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3508
3513
}];
3509
3514
}
3510
3515
3516
+ def DotAccumulate2WayModeLo : I32EnumAttrCase<"LO", 0, "lo">;
3517
+ def DotAccumulate2WayModeHi : I32EnumAttrCase<"HI", 1, "hi">;
3518
+
3519
+ def DotAccumulate2WayMode : I32EnumAttr<"DotAccumulate2WayMode",
3520
+ "NVVM DotAccumulate2WayMode",
3521
+ [DotAccumulate2WayModeLo, DotAccumulate2WayModeHi]> {
3522
+ let cppNamespace = "::mlir::NVVM";
3523
+ let genSpecializedAttr = 0;
3524
+ }
3525
+
3526
+ def DotAccumulate2WayModeAttr : EnumAttr<NVVM_Dialect, DotAccumulate2WayMode, "dot_accumulate_2way_mode"> {
3527
+ let assemblyFormat = "$value";
3528
+ }
3529
+
3530
+ def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
3531
+ let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
3532
+ let description = [{
3533
+ Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
3534
+ 32-bit result.
3535
+ Operand `a` is a vector of two 16-bit elements and operand `b` a vector
3536
+ of four 8-bit elements between which the dot product is computed.
3537
+
3538
+ The `a_type` and `b_type` attributes specify the type of the elements in `a`
3539
+ and `b` respectively.
3540
+ If `a_type` is `s16`, then the elements in `a` are sign-extended to
3541
+ 32-bit before the dot product is computed.
3542
+ If `a_type` is `u16`, then the elements in `a` are zero-extended to
3543
+ 32-bit instead.
3544
+ If `b_type` is `s8`, then the elements in `b` are sign-extended to
3545
+ 32-bit before the dot product is computed.
3546
+ If `b_type` is `u8`, then the elements in `b` are zero-extended to
3547
+ 32-bit instead.
3548
+
3549
+ The 'mode` attribute specifies which two bytes of `b` are used for the dot
3550
+ product. If `mode` is `lo`, then the dot product is computed between `a`
3551
+ and elements at indices 2 and 3 of `b`. If `mode` is `hi`, then the dot
3552
+ product is computed between `a` and elements at indices 0 and 1 of `b`.
3553
+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
3554
+ treated as holding a signed integer if any of `a_type` or `b_type` is
3555
+ signed.
3556
+
3557
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
3558
+ }];
3559
+
3560
+ let hasVerifier = 1;
3561
+
3562
+ let arguments = (ins
3563
+ DotAccumulate2WayModeAttr:$mode,
3564
+ VectorOfLengthAndType<[2], [I16]>:$a,
3565
+ DotAccumulateTypeAttr:$a_type,
3566
+ VectorOfLengthAndType<[4], [I8]>:$b,
3567
+ DotAccumulateTypeAttr:$b_type,
3568
+ I32:$c
3569
+ );
3570
+
3571
+ let results = (outs I32:$res);
3572
+
3573
+ let assemblyFormat = "$mode $a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3574
+
3575
+ let extraClassDeclaration = [{
3576
+ static llvm::Intrinsic::ID
3577
+ getIntrinsicID(NVVM::DotAccumulateType a_type,
3578
+ NVVM::DotAccumulateType b_type);
3579
+ llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3580
+ llvm::Value* isHi(NVVM::DotAccumulate2WayMode mode,
3581
+ llvm::IRBuilderBase& builder);
3582
+ }];
3583
+
3584
+ string llvmBuilder = [{
3585
+ llvm::Intrinsic::ID id = NVVM::DotAccumulate2WayOp::getIntrinsicID($a_type, $b_type);
3586
+ llvm::Value* argA = op.getPackedArg($a, builder);
3587
+ llvm::Value* argB = op.getPackedArg($b, builder);
3588
+ $res = createIntrinsicCall(builder, id,
3589
+ {argA, argB, op.isHi($mode, builder), $c}
3590
+ );
3591
+ }];
3592
+ }
3593
+
3511
3594
//===----------------------------------------------------------------------===//
3512
3595
// NVVM target attribute.
3513
3596
//===----------------------------------------------------------------------===//
0 commit comments