Skip to content

Commit d58ad1c

Browse files
d0klravenclaw
authored andcommitted
Revert "[mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (llvm#96329)"
This reverts commit c01ce79. It depends on f1e0657 which breaks SCF lowering.
1 parent f378dfc commit d58ad1c

File tree

1 file changed

+81
-52
lines changed

1 file changed

+81
-52
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 81 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,6 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
5353
});
5454
}
5555

56-
/// Helper function that computes an insertion point where the given value is
57-
/// defined and can be used without a dominance violation.
58-
static OpBuilder::InsertPoint computeInsertPoint(Value value) {
59-
Block *insertBlock = value.getParentBlock();
60-
Block::iterator insertPt = insertBlock->begin();
61-
if (OpResult inputRes = dyn_cast<OpResult>(value))
62-
insertPt = ++inputRes.getOwner()->getIterator();
63-
return OpBuilder::InsertPoint(insertBlock, insertPt);
64-
}
65-
6656
//===----------------------------------------------------------------------===//
6757
// ConversionValueMapping
6858
//===----------------------------------------------------------------------===//
@@ -455,9 +445,11 @@ class BlockTypeConversionRewrite : public BlockRewrite {
455445
return rewrite->getKind() == Kind::BlockTypeConversion;
456446
}
457447

458-
Block *getOrigBlock() const { return origBlock; }
459-
460-
const TypeConverter *getConverter() const { return converter; }
448+
/// Materialize any necessary conversions for converted arguments that have
449+
/// live users, using the provided `findLiveUser` to search for a user that
450+
/// survives the conversion process.
451+
LogicalResult
452+
materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
461453

462454
void commit(RewriterBase &rewriter) override;
463455

@@ -849,10 +841,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
849841
/// Build an unresolved materialization operation given an output type and set
850842
/// of input operands.
851843
Value buildUnresolvedMaterialization(MaterializationKind kind,
852-
OpBuilder::InsertPoint ip, Location loc,
844+
Block *insertBlock,
845+
Block::iterator insertPt, Location loc,
853846
ValueRange inputs, Type outputType,
854847
Type origOutputType,
855848
const TypeConverter *converter);
849+
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
850+
Type outputType,
851+
const TypeConverter *converter);
856852

857853
//===--------------------------------------------------------------------===//
858854
// Rewriter Notification Hooks
@@ -985,6 +981,49 @@ void BlockTypeConversionRewrite::rollback() {
985981
block->replaceAllUsesWith(origBlock);
986982
}
987983

984+
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
985+
function_ref<Operation *(Value)> findLiveUser) {
986+
// Process the remapping for each of the original arguments.
987+
for (auto it : llvm::enumerate(origBlock->getArguments())) {
988+
BlockArgument origArg = it.value();
989+
// Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
990+
OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
991+
builder.setInsertionPointToStart(block);
992+
993+
// If the type of this argument changed and the argument is still live, we
994+
// need to materialize a conversion.
995+
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
996+
continue;
997+
Operation *liveUser = findLiveUser(origArg);
998+
if (!liveUser)
999+
continue;
1000+
1001+
Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
1002+
assert(replacementValue && "replacement value not found");
1003+
Value newArg;
1004+
if (converter) {
1005+
builder.setInsertionPointAfterValue(replacementValue);
1006+
newArg = converter->materializeSourceConversion(
1007+
builder, origArg.getLoc(), origArg.getType(), replacementValue);
1008+
assert((!newArg || newArg.getType() == origArg.getType()) &&
1009+
"materialization hook did not provide a value of the expected "
1010+
"type");
1011+
}
1012+
if (!newArg) {
1013+
InFlightDiagnostic diag =
1014+
emitError(origArg.getLoc())
1015+
<< "failed to materialize conversion for block argument #"
1016+
<< it.index() << " that remained live after conversion, type was "
1017+
<< origArg.getType();
1018+
diag.attachNote(liveUser->getLoc())
1019+
<< "see existing live user here: " << *liveUser;
1020+
return failure();
1021+
}
1022+
rewriterImpl.mapping.map(origArg, newArg);
1023+
}
1024+
return success();
1025+
}
1026+
9881027
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
9891028
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
9901029
if (!repl)
@@ -1157,10 +1196,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11571196
Type newOperandType = newOperand.getType();
11581197
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
11591198
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1160-
Value castValue = buildUnresolvedMaterialization(
1161-
MaterializationKind::Target, computeInsertPoint(newOperand),
1162-
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
1163-
/*origArgType=*/{}, currentTypeConverter);
1199+
Value castValue = buildUnresolvedTargetMaterialization(
1200+
operandLoc, newOperand, desiredType, currentTypeConverter);
11641201
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
11651202
newOperand = castValue;
11661203
}
@@ -1288,9 +1325,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12881325
// This block argument was dropped and no replacement value was provided.
12891326
// Materialize a replacement value "out of thin air".
12901327
Value repl = buildUnresolvedMaterialization(
1291-
MaterializationKind::Source,
1292-
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1293-
/*inputs=*/ValueRange(),
1328+
MaterializationKind::Source, newBlock, newBlock->begin(),
1329+
origArg.getLoc(), /*inputs=*/ValueRange(),
12941330
/*outputType=*/origArgType, /*origArgType=*/{}, converter);
12951331
mapping.map(origArg, repl);
12961332
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1315,9 +1351,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13151351
auto replArgs =
13161352
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
13171353
Value repl = buildUnresolvedMaterialization(
1318-
MaterializationKind::Argument,
1319-
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1320-
/*inputs=*/replArgs,
1354+
MaterializationKind::Argument, newBlock, newBlock->begin(),
1355+
origArg.getLoc(), /*inputs=*/replArgs,
13211356
/*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
13221357
mapping.map(origArg, repl);
13231358
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1339,22 +1374,34 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13391374
/// Build an unresolved materialization operation given an output type and set
13401375
/// of input operands.
13411376
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1342-
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1343-
ValueRange inputs, Type outputType, Type origArgType,
1377+
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1378+
Location loc, ValueRange inputs, Type outputType, Type origArgType,
13441379
const TypeConverter *converter) {
13451380
// Avoid materializing an unnecessary cast.
13461381
if (inputs.size() == 1 && inputs.front().getType() == outputType)
13471382
return inputs.front();
13481383

13491384
// Create an unresolved materialization. We use a new OpBuilder to avoid
13501385
// tracking the materialization like we do for other operations.
1351-
OpBuilder builder(ip.getBlock(), ip.getPoint());
1386+
OpBuilder builder(insertBlock, insertPt);
13521387
auto convertOp =
13531388
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
13541389
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
13551390
origArgType);
13561391
return convertOp.getResult(0);
13571392
}
1393+
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
1394+
Location loc, Value input, Type outputType,
1395+
const TypeConverter *converter) {
1396+
Block *insertBlock = input.getParentBlock();
1397+
Block::iterator insertPt = insertBlock->begin();
1398+
if (OpResult inputRes = dyn_cast<OpResult>(input))
1399+
insertPt = ++inputRes.getOwner()->getIterator();
1400+
1401+
return buildUnresolvedMaterialization(
1402+
MaterializationKind::Target, insertBlock, insertPt, loc, input,
1403+
outputType, /*origArgType=*/{}, converter);
1404+
}
13581405

13591406
//===----------------------------------------------------------------------===//
13601407
// Rewriter Notification Hooks
@@ -2468,9 +2515,9 @@ LogicalResult
24682515
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
24692516
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
24702517
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2471-
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
2472-
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2473-
inverseMapping)))
2518+
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2519+
inverseMapping)) ||
2520+
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
24742521
return failure();
24752522

24762523
// Process requested operation replacements.
@@ -2526,28 +2573,10 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25262573
++i) {
25272574
auto &rewrite = rewriterImpl.rewrites[i];
25282575
if (auto *blockTypeConversionRewrite =
2529-
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
2530-
// Process the remapping for each of the original arguments.
2531-
for (Value origArg :
2532-
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
2533-
// If the type of this argument changed and the argument is still live,
2534-
// we need to materialize a conversion.
2535-
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
2536-
continue;
2537-
Operation *liveUser = findLiveUser(origArg);
2538-
if (!liveUser)
2539-
continue;
2540-
2541-
Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
2542-
assert(replacementValue && "replacement value not found");
2543-
Value repl = rewriterImpl.buildUnresolvedMaterialization(
2544-
MaterializationKind::Source, computeInsertPoint(replacementValue),
2545-
origArg.getLoc(), /*inputs=*/replacementValue,
2546-
/*outputType=*/origArg.getType(), /*origArgType=*/{},
2547-
blockTypeConversionRewrite->getConverter());
2548-
rewriterImpl.mapping.map(origArg, repl);
2549-
}
2550-
}
2576+
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2577+
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2578+
findLiveUser)))
2579+
return failure();
25512580
}
25522581
return success();
25532582
}

0 commit comments

Comments
 (0)