Skip to content

Commit cee8bec

Browse files
committed
[mlir][Vector] Support mixed mode vector.contract lowering
1 parent 4028bb1 commit cee8bec

File tree

3 files changed

+77
-25
lines changed

3 files changed

+77
-25
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,22 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
8080
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
8181
}
8282

83+
/// Promotes `v` so that its element type becomes `dstElementType`.
///
/// `v` may be a scalar or a vector. For a vector the vector type is cloned
/// with the new element type, so only the element type changes. If the
/// element type of `v` already equals `dstElementType`, `v` is returned as-is
/// and no operation is created. Otherwise an `arith::ExtFOp` is created when
/// `dstElementType` is a `FloatType`, and an `arith::ExtSIOp` for every other
/// destination type.
///
/// NOTE(review): only extension ops are emitted, so this presumably expects
/// `dstElementType` to be wider than (and of the same float/integer kind as)
/// the element type of `v` -- confirm against the callers.
static Value promoteToElementType(Location loc, RewriterBase &rewriter,
                                  Value v, Type dstElementType) {
  auto vecType = dyn_cast<VectorType>(v.getType());
  Type elementType = vecType ? vecType.getElementType() : v.getType();
  if (elementType == dstElementType)
    return v;

  // Scalars are promoted to the bare element type; vectors keep their vector
  // type with only the element type swapped.
  Type promotedType = dstElementType;
  if (vecType)
    promotedType = vecType.clone(dstElementType);

  if (isa<FloatType>(dstElementType))
    return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
  return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}
98+
8399
// Helper method to possibly drop a dimension in a load.
84100
// TODO
85101
static Value reshapeLoad(Location loc, Value val, VectorType type,
@@ -136,6 +152,11 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
136152
using vector::CombiningKind;
137153
Value mul;
138154

155+
if (acc) {
156+
x = promoteToElementType(loc, rewriter, x, getElementTypeOrSelf(acc));
157+
y = promoteToElementType(loc, rewriter, y, getElementTypeOrSelf(acc));
158+
}
159+
139160
if (isInt) {
140161
if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
141162
kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
@@ -413,21 +434,6 @@ struct UnrolledOuterProductGenerator
413434
return rewriter.create<vector::TransposeOp>(loc, v, perm);
414435
}
415436

416-
Value promote(Value v, Type dstElementType) {
417-
Type elementType = v.getType();
418-
auto vecType = dyn_cast<VectorType>(elementType);
419-
if (vecType)
420-
elementType = vecType.getElementType();
421-
if (elementType == dstElementType)
422-
return v;
423-
Type promotedType = dstElementType;
424-
if (vecType)
425-
promotedType = vecType.clone(promotedType);
426-
if (isa<FloatType>(dstElementType))
427-
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
428-
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
429-
}
430-
431437
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
432438
VectorType lhsType, int reductionSize,
433439
std::optional<Value> maybeMask = std::nullopt) {
@@ -439,8 +445,8 @@ struct UnrolledOuterProductGenerator
439445
for (int64_t k = 0; k < reductionSize; ++k) {
440446
Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
441447
Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
442-
extractA = promote(extractA, resElementType);
443-
extractB = promote(extractB, resElementType);
448+
extractA = promoteToElementType(loc, rewriter, extractA, resElementType);
449+
extractB = promoteToElementType(loc, rewriter, extractB, resElementType);
444450
Value extractMask;
445451
if (maybeMask.has_value() && maybeMask.value())
446452
extractMask =
@@ -764,6 +770,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
764770
Value b = rank == 1
765771
? rhs
766772
: rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
773+
a = promoteToElementType(loc, rewriter, a, getElementTypeOrSelf(dstType));
774+
b = promoteToElementType(loc, rewriter, b, getElementTypeOrSelf(dstType));
767775
Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
768776
Value reduced = rewriter.create<vector::ReductionOp>(
769777
op.getLoc(), vector::CombiningKind::ADD, m);
@@ -925,12 +933,6 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
925933
if (failed(filter(op)))
926934
return failure();
927935

928-
// TODO: support mixed mode contract lowering.
929-
if (op.getLhsType().getElementType() !=
930-
getElementTypeOrSelf(op.getAccType()) ||
931-
op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
932-
return failure();
933-
934936
// TODO: the code below assumes the default contraction, make sure it supports
935937
// other kinds before enabling this lowering.
936938
if (op.getKind() != vector::CombiningKind::ADD) {
@@ -1149,10 +1151,15 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
11491151
if (rhsType.getRank() != 1)
11501152
return rewriter.notifyMatchFailure(
11511153
op, "When LHS has rank 1, expected also RHS to have rank 1");
1152-
Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1153-
auto kind = vector::CombiningKind::ADD;
11541154

11551155
Value acc = op.getAcc();
1156+
Value lhs = promoteToElementType(loc, rewriter, op.getLhs(),
1157+
getElementTypeOrSelf(acc));
1158+
Value rhs = promoteToElementType(loc, rewriter, op.getRhs(),
1159+
getElementTypeOrSelf(acc));
1160+
Value m = createMul(loc, lhs, rhs, isInt, rewriter);
1161+
auto kind = vector::CombiningKind::ADD;
1162+
11561163
Operation *reductionOp =
11571164
acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
11581165
: rewriter.create<vector::ReductionOp>(loc, kind, m);

mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,33 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
295295
return %res : vector<2xi32>
296296
}
297297

298+
// CHECK-LABEL: @matmul_mixed
299+
// CHECK: %[[EXT00:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
300+
// CHECK: %[[EXT01:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
301+
// CHECK: %[[MUL1:.+]] = arith.mulf %[[EXT00]], %[[EXT01]] : vector<2xf32>
302+
// CHECK: vector.reduction <add>, %[[MUL1]] : vector<2xf32> into f32
303+
304+
// CHECK: %[[EXT11:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
305+
// CHECK: %[[MUL2:.+]] = arith.mulf %[[EXT00]], %[[EXT11]] : vector<2xf32>
306+
// CHECK: vector.reduction <add>, %[[MUL2]] : vector<2xf32> into f32
307+
308+
// CHECK: %[[EXT20:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
309+
// CHECK: %[[EXT21:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
310+
// CHECK: %[[MUL3:.+]] = arith.mulf %[[EXT20]], %[[EXT21]] : vector<2xf32>
311+
// CHECK: vector.reduction <add>, %[[MUL3]] : vector<2xf32> into f32
312+
313+
// CHECK: %[[EXT31:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
314+
// CHECK: %[[MUL4:.+]] = arith.mulf %[[EXT20]], %[[EXT31]] : vector<2xf32>
315+
// CHECK: vector.reduction <add>, %[[MUL4]] : vector<2xf32> into f32
316+
317+
// Mixed-type contraction test input: both operands are 2x2 f16 vectors while
// the accumulator and the result are 2x2 f32 vectors. The CHECK lines above
// expect the f16 operand slices to be widened with arith.extf before each
// arith.mulf / vector.reduction <add> pair.
// `#matmat_trait` is defined outside this excerpt.
func.func @matmul_mixed(%arg0: vector<2x2xf16>,
318+
%arg1: vector<2x2xf16>,
319+
%arg2: vector<2x2xf32>) -> vector<2x2xf32> {
320+
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
321+
: vector<2x2xf16>, vector<2x2xf16> into vector<2x2xf32>
322+
return %0 : vector<2x2xf32>
323+
}
324+
298325
module attributes {transform.with_named_sequence} {
299326
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
300327
%f = transform.structured.match ops{["func.func"]} in %module_op

mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,24 @@ func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vect
5151
return %0 : f32
5252
}
5353

54+
// CHECK-LABEL: func @parallel_contract_lowering_mixed_types
55+
// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
56+
// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
57+
// CHECK: %[[EXT0:.+]] = arith.extf %[[E0]] : f16 to f32
58+
// CHECK: %[[EXT1:.+]] = arith.extf %[[E1]] : f16 to f32
59+
// CHECK: %[[M:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
60+
// CHECK: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
61+
// CHECK: return %[[A]] : f32
62+
// Mixed-type contraction test input: two 1x1 f16 vectors are contracted into
// a scalar f32 accumulator, with both iterators marked as reductions and
// combining kind `add`. The CHECK lines above expect the two extracted f16
// scalars to be extended to f32 with arith.extf before the arith.mulf and the
// arith.addf whose result is returned.
func.func @parallel_contract_lowering_mixed_types(%arg0: vector<1x1xf16>, %arg1: vector<1x1xf16>, %arg2: f32) -> f32 {
63+
%0 = vector.contract {
64+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
65+
affine_map<(d0, d1) -> (d0, d1)>,
66+
affine_map<(d0, d1) -> ()>],
67+
iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
68+
%arg0, %arg1, %arg2 : vector<1x1xf16>, vector<1x1xf16> into f32
69+
return %0 : f32
70+
}
71+
5472
module attributes {transform.with_named_sequence} {
5573
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
5674
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)