//
//===----------------------------------------------------------------------===//

+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
@@ -210,6 +213,178 @@ struct BubbleUpExpandThroughParallelCollapse
  }
};

+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. If the transformation
+/// is not possible, or if the slice is rank reducing, the function returns
+/// failure.
+///
+/// Example:
+/// ```
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// // The transformation is possible because each reassociation group has a
+/// // contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
+/// // After the transformation:
+///
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note - this pattern could be reworked to be a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
+/// implemented only as a bubble up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice source not produced by expand_shape");
+    }
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unsupported: non-unit stride");
+    }
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        sizes.size()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+    }
+
+    // Helper variables and function for accumulating the new offset and
+    // length values.
+    Location loc = expandShapeOp->getLoc();
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    // Multiply two integers.
+    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                   {v1, v2});
+    };
+
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
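+    // Returns true when the slice covers the full extent of the given
+    // expanded dimension: the offset is statically 0 and the extracted size
+    // provably equals the dimension size.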
+    auto isZeroOffsetAndFullSize =
+        [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
+          if (!isConstantIntValue(offset, 0))
+            return false;
+          FailureOr<bool> maybeEqual =
+              ValueBoundsConstraintSet::areEqual(sliceSize, size);
+          return llvm::succeeded(maybeEqual) && maybeEqual.value();
+        };
+
+    // First verify that this is a full slice of the expanded tensor.
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Find the first expanded dim after the first dim with non-unit
+      // extracted size.
+      for (; i < e; ++i) {
+        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+          // +1 to skip the first non-unit size dim.
+          i++;
+          break;
+        }
+      }
+
+      // Verify that all subsequent dimensions extract the full size of the
+      // source tensor.
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                     outputShape[expandedDim])) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "Not a contiguous slice of the expanded tensor.");
+        }
+      }
+    }
+
+    // Compute new offsets, lengths, and strides.
+    SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      OpFoldResult newSize = rewriter.getIndexAttr(1);
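+      // For each expanded dimension of this group, `basis` collects the
+      // dimension extent and `delinOffsets` the slice offset; together they
+      // form a delinearized offset that is re-linearized further below.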
+      SmallVector<OpFoldResult> basis, delinOffsets;
+
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Offset = cumulative product of leading unit extracted dims.
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isConstantIntValue(sizes[expandedDim], 1))
+          break;
+
+        basis.push_back(outputShape[expandedDim]);
+        delinOffsets.push_back(offsets[expandedDim]);
+      }
+
+      if (i != e) {
+        int64_t expandedDim = indices[i];
+        basis.push_back(outputShape[expandedDim]);
+        delinOffsets.push_back(offsets[expandedDim]);
+        newSize = sizes[expandedDim];
+        i++;
+      }
+
+      for (; i < e; ++i) {
+        OpFoldResult fullSize = outputShape[indices[i]];
+        basis.push_back(fullSize);
+        delinOffsets.push_back(rewriter.getIndexAttr(0));
+        newSize = mul(newSize, fullSize);
+      }
+      SmallVector<Value> offsetVals =
+          llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
+            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+          });
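+      // Linearize the collected per-dimension offsets into a single offset
+      // into the corresponding collapsed dimension of the source tensor.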
+      OpFoldResult newOffset =
+          rewriter
+              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis,
+                                                      /*disjoint=*/true)
+              .getResult();
+      newOffsets.push_back(newOffset);
+      newLengths.push_back(newSize);
+
+      // Only unit stride supported.
+      newStrides.push_back(rewriter.getIndexAttr(1));
+    }
+
+    // The shape of the result can be obtained from the sizes passed in.
+    SmallVector<Value> dynDims;
+    SmallVector<int64_t> shape;
+    dispatchIndexOpFoldResults(sizes, dynDims, shape);
+    RankedTensorType resultType = RankedTensorType::get(
+        shape, expandShapeOp.getResultType().getElementType());
+
+    // Create a new ExtractSliceOp and ExpandShapeOp.
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        sliceOp, resultType, newSliceOp,
+        expandShapeOp.getReassociationIndices(), sizes);
+    return success();
+  }
+};
+
} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +402,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}
+
+void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+}
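
For context, a minimal sketch of how the new entry point might be driven from a pass using the greedy rewrite driver. The pass name, argument, and registration boilerplate below are hypothetical, not part of this patch:

```c++
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
/// Hypothetical test pass that greedily applies the patterns registered by
/// populateBubbleUpExtractSliceOpPatterns.
struct TestBubbleUpExtractSlicePass
    : PassWrapper<TestBubbleUpExtractSlicePass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBubbleUpExtractSlicePass)

  StringRef getArgument() const final {
    return "test-bubble-up-extract-slice";
  }

  // The rewrite materializes affine.linearize_index ops, so the Affine
  // dialect must be loadable.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
    // Apply to a fixed point; fail the pass if the driver does not converge.
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```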