Skip to content

Commit da37c76

Browse files
[mlir][vector] Add a check to ensure input vector rank equals target shape rank (#127706)
Fixes issue #126197 The crash is caused because, during IR transformation, the vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an input vector of higher rank using a target vector of lower rank, which is not supported. Specific example : ``` module { func.func @func1() { %cst_25 = arith.constant dense<3.718400e+04> : vector<4x2x2xf16> %cst_26 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32> %47 = vector.fma %cst_26, %cst_26, %cst_26 : vector<24x2x2xf32> %818 = scf.execute_region -> vector<24x2x2xf32> { scf.yield %47 : vector<24x2x2xf32> } %823 = vector.extract_strided_slice %cst_25 {offsets = [2], sizes = [1], strides = [1]} : vector<4x2x2xf16> to vector<1x2x2xf16> return } } ``` --------- Co-authored-by: Kai Sasaki <[email protected]>
1 parent b1a735b commit da37c76

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,12 @@ struct UnrollElementwisePattern : public RewritePattern {
437437
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
438438
SmallVector<int64_t> originalSize =
439439
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
440+
// Bail-out if rank(source) != rank(target). The main limitation here is the
441+
// fact that `ExtractStridedSlice` requires the rank for the input and
442+
// output to match. If needed, we can relax this later.
443+
if (originalSize.size() != targetShape->size())
444+
return rewriter.notifyMatchFailure(
445+
op, "expected input vector rank to match target shape rank");
440446
Location loc = op->getLoc();
441447
// Prepare the result vector.
442448
Value result = rewriter.create<arith::ConstantOp>(

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,16 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
188188
// CHECK-LABEL: func @vector_fma
189189
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
190190

191+
// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
192+
func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
193+
%0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
194+
return %0 : vector<3x2x2xf32>
195+
}
196+
// CHECK-LABEL: func @negative_vector_fma_3d
197+
// CHECK-NOT: vector.extract_strided_slice
198+
// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
199+
// CHECK: return
200+
191201
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
192202
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
193203
return %0 : vector<4xf32>

0 commit comments

Comments
 (0)