Skip to content

Commit 0d72f0b

Browse files
authored
[mlir][Vector] Fix "scalability" in CastAwayExtractStridedSliceLeadingOneDim (#81187)
Makes sure that "scalability" flags in the `CastAwayExtractStridedSliceLeadingOneDim` pattern are correctly updated.
1 parent 9dd8ba4 commit 0d72f0b

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
7373
VectorType oldDstType = extractOp.getType();
7474
VectorType newDstType =
7575
VectorType::get(oldDstType.getShape().drop_front(dropCount),
76-
oldDstType.getElementType());
76+
oldDstType.getElementType(),
77+
oldDstType.getScalableDims().drop_front(dropCount));
7778

7879
Location loc = extractOp.getLoc();
7980

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,16 @@ func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8x
206206
return %0: vector<1x1x8xf16>
207207
}
208208

209+
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims_scalable
210+
func.func @cast_away_extract_strided_slice_leading_one_dims_scalable(%arg0: vector<1x8x[8]xf16>) -> vector<1x1x[8]xf16> {
211+
// CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
212+
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x[8]xf16> to vector<1x[8]xf16>
213+
%0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x[8]xf16> to vector<1x1x[8]xf16>
214+
// CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16>
215+
// CHECK: return %[[RET]]
216+
return %0: vector<1x1x[8]xf16>
217+
}
218+
209219
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
210220
func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
211221
// CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8xf16> from vector<1x8xf16>
@@ -217,6 +227,17 @@ func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16
217227
return %0: vector<1x8x8xf16>
218228
}
219229

230+
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_scalable
231+
func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vector<1x[8]xf16>, %arg1: vector<1x8x[8]xf16>) -> vector<1x8x[8]xf16> {
232+
// CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<[8]xf16> from vector<1x[8]xf16>
233+
// CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
234+
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<[8]xf16> into vector<8x[8]xf16>
235+
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[8]xf16> into vector<1x8x[8]xf16>
236+
// CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16>
237+
// CHECK: return %[[RET]]
238+
return %0: vector<1x8x[8]xf16>
239+
}
240+
220241
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
221242
// CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
222243
func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {

0 commit comments

Comments
 (0)