Skip to content

[mlir][Vector] Support mixed mode vector.contract lowering #117753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

Groverkss
Copy link
Member

This patch adds mixed-mode contract support. The implementation follows the documentation of vector.contract:

https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop

If operands and the result have types of different bitwidths, operands are promoted to have the same bitwidth as the result before performing the contraction. For integer types, only signless integer types are supported, and the promotion happens via sign extension.

@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

This patch adds mixed-mode contract support. The implementation follows the documentation of vector.contract:

https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop

> If operands and the result have types of different bitwidths, operands are promoted to have the same bitwidth as the result before performing the contraction. For integer types, only signless integer types are supported, and the promotion happens via sign extension.


Full diff: https://github.com/llvm/llvm-project/pull/117753.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+32-25)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir (+27)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir (+18)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648f..c8ad2892384995 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -80,6 +80,22 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
 }
 
+Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
+                           Type dstElementType) {
+  Type elementType = v.getType();
+  auto vecType = dyn_cast<VectorType>(elementType);
+  if (vecType)
+    elementType = vecType.getElementType();
+  if (elementType == dstElementType)
+    return v;
+  Type promotedType = dstElementType;
+  if (vecType)
+    promotedType = vecType.clone(promotedType);
+  if (isa<FloatType>(dstElementType))
+    return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
+  return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
+}
+
 // Helper method to possibly drop a dimension in a load.
 // TODO
 static Value reshapeLoad(Location loc, Value val, VectorType type,
@@ -136,6 +152,11 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
   using vector::CombiningKind;
   Value mul;
 
+  if (acc) {
+    x = promoteToElementType(loc, rewriter, x, getElementTypeOrSelf(acc));
+    y = promoteToElementType(loc, rewriter, y, getElementTypeOrSelf(acc));
+  }
+
   if (isInt) {
     if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
         kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
@@ -413,21 +434,6 @@ struct UnrolledOuterProductGenerator
     return rewriter.create<vector::TransposeOp>(loc, v, perm);
   }
 
-  Value promote(Value v, Type dstElementType) {
-    Type elementType = v.getType();
-    auto vecType = dyn_cast<VectorType>(elementType);
-    if (vecType)
-      elementType = vecType.getElementType();
-    if (elementType == dstElementType)
-      return v;
-    Type promotedType = dstElementType;
-    if (vecType)
-      promotedType = vecType.clone(promotedType);
-    if (isa<FloatType>(dstElementType))
-      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
-    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
-  }
-
   FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                              VectorType lhsType, int reductionSize,
                              std::optional<Value> maybeMask = std::nullopt) {
@@ -439,8 +445,8 @@ struct UnrolledOuterProductGenerator
     for (int64_t k = 0; k < reductionSize; ++k) {
       Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
       Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
-      extractA = promote(extractA, resElementType);
-      extractB = promote(extractB, resElementType);
+      extractA = promoteToElementType(loc, rewriter, extractA, resElementType);
+      extractB = promoteToElementType(loc, rewriter, extractB, resElementType);
       Value extractMask;
       if (maybeMask.has_value() && maybeMask.value())
         extractMask =
@@ -764,6 +770,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
       Value b = rank == 1
                     ? rhs
                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+      a = promoteToElementType(loc, rewriter, a, getElementTypeOrSelf(dstType));
+      b = promoteToElementType(loc, rewriter, b, getElementTypeOrSelf(dstType));
       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
       Value reduced = rewriter.create<vector::ReductionOp>(
           op.getLoc(), vector::CombiningKind::ADD, m);
@@ -925,12 +933,6 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
   if (failed(filter(op)))
     return failure();
 
-  // TODO: support mixed mode contract lowering.
-  if (op.getLhsType().getElementType() !=
-          getElementTypeOrSelf(op.getAccType()) ||
-      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
-    return failure();
-
   // TODO: the code below assumes the default contraction, make sure it supports
   // other kinds before enabling this lowering.
   if (op.getKind() != vector::CombiningKind::ADD) {
@@ -1149,10 +1151,15 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
     if (rhsType.getRank() != 1)
       return rewriter.notifyMatchFailure(
           op, "When LHS has rank 1, expected also RHS to have rank 1");
-    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
-    auto kind = vector::CombiningKind::ADD;
 
     Value acc = op.getAcc();
+    Value lhs = promoteToElementType(loc, rewriter, op.getLhs(),
+                                     getElementTypeOrSelf(acc));
+    Value rhs = promoteToElementType(loc, rewriter, op.getRhs(),
+                                     getElementTypeOrSelf(acc));
+    Value m = createMul(loc, lhs, rhs, isInt, rewriter);
+    auto kind = vector::CombiningKind::ADD;
+
     Operation *reductionOp =
         acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
             : rewriter.create<vector::ReductionOp>(loc, kind, m);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 0ba185bb847609..3927058a4c6b45 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -295,6 +295,33 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
   return %res : vector<2xi32>
 }
 
+// CHECK-LABEL: @matmul_mixed
+// CHECK:  %[[EXT00:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[EXT01:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[MUL1:.+]] = arith.mulf %[[EXT00]], %[[EXT01]] : vector<2xf32>
+// CHECK:  vector.reduction <add>, %[[MUL1]] : vector<2xf32> into f32
+
+// CHECK:  %[[EXT11:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[MUL2:.+]] = arith.mulf %[[EXT00]], %[[EXT11]] : vector<2xf32>
+// CHECK:  vector.reduction <add>, %[[MUL2]] : vector<2xf32> into f32
+
+// CHECK:  %[[EXT20:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[EXT21:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[MUL3:.+]] = arith.mulf %[[EXT20]], %[[EXT21]] : vector<2xf32>
+// CHECK:  vector.reduction <add>, %[[MUL3]] : vector<2xf32> into f32
+
+// CHECK:  %[[EXT31:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[MUL4:.+]] = arith.mulf %[[EXT20]], %[[EXT31]] : vector<2xf32>
+// CHECK:  vector.reduction <add>, %[[MUL4]] : vector<2xf32> into f32
+
+func.func @matmul_mixed(%arg0: vector<2x2xf16>,
+                        %arg1: vector<2x2xf16>,
+                        %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x2xf16>, vector<2x2xf16> into vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index e93c5a08bdc7c9..5d9977e94b1598 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -51,6 +51,24 @@ func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vect
   return %0 : f32
 }
 
+// CHECK-LABEL: func @parallel_contract_lowering_mixed_types
+//       CHECK:   %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
+//       CHECK:   %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
+//       CHECK:   %[[EXT0:.+]] = arith.extf %[[E0]] : f16 to f32
+//       CHECK:   %[[EXT1:.+]] = arith.extf %[[E1]] : f16 to f32
+//       CHECK:   %[[M:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
+//       CHECK:   %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
+//       CHECK:   return %[[A]] : f32
+func.func @parallel_contract_lowering_mixed_types(%arg0: vector<1x1xf16>, %arg1: vector<1x1xf16>, %arg2: f32) -> f32 {
+  %0 = vector.contract {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> ()>],
+    iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
+  %arg0, %arg1, %arg2 : vector<1x1xf16>, vector<1x1xf16> into f32
+  return %0 : f32
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

@Groverkss Groverkss force-pushed the mixed-mode-contract branch from cee8bec to 466676c Compare December 3, 2024 15:28
@Groverkss Groverkss requested a review from kuhar December 3, 2024 15:55

if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
// For integer types, vector.contract only supports signless integer types
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'd add a new line before this comment

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, thanks a lot for the contribution. As you may know, this has been a controversial topic (actually, I'm not sure if we discussed it in the context of vector.contract or linalg at the time). The main problem is this:

For integer types, only signless integer types are supported, and the promotion happens via sign extension.

If we want to walk this path, we should find a consistent way to represent signed and unsigned extensions per operand (and any potential flags associated with each operand extension, including FP ones). Otherwise, we would need different matches/handling depending on whether the specific extension each operand is going through.

Any ideas to address this issue?

@Groverkss
Copy link
Member Author

Hey, thanks a lot for the contribution. As you may know, this has been a controversial topic (actually, I'm not sure if we discussed it in the context of vector.contract or linalg at the time). The main problem is this:

For integer types, only signless integer types are supported, and the promotion happens via sign extension.

If we want to walk this path, we should find a consistent way to represent signed and unsigned extensions per operand (and any potential flags associated with each operand extension, including FP ones). Otherwise, we would need different matches/handling depending on whether the specific extension each operand is going through.

Any ideas to address this issue?

Just to be clear, this patch is implementing mixed mode vector.contract lowering for lowering variants other than outer product. Outer product lowering already implemented mixed mode lowering, this PR simply adds it for other variants as well, in the same way outer product lowering did it. If something is controversial here, I'm only following what is already implemented.

On the point of signed/unsigned integer extensions per operand, the line that you quoted regarding using sign extension is not something I decided, but is from vector.contract documentation (https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop):

If operands and the result have types of different bitwidths, operands are promoted to have the same bitwidth as the result before performing the contraction. For integer types, only signless integer types are supported, and the promotion happens via sign extension.

The documentation already mentions that we use signed extension, which is what this pr implements.

Regarding FP types, if an operand extension requires fast math flags, I would guess that this should be annotated on vector.contract. But I'm not aware of any other lowerings looking at fast math flags or any fp flags. Do you have an example or a documentation link I could follow for what is the intended behavior there?

@dcaballe
Copy link
Contributor

dcaballe commented Dec 3, 2024

Sorry if my comment reads critical. It was not the intent. This is a recurring issue, and I appreciate you bringing it up! Yet another half-baked thing we have in the Vector dialect and it’s probably a good time to address it.

The main challenge lies in embedding conversion semantics for each operand within vector.contract, as conversions can be complex (signedness, FMF, rounding/denormal modes and what not). Regardless of the current implementation, supporting only a few conversion cases would lead to an inconsistent representation and I think we shouldn’t move towards another inconsistent intermediate state unless there is a clear path to a final state where full conversion semantics are supported. With this in mind, I see a couple of ways to move forward:

  1. Propose a way to encode conversion information per operand (e.g., using an attribute). This would involve translating explicit Arith/Vector conversion semantics into that embedded representation, which may or may not be worthwhile depending on what you are trying to achieve.

  2. Take a step back, disallow conversions in vector.contract until we have a comprehensive solution and update the documentation accordingly.

What do you think?

@Groverkss
Copy link
Member Author

Sorry if my comment reads critical. It was not the intent. This is a recurring issue, and I appreciate you bringing it up! Yet another half-baked thing we have in the Vector dialect and it’s probably a good time to address it.

The main challenge lies in embedding conversion semantics for each operand within vector.contract, as conversions can be complex (signedness, FMF, rounding/denormal modes and what not). Regardless of the current implementation, supporting only a few conversion cases would lead to an inconsistent representation and I think we shouldn’t move towards another inconsistent intermediate state unless there is a clear path to a final state where full conversion semantics are supported. With this in mind, I see a couple of ways to move forward:

  1. Propose a way to encode conversion information per operand (e.g., using an attribute). This would involve translating explicit Arith/Vector conversion semantics into that embedded representation, which may or may not be worthwhile depending on what you are trying to achieve.
  2. Take a step back, disallow conversions in vector.contract until we have a comprehensive solution and update the documentation accordingly.

What do you think?

Ok, thanks for making it clear with the full picture. This makes sense, it does seem half-baked.

I think that option 2 might be too big of change to plumb through, as many transformations need it. I know there are some mixed precision fadd intrinsics that use vector.contract as a way to target it, and might cause a lot of churn.

How about instead we do Option 1) in 2 parts:

  • We update vector.contract documentation to mention explicitly how it handles extension for now. For integer types, it is already mentioned. We can also mention it for floating types (no extra flags for rounding / denormal modes / etc.). We can also mention that we would like to support this in future by allowing per operand extension information. This will make the lowering match what the documentation says and we can build on a restricted solution that matches the documentation for now. We can also mention that for now if the user has some other extension semantics, they should move the extension out of vector.contract themselves.
  • I work on a patch on the side to implement FastMath/Denorm interfaces for vector.contract operand extension.

What do you think?

@dcaballe
Copy link
Contributor

dcaballe commented Dec 4, 2024

It sounds great to me!

@Groverkss
Copy link
Member Author

It sounds great to me!

Cool! Let me send a patch tommorow to update vector.contract documentation as discussed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants