Skip to content

Commit bbd4af5

Browse files
[mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (#97213)
This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated. When converting a block signature, a `BlockTypeConversionRewrite` object and potentially multiple `ReplaceBlockArgRewrite` are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in `BlockTypeConversionRewrite::commit` and some were replaced in `ReplaceBlockArgRewrite::commit`. The new `BlockTypeConversionRewrite::commit` implementation is much simpler and no longer modifies any IR; that is done only in `ReplaceBlockArgRewrite` now. The `ConvertedArgInfo` data structure is no longer needed. To that end, materializations of dropped arguments are now built in `applySignatureConversion` instead of `materializeLiveConversions`; the latter function no longer has to deal with dropped arguments. Other minor improvements: - Add more comments to `applySignatureConversion`. Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in `legalizeUnresolvedMaterialization` instead of `legalizeConvertedArgumentTypes`. This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion. This is a re-upload of #96207.
1 parent 0d26f65 commit bbd4af5

File tree

2 files changed

+57
-122
lines changed

2 files changed

+57
-122
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 55 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
432432
Block *insertBeforeBlock;
433433
};
434434

435-
/// This structure contains the information pertaining to an argument that has
436-
/// been converted.
437-
struct ConvertedArgInfo {
438-
ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
439-
Value castValue = nullptr)
440-
: newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
441-
442-
/// The start index of in the new argument list that contains arguments that
443-
/// replace the original.
444-
unsigned newArgIdx;
445-
446-
/// The number of arguments that replaced the original argument.
447-
unsigned newArgSize;
448-
449-
/// The cast value that was created to cast from the new arguments to the
450-
/// old. This only used if 'newArgSize' > 1.
451-
Value castValue;
452-
};
453-
454435
/// Block type conversion. This rewrite is partially reflected in the IR.
455436
class BlockTypeConversionRewrite : public BlockRewrite {
456437
public:
457-
BlockTypeConversionRewrite(
458-
ConversionPatternRewriterImpl &rewriterImpl, Block *block,
459-
Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
460-
const TypeConverter *converter)
438+
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
439+
Block *block, Block *origBlock,
440+
const TypeConverter *converter)
461441
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
462-
origBlock(origBlock), argInfo(argInfo), converter(converter) {}
442+
origBlock(origBlock), converter(converter) {}
463443

464444
static bool classof(const IRRewrite *rewrite) {
465445
return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
479459
/// The original block that was requested to have its signature converted.
480460
Block *origBlock;
481461

482-
/// The conversion information for each of the arguments. The information is
483-
/// std::nullopt if the argument was dropped during conversion.
484-
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
485-
486462
/// The type converter used to convert the arguments.
487463
const TypeConverter *converter;
488464
};
@@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
691667
/// The type of materialization.
692668
enum MaterializationKind {
693669
/// This materialization materializes a conversion for an illegal block
694-
/// argument type, to a legal one.
670+
/// argument type, to the original one.
695671
Argument,
696672

697673
/// This materialization materializes a conversion from an illegal type to a
698674
/// legal one.
699-
Target
675+
Target,
676+
677+
/// This materialization materializes a conversion from a legal type back to
678+
/// an illegal one.
679+
Source
700680
};
701681

702682
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
736716
private:
737717
/// The corresponding type converter to use when resolving this
738718
/// materialization, and the kind of this materialization.
739-
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
719+
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
740720
converterAndKind;
741721
};
742722
} // namespace
@@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
855835
ValueRange inputs, Type outputType,
856836
const TypeConverter *converter);
857837

858-
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
859-
ValueRange inputs,
860-
Type outputType,
861-
const TypeConverter *converter);
862-
863838
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
864839
Type outputType,
865840
const TypeConverter *converter);
@@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
989964
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
990965
for (Operation *op : block->getUsers())
991966
listener->notifyOperationModified(op);
992-
993-
// Process the remapping for each of the original arguments.
994-
for (auto [origArg, info] :
995-
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
996-
// Handle the case of a 1->0 value mapping.
997-
if (!info) {
998-
if (Value newArg =
999-
rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
1000-
rewriter.replaceAllUsesWith(origArg, newArg);
1001-
continue;
1002-
}
1003-
1004-
// Otherwise this is a 1->1+ value mapping.
1005-
Value castValue = info->castValue;
1006-
assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
1007-
1008-
// If the argument is still used, replace it with the generated cast.
1009-
if (!origArg.use_empty()) {
1010-
rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
1011-
castValue, origArg.getType()));
1012-
}
1013-
}
1014967
}
1015968

1016969
void BlockTypeConversionRewrite::rollback() {
@@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1035988
continue;
1036989

1037990
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
1038-
bool isDroppedArg = replacementValue == origArg;
1039-
if (!isDroppedArg)
1040-
builder.setInsertionPointAfterValue(replacementValue);
991+
assert(replacementValue && "replacement value not found");
1041992
Value newArg;
1042993
if (converter) {
994+
builder.setInsertionPointAfterValue(replacementValue);
1043995
newArg = converter->materializeSourceConversion(
1044-
builder, origArg.getLoc(), origArg.getType(),
1045-
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
996+
builder, origArg.getLoc(), origArg.getType(), replacementValue);
1046997
assert((!newArg || newArg.getType() == origArg.getType()) &&
1047998
"materialization hook did not provide a value of the expected "
1048999
"type");
@@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
10531004
<< "failed to materialize conversion for block argument #"
10541005
<< it.index() << " that remained live after conversion, type was "
10551006
<< origArg.getType();
1056-
if (!isDroppedArg)
1057-
diag << ", with target type " << replacementValue.getType();
10581007
diag.attachNote(liveUser->getLoc())
10591008
<< "see existing live user here: " << *liveUser;
10601009
return failure();
@@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13401289
// Replace all uses of the old block with the new block.
13411290
block->replaceAllUsesWith(newBlock);
13421291

1343-
// Remap each of the original arguments as determined by the signature
1344-
// conversion.
1345-
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
1346-
argInfo.resize(origArgCount);
1347-
13481292
for (unsigned i = 0; i != origArgCount; ++i) {
1349-
auto inputMap = signatureConversion.getInputMapping(i);
1350-
if (!inputMap)
1351-
continue;
13521293
BlockArgument origArg = block->getArgument(i);
1294+
Type origArgType = origArg.getType();
1295+
1296+
std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1297+
signatureConversion.getInputMapping(i);
1298+
if (!inputMap) {
1299+
// This block argument was dropped and no replacement value was provided.
1300+
// Materialize a replacement value "out of thin air".
1301+
Value repl = buildUnresolvedMaterialization(
1302+
MaterializationKind::Source, newBlock, newBlock->begin(),
1303+
origArg.getLoc(), /*inputs=*/ValueRange(),
1304+
/*outputType=*/origArgType, converter);
1305+
mapping.map(origArg, repl);
1306+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1307+
continue;
1308+
}
13531309

1354-
// If inputMap->replacementValue is not nullptr, then the argument is
1355-
// dropped and a replacement value is provided to be the remappedValue.
1356-
if (inputMap->replacementValue) {
1310+
if (Value repl = inputMap->replacementValue) {
1311+
// This block argument was dropped and a replacement value was provided.
13571312
assert(inputMap->size == 0 &&
13581313
"invalid to provide a replacement value when the argument isn't "
13591314
"dropped");
1360-
mapping.map(origArg, inputMap->replacementValue);
1315+
mapping.map(origArg, repl);
13611316
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13621317
continue;
13631318
}
13641319

1365-
// Otherwise, this is a 1->1+ mapping.
1320+
// This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1321+
// dialect conversion. Therefore, we need an argument materialization to
1322+
// turn the replacement block arguments into a single SSA value that can be
1323+
// used as a replacement.
13661324
auto replArgs =
13671325
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1368-
Value newArg;
1326+
Value argMat = buildUnresolvedMaterialization(
1327+
MaterializationKind::Argument, newBlock, newBlock->begin(),
1328+
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
1329+
mapping.map(origArg, argMat);
1330+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13691331

1370-
// If this is a 1->1 mapping and the types of new and replacement arguments
1371-
// match (i.e. it's an identity map), then the argument is mapped to its
1372-
// original type.
13731332
// FIXME: We simply pass through the replacement argument if there wasn't a
13741333
// converter, which isn't great as it allows implicit type conversions to
13751334
// appear. We should properly restructure this code to handle cases where a
13761335
// converter isn't provided and also to properly handle the case where an
13771336
// argument materialization is actually a temporary source materialization
13781337
// (e.g. in the case of 1->N).
1379-
if (replArgs.size() == 1 &&
1380-
(!converter || replArgs[0].getType() == origArg.getType())) {
1381-
newArg = replArgs.front();
1382-
mapping.map(origArg, newArg);
1383-
} else {
1384-
// Build argument materialization: new block arguments -> old block
1385-
// argument type.
1386-
Value argMat = buildUnresolvedArgumentMaterialization(
1387-
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
1388-
mapping.map(origArg, argMat);
1389-
1390-
// Build target materialization: old block argument type -> legal type.
1391-
// Note: This function returns an "empty" type if no valid conversion to
1392-
// a legal type exists. In that case, we continue the conversion with the
1393-
// original block argument type.
1394-
Type legalOutputType = converter->convertType(origArg.getType());
1395-
if (legalOutputType && legalOutputType != origArg.getType()) {
1396-
newArg = buildUnresolvedTargetMaterialization(
1397-
origArg.getLoc(), argMat, legalOutputType, converter);
1398-
mapping.map(argMat, newArg);
1399-
} else {
1400-
newArg = argMat;
1401-
}
1338+
Type legalOutputType;
1339+
if (converter)
1340+
legalOutputType = converter->convertType(origArgType);
1341+
if (legalOutputType && legalOutputType != origArgType) {
1342+
Value targetMat = buildUnresolvedTargetMaterialization(
1343+
origArg.getLoc(), argMat, legalOutputType, converter);
1344+
mapping.map(argMat, targetMat);
14021345
}
1403-
14041346
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1405-
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
14061347
}
14071348

1408-
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1409-
converter);
1349+
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
14101350

14111351
// Erase the old block. (It is just unlinked for now and will be erased during
14121352
// cleanup.)
@@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14371377
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
14381378
return convertOp.getResult(0);
14391379
}
1440-
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
1441-
Block *block, Location loc, ValueRange inputs, Type outputType,
1442-
const TypeConverter *converter) {
1443-
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
1444-
block->begin(), loc, inputs, outputType,
1445-
converter);
1446-
}
14471380
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
14481381
Location loc, Value input, Type outputType,
14491382
const TypeConverter *converter) {
@@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
28622795
newMaterialization = converter->materializeTargetConversion(
28632796
rewriter, op->getLoc(), outputType, inputOperands);
28642797
break;
2798+
case MaterializationKind::Source:
2799+
newMaterialization = converter->materializeSourceConversion(
2800+
rewriter, op->getLoc(), outputType, inputOperands);
2801+
break;
28652802
}
28662803
if (newMaterialization) {
28672804
assert(newMaterialization.getType() == outputType &&
@@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
28742811

28752812
InFlightDiagnostic diag = op->emitError()
28762813
<< "failed to legalize unresolved materialization "
2877-
"from "
2878-
<< inputOperands.getTypes() << " to " << outputType
2814+
"from ("
2815+
<< inputOperands.getTypes() << ") to " << outputType
28792816
<< " that remained live after conversion";
28802817
if (Operation *liveUser = findLiveUser(op->getUsers())) {
28812818
diag.attachNote(liveUser->getLoc())

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33

44
func.func @test_invalid_arg_materialization(
5-
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
5+
// expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
66
%arg0: i16) {
7-
// expected-note@below {{see existing live user here}}
87
"foo.return"(%arg0) : (i16) -> ()
98
}
109

@@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
104103
// Make sure argument type changes aren't implicitly forwarded.
105104
func.func @test_signature_conversion_no_converter() {
106105
"test.signature_conversion_no_converter"() ({
107-
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
106+
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
108107
^bb0(%arg0: f32):
109-
// expected-note@below {{see existing live user here}}
110108
"test.type_consumer"(%arg0) : (f32) -> ()
111109
"test.return"(%arg0) : (f32) -> ()
112110
}) : () -> ()

0 commit comments

Comments
 (0)