 //
 //===----------------------------------------------------------------------===//

+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"

 using namespace mlir;
 using namespace mlir::tensor;
@@ -210,6 +214,217 @@ struct BubbleUpExpandThroughParallelCollapse
   }
 };

+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. A slice is fully
+/// contiguous within a reassociation group if, after flattening the group
+/// into a single 1D range, the extracted slice forms a single contiguous
+/// subrange of that range.
+/// If the transformation is not possible, or if the slice is rank-reducing,
+/// the function returns failure.
+///
+/// Example:
+/// ```
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// // The transformation is possible because each reassociation group has a
+/// // contiguous slice (i.e., [2x4 -> 2x4], [2x8 -> 1x5], [4x2x4 -> 1x1x4]).
+/// // After the transformation:
+///
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note: this pattern could be reworked into a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but it is currently
+/// implemented only as a bubble-up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+
+    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
+                                                 rewriter)
+            .failed())
+      return failure();
+
+    // The tensor.extract_slice before applying the pattern works on the
+    // result of the tensor.expand_shape, so variables referring to the state
+    // before applying the pattern are named with the prefix "expanded", and
+    // ones referring to the state after applying the pattern are named with
+    // the prefix "collapsed".
+    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> expandedShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    // Helper variables and function for accumulating the size values.
+    Location loc = expandShapeOp->getLoc();
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    // Multiply two integers, folding to a constant when both are static.
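+    // For example, mul(2, 4) folds directly to the index attribute 8, while
+    // a dynamic operand materializes an affine.apply op.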
+    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                   {v1, v2});
+    };
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank
+    // equal to ReassociationIndices.size(). In the loop, a single offset,
+    // size, and stride value is computed per reassociation group.
+    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
+        collapsedStrides;
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      // collapsedSize will hold the size of the single dim that represents
+      // the reassociation group in the non-expanded tensor.
+      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
+      // The basis and delinOffsets are used to create an
+      // affine.linearize_index op to linearize the single offset value
+      // required for this reassociation group.
+      // basis holds the full sizes of the reassociation group dimensions of
+      // the expanded tensor.
+      // delinOffsets, as in "delinearized offsets", holds the offsets within
+      // the reassociation group dimensions of the expanded tensor.
+      SmallVector<OpFoldResult> basis, delinOffsets;
+
+      for (long expandedDim : indices) {
+        // basis and delinOffsets can be obtained directly from the expanded
+        // state, but the collapsed size requires calculation as it did not
+        // previously exist.
+        basis.push_back(expandedShape[expandedDim]);
+        delinOffsets.push_back(expandedOffsets[expandedDim]);
+        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+      }
+
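+      // For the second group [2, 3] in the docstring example (full sizes 2x8
+      // sliced to 1x5 at some offsets o2 and o3), basis is [2, 8],
+      // collapsedSize folds to 5, and the offset linearized below is
+      // o2 * 8 + o3.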
+      SmallVector<Value> offsetVals =
+          llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
+            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+          });
+      OpFoldResult collapsedOffset =
+          rewriter
+              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis,
+                                                      /*disjoint=*/true)
+              .getResult();
+      collapsedOffsets.push_back(collapsedOffset);
+      collapsedSizes.push_back(collapsedSize);
+
+      // Only unit strides are supported (checked in the precondition).
+      collapsedStrides.push_back(rewriter.getIndexAttr(1));
+    }
+
+    // The shape of the result can be obtained from the sizes passed in.
+    SmallVector<Value> dynDims;
+    SmallVector<int64_t> shape;
+    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
+    RankedTensorType resultType = RankedTensorType::get(
+        shape, expandShapeOp.getResultType().getElementType());
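+    // In the docstring example this yields tensor<2x4x1x5x1x1x4xf32>, i.e.
+    // the sizes of the slice taken on the expanded tensor.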
+
+    // Create a new ExtractSliceOp and ExpandShapeOp.
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
+        collapsedStrides);
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        sliceOp, resultType, newSliceOp,
+        expandShapeOp.getReassociationIndices(), expandedSizes);
+    return success();
+  }
+
+  // Helper function to check if all the required conditions for the
+  // tensor.extract_slice to be bubbled up through the tensor.expand_shape
+  // are met.
+  LogicalResult
+  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
+                                           tensor::ExpandShapeOp expandShapeOp,
+                                           PatternRewriter &rewriter) const {
+
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "tensor.extract_slice source not produced by expand_shape");
+    }
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride; only contiguous slices are "
+                   "supported by this transformation");
+    }
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        sizes.size()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank-reducing slice");
+    }
+
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
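+    // Helper that returns true when `offset` is statically 0 and `sliceSize`
+    // can be proven equal to `size` by the value-bounds analysis.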
+    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
+        isZeroOffsetAndFullSize =
+            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
+              if (!isConstantIntValue(offset, 0))
+                return false;
+              FailureOr<bool> maybeEqual =
+                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
+              return llvm::succeeded(maybeEqual) && maybeEqual.value();
+            };
+
+    // Check that the slice is contiguous within each reassociation group.
+    // The slice is contiguous only if, after the first dimension where a
+    // non-unit slice is taken, the slice size on all subsequent dimensions of
+    // the group equals the entire size of the dimension.
+    // Examples of contiguous slices:
+    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+    //   full sizes: [5, 10]    slice offsets: [3, 0]    slice sizes: [2, 10]
+    // Examples of non-contiguous slices:
+    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+    //   full sizes: [5, 10]    slice offsets: [0, 4]    slice sizes: [2, 5]
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Find the first expanded dim after the first dim with non-unit
+      // extracted size.
+      for (; i < e; ++i) {
+        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+          // +1 to skip the first non-unit size dim.
+          i++;
+          break;
+        }
+      }
+
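+      // E.g. for the docstring's first group (full sizes [2, 4], slice sizes
+      // [2, 4]), the scan stops after dim 0, so the loop below must prove
+      // that dim 1 is taken in full with offset 0.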
+      // Verify that all subsequent dimensions extract the full size of the
+      // source tensor.
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                     outputShape[expandedDim])) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "not a contiguous slice of the expanded tensor");
+        }
+      }
+    }
+
+    return success();
+  }
+};
+
 } // namespace

 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +442,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
     RewritePatternSet &patterns) {
   patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }
+
+void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+}
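
For reference, a minimal sketch of how the new entry point might be exercised from a pass, assuming the greedy rewrite driver declared in mlir/Transforms/GreedyPatternRewriteDriver.h (applyPatternsGreedily; older releases name it applyPatternsAndFoldGreedily). The wrapper function below is hypothetical and not part of this change:

#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Greedily applies the bubble-up patterns within `root` (e.g. a func::FuncOp).
static mlir::LogicalResult bubbleUpExtractSlices(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}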