Skip to content

Commit a79501e

Browse files
[mlir][Transforms][NFC] Simplify ArgConverter state
* When converting a block signature, `ArgConverter` creates a new block with the new signature and moves all operation from the old block to the new block. The new block is temporarily inserted into a region that is stored in `regionMapping`. The old block is not yet deleted, so that the conversion can be rolled back. `regionMapping` is not needed. Instead of moving the old block to a temporary region, it can just be unlinked. Block erasures are handles in the same way in the dialect conversion. * `regionToConverter` is a mapping from regions to type converter. That field is never accessed within `ArgConverter`. It should be stored in `ConversionPatternRewriterImpl` instead.
1 parent 820bcdd commit a79501e

File tree

1 file changed

+22
-57
lines changed

1 file changed

+22
-57
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 22 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -343,23 +343,6 @@ struct ArgConverter {
343343
const TypeConverter *converter;
344344
};
345345

346-
/// Return if the signature of the given block has already been converted.
347-
bool hasBeenConverted(Block *block) const {
348-
return conversionInfo.count(block) || convertedBlocks.count(block);
349-
}
350-
351-
/// Set the type converter to use for the given region.
352-
void setConverter(Region *region, const TypeConverter *typeConverter) {
353-
assert(typeConverter && "expected valid type converter");
354-
regionToConverter[region] = typeConverter;
355-
}
356-
357-
/// Return the type converter to use for the given region, or null if there
358-
/// isn't one.
359-
const TypeConverter *getConverter(Region *region) {
360-
return regionToConverter.lookup(region);
361-
}
362-
363346
//===--------------------------------------------------------------------===//
364347
// Rewrite Application
365348
//===--------------------------------------------------------------------===//
@@ -409,24 +392,10 @@ struct ArgConverter {
409392
ConversionValueMapping &mapping,
410393
SmallVectorImpl<BlockArgument> &argReplacements);
411394

412-
/// Insert a new conversion into the cache.
413-
void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
414-
415395
/// A collection of blocks that have had their arguments converted. This is a
416396
/// map from the new replacement block, back to the original block.
417397
llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
418398

419-
/// The set of original blocks that were converted.
420-
DenseSet<Block *> convertedBlocks;
421-
422-
/// A mapping from valid regions, to those containing the original blocks of a
423-
/// conversion.
424-
DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
425-
426-
/// A mapping of regions to type converters that should be used when
427-
/// converting the arguments of blocks within that region.
428-
DenseMap<Region *, const TypeConverter *> regionToConverter;
429-
430399
/// The pattern rewriter to use when materializing conversions.
431400
PatternRewriter &rewriter;
432401

@@ -474,12 +443,12 @@ void ArgConverter::discardRewrites(Block *block) {
474443
block->getArgument(i).dropAllUses();
475444
block->replaceAllUsesWith(origBlock);
476445

477-
// Move the operations back the original block and the delete the new block.
446+
// Move the operations back the original block, move the original block back
447+
// into its original location and the delete the new block.
478448
origBlock->getOperations().splice(origBlock->end(), block->getOperations());
479-
origBlock->moveBefore(block);
449+
block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
480450
block->erase();
481451

482-
convertedBlocks.erase(origBlock);
483452
conversionInfo.erase(it);
484453
}
485454

@@ -510,6 +479,9 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
510479
mapping.lookupOrDefault(castValue, origArg.getType()));
511480
}
512481
}
482+
483+
delete origBlock;
484+
blockInfo.origBlock = nullptr;
513485
}
514486
}
515487

@@ -572,9 +544,11 @@ FailureOr<Block *> ArgConverter::convertSignature(
572544
Block *block, const TypeConverter *converter,
573545
ConversionValueMapping &mapping,
574546
SmallVectorImpl<BlockArgument> &argReplacements) {
575-
// Check if the block was already converted. If the block is detached,
576-
// conservatively assume it is going to be deleted.
577-
if (hasBeenConverted(block) || !block->getParent())
547+
// Check if the block was already converted.
548+
// * If the block is mapped in `conversionInfo`, it is a converted block.
549+
// * If the block is detached, conservatively assume that it is going to be
550+
// deleted; it is likely the old block (before it was converted).
551+
if (conversionInfo.count(block) || !block->getParent())
578552
return block;
579553
// If a converter wasn't provided, and the block wasn't already converted,
580554
// there is nothing we can do.
@@ -603,6 +577,9 @@ Block *ArgConverter::applySignatureConversion(
603577
// signature.
604578
Block *newBlock = block->splitBlock(block->begin());
605579
block->replaceAllUsesWith(newBlock);
580+
// Unlink the block, but do not erase it yet, so that the change can be rolled
581+
// back.
582+
block->getParent()->getBlocks().remove(block);
606583

607584
// Map all new arguments to the location of the argument they originate from.
608585
SmallVector<Location> newLocs(convertedTypes.size(),
@@ -679,24 +656,8 @@ Block *ArgConverter::applySignatureConversion(
679656
ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
680657
}
681658

682-
// Remove the original block from the region and return the new one.
683-
insertConversion(newBlock, std::move(info));
684-
return newBlock;
685-
}
686-
687-
void ArgConverter::insertConversion(Block *newBlock,
688-
ConvertedBlockInfo &&info) {
689-
// Get a region to insert the old block.
690-
Region *region = newBlock->getParent();
691-
std::unique_ptr<Region> &mappedRegion = regionMapping[region];
692-
if (!mappedRegion)
693-
mappedRegion = std::make_unique<Region>(region->getParentOp());
694-
695-
// Move the original block to the mapped region and emplace the conversion.
696-
mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
697-
info.origBlock->getIterator());
698-
convertedBlocks.insert(info.origBlock);
699659
conversionInfo.insert({newBlock, std::move(info)});
660+
return newBlock;
700661
}
701662

702663
//===----------------------------------------------------------------------===//
@@ -1196,6 +1157,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11961157
/// active.
11971158
const TypeConverter *currentTypeConverter = nullptr;
11981159

1160+
/// A mapping of regions to type converters that should be used when
1161+
/// converting the arguments of blocks within that region.
1162+
DenseMap<Region *, const TypeConverter *> regionToConverter;
1163+
11991164
/// This allows the user to collect the match failure message.
12001165
function_ref<void(Diagnostic &)> notifyCallback;
12011166

@@ -1473,7 +1438,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14731438
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
14741439
Region *region, const TypeConverter &converter,
14751440
TypeConverter::SignatureConversion *entryConversion) {
1476-
argConverter.setConverter(region, &converter);
1441+
regionToConverter[region] = &converter;
14771442
if (region->empty())
14781443
return nullptr;
14791444

@@ -1488,7 +1453,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
14881453
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
14891454
Region *region, const TypeConverter &converter,
14901455
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1491-
argConverter.setConverter(region, &converter);
1456+
regionToConverter[region] = &converter;
14921457
if (region->empty())
14931458
return success();
14941459

@@ -2162,7 +2127,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
21622127

21632128
// If the region of the block has a type converter, try to convert the block
21642129
// directly.
2165-
if (auto *converter = impl.argConverter.getConverter(block->getParent())) {
2130+
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
21662131
if (failed(impl.convertBlockSignature(block, converter))) {
21672132
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
21682133
"block"));

0 commit comments

Comments
 (0)