|
18 | 18 | #include "mlir/IR/PatternMatch.h"
|
19 | 19 | #include "mlir/IR/TypeUtilities.h"
|
20 | 20 | #include "mlir/Interfaces/InferTypeOpInterface.h"
|
| 21 | +#include "mlir/Interfaces/SideEffectInterfaces.h" |
21 | 22 | #include "mlir/Interfaces/ViewLikeInterface.h"
|
22 | 23 | #include "llvm/ADT/STLExtras.h"
|
23 | 24 | #include "llvm/ADT/SmallBitVector.h"
|
@@ -258,6 +259,159 @@ void AllocaScopeOp::getSuccessorRegions(
|
258 | 259 | regions.push_back(RegionSuccessor(&bodyRegion()));
|
259 | 260 | }
|
260 | 261 |
|
| 262 | +/// Given an operation, return whether this op is guaranteed to |
| 263 | +/// allocate an AutomaticAllocationScopeResource |
| 264 | +static bool isGuaranteedAutomaticAllocationScope(Operation *op) { |
| 265 | + MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); |
| 266 | + if (!interface) |
| 267 | + return false; |
| 268 | + for (auto res : op->getResults()) { |
| 269 | + if (auto effect = |
| 270 | + interface.getEffectOnValue<MemoryEffects::Allocate>(res)) { |
| 271 | + if (isa<SideEffects::AutomaticAllocationScopeResource>( |
| 272 | + effect->getResource())) |
| 273 | + return true; |
| 274 | + } |
| 275 | + } |
| 276 | + return false; |
| 277 | +} |
| 278 | + |
| 279 | +/// Given an operation, return whether this op could to |
| 280 | +/// allocate an AutomaticAllocationScopeResource |
| 281 | +static bool isPotentialAutomaticAllocationScope(Operation *op) { |
| 282 | + MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); |
| 283 | + if (!interface) |
| 284 | + return true; |
| 285 | + for (auto res : op->getResults()) { |
| 286 | + if (auto effect = |
| 287 | + interface.getEffectOnValue<MemoryEffects::Allocate>(res)) { |
| 288 | + if (isa<SideEffects::AutomaticAllocationScopeResource>( |
| 289 | + effect->getResource())) |
| 290 | + return true; |
| 291 | + } |
| 292 | + } |
| 293 | + return false; |
| 294 | +} |
| 295 | + |
| 296 | +/// Return whether this op is the last non terminating op |
| 297 | +/// in a region. That is to say, it is in a one-block region |
| 298 | +/// and is only followed by a terminator. This prevents |
| 299 | +/// extending the lifetime of allocations. |
| 300 | +static bool lastNonTerminatorInRegion(Operation *op) { |
| 301 | + return op->getNextNode() == op->getBlock()->getTerminator() && |
| 302 | + op->getParentRegion()->getBlocks().size() == 1; |
| 303 | +} |
| 304 | + |
| 305 | +/// Inline an AllocaScopeOp if either the direct parent is an allocation scope |
| 306 | +/// or it contains no allocation. |
| 307 | +struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> { |
| 308 | + using OpRewritePattern<AllocaScopeOp>::OpRewritePattern; |
| 309 | + |
| 310 | + LogicalResult matchAndRewrite(AllocaScopeOp op, |
| 311 | + PatternRewriter &rewriter) const override { |
| 312 | + if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) { |
| 313 | + bool hasPotentialAlloca = |
| 314 | + op->walk([&](Operation *alloc) { |
| 315 | + if (isPotentialAutomaticAllocationScope(alloc)) |
| 316 | + return WalkResult::interrupt(); |
| 317 | + return WalkResult::skip(); |
| 318 | + }).wasInterrupted(); |
| 319 | + if (hasPotentialAlloca) |
| 320 | + return failure(); |
| 321 | + } |
| 322 | + |
| 323 | + // Only apply to if this is this last non-terminator |
| 324 | + // op in the block (lest lifetime be extended) of a one |
| 325 | + // block region |
| 326 | + if (!lastNonTerminatorInRegion(op)) |
| 327 | + return failure(); |
| 328 | + |
| 329 | + Block *block = &op.getRegion().front(); |
| 330 | + Operation *terminator = block->getTerminator(); |
| 331 | + ValueRange results = terminator->getOperands(); |
| 332 | + rewriter.mergeBlockBefore(block, op); |
| 333 | + rewriter.replaceOp(op, results); |
| 334 | + rewriter.eraseOp(terminator); |
| 335 | + return success(); |
| 336 | + } |
| 337 | +}; |
| 338 | + |
| 339 | +/// Move allocations into an allocation scope, if it is legal to |
| 340 | +/// move them (e.g. their operands are available at the location |
| 341 | +/// the op would be moved to). |
| 342 | +struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> { |
| 343 | + using OpRewritePattern<AllocaScopeOp>::OpRewritePattern; |
| 344 | + |
| 345 | + LogicalResult matchAndRewrite(AllocaScopeOp op, |
| 346 | + PatternRewriter &rewriter) const override { |
| 347 | + |
| 348 | + if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) |
| 349 | + return failure(); |
| 350 | + |
| 351 | + Operation *lastParentWithoutScope = op->getParentOp(); |
| 352 | + |
| 353 | + if (!lastParentWithoutScope || |
| 354 | + lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>()) |
| 355 | + return failure(); |
| 356 | + |
| 357 | + // Only apply to if this is this last non-terminator |
| 358 | + // op in the block (lest lifetime be extended) of a one |
| 359 | + // block region |
| 360 | + if (!lastNonTerminatorInRegion(op) || |
| 361 | + !lastNonTerminatorInRegion(lastParentWithoutScope)) |
| 362 | + return failure(); |
| 363 | + |
| 364 | + while (!lastParentWithoutScope->getParentOp() |
| 365 | + ->hasTrait<OpTrait::AutomaticAllocationScope>()) { |
| 366 | + lastParentWithoutScope = lastParentWithoutScope->getParentOp(); |
| 367 | + if (!lastParentWithoutScope || |
| 368 | + !lastNonTerminatorInRegion(lastParentWithoutScope)) |
| 369 | + return failure(); |
| 370 | + } |
| 371 | + Operation *scope = lastParentWithoutScope->getParentOp(); |
| 372 | + assert(scope->hasTrait<OpTrait::AutomaticAllocationScope>()); |
| 373 | + |
| 374 | + Region *containingRegion = nullptr; |
| 375 | + for (auto &r : lastParentWithoutScope->getRegions()) { |
| 376 | + if (r.isAncestor(op->getParentRegion())) { |
| 377 | + assert(containingRegion == nullptr && |
| 378 | + "only one region can contain the op"); |
| 379 | + containingRegion = &r; |
| 380 | + } |
| 381 | + } |
| 382 | + assert(containingRegion && "op must be contained in a region"); |
| 383 | + |
| 384 | + SmallVector<Operation *> toHoist; |
| 385 | + op->walk([&](Operation *alloc) { |
| 386 | + if (!isGuaranteedAutomaticAllocationScope(alloc)) |
| 387 | + return WalkResult::skip(); |
| 388 | + |
| 389 | + // If any operand is not defined before the location of |
| 390 | + // lastParentWithoutScope (i.e. where we would hoist to), skip. |
| 391 | + if (llvm::any_of(alloc->getOperands(), [&](Value v) { |
| 392 | + return containingRegion->isAncestor(v.getParentRegion()); |
| 393 | + })) |
| 394 | + return WalkResult::skip(); |
| 395 | + toHoist.push_back(alloc); |
| 396 | + return WalkResult::advance(); |
| 397 | + }); |
| 398 | + |
| 399 | + if (!toHoist.size()) |
| 400 | + return failure(); |
| 401 | + rewriter.setInsertionPoint(lastParentWithoutScope); |
| 402 | + for (auto op : toHoist) { |
| 403 | + auto cloned = rewriter.clone(*op); |
| 404 | + rewriter.replaceOp(op, cloned->getResults()); |
| 405 | + } |
| 406 | + return success(); |
| 407 | + } |
| 408 | +}; |
| 409 | + |
| 410 | +void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 411 | + MLIRContext *context) { |
| 412 | + results.add<AllocaScopeInliner, AllocaScopeHoister>(context); |
| 413 | +} |
| 414 | + |
261 | 415 | //===----------------------------------------------------------------------===//
|
262 | 416 | // AssumeAlignmentOp
|
263 | 417 | //===----------------------------------------------------------------------===//
|
|
0 commit comments