@@ -151,13 +151,39 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
151
151
return false ;
152
152
}
153
153
154
- AffineMap mlir::vector::getTransferMinorIdentityMap (ShapedType shapedType,
155
- VectorType vectorType) {
156
- int64_t elementVectorRank = 0 ;
154
+ // / Returns the number of dimensions of the `shapedType` that participate in the
155
+ // / vector transfer, effectively the rank of the vector dimensions within the
156
+ // / `shapedType`. This is calculated by taking the rank of the `vectorType`
157
+ // / being transferred and subtracting the rank of the `shapedType`'s element
158
+ // / type if it's also a vector.
159
+ // /
160
+ // / This is used to determine the number of minor dimensions for identity maps
161
+ // / in vector transfers.
162
+ // /
163
+ // / For example, given a transfer operation involving `shapedType` and
164
+ // / `vectorType`:
165
+ // /
166
+ // / - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167
+ // / - shapedType.getElementType() = f32 (rank 0)
168
+ // / - vectorType.getRank() = 2
169
+ // / - Result = 2 - 0 = 2
170
+ // /
171
+ // / - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
172
+ // / - shapedType.getElementType() = vector<20xf32> (rank 1)
173
+ // / - vectorType.getRank() = 1
174
+ // / - Result = 1 - 1 = 0
175
+ static unsigned getRealVectorRank (ShapedType shapedType,
176
+ VectorType vectorType) {
177
+ unsigned elementVectorRank = 0 ;
157
178
VectorType elementVectorType =
158
179
llvm::dyn_cast<VectorType>(shapedType.getElementType ());
159
180
if (elementVectorType)
160
181
elementVectorRank += elementVectorType.getRank ();
182
+ return vectorType.getRank () - elementVectorRank;
183
+ }
184
+
185
+ AffineMap mlir::vector::getTransferMinorIdentityMap (ShapedType shapedType,
186
+ VectorType vectorType) {
161
187
// 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
162
188
// TODO: replace once we have 0-d vectors.
163
189
if (shapedType.getRank () == 0 &&
@@ -166,7 +192,7 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
166
192
/* numDims=*/ 0 , /* numSymbols=*/ 0 ,
167
193
getAffineConstantExpr (0 , shapedType.getContext ()));
168
194
return AffineMap::getMinorIdentityMap (
169
- shapedType.getRank (), vectorType. getRank () - elementVectorRank ,
195
+ shapedType.getRank (), getRealVectorRank (shapedType, vectorType) ,
170
196
shapedType.getContext ());
171
197
}
172
198
@@ -4234,6 +4260,10 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4234
4260
Attribute permMapAttr = result.attributes .get (permMapAttrName);
4235
4261
AffineMap permMap;
4236
4262
if (!permMapAttr) {
4263
+ if (shapedType.getRank () < getRealVectorRank (shapedType, vectorType))
4264
+ return parser.emitError (typesLoc,
4265
+ " expected a custom permutation_map when "
4266
+ " rank(source) != rank(destination)" );
4237
4267
permMap = getTransferMinorIdentityMap (shapedType, vectorType);
4238
4268
result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
4239
4269
} else {
@@ -4649,6 +4679,10 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4649
4679
auto permMapAttr = result.attributes .get (permMapAttrName);
4650
4680
AffineMap permMap;
4651
4681
if (!permMapAttr) {
4682
+ if (shapedType.getRank () < getRealVectorRank (shapedType, vectorType))
4683
+ return parser.emitError (typesLoc,
4684
+ " expected a custom permutation_map when "
4685
+ " rank(source) != rank(destination)" );
4652
4686
permMap = getTransferMinorIdentityMap (shapedType, vectorType);
4653
4687
result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
4654
4688
} else {
0 commit comments