
[MLIR][NVVM] Add dot.accumulate.2way Op #140518


Open
Wolfram70 wants to merge 1 commit into main from dev/Wolfram70/mlir-nvvm-dp2a

Conversation

Wolfram70
Contributor

@Wolfram70 Wolfram70 commented May 19, 2025

This change:

  • Adds the dot.accumulate.2way Op to the NVVM dialect for the 16-bit to 8-bit dot-product accumulate operation (see the usage sketch below).
  • Refactors the recently added dot.accumulate.4way Op.
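
A minimal usage sketch of the new Op, based on the MLIR tests added in this PR (the function and value names are placeholders, not part of the patch):

  llvm.func @dp2a_example(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) -> i32 {
    // `lo` selects the two low bytes of %b; <u16>/<u8> request zero-extension
    // of the elements before the dot product is computed.
    %r = nvvm.dot.accumulate.2way lo %a <u16>, %b <u8>, %c : vector<2xi16>, vector<4xi8>
    llvm.return %r : i32
  }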

@Wolfram70 Wolfram70 requested a review from durga4github May 19, 2025 09:30
@Wolfram70 Wolfram70 self-assigned this May 19, 2025
@Wolfram70 Wolfram70 requested a review from grypp as a code owner May 19, 2025 09:30
@llvmbot
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Srinivasa Ravi (Wolfram70)

Changes

This change:

  • Adds the dot.accumulate.2way Op to the NVVM dialect for the 16-bit to 8-bit dot-product accumulate operation.
  • Refactors the recently added dot.accumulate.4way Op and adds a verifier.

Full diff: https://github.com/llvm/llvm-project/pull/140518.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+95-12)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+64-4)
  • (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+10-1)
  • (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+32)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+38)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 654aff71f25be..634251d6a9de1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3445,25 +3445,28 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
 }
 
 //===----------------------------------------------------------------------===//
-// NVVM dot.accumulate.4way Op
+// NVVM dot.accumulate Ops
 //===----------------------------------------------------------------------===//
 
-def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
-def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
+def DotAccumulateS8 : I32EnumAttrCase<"S8", 1, "s8">;
+def DotAccumulateU8 : I32EnumAttrCase<"U8", 0, "u8">;
+def DotAccumulateS16 : I32EnumAttrCase<"S16", 2, "s16">;
+def DotAccumulateU16 : I32EnumAttrCase<"U16", 3, "u16">;
 
-def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
-                              "NVVM DotAccumulate4WayType",
-                              [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
+def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
+                              "NVVM DotAccumulateType",
+                              [DotAccumulateS8, DotAccumulateU8, 
+                                DotAccumulateS16, DotAccumulateU16]> {
   let cppNamespace = "::mlir::NVVM";
   let genSpecializedAttr = 0;
 }
 
-def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
+def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType, "dot_accumulate_type"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
 def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
-  let summary = "Four-way byte dot product-accumulate instruction.";
+  let summary = "Four-way byte dot product-accumulate instruction";
   let description = [{
     Performs a four-way byte dot-product which is accumulated in a 32-bit
     result.
@@ -3481,11 +3484,13 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
   }];
   
+  let hasVerifier = 1;
+  
   let arguments = (ins
     VectorOfLengthAndType<[4], [I8]>:$a,
-    DotAccumulate4WayTypeAttr:$a_type,
+    DotAccumulateTypeAttr:$a_type,
     VectorOfLengthAndType<[4], [I8]>:$b,
-    DotAccumulate4WayTypeAttr:$b_type,
+    DotAccumulateTypeAttr:$b_type,
     I32:$c
   );
 
@@ -3495,8 +3500,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
   
   let extraClassDeclaration = [{
     static llvm::Intrinsic::ID
-    getIntrinsicID(NVVM::DotAccumulate4WayType a_type, 
-                   NVVM::DotAccumulate4WayType b_type);
+    getIntrinsicID(NVVM::DotAccumulateType a_type, 
+                   NVVM::DotAccumulateType b_type);
     llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
   }];
 
@@ -3508,6 +3513,84 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
   }];
 }
 
+def DotAccumulate2WayModeLo : I32EnumAttrCase<"LO", 0, "lo">;
+def DotAccumulate2WayModeHi : I32EnumAttrCase<"HI", 1, "hi">;
+
+def DotAccumulate2WayMode : I32EnumAttr<"DotAccumulate2WayMode",
+                              "NVVM DotAccumulate2WayMode",
+                              [DotAccumulate2WayModeLo, DotAccumulate2WayModeHi]> {
+  let cppNamespace = "::mlir::NVVM";
+  let genSpecializedAttr = 0;
+}
+
+def DotAccumulate2WayModeAttr : EnumAttr<NVVM_Dialect, DotAccumulate2WayMode, "dot_accumulate_2way_mode"> {
+  let assemblyFormat = "$value";
+}
+
+def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
+  let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
+  let description = [{
+    Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a 
+    32-bit result.
+    Operand `a` is a vector of two 16-bit elements and operand `b` a vector 
+    of four 8-bit elements between which the dot product is computed.
+
+    The `a_type` and `b_type` attributes specify the type of the elements in `a`
+    and `b` respectively.
+    If `a_type` is `s16`, then the elements in `a` are sign-extended to 
+    32-bit before the dot product is computed.
+    If `a_type` is `u16`, then the elements in `a` are zero-extended to 
+    32-bit instead.
+    If `b_type` is `s8`, then the elements in `b` are sign-extended to 
+    32-bit before the dot product is computed.
+    If `b_type` is `u8`, then the elements in `b` are zero-extended to 
+    32-bit instead.
+
+    The `mode` attribute specifies which two bytes of `b` are used for the dot
+    product. If `mode` is `lo`, then the dot product is computed between `a`
+    and the elements at indices 0 and 1 of `b`. If `mode` is `hi`, then the dot
+    product is computed between `a` and the elements at indices 2 and 3 of `b`.
+    Operand `c` is a 32-bit integer to which the result is accumulated. It is
+    treated as holding a signed integer if either `a_type` or `b_type` is
+    signed.
+    
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
+  }];
+
+  let hasVerifier = 1;
+
+  let arguments = (ins
+    DotAccumulate2WayModeAttr:$mode,
+    VectorOfLengthAndType<[2], [I16]>:$a,
+    DotAccumulateTypeAttr:$a_type,
+    VectorOfLengthAndType<[4], [I8]>:$b,
+    DotAccumulateTypeAttr:$b_type,
+    I32:$c
+  );
+
+  let results = (outs I32:$res);
+
+  let assemblyFormat = "$mode $a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
+  
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID
+    getIntrinsicID(NVVM::DotAccumulateType a_type, 
+                   NVVM::DotAccumulateType b_type);
+    llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
+    llvm::Value* isHi(NVVM::DotAccumulate2WayMode mode, 
+                            llvm::IRBuilderBase& builder);
+  }];
+  
+  string llvmBuilder = [{
+    llvm::Intrinsic::ID id = NVVM::DotAccumulate2WayOp::getIntrinsicID($a_type, $b_type);
+    llvm::Value* argA = op.getPackedArg($a, builder);
+    llvm::Value* argB = op.getPackedArg($b, builder);
+    $res = createIntrinsicCall(builder, id,
+              {argA, argB, op.isHi($mode, builder), $c}
+            );
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 1ea3f96fa75f5..2b60a34edf313 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1211,6 +1211,46 @@ NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
                                llvm::Type::getInt32Ty(builder.getContext()));
 }
 
+LogicalResult NVVM::DotAccumulate4WayOp::verify() {
+  NVVM::DotAccumulateType aType = getAType();
+  NVVM::DotAccumulateType bType = getBType();
+
+  if (aType != NVVM::DotAccumulateType::S8 &&
+      aType != NVVM::DotAccumulateType::U8)
+    return emitOpError("a_type must be S8 or U8");
+  if (bType != NVVM::DotAccumulateType::S8 &&
+      bType != NVVM::DotAccumulateType::U8)
+    return emitOpError("b_type must be S8 or U8");
+
+  return success();
+}
+
+llvm::Value *
+NVVM::DotAccumulate2WayOp::getPackedArg(llvm::Value *arg,
+                                        llvm::IRBuilderBase &builder) {
+  return builder.CreateBitCast(arg,
+                               llvm::Type::getInt32Ty(builder.getContext()));
+}
+
+llvm::Value *NVVM::DotAccumulate2WayOp::isHi(NVVM::DotAccumulate2WayMode mode,
+                                             llvm::IRBuilderBase &builder) {
+  return builder.getInt1(mode == NVVM::DotAccumulate2WayMode::HI);
+}
+
+LogicalResult NVVM::DotAccumulate2WayOp::verify() {
+  NVVM::DotAccumulateType aType = getAType();
+  NVVM::DotAccumulateType bType = getBType();
+
+  if (aType != NVVM::DotAccumulateType::S16 &&
+      aType != NVVM::DotAccumulateType::U16)
+    return emitOpError("a_type must be S16 or U16");
+  if (bType != NVVM::DotAccumulateType::S8 &&
+      bType != NVVM::DotAccumulateType::U8)
+    return emitOpError("b_type must be S8 or U8");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//
@@ -1599,10 +1639,10 @@ static void nvvmInferResultRanges(Operation *op, Value result,
 }
 
 llvm::Intrinsic::ID
-DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
-                                    NVVM::DotAccumulate4WayType b_type) {
-  bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
-  bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
+DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulateType a_type,
+                                    NVVM::DotAccumulateType b_type) {
+  bool is_a_siext = a_type == NVVM::DotAccumulateType::S8;
+  bool is_b_siext = b_type == NVVM::DotAccumulateType::S8;
   unsigned type = (is_a_siext << 1) | is_b_siext;
   switch (type) {
   case 0:
@@ -1618,6 +1658,26 @@ DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
   }
 }
 
+llvm::Intrinsic::ID
+DotAccumulate2WayOp::getIntrinsicID(NVVM::DotAccumulateType a_type,
+                                    NVVM::DotAccumulateType b_type) {
+  bool is_a_siext = a_type == NVVM::DotAccumulateType::S16;
+  bool is_b_siext = b_type == NVVM::DotAccumulateType::S8;
+  unsigned type = (is_a_siext << 1) | is_b_siext;
+  switch (type) {
+  case 0:
+    return llvm::Intrinsic::nvvm_idp2a_u_u;
+  case 1:
+    return llvm::Intrinsic::nvvm_idp2a_u_s;
+  case 2:
+    return llvm::Intrinsic::nvvm_idp2a_s_u;
+  case 3:
+    return llvm::Intrinsic::nvvm_idp2a_s_s;
+  default:
+    llvm_unreachable("Invalid DP2a type");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e8425638cc9be..5568e104afcab 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -579,7 +579,7 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
 }
 
 // CHECK-LABEL: @dot_accumulate_4way
-func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
   // CHECK:   nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
   %1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
   // CHECK:   nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
@@ -587,6 +587,15 @@ func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: v
   return
 }
 
+// CHECK-LABEL: @dot_accumulate_2way
+func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c: i32) {
+  // CHECK:   nvvm.dot.accumulate.2way lo %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi16>, vector<4xi8>
+  %1 = nvvm.dot.accumulate.2way lo %a_vec <u16>, %b_vec <u8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK:   nvvm.dot.accumulate.2way hi %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi16>, vector<4xi8>
+  %3 = nvvm.dot.accumulate.2way hi %a_vec <s16>, %b_vec <s8>, %c: vector<2xi16>, vector<4xi8>
+  return
+}
+
 // -----
 
 // Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index accec9c7af4f2..e350d5256b5a6 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -248,3 +248,35 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
   %res = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
   llvm.return
 }
+
+// -----
+
+llvm.func @nvvm_dot_accumulate_4way_invalid_type_a(%a_vec : vector<4xi8>, %b_vec : vector<4xi8>, %c : i32) {
+  // expected-error @below {{a_type must be S8 or U8}}
+  %res = nvvm.dot.accumulate.4way %a_vec <u16>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_dot_accumulate_4way_invalid_type_b(%a_vec : vector<4xi8>, %b_vec : vector<4xi8>, %c : i32) {
+  // expected-error @below {{b_type must be S8 or U8}}
+  %res = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u16>, %c: vector<4xi8>, vector<4xi8>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_dot_accumulate_2way_invalid_type_a(%a_vec : vector<2xi16>, %b_vec : vector<4xi8>, %c : i32) {
+  // expected-error @below {{a_type must be S16 or U16}}
+  %res = nvvm.dot.accumulate.2way lo %a_vec <u8>, %b_vec <u8>, %c: vector<2xi16>, vector<4xi8>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_dot_accumulate_2way_invalid_type_b(%a_vec : vector<2xi16>, %b_vec : vector<4xi8>, %c : i32) {
+  // expected-error @below {{b_type must be S8 or U8}}
+  %res = nvvm.dot.accumulate.2way lo %a_vec <u16>, %b_vec <u16>, %c: vector<2xi16>, vector<4xi8>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 894b72733a46a..4bd9326da2233 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -866,3 +866,41 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
   %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: @nvvm_dot_accumulate_2way
+llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %0 = nvvm.dot.accumulate.2way lo %a <u16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %1 = nvvm.dot.accumulate.2way hi %a <u16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %2 = nvvm.dot.accumulate.2way lo %a <s16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %3 = nvvm.dot.accumulate.2way hi %a <s16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %4 = nvvm.dot.accumulate.2way lo %a <u16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %5 = nvvm.dot.accumulate.2way hi %a <u16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %6 = nvvm.dot.accumulate.2way lo %a <s16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %7 = nvvm.dot.accumulate.2way hi %a <s16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>  
+  llvm.return
+}

Comment on lines 3451 to 3454
def DotAccumulateS8 : I32EnumAttrCase<"S8", 1, "s8">;
def DotAccumulateU8 : I32EnumAttrCase<"U8", 0, "u8">;
def DotAccumulateS16 : I32EnumAttrCase<"S16", 2, "s16">;
def DotAccumulateU16 : I32EnumAttrCase<"U16", 3, "u16">;
Member


can we just use existing MLIR types? why are we adding new types?

Member


We added our own types in the past because, at the time, MLIR didn't have all the types that tensor cores support. But now it does, and we should clean those up.
In the meantime, we can start using MLIR types.

Contributor Author

@Wolfram70 Wolfram70 May 21, 2025


Sorry, I think the names here were misleading. I have reduced it to only `s` and `u` in the latest revision. I did it this way here (based on our discussion in #139043 (comment)) since this only refers to whether the individual elements are sign- or zero-extended when computing the dot product (and, by extension, whether the result is signed or unsigned). Another way to do this would be to use AnyTypeOf for the result and the inputs, between S8/16/32 and I8/16/32, but I'm not sure that would look clean.
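
To illustrate the point about extension with the syntax in the current diff: the `a_type`/`b_type` attributes only choose sign- vs zero-extension of the packed elements, while the SSA operand types stay vector<2xi16> and vector<4xi8>. A sketch with placeholder names:

  llvm.func @dp2a_mixed_signedness(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) -> i32 {
    // <s16> sign-extends the elements of %a, <u8> zero-extends those of %b;
    // the accumulator %c is then treated as signed.
    %r = nvvm.dot.accumulate.2way lo %a <s16>, %b <u8>, %c : vector<2xi16>, vector<4xi8>
    llvm.return %r : i32
  }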

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-dp2a branch from 620724c to dfc84ee Compare May 21, 2025 06:19

github-actions bot commented May 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-dp2a branch from dfc84ee to 7aeccec Compare May 21, 2025 06:45
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-dp2a branch from 7aeccec to 0a7100d Compare May 23, 2025 04:46
   case 3:
-    return llvm::Intrinsic::nvvm_idp4a_s_s;
+    return {llvm::Intrinsic::nvvm_idp4a_s_s, args};
   default:
Contributor


We can try an array with all 4 IDs initialized. Seems that will be shorter.

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-dp2a branch 2 times, most recently from 7401846 to 83f858d Compare May 23, 2025 08:49
Contributor

@durga4github durga4github left a comment


The latest revision looks good to me.

This change:
- Adds the dot.accumulate.2way Op to the NVVM dialect for 16-bit to 8-bit
  dot-product accumulate operation.
- Refactors the recently added dot.accumulate.4way and adds a verifier.
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-dp2a branch from 83f858d to 38c55a5 Compare May 23, 2025 08:55
Comment on lines +3599 to +3649
def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
  let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
  let description = [{
    Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
    32-bit result.
    Operand `a` is a vector of two 16-bit elements and operand `b` a vector
    of four 8-bit elements between which the dot product is computed.

    The `a_type` and `b_type` attributes specify the type of the elements in `a`
    and `b` respectively.
    If `a_type` or `b_type` is `s`, then the elements in the corresponding
    vector are sign-extended to 32-bit before the dot product is computed.
    If `a_type` or `b_type` is `u`, then the elements in the corresponding
    vector are zero-extended to 32-bit instead.

    The `b_hi` boolean attribute specifies which two bytes of `b` are used for
    the dot product. If `b_hi` is true, then the dot product is computed
    between `a` and elements at indices 2 and 3 of `b`. If `b_hi` is false,
    then the dot product is computed between `a` and elements at indices 0 and
    1 of `b`.

    Operand `c` is a 32-bit integer to which the result is accumulated. It is
    treated as holding a signed integer if any of `a_type` or `b_type` is
    signed.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
  }];

  let arguments = (ins
    VectorOfLengthAndType<[2], [I16]>:$a,
    DotAccumulateTypeAttr:$a_type,
    VectorOfLengthAndType<[4], [I8]>:$b,
    DotAccumulateTypeAttr:$b_type,
    I32:$c,
    BoolAttr:$b_hi
  );

  let results = (outs I32:$res);

  let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";

  let extraClassDeclaration = [{
    static mlir::NVVM::IDArgPair
    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                          llvm::IRBuilderBase &builder);
  }];

  string llvmBuilder = [{
    llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
    llvm::Value* argA = op.getPackedArg($a, builder);
    llvm::Value* argB = op.getPackedArg($b, builder);
    $res = createIntrinsicCall(builder, id, {argA, argB, $c});
    auto [id, args] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
        *op, moduleTranslation, builder);
    $res = createIntrinsicCall(builder, id, args);
Member


can we split this op dot.accumulate.2way into a separate PR?

Contributor Author

@Wolfram70 Wolfram70 May 23, 2025


I have created another PR for updating the dot.accumulate.4way Op here: #141223
Since the dot.accumulate.2way Op needs some changes that are common to both PRs, I plan to rebase this once that one is merged.

@Wolfram70 Wolfram70 changed the title [MLIR][NVVM] Update dot.accumulate NVVM Ops [MLIR][NVVM] Add dot.accumulate.2way Op May 26, 2025