[NFC] Simplify the tiling implementation using cloning. #72178

Merged · 7 commits · Nov 20, 2023
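
At a high level, this patch restructures `tileUsingSCFForOp` so that destination tensors are created up front and threaded through the loop nest as `iter_args`; the op is then cloned into the innermost loop body with those region iter_args as its destinations, the clone is tiled and erased, and the tiled values are yielded back through `tensor.insert_slice`. A condensed C++ sketch of that flow (simplified from the hunks below, not the verbatim code, and not compilable on its own):

// Condensed sketch of the restructured flow; error handling and the
// interchange logic are elided.
SmallVector<Value> destinationTensors;
if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
                                           destinationTensors)))
  return rewriter.notifyMatchFailure(op, "unable to create destination tensors");

// Build the (possibly empty) loop nest, threading the destinations through
// as iter_args and yielding the inner loop results from the outer loops.
SmallVector<OpFoldResult> offsets, sizes;
SmallVector<scf::ForOp> forLoops =
    generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
                         tileSizeVector, offsets, sizes, destinationTensors);

// Clone the op into the innermost body with the region iter_args as its
// destinations, tile the clone, then erase it.
SmallVector<Value> clonedOpDestination = destinationTensors;
if (!forLoops.empty()) {
  rewriter.setInsertionPointToEnd(forLoops.back().getBody());
  clonedOpDestination = llvm::map_to_vector(
      forLoops.back().getRegionIterArgs(),
      [](BlockArgument b) -> Value { return b; });
}
auto clonedOp = cast<TilingInterface>(
    cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
FailureOr<TilingResult> tiledImplementation =
    clonedOp.getTiledImplementation(rewriter, offsets, sizes);
rewriter.eraseOp(clonedOp);

// Finally, yield a tensor.insert_slice of each tiled value into the
// corresponding iter_arg; the results of the outermost loop replace the
// untiled op.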
117 changes: 78 additions & 39 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -128,6 +128,9 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
Operation *op,
ValueRange newDestArgs) {
Operation *clonedOp = rewriter.clone(*op);
if (newDestArgs.empty()) {
return clonedOp;
}
if (auto destinationStyleOp =
dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
@@ -139,15 +142,17 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represents an untiled loop.
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
/// the
/// tile processed within the inner most loop.
/// the tile processed within the inner most loop.
/// Note that this method adds `scf.yield` operations for all but the innermost
/// loop. These yield the value returned by the immediately inner loop. The
/// caller is expected to add the scf.yield operation for the innermost loop.
static SmallVector<scf::ForOp> generateTileLoopNest(
OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes) {
assert(!loopRanges.empty() && "expected at least one loop range");
assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
SmallVector<OpFoldResult> &sizes, ValueRange destinationTensors = {}) {
if (loopRanges.empty()) {
return {};
}
OpBuilder::InsertionGuard guard(builder);
SmallVector<scf::ForOp> loops;
offsets.resize(loopRanges.size());
@@ -169,17 +174,25 @@ static SmallVector<scf::ForOp> generateTileLoopNest(
}

auto loop = builder.create<scf::ForOp>(
loc, offset, size, tileSize, ValueRange{},
loc, offset, size, tileSize, destinationTensors,
[&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
ValueRange /*iterArgs*/) {
sizes[loopRange.index()] =
getBoundedTileSize(bodyBuilder, bodyLoc, loopRange.value(), iv,
getAsOpFoldResult(tileSize));
builder.create<scf::YieldOp>(loc);
});
offsets[loopRange.index()] = loop.getInductionVar();
loops.push_back(loop);
builder.setInsertionPoint(loop.getBody()->getTerminator());
builder.setInsertionPointToEnd(loop.getBody());
destinationTensors = loop.getRegionIterArgs();
}

// Add the scf.yield operations for all the outer loops.
for (auto [outerLoop, innerLoop] :
llvm::zip(MutableArrayRef(loops).drop_back(),
MutableArrayRef(loops).drop_front())) {
builder.setInsertionPointToEnd(outerLoop.getBody());
builder.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop.getResults());
}
return loops;
}
@@ -317,10 +330,6 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
size_t numLoops = iterationDomain.size();
if (numLoops == 0) {
return rewriter.notifyMatchFailure(
op, "unable to tile op with no iteration domain");
}

// 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
// skips tiling a particular dimension. This convention is significantly
@@ -333,6 +342,14 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
}

// 3. Find the destination tensors to use for the operation.
SmallVector<Value> destinationTensors;
if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
destinationTensors))) {
return rewriter.notifyMatchFailure(op,
"unable to create destination tensors");
}

SmallVector<OpFoldResult> offsets, sizes;
SmallVector<scf::ForOp> forLoops;
{
@@ -354,11 +371,12 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
applyPermutationToVector(tileSizeVector, interchangeVector);
}

// 3. Materialize an empty loop nest that iterates over the tiles. These
// 4. Materialize an empty loop nest that iterates over the tiles. These
// loops for now do not return any values even if the original operation has
// results.
forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
tileSizeVector, offsets, sizes);
tileSizeVector, offsets, sizes,
destinationTensors);

if (!interchangeVector.empty()) {
auto inversePermutation = invertPermutationVector(interchangeVector);
@@ -375,17 +393,29 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
}
});

// 4. Generate the tiled implementation within the inner most loop.
if (!forLoops.empty())
rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
FailureOr<TilingResult> tiledImplementation =
op.getTiledImplementation(rewriter, offsets, sizes);
// 5. Generate the tiled implementation within the inner most loop.
SmallVector<Value> clonedOpDestination = destinationTensors;
if (!forLoops.empty()) {
rewriter.setInsertionPointToEnd(forLoops.back().getBody());
clonedOpDestination =
llvm::map_to_vector(forLoops.back().getRegionIterArgs(),
[](BlockArgument b) -> Value { return b; });
}

if (op->getNumResults() == 0) {
return scf::SCFTilingResult{
tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
// 5a. Clone the operation within the loop body.
auto clonedOp = cast<TilingInterface>(
cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));

// 5b. Tile the cloned operation.
FailureOr<TilingResult> tiledImplementation =
clonedOp.getTiledImplementation(rewriter, offsets, sizes);
if (failed(tiledImplementation)) {
return rewriter.notifyMatchFailure(op, "failed to tile operation");
}

// 5c. Delete the cloned operation.
rewriter.eraseOp(clonedOp);

// If loops are empty, the tiled op is used as the replacement for the untiled
// op.
if (forLoops.empty()) {
@@ -394,30 +424,39 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
tiledImplementation->tiledValues};
}

// 5. Yield all the results of the tiled operation. The surrounding loop
// nest is modified to insert a destructive update pattern to yield
// from the loop nest values to replace the untiled op with.
if (op->getNumResults() == 0) {
// The innermost loop does not have a `scf.yield` yet. There is nothing to
// return, so generate an empty `scf.yield` operation.
rewriter.setInsertionPointToEnd(forLoops.back().getBody());
rewriter.create<scf::YieldOp>(op->getLoc());
return scf::SCFTilingResult{
tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
}

// 6. Yield all the results of the tiled operation.
int64_t numResults = op->getNumResults();
SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
resultSizesList(numResults);
for (const auto &result : llvm::enumerate(op->getResults())) {
if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
sizes,
resultOffsetsList[result.index()],
resultSizesList[result.index()]))) {
SmallVector<Value> yieldedValues;
for (auto [index, tiledValue] :
llvm::enumerate(tiledImplementation->tiledValues)) {
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
resultOffsets, resultSizes))) {
return rewriter.notifyMatchFailure(
op, "failed to get slice of result produced");
}
SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
rewriter.getIndexAttr(1));
auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
op->getLoc(), tiledValue, clonedOpDestination[index], resultOffsets,
resultSizes, resultStrides);
yieldedValues.push_back(insertSlice);
}
rewriter.create<scf::YieldOp>(op->getLoc(), yieldedValues);

SmallVector<Value> destinationTensors;
if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
destinationTensors)))
return rewriter.notifyMatchFailure(op, "failed to get destinations");

SmallVector<Value> replacements = yieldTiledValues(
rewriter, destinationTensors, tiledImplementation.value(),
resultOffsetsList, resultSizesList, forLoops);
SmallVector<Value> replacements = llvm::map_to_vector(
forLoops.front().getResults(), [](OpResult r) -> Value { return r; });
LLVM_DEBUG({
if (!forLoops.empty()) {
llvm::dbgs() << "After tiled implementation :\n";
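The public entry point is unchanged: callers still go through `scf::tileUsingSCFForOp` with `SCFTilingOptions` and replace the untiled op with the returned replacements. A minimal caller fragment; the `setTileSizes` overload taking `OpFoldResult`s and the `getAsIndexOpFoldResult` helper are assumed to be available at this revision, and the fragment assumes an enclosing pattern or pass that provides `rewriter` and a `TilingInterface` op:

// Hypothetical driver fragment, not part of this patch.
scf::SCFTilingOptions options;
// Tile the first two loops by 10 and 20; a tile size of 0 leaves a loop untiled.
SmallVector<int64_t> tileSizes = {10, 20, 0};
options.setTileSizes(
    getAsIndexOpFoldResult(rewriter.getContext(), tileSizes));

FailureOr<scf::SCFTilingResult> tilingResult =
    scf::tileUsingSCFForOp(rewriter, op, options);
if (failed(tilingResult))
  return failure();

// With this patch the destructive updates are already yielded from the loop
// nest, so the loop results (or the tiled op itself in the zero-loop case)
// are the replacements for the untiled op.
rewriter.replaceOp(op, tilingResult->replacements);
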
126 changes: 82 additions & 44 deletions mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -100,37 +100,37 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x
} -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
}
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
// CHECK-LABEL: func.func @multi_result(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty()
// CHECK-DAG: %[[INIT1:.+]] = tensor.empty()
// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
// CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
// CHECK: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
// CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
// CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
// CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic
// CHECK-SAME: ins(%[[ARG_TILE]] :
// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
// CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]]
// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1
// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
// CHECK-LABEL: func.func @multi_result(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty()
// CHECK-DAG: %[[INIT1:.+]] = tensor.empty()
// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
// CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
// CHECK: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
// CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
// CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
// CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic
// CHECK-SAME: ins(%[[ARG_TILE]] :
// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
// CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]]
// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1
// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1

// -----

Expand Down Expand Up @@ -193,14 +193,9 @@ func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,

// -----

// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>

// CHECK-LABEL: @indexed_semantics
func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// Check that we correctly amend "linalg.index" results.

// CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
%0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
Expand All @@ -209,13 +204,8 @@ func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
ins(%arg0: tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32):
// CHECK: %[[INDEX0:.+]] = linalg.index 0
// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
%1 = linalg.index 0 : index
// CHECK: %[[INDEX1:.+]] = linalg.index 1
// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
%2 = linalg.index 1 : index
// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
%3 = arith.addi %1, %2 : index
%4 = arith.index_cast %3 : index to i64
%5 = arith.uitofp %4 : i64 to f32
Expand All @@ -224,6 +214,15 @@ func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
} -> (tensor<?x?xf32>)
return %0 : tensor<?x?xf32>
}
// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: @indexed_semantics
// CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
// CHECK: %[[INDEX0:.+]] = linalg.index 0
// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
// CHECK: %[[INDEX1:.+]] = linalg.index 1
// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]

// -----

Expand Down Expand Up @@ -276,14 +275,53 @@ func.func @interchange_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,

// -----

func.func @linalg_copy_matmul(%a: memref<?x?xf32>, %b: memref<?x?xf32>) {
linalg.copy {__internal_transform__ = "simple_copy_memref"}
ins(%a : memref<?x?xf32>) outs(%b : memref<?x?xf32>)
return
}
// CHECK-LABEL: func @linalg_copy_matmul(
// CHECK: scf.for
// CHECK: scf.for
// CHECK: memref.subview
// CHECK: memref.subview
// CHECK: linalg.copy
func.func @linalg_copy_matmul(%a: memref<?x?xf32>, %b: memref<?x?xf32>) {
linalg.copy {__internal_transform__ = "simple_copy_memref"}
ins(%a : memref<?x?xf32>) outs(%b : memref<?x?xf32>)

// -----

func.func @check_scalar_operation(%arg0 : tensor<f32>) -> tensor<f32> {
%init = tensor.empty() : tensor<f32>
%0 = linalg.generic {
indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
iterator_types = []}
{__internal_transform__ = "scalar_op"}
ins(%arg0 : tensor<f32>) outs(%init : tensor<f32>){
^bb0(%b0 : f32, %b1 : f32):
%1 = arith.mulf %b0, %b0 : f32
linalg.yield %1 : f32
} -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK-LABEL: func @check_scalar_operation
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK-SAME: __internal_transform__ = "tiled"

// -----

func.func @check_scalar_memref_operation(%arg0 : memref<f32>, %arg1 : memref<f32>){
linalg.generic {
indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
iterator_types = []}
{__internal_transform__ = "scalar_op"}
ins(%arg0 : memref<f32>) outs(%arg1 : memref<f32>){
^bb0(%b0 : f32, %b1 : f32):
%1 = arith.mulf %b0, %b0 : f32
linalg.yield %1 : f32
}
return
}
// CHECK-LABEL: func @check_scalar_memref_operation
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK-SAME: __internal_transform__ = "tiled"
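
The two scalar tests above exercise the new zero-loop path: for an empty iteration domain, `generateTileLoopNest` now returns no loops and the tiled op itself becomes the replacement, instead of the previous behaviour of rejecting ops with no iteration domain. Condensed from the TileUsingInterface.cpp hunks earlier in this PR (not the verbatim code):

// Zero-loop path: nothing to tile over, so no loops are generated.
static SmallVector<scf::ForOp> generateTileLoopNest(
    OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
    SmallVector<OpFoldResult> &sizes, ValueRange destinationTensors = {}) {
  if (loopRanges.empty())
    return {}; // Scalar op: no loops to generate.
  // ... loop construction as in the hunk above ...
}

// Later in tileUsingSCFForOp: when no loops were generated, the tiled op
// (created directly at the insertion point of the untiled op) is returned
// as the replacement.
if (forLoops.empty()) {
  return scf::SCFTilingResult{tiledImplementation->tiledOps,
                              getAsOperations(forLoops),
                              tiledImplementation->tiledValues};
}
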
@@ -579,6 +579,8 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
addPatternForTiling(context, patterns, "pad_outer_tiling", {2, 3});
// 10. Tiling M and N dims of `linalg.copy` on memrefs.
addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20});
// 11. Tiling scalar operations.
addPatternForTiling(context, patterns, "scalar_op", {});
return;
}
if (testTilingForAll) {