Skip to content

Commit 78fb4f9

Browse files
committed
[SCF][MemRef] Enable SCF.Parallel Lowering to use Scope Op
As discussed in https://reviews.llvm.org/D119743, scf.parallel would continuously stack allocate since the alloca op was placed in the wsloop rather than the omp.parallel. This PR is the second stage of the fix for that problem. Specifically, we now introduce an alloca scope around the inlined body of the scf.parallel and enable a canonicalization to hoist the allocations to the surrounding allocation scope (e.g. omp.parallel). Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D120423
1 parent b9d6e8c commit 78fb4f9

File tree

7 files changed

+279
-28
lines changed

7 files changed

+279
-28
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,8 @@ def ConvertSCFToOpenMP : Pass<"convert-scf-to-openmp", "ModuleOp"> {
501501
let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
502502
"constructs.";
503503
let constructor = "mlir::createConvertSCFToOpenMPPass()";
504-
let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect"];
504+
let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
505+
"memref::MemRefDialect"];
505506
}
506507

507508
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
274274
let regions = (region SizedRegion<1>:$bodyRegion);
275275
let hasCustomAssemblyFormat = 1;
276276
let hasVerifier = 1;
277+
let hasCanonicalizer = 1;
277278
}
278279

279280
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1818
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1919
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2021
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2122
#include "mlir/Dialect/SCF/SCF.h"
2223
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -364,8 +365,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
364365
loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
365366
SmallVector<Value> reductionVariables;
366367
reductionVariables.reserve(parallelOp.getNumReductions());
367-
Value token = rewriter.create<LLVM::StackSaveOp>(
368-
loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(8)));
369368
for (Value init : parallelOp.getInitVals()) {
370369
assert((LLVM::isCompatibleType(init.getType()) ||
371370
init.getType().isa<LLVM::PointerElementTypeInterface>()) &&
@@ -392,31 +391,31 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
392391
// Create the parallel wrapper.
393392
auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
394393
{
394+
395395
OpBuilder::InsertionGuard guard(rewriter);
396396
rewriter.createBlock(&ompParallel.region());
397397

398-
// Replace SCF yield with OpenMP yield.
399398
{
400-
OpBuilder::InsertionGuard innerGuard(rewriter);
401-
rewriter.setInsertionPointToEnd(parallelOp.getBody());
402-
assert(llvm::hasSingleElement(parallelOp.getRegion()) &&
403-
"expected scf.parallel to have one block");
404-
rewriter.replaceOpWithNewOp<omp::YieldOp>(
405-
parallelOp.getBody()->getTerminator(), ValueRange());
406-
}
407-
408-
// Replace the loop.
409-
auto loop = rewriter.create<omp::WsLoopOp>(
410-
parallelOp.getLoc(), parallelOp.getLowerBound(),
411-
parallelOp.getUpperBound(), parallelOp.getStep());
412-
rewriter.create<omp::TerminatorOp>(loc);
413-
414-
rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(),
415-
loop.region().begin());
416-
if (!reductionVariables.empty()) {
417-
loop.reductionsAttr(
418-
ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
419-
loop.reduction_varsMutable().append(reductionVariables);
399+
auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
400+
TypeRange());
401+
rewriter.create<omp::TerminatorOp>(loc);
402+
OpBuilder::InsertionGuard allocaGuard(rewriter);
403+
rewriter.createBlock(&scope.getBodyRegion());
404+
rewriter.setInsertionPointToStart(&scope.getBodyRegion().front());
405+
406+
// Replace the loop.
407+
auto loop = rewriter.create<omp::WsLoopOp>(
408+
parallelOp.getLoc(), parallelOp.getLowerBound(),
409+
parallelOp.getUpperBound(), parallelOp.getStep());
410+
rewriter.create<memref::AllocaScopeReturnOp>(loc);
411+
412+
rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(),
413+
loop.region().begin());
414+
if (!reductionVariables.empty()) {
415+
loop.reductionsAttr(
416+
ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
417+
loop.reduction_varsMutable().append(reductionVariables);
418+
}
420419
}
421420
}
422421

@@ -429,7 +428,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
429428
}
430429
rewriter.replaceOp(parallelOp, results);
431430

432-
rewriter.create<LLVM::StackRestoreOp>(loc, token);
433431
return success();
434432
}
435433
};
@@ -438,7 +436,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
438436
static LogicalResult applyPatterns(ModuleOp module) {
439437
ConversionTarget target(*module.getContext());
440438
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
441-
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect>();
439+
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
440+
memref::MemRefDialect>();
442441

443442
RewritePatternSet patterns(module.getContext());
444443
patterns.add<ParallelOpLowering>(module.getContext());

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/PatternMatch.h"
1919
#include "mlir/IR/TypeUtilities.h"
2020
#include "mlir/Interfaces/InferTypeOpInterface.h"
21+
#include "mlir/Interfaces/SideEffectInterfaces.h"
2122
#include "mlir/Interfaces/ViewLikeInterface.h"
2223
#include "llvm/ADT/STLExtras.h"
2324
#include "llvm/ADT/SmallBitVector.h"
@@ -258,6 +259,159 @@ void AllocaScopeOp::getSuccessorRegions(
258259
regions.push_back(RegionSuccessor(&bodyRegion()));
259260
}
260261

262+
/// Given an operation, return whether this op is guaranteed to
263+
/// allocate an AutomaticAllocationScopeResource
264+
static bool isGuaranteedAutomaticAllocationScope(Operation *op) {
265+
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
266+
if (!interface)
267+
return false;
268+
for (auto res : op->getResults()) {
269+
if (auto effect =
270+
interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
271+
if (isa<SideEffects::AutomaticAllocationScopeResource>(
272+
effect->getResource()))
273+
return true;
274+
}
275+
}
276+
return false;
277+
}
278+
279+
/// Given an operation, return whether this op could to
280+
/// allocate an AutomaticAllocationScopeResource
281+
static bool isPotentialAutomaticAllocationScope(Operation *op) {
282+
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
283+
if (!interface)
284+
return true;
285+
for (auto res : op->getResults()) {
286+
if (auto effect =
287+
interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
288+
if (isa<SideEffects::AutomaticAllocationScopeResource>(
289+
effect->getResource()))
290+
return true;
291+
}
292+
}
293+
return false;
294+
}
295+
296+
/// Return whether this op is the last non terminating op
297+
/// in a region. That is to say, it is in a one-block region
298+
/// and is only followed by a terminator. This prevents
299+
/// extending the lifetime of allocations.
300+
static bool lastNonTerminatorInRegion(Operation *op) {
301+
return op->getNextNode() == op->getBlock()->getTerminator() &&
302+
op->getParentRegion()->getBlocks().size() == 1;
303+
}
304+
305+
/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
306+
/// or it contains no allocation.
307+
struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
308+
using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
309+
310+
LogicalResult matchAndRewrite(AllocaScopeOp op,
311+
PatternRewriter &rewriter) const override {
312+
if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) {
313+
bool hasPotentialAlloca =
314+
op->walk([&](Operation *alloc) {
315+
if (isPotentialAutomaticAllocationScope(alloc))
316+
return WalkResult::interrupt();
317+
return WalkResult::skip();
318+
}).wasInterrupted();
319+
if (hasPotentialAlloca)
320+
return failure();
321+
}
322+
323+
// Only apply to if this is this last non-terminator
324+
// op in the block (lest lifetime be extended) of a one
325+
// block region
326+
if (!lastNonTerminatorInRegion(op))
327+
return failure();
328+
329+
Block *block = &op.getRegion().front();
330+
Operation *terminator = block->getTerminator();
331+
ValueRange results = terminator->getOperands();
332+
rewriter.mergeBlockBefore(block, op);
333+
rewriter.replaceOp(op, results);
334+
rewriter.eraseOp(terminator);
335+
return success();
336+
}
337+
};
338+
339+
/// Move allocations into an allocation scope, if it is legal to
340+
/// move them (e.g. their operands are available at the location
341+
/// the op would be moved to).
342+
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
343+
using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
344+
345+
LogicalResult matchAndRewrite(AllocaScopeOp op,
346+
PatternRewriter &rewriter) const override {
347+
348+
if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
349+
return failure();
350+
351+
Operation *lastParentWithoutScope = op->getParentOp();
352+
353+
if (!lastParentWithoutScope ||
354+
lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
355+
return failure();
356+
357+
// Only apply to if this is this last non-terminator
358+
// op in the block (lest lifetime be extended) of a one
359+
// block region
360+
if (!lastNonTerminatorInRegion(op) ||
361+
!lastNonTerminatorInRegion(lastParentWithoutScope))
362+
return failure();
363+
364+
while (!lastParentWithoutScope->getParentOp()
365+
->hasTrait<OpTrait::AutomaticAllocationScope>()) {
366+
lastParentWithoutScope = lastParentWithoutScope->getParentOp();
367+
if (!lastParentWithoutScope ||
368+
!lastNonTerminatorInRegion(lastParentWithoutScope))
369+
return failure();
370+
}
371+
Operation *scope = lastParentWithoutScope->getParentOp();
372+
assert(scope->hasTrait<OpTrait::AutomaticAllocationScope>());
373+
374+
Region *containingRegion = nullptr;
375+
for (auto &r : lastParentWithoutScope->getRegions()) {
376+
if (r.isAncestor(op->getParentRegion())) {
377+
assert(containingRegion == nullptr &&
378+
"only one region can contain the op");
379+
containingRegion = &r;
380+
}
381+
}
382+
assert(containingRegion && "op must be contained in a region");
383+
384+
SmallVector<Operation *> toHoist;
385+
op->walk([&](Operation *alloc) {
386+
if (!isGuaranteedAutomaticAllocationScope(alloc))
387+
return WalkResult::skip();
388+
389+
// If any operand is not defined before the location of
390+
// lastParentWithoutScope (i.e. where we would hoist to), skip.
391+
if (llvm::any_of(alloc->getOperands(), [&](Value v) {
392+
return containingRegion->isAncestor(v.getParentRegion());
393+
}))
394+
return WalkResult::skip();
395+
toHoist.push_back(alloc);
396+
return WalkResult::advance();
397+
});
398+
399+
if (!toHoist.size())
400+
return failure();
401+
rewriter.setInsertionPoint(lastParentWithoutScope);
402+
for (auto op : toHoist) {
403+
auto cloned = rewriter.clone(*op);
404+
rewriter.replaceOp(op, cloned->getResults());
405+
}
406+
return success();
407+
}
408+
};
409+
410+
/// Registers the canonicalization patterns of AllocaScopeOp: first the
/// inliner, then the allocation hoister.
void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner>(context);
  results.add<AllocaScopeHoister>(context);
}
414+
261415
//===----------------------------------------------------------------------===//
262416
// AssumeAlignmentOp
263417
//===----------------------------------------------------------------------===//

mlir/test/Conversion/SCFToOpenMP/reductions.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
2121
%arg3 : index, %arg4 : index) {
2222
// CHECK: %[[CST:.*]] = arith.constant 0.0
2323
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1
24-
// CHECK: llvm.intr.stacksave
2524
// CHECK: %[[BUF:.*]] = llvm.alloca %[[ONE]] x f32
2625
// CHECK: llvm.store %[[CST]], %[[BUF]]
2726
%step = arith.constant 1 : index
2827
%zero = arith.constant 0.0 : f32
2928
// CHECK: omp.parallel
29+
// CHECK: memref.alloca_scope
3030
// CHECK: omp.wsloop
3131
// CHECK-SAME: reduction(@[[$REDF]] -> %[[BUF]]
3232
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
@@ -43,7 +43,6 @@ func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
4343
}
4444
// CHECK: omp.terminator
4545
// CHECK: llvm.load %[[BUF]]
46-
// CHECK: llvm.intr.stackrestore
4746
return
4847
}
4948

@@ -162,6 +161,7 @@ func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
162161
// CHECK: llvm.store %[[IONE]], %[[BUF2]]
163162

164163
// CHECK: omp.parallel
164+
// CHECK: memref.alloca_scope
165165
// CHECK: omp.wsloop
166166
// CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]]
167167
// CHECK-SAME: @[[$REDF2]] -> %[[BUF2]]

mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
func @parallel(%arg0: index, %arg1: index, %arg2: index,
55
%arg3: index, %arg4: index, %arg5: index) {
66
// CHECK: omp.parallel {
7+
// CHECK: memref.alloca_scope
78
// CHECK: omp.wsloop (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
89
scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
910
// CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> ()
@@ -20,9 +21,11 @@ func @parallel(%arg0: index, %arg1: index, %arg2: index,
2021
func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
2122
%arg3: index, %arg4: index, %arg5: index) {
2223
// CHECK: omp.parallel {
24+
// CHECK: memref.alloca_scope
2325
// CHECK: omp.wsloop (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
2426
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
2527
// CHECK: omp.parallel
28+
// CHECK: memref.alloca_scope
2629
// CHECK: omp.wsloop (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
2730
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
2831
// CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()
@@ -41,6 +44,7 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
4144
func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
4245
%arg3: index, %arg4: index, %arg5: index) {
4346
// CHECK: omp.parallel {
47+
// CHECK: memref.alloca_scope
4448
// CHECK: omp.wsloop (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
4549
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
4650
// CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> ()
@@ -52,6 +56,7 @@ func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
5256
// CHECK: }
5357

5458
// CHECK: omp.parallel {
59+
// CHECK: memref.alloca_scope
5560
// CHECK: omp.wsloop (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
5661
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
5762
// CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> ()

0 commit comments

Comments
 (0)