@@ -151,10 +151,12 @@ struct LinearizeVectorExtractStridedSlice final
151
151
LogicalResult
152
152
matchAndRewrite (vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
153
153
ConversionPatternRewriter &rewriter) const override {
154
- Type dstType = getTypeConverter ()->convertType (extractOp.getType ());
155
- assert (!(extractOp.getVector ().getType ().isScalable () ||
156
- cast<VectorType>(dstType).isScalable ()) &&
157
- " scalable vectors are not supported." );
154
+ VectorType dstType =
155
+ getTypeConverter ()->convertType <VectorType>(extractOp.getType ());
156
+ assert (dstType && " vector type destination expected." );
157
+ if (extractOp.getVector ().getType ().isScalable () || dstType.isScalable ())
158
+ return rewriter.notifyMatchFailure (extractOp,
159
+ " scalable vectors are not supported." );
158
160
if (!isLessThanTargetBitWidth (extractOp, targetVectorBitWidth))
159
161
return rewriter.notifyMatchFailure (
160
162
extractOp, " Can't flatten since targetBitWidth <= OpSize" );
@@ -264,10 +266,14 @@ struct LinearizeVectorShuffle final
264
266
LogicalResult
265
267
matchAndRewrite (vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
266
268
ConversionPatternRewriter &rewriter) const override {
267
- Type dstType = getTypeConverter ()->convertType (shuffleOp.getType ());
269
+ VectorType dstType =
270
+ getTypeConverter ()->convertType <VectorType>(shuffleOp.getType ());
271
+ assert (dstType && " vector type destination expected." );
272
+ // The assert is used because vector.shuffle does not support scalable
273
+ // vectors.
268
274
assert (!(shuffleOp.getV1VectorType ().isScalable () ||
269
275
shuffleOp.getV2VectorType ().isScalable () ||
270
- cast<VectorType>( dstType) .isScalable ()) &&
276
+ dstType.isScalable ()) &&
271
277
" scalable vectors are not supported." );
272
278
if (!isLessThanTargetBitWidth (shuffleOp, targetVectorBitWidth))
273
279
return rewriter.notifyMatchFailure (
@@ -336,9 +342,10 @@ struct LinearizeVectorExtract final
336
342
matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
337
343
ConversionPatternRewriter &rewriter) const override {
338
344
Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
339
- assert (!(extractOp.getVector ().getType ().isScalable () ||
340
- cast<VectorType>(dstTy).isScalable ()) &&
341
- " scalable vectors are not supported." );
345
+ if (extractOp.getVector ().getType ().isScalable () ||
346
+ cast<VectorType>(dstTy).isScalable ())
347
+ return rewriter.notifyMatchFailure (extractOp,
348
+ " scalable vectors are not supported." );
342
349
if (!isLessThanTargetBitWidth (extractOp, targetVectorBitWidth))
343
350
return rewriter.notifyMatchFailure (
344
351
extractOp, " Can't flatten since targetBitWidth <= OpSize" );
@@ -394,10 +401,12 @@ struct LinearizeVectorInsert final
394
401
LogicalResult
395
402
matchAndRewrite (vector::InsertOp insertOp, OpAdaptor adaptor,
396
403
ConversionPatternRewriter &rewriter) const override {
397
- Type dstTy = getTypeConverter ()->convertType (insertOp.getDestVectorType ());
398
- assert (!(insertOp.getDestVectorType ().isScalable () ||
399
- cast<VectorType>(dstTy).isScalable ()) &&
400
- " scalable vectors are not supported." );
404
+ VectorType dstTy = getTypeConverter ()->convertType <VectorType>(
405
+ insertOp.getDestVectorType ());
406
+ assert (dstTy && " vector type destination expected." );
407
+ if (insertOp.getDestVectorType ().isScalable () || dstTy.isScalable ())
408
+ return rewriter.notifyMatchFailure (insertOp,
409
+ " scalable vectors are not supported." );
401
410
402
411
if (!isLessThanOrEqualTargetBitWidth (insertOp.getSourceType (),
403
412
targetVectorBitWidth))
0 commit comments