Skip to content

Commit 04bf1a4

Browse files
authored
Update LowerContractionToSMMLAPattern to ingnore matvec (#88288)
Patterns in `LowerContractionToSMMLAPattern` are designed to handle vector-to-matrix multiplication but not matrix-to-vector. This leads to the following error when processing `rhs` with rank < 2: ``` iree-compile: /usr/local/google/home/kooljblack/code/iree-build/llvm-project/tools/mlir/include/mlir/IR/BuiltinTypeInterfaces.h.inc:268: int64_t mlir::detail::ShapedTypeTrait<mlir::VectorType>::getDimSize(unsigned int) const [ConcreteType = mlir::VectorType]: Assertion `idx < getRank() && "invalid index for shaped type"' failed. ``` Updates to explicitly check the rhs rank and fail cases that cannot process.
1 parent 4dcf33b commit 04bf1a4

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ class LowerContractionToSMMLAPattern
5454
// Note: RHS is not transposed.
5555
mlir::VectorType lhsType = op.getLhsType();
5656
mlir::VectorType rhsType = op.getRhsType();
57+
// Avoid 0-D vectors and 1-D rhs:
58+
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
59+
return failure();
5760
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
5861
auto dimN = rhsType.getDimSize(0);
5962
auto dimK = rhsType.getDimSize(1);

mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,14 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8
258258
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
259259
return %res : vector<1x8xi32>
260260
}
261+
262+
// -----
263+
264+
// CHECK-LABEL: func.func @test_lower_vector_arm_neon_matvec
265+
// CHECK-NOT: arm_neon.intr.smmla
266+
func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
267+
%rhs_extsi= arith.extsi %rhs : vector<8xi8> to vector<8xi32>
268+
%lhs_extsi = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
269+
%res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32>
270+
return %res : vector<8xi32>
271+
}

0 commit comments

Comments
 (0)