@@ -80,19 +80,25 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
80
80
return AffineMap::get (map.getNumDims () - 1 , 0 , results, ctx);
81
81
}
82
82
83
- Value promoteToElementType (Location loc, RewriterBase &rewriter, Value v,
84
- Type dstElementType) {
85
- Type elementType = v.getType ();
86
- auto vecType = dyn_cast<VectorType>(elementType);
87
- if (vecType)
88
- elementType = vecType.getElementType ();
83
+ static Value promoteToElementType (Location loc, RewriterBase &rewriter, Value v,
84
+ Type dstElementType) {
85
+ Type elementType = getElementTypeOrSelf (v.getType ());
89
86
if (elementType == dstElementType)
90
87
return v;
88
+
89
+ // vector.contract only allows extension on operands.
90
+ assert (elementType.getIntOrFloatBitWidth () <=
91
+ dstElementType.getIntOrFloatBitWidth () &&
92
+ " vector.contract does not allow truncation of operands" );
93
+
91
94
Type promotedType = dstElementType;
92
- if (vecType)
95
+ if (auto vecType = dyn_cast<VectorType>(v. getType ()) )
93
96
promotedType = vecType.clone (promotedType);
97
+
94
98
if (isa<FloatType>(dstElementType))
95
99
return rewriter.create <arith::ExtFOp>(loc, promotedType, v);
100
+ // For integer types, vector.contract only supports signless integer types
101
+ // and promotion happens via sign extension.
96
102
return rewriter.create <arith::ExtSIOp>(loc, promotedType, v);
97
103
}
98
104
0 commit comments