Skip to content

Commit a68c8e8

Browse files
authored
[mlir][vector] Fix parser of vector.transfer_read (#133721)
This PR adds a check in the parser to prevent a crash when vector.transfer_read fails to create minor identity permutation. map. Fixes #132851 a.mlir ``` module { func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> { %c3_i32 = arith.constant 3 : i32 %0 = vector.transfer_read %arg1[%c3_i32, %c3_i32], %c3_i32 : memref<?xindex>, vector<3x4xi32> return %0 : vector<3x4xi32> } } ```
1 parent 5981be7 commit a68c8e8

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,39 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
151151
return false;
152152
}
153153

154-
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
155-
VectorType vectorType) {
156-
int64_t elementVectorRank = 0;
154+
/// Returns the number of dimensions of the `shapedType` that participate in the
155+
/// vector transfer, effectively the rank of the vector dimensions within the
156+
/// `shapedType`. This is calculated by taking the rank of the `vectorType`
157+
/// being transferred and subtracting the rank of the `shapedType`'s element
158+
/// type if it's also a vector.
159+
///
160+
/// This is used to determine the number of minor dimensions for identity maps
161+
/// in vector transfers.
162+
///
163+
/// For example, given a transfer operation involving `shapedType` and
164+
/// `vectorType`:
165+
///
166+
/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167+
/// - shapedType.getElementType() = f32 (rank 0)
168+
/// - vectorType.getRank() = 2
169+
/// - Result = 2 - 0 = 2
170+
///
171+
/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
172+
/// - shapedType.getElementType() = vector<20xf32> (rank 1)
173+
/// - vectorType.getRank() = 1
174+
/// - Result = 1 - 1 = 0
175+
static unsigned getRealVectorRank(ShapedType shapedType,
176+
VectorType vectorType) {
177+
unsigned elementVectorRank = 0;
157178
VectorType elementVectorType =
158179
llvm::dyn_cast<VectorType>(shapedType.getElementType());
159180
if (elementVectorType)
160181
elementVectorRank += elementVectorType.getRank();
182+
return vectorType.getRank() - elementVectorRank;
183+
}
184+
185+
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
186+
VectorType vectorType) {
161187
// 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
162188
// TODO: replace once we have 0-d vectors.
163189
if (shapedType.getRank() == 0 &&
@@ -166,7 +192,7 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
166192
/*numDims=*/0, /*numSymbols=*/0,
167193
getAffineConstantExpr(0, shapedType.getContext()));
168194
return AffineMap::getMinorIdentityMap(
169-
shapedType.getRank(), vectorType.getRank() - elementVectorRank,
195+
shapedType.getRank(), getRealVectorRank(shapedType, vectorType),
170196
shapedType.getContext());
171197
}
172198

@@ -4234,6 +4260,10 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
42344260
Attribute permMapAttr = result.attributes.get(permMapAttrName);
42354261
AffineMap permMap;
42364262
if (!permMapAttr) {
4263+
if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
4264+
return parser.emitError(typesLoc,
4265+
"expected a custom permutation_map when "
4266+
"rank(source) != rank(destination)");
42374267
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
42384268
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
42394269
} else {
@@ -4649,6 +4679,10 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
46494679
auto permMapAttr = result.attributes.get(permMapAttrName);
46504680
AffineMap permMap;
46514681
if (!permMapAttr) {
4682+
if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
4683+
return parser.emitError(typesLoc,
4684+
"expected a custom permutation_map when "
4685+
"rank(source) != rank(destination)");
46524686
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
46534687
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
46544688
} else {

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,15 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
525525

526526
// -----
527527

528+
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xindex> {
529+
%c3 = arith.constant 3 : index
530+
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
531+
%0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xindex>
532+
return %0 : vector<3x4xindex>
533+
}
534+
535+
// -----
536+
528537
func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
529538
%c3 = arith.constant 3 : index
530539
%cst = arith.constant 3.0 : f32
@@ -646,6 +655,14 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
646655

647656
// -----
648657

658+
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref<?xindex>) {
659+
%c3 = arith.constant 3 : index
660+
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
661+
vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xindex>, memref<?xindex>
662+
}
663+
664+
// -----
665+
649666
func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
650667
// expected-error@+1 {{expected offsets of same size as destination vector rank}}
651668
%1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>

0 commit comments

Comments
 (0)