6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
10
+ #include " mlir/Dialect/Arith/Utils/Utils.h"
9
11
#include " mlir/Dialect/Tensor/IR/Tensor.h"
10
12
#include " mlir/Dialect/Tensor/Transforms/Transforms.h"
11
13
#include " mlir/IR/PatternMatch.h"
14
+ #include " mlir/Interfaces/ValueBoundsOpInterface.h"
12
15
#include " llvm/Support/Debug.h"
16
+ #include " llvm/Support/LogicalResult.h"
13
17
14
18
using namespace mlir ;
15
19
using namespace mlir ::tensor;
@@ -210,6 +214,200 @@ struct BubbleUpExpandThroughParallelCollapse
210
214
}
211
215
};
212
216
217
/// Converts `tensor.extract_slice(tensor.expand_shape)` to
/// `tensor.expand_shape(tensor.extract_slice)`.
/// For this transformation to be possible, the slice must be fully contiguous
/// within each reassociation group of the expand_shape. If the transformation
/// is not possible, or if the slice is rank reducing, the function returns
/// failure.
///
/// Example:
/// ```
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
/// %slice = tensor.extract_slice %reshape ...
///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
///
/// // The transformation is possible because each reassociation group has a
/// // contiguous slice. (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4])
/// // After the transformation:
///
/// %slice = tensor.extract_slice %in ...
///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
/// ```
///
/// Note - this pattern could be reworked to be a swap pattern between
/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
/// implemented only as a bubble up pattern for `tensor.extract_slice`.
struct BubbleUpExpandShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  /// Rewrites `extract_slice(expand_shape(%src))` into
  /// `expand_shape(extract_slice(%src))` when the slice is contiguous within
  /// every reassociation group (see `checkPreconditionForBubbleUpExtractSlice`
  /// for the exact legality conditions).
  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // May be null; the precondition check handles that case and fails the
    // match with a diagnostic.
    auto expandShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();

    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
                                                 rewriter)
            .failed())
      return failure();

    // Mixed (static-or-dynamic) slice parameters and the expanded shape,
    // expressed uniformly as OpFoldResults.
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> outputShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    // Helper variables and function for accumulating the new offset and length
    // values.
    Location loc = expandShapeOp->getLoc();
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    // Multiply two integers (as a folded affine apply, so constant operands
    // fold away instead of materializing ops).
    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
                                                   {v1, v2});
    };

    // Compute new offsets, lengths, and strides. One entry per reassociation
    // group, i.e. per dimension of the collapsed source tensor.
    SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      OpFoldResult newSize = rewriter.getIndexAttr(1);
      // `basis`/`delinOffsets` accumulate the per-expanded-dim extents and
      // offsets that get linearized back into a single source-dim offset.
      SmallVector<OpFoldResult> basis, delinOffsets;

      int64_t i = 0;
      int64_t e = indices.size();
      // Offset = cumulative product of leading unit extracted dims.
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isConstantIntValue(sizes[expandedDim], 1))
          break;

        basis.push_back(outputShape[expandedDim]);
        delinOffsets.push_back(offsets[expandedDim]);
      }

      // The first non-unit extracted dim (if any) contributes both its offset
      // and its extracted size to the group's slice.
      if (i != e) {
        int64_t expandedDim = indices[i];
        basis.push_back(outputShape[expandedDim]);
        delinOffsets.push_back(offsets[expandedDim]);
        newSize = sizes[expandedDim];
        i++;
      }

      // All trailing dims must be taken in full (guaranteed by the
      // precondition check), so they multiply into the size with offset 0.
      for (; i < e; ++i) {
        OpFoldResult fullSize = outputShape[indices[i]];
        basis.push_back(fullSize);
        delinOffsets.push_back(rewriter.getIndexAttr(0));
        newSize = mul(newSize, fullSize);
      }
      SmallVector<Value> offsetVals =
          llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
          });
      // Linearize the per-dim offsets over `basis` to get the offset into the
      // collapsed source dim. `disjoint` holds because each offset is within
      // its basis extent.
      OpFoldResult newOffset =
          rewriter
              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis,
                                                      /*disjoint=*/true)
              .getResult();
      newOffsets.push_back(newOffset);
      newLengths.push_back(newSize);

      // Only unit stride supported.
      newStrides.push_back(rewriter.getIndexAttr(1));
    }

    // The shape of the result can be obtained from the sizes passed in.
    SmallVector<Value> dynDims;
    SmallVector<int64_t> shape;
    dispatchIndexOpFoldResults(sizes, dynDims, shape);
    RankedTensorType resultType = RankedTensorType::get(
        shape, expandShapeOp.getResultType().getElementType());

    // Create a new ExtractSliceOp and ExpandShapeOp.
    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSliceOp,
        expandShapeOp.getReassociationIndices(), sizes);
    return success();
  }

  /// Checks that the bubble-up is legal: `expandShapeOp` is non-null, the
  /// slice has unit strides and is not rank-reducing, and within each
  /// reassociation group every dim after the first non-unit extracted size is
  /// taken in full at offset 0 (i.e. the slice is contiguous in the collapsed
  /// source). Returns failure with a match-failure diagnostic otherwise.
  LogicalResult
  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
                                           tensor::ExpandShapeOp expandShapeOp,
                                           PatternRewriter &rewriter) const {

    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "tensor.extract_slice source not produced by expand_shape");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();

    // A rank-reducing slice drops dims, so result rank < number of sizes.
    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        sizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    SmallVector<OpFoldResult> outputShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    // True iff `offset` is the constant 0 and `sliceSize` provably equals
    // `size` (ValueBoundsConstraintSet handles the dynamic cases).
    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
        isZeroOffsetAndFullSize =
            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
              if (!isConstantIntValue(offset, 0))
                return false;
              FailureOr<bool> maybeEqual =
                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
              return llvm::succeeded(maybeEqual) && maybeEqual.value();
            };

    // First verify that this is a full slice of the expanded tensor.
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      int64_t i = 0;
      int64_t e = indices.size();
      // Find the first expanded dim after the first dim with non-unit extracted
      // size.
      for (; i < e; ++i) {
        if (!isConstantIntValue(sizes[indices[i]], 1)) {
          // +1 to skip the first non-unit size dim.
          i++;
          break;
        }
      }

      // Verify that all subsequent dimensions extract the full size of the
      // source tensor.
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
                                     outputShape[expandedDim])) {
          return rewriter.notifyMatchFailure(
              sliceOp, "Not a contiguous slice of the expanded tensor.");
        }
      }
    }

    return success();
  }
};
410
+
213
411
} // namespace
214
412
215
413
void mlir::tensor::populateReassociativeReshapeFoldingPatterns (
@@ -227,3 +425,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
227
425
RewritePatternSet &patterns) {
228
426
patterns.add <BubbleUpExpandThroughParallelCollapse>(patterns.getContext ());
229
427
}
428
+
429
+ void mlir::tensor::populateBubbleUpExtractSliceOpPatterns (
430
+ RewritePatternSet &patterns) {
431
+ patterns.add <BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext ());
432
+ }
0 commit comments