-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][Vector] Support mixed mode vector.contract lowering #117753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Kunwar Grover (Groverkss) ChangesThis patch adds mixed-mode contract support. The implementation follows the documentation of vector.contract: https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop > If operands and the result have types of different bitwidths, operands are promoted to have the same bitwidth as the result before performing the contraction. For integer types, only signless integer types are supported, and the promotion happens via sign extension. Full diff: https://github.com/llvm/llvm-project/pull/117753.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648f..c8ad2892384995 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -80,6 +80,22 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}
+Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
+ Type dstElementType) {
+ Type elementType = v.getType();
+ auto vecType = dyn_cast<VectorType>(elementType);
+ if (vecType)
+ elementType = vecType.getElementType();
+ if (elementType == dstElementType)
+ return v;
+ Type promotedType = dstElementType;
+ if (vecType)
+ promotedType = vecType.clone(promotedType);
+ if (isa<FloatType>(dstElementType))
+ return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
+ return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
+}
+
// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
@@ -136,6 +152,11 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
using vector::CombiningKind;
Value mul;
+ if (acc) {
+ x = promoteToElementType(loc, rewriter, x, getElementTypeOrSelf(acc));
+ y = promoteToElementType(loc, rewriter, y, getElementTypeOrSelf(acc));
+ }
+
if (isInt) {
if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
@@ -413,21 +434,6 @@ struct UnrolledOuterProductGenerator
return rewriter.create<vector::TransposeOp>(loc, v, perm);
}
- Value promote(Value v, Type dstElementType) {
- Type elementType = v.getType();
- auto vecType = dyn_cast<VectorType>(elementType);
- if (vecType)
- elementType = vecType.getElementType();
- if (elementType == dstElementType)
- return v;
- Type promotedType = dstElementType;
- if (vecType)
- promotedType = vecType.clone(promotedType);
- if (isa<FloatType>(dstElementType))
- return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
- return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
- }
-
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
VectorType lhsType, int reductionSize,
std::optional<Value> maybeMask = std::nullopt) {
@@ -439,8 +445,8 @@ struct UnrolledOuterProductGenerator
for (int64_t k = 0; k < reductionSize; ++k) {
Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
- extractA = promote(extractA, resElementType);
- extractB = promote(extractB, resElementType);
+ extractA = promoteToElementType(loc, rewriter, extractA, resElementType);
+ extractB = promoteToElementType(loc, rewriter, extractB, resElementType);
Value extractMask;
if (maybeMask.has_value() && maybeMask.value())
extractMask =
@@ -764,6 +770,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
Value b = rank == 1
? rhs
: rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+ a = promoteToElementType(loc, rewriter, a, getElementTypeOrSelf(dstType));
+ b = promoteToElementType(loc, rewriter, b, getElementTypeOrSelf(dstType));
Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
Value reduced = rewriter.create<vector::ReductionOp>(
op.getLoc(), vector::CombiningKind::ADD, m);
@@ -925,12 +933,6 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
if (failed(filter(op)))
return failure();
- // TODO: support mixed mode contract lowering.
- if (op.getLhsType().getElementType() !=
- getElementTypeOrSelf(op.getAccType()) ||
- op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
- return failure();
-
// TODO: the code below assumes the default contraction, make sure it supports
// other kinds before enabling this lowering.
if (op.getKind() != vector::CombiningKind::ADD) {
@@ -1149,10 +1151,15 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
if (rhsType.getRank() != 1)
return rewriter.notifyMatchFailure(
op, "When LHS has rank 1, expected also RHS to have rank 1");
- Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
- auto kind = vector::CombiningKind::ADD;
Value acc = op.getAcc();
+ Value lhs = promoteToElementType(loc, rewriter, op.getLhs(),
+ getElementTypeOrSelf(acc));
+ Value rhs = promoteToElementType(loc, rewriter, op.getRhs(),
+ getElementTypeOrSelf(acc));
+ Value m = createMul(loc, lhs, rhs, isInt, rewriter);
+ auto kind = vector::CombiningKind::ADD;
+
Operation *reductionOp =
acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
: rewriter.create<vector::ReductionOp>(loc, kind, m);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 0ba185bb847609..3927058a4c6b45 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -295,6 +295,33 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
return %res : vector<2xi32>
}
+// CHECK-LABEL: @matmul_mixed
+// CHECK: %[[EXT00:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[EXT01:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL1:.+]] = arith.mulf %[[EXT00]], %[[EXT01]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL1]] : vector<2xf32> into f32
+
+// CHECK: %[[EXT11:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL2:.+]] = arith.mulf %[[EXT00]], %[[EXT11]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL2]] : vector<2xf32> into f32
+
+// CHECK: %[[EXT20:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[EXT21:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL3:.+]] = arith.mulf %[[EXT20]], %[[EXT21]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL3]] : vector<2xf32> into f32
+
+// CHECK: %[[EXT31:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL4:.+]] = arith.mulf %[[EXT20]], %[[EXT31]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL4]] : vector<2xf32> into f32
+
+func.func @matmul_mixed(%arg0: vector<2x2xf16>,
+ %arg1: vector<2x2xf16>,
+ %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+ %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<2x2xf16>, vector<2x2xf16> into vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index e93c5a08bdc7c9..5d9977e94b1598 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -51,6 +51,24 @@ func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vect
return %0 : f32
}
+// CHECK-LABEL: func @parallel_contract_lowering_mixed_types
+// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
+// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
+// CHECK: %[[EXT0:.+]] = arith.extf %[[E0]] : f16 to f32
+// CHECK: %[[EXT1:.+]] = arith.extf %[[E1]] : f16 to f32
+// CHECK: %[[M:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
+// CHECK: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
+// CHECK: return %[[A]] : f32
+func.func @parallel_contract_lowering_mixed_types(%arg0: vector<1x1xf16>, %arg1: vector<1x1xf16>, %arg2: f32) -> f32 {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> ()>],
+ iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2 : vector<1x1xf16>, vector<1x1xf16> into f32
+ return %0 : f32
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
|
cee8bec
to
466676c
Compare
|
||
if (isa<FloatType>(dstElementType)) | ||
return rewriter.create<arith::ExtFOp>(loc, promotedType, v); | ||
// For integer types, vector.contract only supports signless integer types |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I'd add a new line before this comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, thanks a lot for the contribution. As you may know, this has been a controversial topic (actually, I'm not sure if we discussed it in the context of vector.contract
or linalg at the time). The main problem is this:
For integer types, only signless integer types are supported, and the promotion happens via sign extension.
If we want to walk this path, we should find a consistent way to represent signed and unsigned extensions per operand (and any potential flags associated with each operand extension, including FP ones). Otherwise, we would need different matches/handling depending on whether the specific extension each operand is going through.
Any ideas to address this issue?
Just to be clear, this patch is implementing mixed mode vector.contract lowering for lowering variants other than outer product. Outer product lowering already implemented mixed mode lowering, this PR simply adds it for other variants as well, in the same way outer product lowering did it. If something is controversial here, I'm only following what is already implemented. On the point of signed/unsigned integer extensions per operand, the line that you quoted regarding using sign extension is not something I decided, but is from vector.contract documentation (https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop):
The documentation already mentions that we use signed extension, which is what this pr implements. Regarding FP types, if an operand extension requires fast math flags, I would guess that this should be annotated on vector.contract. But I'm not aware of any other lowerings looking at fast math flags or any fp flags. Do you have an example or a documentation link I could follow for what is the intended behavior there? |
Sorry if my comment reads critical. It was not the intent. This is a recurring issue, and I appreciate you bringing it up! Yet another half-baked thing we have in the Vector dialect and it’s probably a good time to address it. The main challenge lies in embedding conversion semantics for each operand within
What do you think? |
Ok, thanks for making it clear with the full picture. This makes sense, it does seem half-baked. I think that option 2 might be too big of change to plumb through, as many transformations need it. I know there are some mixed precision fadd intrinsics that use vector.contract as a way to target it, and might cause a lot of churn. How about instead we do Option 1) in 2 parts:
What do you think? |
It sounds great to me! |
Cool! Let me send a patch tommorow to update vector.contract documentation as discussed. |
This patch adds mixed-mode contract support. The implementation follows the documentation of vector.contract:
https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop