@@ -80,6 +80,22 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
80
80
return AffineMap::get (map.getNumDims () - 1 , 0 , results, ctx);
81
81
}
82
82
83
+ Value promoteToElementType (Location loc, RewriterBase &rewriter, Value v,
84
+ Type dstElementType) {
85
+ Type elementType = v.getType ();
86
+ auto vecType = dyn_cast<VectorType>(elementType);
87
+ if (vecType)
88
+ elementType = vecType.getElementType ();
89
+ if (elementType == dstElementType)
90
+ return v;
91
+ Type promotedType = dstElementType;
92
+ if (vecType)
93
+ promotedType = vecType.clone (promotedType);
94
+ if (isa<FloatType>(dstElementType))
95
+ return rewriter.create <arith::ExtFOp>(loc, promotedType, v);
96
+ return rewriter.create <arith::ExtSIOp>(loc, promotedType, v);
97
+ }
98
+
83
99
// Helper method to possibly drop a dimension in a load.
84
100
// TODO
85
101
static Value reshapeLoad (Location loc, Value val, VectorType type,
@@ -136,6 +152,11 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
136
152
using vector::CombiningKind;
137
153
Value mul;
138
154
155
+ if (acc) {
156
+ x = promoteToElementType (loc, rewriter, x, getElementTypeOrSelf (acc));
157
+ y = promoteToElementType (loc, rewriter, y, getElementTypeOrSelf (acc));
158
+ }
159
+
139
160
if (isInt) {
140
161
if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
141
162
kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
@@ -413,21 +434,6 @@ struct UnrolledOuterProductGenerator
413
434
return rewriter.create <vector::TransposeOp>(loc, v, perm);
414
435
}
415
436
416
- Value promote (Value v, Type dstElementType) {
417
- Type elementType = v.getType ();
418
- auto vecType = dyn_cast<VectorType>(elementType);
419
- if (vecType)
420
- elementType = vecType.getElementType ();
421
- if (elementType == dstElementType)
422
- return v;
423
- Type promotedType = dstElementType;
424
- if (vecType)
425
- promotedType = vecType.clone (promotedType);
426
- if (isa<FloatType>(dstElementType))
427
- return rewriter.create <arith::ExtFOp>(loc, promotedType, v);
428
- return rewriter.create <arith::ExtSIOp>(loc, promotedType, v);
429
- }
430
-
431
437
FailureOr<Value> outerProd (Value lhs, Value rhs, Value res,
432
438
VectorType lhsType, int reductionSize,
433
439
std::optional<Value> maybeMask = std::nullopt) {
@@ -439,8 +445,8 @@ struct UnrolledOuterProductGenerator
439
445
for (int64_t k = 0 ; k < reductionSize; ++k) {
440
446
Value extractA = rewriter.create <vector::ExtractOp>(loc, lhs, k);
441
447
Value extractB = rewriter.create <vector::ExtractOp>(loc, rhs, k);
442
- extractA = promote ( extractA, resElementType);
443
- extractB = promote ( extractB, resElementType);
448
+ extractA = promoteToElementType (loc, rewriter, extractA, resElementType);
449
+ extractB = promoteToElementType (loc, rewriter, extractB, resElementType);
444
450
Value extractMask;
445
451
if (maybeMask.has_value () && maybeMask.value ())
446
452
extractMask =
@@ -764,6 +770,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
764
770
Value b = rank == 1
765
771
? rhs
766
772
: rewriter.create <vector::ExtractOp>(op.getLoc (), rhs, c);
773
+ a = promoteToElementType (loc, rewriter, a, getElementTypeOrSelf (dstType));
774
+ b = promoteToElementType (loc, rewriter, b, getElementTypeOrSelf (dstType));
767
775
Value m = createMul (op.getLoc (), a, b, isInt, rewriter);
768
776
Value reduced = rewriter.create <vector::ReductionOp>(
769
777
op.getLoc (), vector::CombiningKind::ADD, m);
@@ -925,12 +933,6 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
925
933
if (failed (filter (op)))
926
934
return failure ();
927
935
928
- // TODO: support mixed mode contract lowering.
929
- if (op.getLhsType ().getElementType () !=
930
- getElementTypeOrSelf (op.getAccType ()) ||
931
- op.getRhsType ().getElementType () != getElementTypeOrSelf (op.getAccType ()))
932
- return failure ();
933
-
934
936
// TODO: the code below assumes the default contraction, make sure it supports
935
937
// other kinds before enabling this lowering.
936
938
if (op.getKind () != vector::CombiningKind::ADD) {
@@ -1149,10 +1151,15 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
1149
1151
if (rhsType.getRank () != 1 )
1150
1152
return rewriter.notifyMatchFailure (
1151
1153
op, " When LHS has rank 1, expected also RHS to have rank 1" );
1152
- Value m = createMul (loc, op.getLhs (), op.getRhs (), isInt, rewriter);
1153
- auto kind = vector::CombiningKind::ADD;
1154
1154
1155
1155
Value acc = op.getAcc ();
1156
+ Value lhs = promoteToElementType (loc, rewriter, op.getLhs (),
1157
+ getElementTypeOrSelf (acc));
1158
+ Value rhs = promoteToElementType (loc, rewriter, op.getRhs (),
1159
+ getElementTypeOrSelf (acc));
1160
+ Value m = createMul (loc, lhs, rhs, isInt, rewriter);
1161
+ auto kind = vector::CombiningKind::ADD;
1162
+
1156
1163
Operation *reductionOp =
1157
1164
acc ? rewriter.create <vector::ReductionOp>(loc, kind, m, acc)
1158
1165
: rewriter.create <vector::ReductionOp>(loc, kind, m);
0 commit comments