@@ -304,188 +304,6 @@ static void calculateTileOffsetsAndSizes(
  }
}

- /// Returns a vector of bools indicating, for each axis, whether `op` can be
- /// tiled without incurring a race condition, i.e. whether tiling that axis
- /// is thread-safe. This is checked by iterating over `numThreads` and
- /// ensuring that the corresponding iterator type is "parallel"; if it is
- /// not, that dimension is unsafe to tile.
- SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
-                                      ArrayRef<OpFoldResult> numThreads) {
-   auto iterators = linalgOp.getIteratorTypesArray();
-   SmallVector<bool> safeToTile(numThreads.size(), true);
-
-   for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
-     if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
-       if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
-         safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-       }
-     } else {
-       safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-     }
-   }
-   return safeToTile;
- }
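
For intuition, a hypothetical usage sketch (the names `b` and `matmul` are illustrative, not part of this patch): a linalg.matmul has iterator types [parallel, parallel, reduction], so requesting more than one thread on the reduction axis is flagged as unsafe.

    // Sketch only; assumes an OpBuilder `b` and a linalg.matmul `matmul`.
    SmallVector<OpFoldResult> numThreads = {b.getIndexAttr(4), b.getIndexAttr(2),
                                            b.getIndexAttr(2)};
    SmallVector<bool> safe =
        safeToTileToForall(b.getContext(), matmul, numThreads);
    // safe == {true, true, false}: two threads on the reduction axis would
    // race on the shared accumulator, so that axis is reported as unsafe.
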
-
- /// Rewrite a TilingInterface `op` into a tiled `scf.forall`. The tiling is
- /// specified by the number of tiles/threads `numThreads` and the optional
- /// nominal tile size `nominalTileSizes`. If `nominalTileSizes` is not
- /// specified, tile sizes are derived from `numThreads` as
- /// `ceilDiv(dimSize[i], numThreads[i])`. If non-empty, `mapping` is added as
- /// an attribute to the resulting `scf.forall`. A zero tile size indicates
- /// that the dimension is not tiled and can be thought of as tiling by the
- /// full size of data.
- /// It is the user's responsibility to ensure that `numThreads` is a valid
- /// tiling specification, i.e. one that only tiles parallel dimensions (e.g.
- /// in the Linalg case). If a dimension is not parallelizable, a warning is
- /// issued to notify the user that the generated code is not safe to
- /// parallelize. If `omitTileOffsetBoundsCheck` is true, the function assumes
- /// that `tileSize[i] * (numThreads[i] - 1) <= dimSize[i]` holds.
- static FailureOr<ForallTilingResult> tileToForallOpImpl(
-     RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
-     std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
-     std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
-   Location loc = op->getLoc();
-   OpBuilder::InsertionGuard g(b);
-
-   SmallVector<Range> loopRanges = op.getIterationDomain(b);
-   if (loopRanges.empty())
-     return op->emitOpError("expected non-empty loop ranges");
-   // True unless the range's stride is statically known to be 1.
-   auto hasNonUnitStride = [](Range r) {
-     return !isConstantIntValue(r.stride, 1);
-   };
-   if (llvm::any_of(loopRanges, hasNonUnitStride))
-     return op->emitOpError("only stride-1 supported atm");
-
-   // Gather destination tensors.
-   SmallVector<Value> dest;
-   if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
-     return op->emitOpError("failed to get destination tensors");
-
-   SmallVector<OpFoldResult> nonZeroNumThreads =
-       llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
-         return !isConstantIntValue(ofr, 0);
-       }));
-   SmallVector<Value> materializedNonZeroNumThreads =
-       llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
-         return getValueOrCreateConstantIndexOp(b, loc, ofr);
-       }));
-
-   LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
-   if (linalgOp) {
-     // Check if tiling is thread safe and emit a warning if it is not.
-     SmallVector<bool> tilingSafety =
-         safeToTileToForall(b.getContext(), linalgOp, numThreads);
-     for (size_t i = 0; i < tilingSafety.size(); i++)
-       if (!tilingSafety[i])
-         op.emitWarning() << "tiling is not thread safe at axis #" << i;
-   }
-
-   // 1. Create the ForallOp. We don't use the lambda body-builder version
-   // because we require the use of RewriterBase in the body, so we manually
-   // move the insertion point to the body below.
-   scf::ForallOp forallOp = b.create<scf::ForallOp>(
-       loc, getAsOpFoldResult(materializedNonZeroNumThreads), dest, mapping);
-
-   // 2. Fill out the ForallOp body.
-   SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-   calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
-                                omitTileOffsetBoundsCheck, nominalTileSizes,
-                                tiledOffsets, tiledSizes);
-
-   // 3. Clone the tileable op and update its destination operands to use the
-   // output bbArgs of the ForallOp.
-   ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
-   Operation *tiledOp = nullptr;
-   SmallVector<Value> tiledValues;
-   {
-     // 3.a. RAII guard, inserting within forallOp, before terminator.
-     OpBuilder::InsertionGuard g(b);
-     b.setInsertionPoint(forallOp.getTerminator());
-     Operation *clonedOp = b.clone(*op.getOperation());
-     auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
-     if (destinationStyleOp) {
-       for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
-         // Swap tensor inits with the corresponding block argument of the
-         // scf.forall op. Memref inits remain as is.
-         if (isa<TensorType>(outOperand.get().getType())) {
-           auto *it = llvm::find(dest, outOperand.get());
-           assert(it != dest.end() && "could not find destination tensor");
-           unsigned destNum = std::distance(dest.begin(), it);
-           outOperand.set(destBbArgs[destNum]);
-         }
-       }
-     }
-
-     // 4. Tile the cloned op and delete the clone.
-     FailureOr<TilingResult> tilingResult =
-         cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
-                                                                tiledSizes);
-     if (failed(tilingResult))
-       return clonedOp->emitError("failed to tile op");
-     if (tilingResult->tiledOps.size() != 1) {
-       return clonedOp->emitError("expected a single produced tiled op, got ")
-              << tilingResult->tiledOps.size();
-     }
-
-     b.eraseOp(clonedOp);
-     tiledOp = tilingResult->tiledOps.front();
-     tiledValues = tilingResult->tiledValues;
-   }
-
-   // 5. Parallel insert back into the result tensor.
-   for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
-                            tiledValues, destBbArgs)) {
-     // 5.a. Partial subset information is inserted just before the terminator.
-     OpBuilder::InsertionGuard g(b);
-     b.setInsertionPoint(forallOp.getTerminator());
-
-     SmallVector<OpFoldResult> resultOffsets, resultSizes;
-     if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
-                                         tiledSizes, resultOffsets,
-                                         resultSizes)))
-       return op->emitOpError("output offsets couldn't be calculated");
-     SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
-
-     // 5.b. Parallel insertions are inserted at the end of the combining
-     // terminator.
-     b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
-     b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
-                                             std::get<2>(it), resultOffsets,
-                                             resultSizes, strides);
-   }
-   return ForallTilingResult{forallOp, tiledOp};
- }
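
For intuition (hypothetical numbers, not from this commit): tiling a dimension of size 10 with numThreads = 4 and no nominal tile size gives a tile size of ceilDiv(10, 4) = 3; the threads get offsets 0, 3, 6, 9 with clamped sizes 3, 3, 3, 1. The bound tileSize * (numThreads - 1) = 3 * 3 = 9 <= 10 holds here, which is exactly what callers promise when they set `omitTileOffsetBoundsCheck`.
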
-
- FailureOr<ForallTilingResult>
- linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
-                        ArrayRef<OpFoldResult> numThreads,
-                        std::optional<ArrayAttr> mapping) {
-   return tileToForallOpImpl(b, op, numThreads,
-                             /*nominalTileSizes=*/std::nullopt, mapping,
-                             /*omitTileOffsetBoundsCheck=*/false);
- }
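
A minimal caller sketch, assuming a `TilingInterface` op (`op` and `rewriter` are illustrative names, and the surrounding function is presumed to return a LogicalResult):

    // Sketch only: tile `op` to an scf.forall with 8 threads along axis 0;
    // a numThreads entry of 0 means "do not tile that axis".
    IRRewriter rewriter(op->getContext());
    rewriter.setInsertionPoint(op);
    SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(8),
                                            rewriter.getIndexAttr(0)};
    FailureOr<ForallTilingResult> result =
        linalg::tileToForallOp(rewriter, op, numThreads, /*mapping=*/std::nullopt);
    if (failed(result))
      return failure();
    // On success, the result carries the created scf.forall and the tiled op
    // inside it; callers typically replace the original op's uses with the
    // forall's results.
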
-
- FailureOr<ForallTilingResult>
- linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
-                                      ArrayRef<OpFoldResult> tileSizes,
-                                      std::optional<ArrayAttr> mapping) {
-   SmallVector<Range> loopRanges = op.getIterationDomain(b);
-   unsigned nLoops = loopRanges.size();
-   SmallVector<OpFoldResult> numThreads;
-   numThreads.reserve(nLoops);
-   AffineExpr s0, s1;
-   bindSymbols(b.getContext(), s0, s1);
-   AffineExpr divExpr = s0.ceilDiv(s1);
-   for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
-     OpFoldResult numTiles = std::get<0>(it);
-     if (!isConstantIntValue(numTiles, 0))
-       numTiles = makeComposedFoldedAffineApply(
-           b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
-     numThreads.push_back(numTiles);
-   }
-   return tileToForallOpImpl(b, op, numThreads,
-                             /*nominalTileSizes=*/tileSizes, mapping,
-                             /*omitTileOffsetBoundsCheck=*/true);
- }
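
Worked example (hypothetical numbers): with loop ranges of sizes 128 and 64 and tileSizes = {32, 0}, the loop above derives numThreads = {ceilDiv(128, 32), 0} = {4, 0}, i.e. four threads along the first axis and no tiling of the second. `tileSizes` is forwarded as the nominal tile sizes, and the offset bounds check can be omitted because numThreads was obtained by ceil division.
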
-
template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,