Skip to content

Commit 38689e2

Browse files
committed
[mlir][vector] Use notifyMatchFailure instead of assert in VectorLinearize
1 parent 7bea41e commit 38689e2

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,10 @@ struct LinearizeVectorExtractStridedSlice final
152152
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
153153
ConversionPatternRewriter &rewriter) const override {
154154
Type dstType = getTypeConverter()->convertType(extractOp.getType());
155-
assert(!(extractOp.getVector().getType().isScalable() ||
156-
cast<VectorType>(dstType).isScalable()) &&
157-
"scalable vectors are not supported.");
155+
if (extractOp.getVector().getType().isScalable() ||
156+
cast<VectorType>(dstType).isScalable())
157+
return rewriter.notifyMatchFailure(extractOp,
158+
"scalable vectors are not supported.");
158159
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
159160
return rewriter.notifyMatchFailure(
160161
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -265,6 +266,8 @@ struct LinearizeVectorShuffle final
265266
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
266267
ConversionPatternRewriter &rewriter) const override {
267268
Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
269+
// The assert is used because vector.shuffle does not support scalable
270+
// vectors.
268271
assert(!(shuffleOp.getV1VectorType().isScalable() ||
269272
shuffleOp.getV2VectorType().isScalable() ||
270273
cast<VectorType>(dstType).isScalable()) &&
@@ -336,9 +339,10 @@ struct LinearizeVectorExtract final
336339
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
337340
ConversionPatternRewriter &rewriter) const override {
338341
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
339-
assert(!(extractOp.getVector().getType().isScalable() ||
340-
cast<VectorType>(dstTy).isScalable()) &&
341-
"scalable vectors are not supported.");
342+
if (extractOp.getVector().getType().isScalable() ||
343+
cast<VectorType>(dstTy).isScalable())
344+
return rewriter.notifyMatchFailure(extractOp,
345+
"scalable vectors are not supported.");
342346
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
343347
return rewriter.notifyMatchFailure(
344348
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -395,9 +399,10 @@ struct LinearizeVectorInsert final
395399
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
396400
ConversionPatternRewriter &rewriter) const override {
397401
Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
398-
assert(!(insertOp.getDestVectorType().isScalable() ||
399-
cast<VectorType>(dstTy).isScalable()) &&
400-
"scalable vectors are not supported.");
402+
if (insertOp.getDestVectorType().isScalable() ||
403+
cast<VectorType>(dstTy).isScalable())
404+
return rewriter.notifyMatchFailure(insertOp,
405+
"scalable vectors are not supported.");
401406

402407
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
403408
targetVectorBitWidth))

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
129129
// -----
130130

131131
// ALL-LABEL: func.func @test_scalable_no_linearize(
132-
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
133-
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
132+
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>, %[[VAL_1:.*]]: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
133+
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>, %arg1: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
134134
// ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
135135
%0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
136136

@@ -140,6 +140,15 @@ func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x
140140
// ALL: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<[2]x[2]xf32>
141141
%2 = arith.addf %0, %1 : vector<[2]x[2]xf32>
142142

143+
// ALL: %[[EXTRACTSLICE:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [1, 0], sizes = [1, 2], strides = [1, 1]} : vector<2x[2]xf32> to vector<1x[2]xf32>
144+
%3 = vector.extract_strided_slice %arg1 { sizes = [1, 2], strides = [1, 1], offsets = [1, 0] } : vector<2x[2]xf32> to vector<1x[2]xf32>
145+
146+
// ALL: %[[EXTRACT:.*]] = vector.extract %[[VAL_1]][0, 0] : f32 from vector<2x[2]xf32>
147+
%4 = vector.extract %arg1[0, 0]: f32 from vector<2x[2]xf32>
148+
149+
// ALL: %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[VAL_1]] [0, 0] : f32 into vector<2x[2]xf32>
150+
%5 = vector.insert %4, %arg1[0, 0]: f32 into vector<2x[2]xf32>
151+
143152
// ALL: return %[[RES]] : vector<[2]x[2]xf32>
144153
return %2 : vector<[2]x[2]xf32>
145154
}
@@ -274,3 +283,4 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
274283
%0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
275284
return %0 : vector<2x8x4xf32>
276285
}
286+

0 commit comments

Comments
 (0)