Skip to content

Commit 466676c

Browse files
committed
Address comments
1 parent ec96f24 commit 466676c

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,25 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
8080
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
8181
}
8282

83-
Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
84-
Type dstElementType) {
85-
Type elementType = v.getType();
86-
auto vecType = dyn_cast<VectorType>(elementType);
87-
if (vecType)
88-
elementType = vecType.getElementType();
83+
static Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
84+
Type dstElementType) {
85+
Type elementType = getElementTypeOrSelf(v.getType());
8986
if (elementType == dstElementType)
9087
return v;
88+
89+
// vector.contract only allows extension on operands.
90+
assert(elementType.getIntOrFloatBitWidth() <=
91+
dstElementType.getIntOrFloatBitWidth() &&
92+
"vector.contract does not allow truncation of operands");
93+
9194
Type promotedType = dstElementType;
92-
if (vecType)
95+
if (auto vecType = dyn_cast<VectorType>(v.getType()))
9396
promotedType = vecType.clone(promotedType);
97+
9498
if (isa<FloatType>(dstElementType))
9599
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
100+
// For integer types, vector.contract only supports signless integer types
101+
// and promotion happens via sign extension.
96102
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
97103
}
98104

0 commit comments

Comments
 (0)