Skip to content

Commit 37fecfa

Browse files
authored
[mlir] Support rank-reduced extract_slice in ExtractSliceOfPadTensorSwapPattern (#138921)
This PR fixes `ExtractSliceOfPadTensorSwapPattern` to support rank-reducing `tensor.extract_slice` ops, which were previously unhandled and could cause crashes. To support this, an additional `tensor.extract_slice` is inserted after `tensor.pad` to reduce the result rank.
1 parent f2bc7b7 commit 37fecfa

File tree

2 files changed

+57
-3
lines changed

2 files changed

+57
-3
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,9 +1017,22 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
       sliceOp.getMixedSizes(), zeroSliceGuard);
   if (failed(tilingResult))
     return failure();
-  // All shapes are static and the data source is actually used. Rewrite into
-  // pad(extract_slice(x)).
-  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+
+  RankedTensorType sourceType = sliceOp.getSourceType();
+  RankedTensorType resultType = sliceOp.getResultType();
+
+  // If the extract_slice is not rank-reduced, all shapes are static and the
+  // data source is actually used. Rewrite into pad(extract_slice(x)).
+  if (sourceType.getRank() == resultType.getRank()) {
+    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+    return success();
+  }
+
+  // Handle rank-reduced slice by creating another extract_slice op.
+  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
+
+  rewriter.replaceOp(sliceOp, rankReduced);
   return success();
 }

mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,26 @@ func.func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
 
 // -----
 
+// CHECK-LABEL: @static_rank_reduce
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x16x4xf32>, %[[PADVAL:.*]]: f32
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 14, 4] [1, 1, 1] : tensor<8x16x4xf32> to tensor<1x14x4xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[SLICE]] low[0, 2, 0] high[0, 0, 0] {
+// CHECK: } : tensor<1x14x4xf32> to tensor<1x16x4xf32>
+// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[PADDED]][0, 0, 0] [1, 16, 4] [1, 1, 1] : tensor<1x16x4xf32> to tensor<16x4xf32>
+// CHECK: return %[[RESULT]]
+func.func @static_rank_reduce(%arg0: tensor<8x16x4xf32>, %pad: f32)
+    -> tensor<16x4xf32> {
+  %0 = tensor.pad %arg0 low[0, 2, 0] high[0, 0, 0] {
+  ^bb0(%i: index, %j: index, %k: index):
+    tensor.yield %pad : f32
+  } : tensor<8x16x4xf32> to tensor<8x18x4xf32>
+  %1 = tensor.extract_slice %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
+      : tensor<8x18x4xf32> to tensor<16x4xf32>
+  return %1 : tensor<16x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @dynamic_high_pad
 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x5xf32>
 // CHECK-NOT: tensor.pad
@@ -217,6 +237,27 @@ func.func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: @dynamic_rank_reduce
+// CHECK: %[[TEMP:.*]] = scf.if %{{.*}} -> (tensor<1x4xf32>) {
+// CHECK: tensor.generate
+// CHECK: } else {
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %{{.*}} : tensor<?x5xf32> to tensor<?x1xf32>
+// CHECK: tensor.pad %[[SLICE]] low[0, 0] high[%{{.*}}, 3] {
+// CHECK: } : tensor<?x1xf32> to tensor<1x4xf32>
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[TEMP]]{{.*}} : tensor<1x4xf32> to tensor<4xf32>
+// CHECK: return %[[RESULT]]
+func.func @dynamic_rank_reduce(%arg0 : tensor<?x5xf32>, %s1: index, %pad : f32) -> tensor<4xf32> {
+  %0 = tensor.pad %arg0 low[0, 0] high[7, 8] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %pad : f32
+  } : tensor<?x5xf32> to tensor<?x13xf32>
+  %1 = tensor.extract_slice %0[2, 4] [1, 4] [1, 1] : tensor<?x13xf32> to tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
 // -----
 // CHECK-LABEL: @nopaddim_with_dynamic_extract(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4x5xf32>

0 commit comments

Comments
 (0)