Skip to content

Commit cdcd673

Browse files
committed
Revert "[mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (llvm#97213)"
This reverts commit bbd4af5.
1 parent 6245561 commit cdcd673

File tree

2 files changed

+122
-57
lines changed

2 files changed

+122
-57
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 118 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -431,14 +431,34 @@ class MoveBlockRewrite : public BlockRewrite {
431431
Block *insertBeforeBlock;
432432
};
433433

434+
/// This structure contains the information pertaining to an argument that has
435+
/// been converted.
436+
struct ConvertedArgInfo {
437+
ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
438+
Value castValue = nullptr)
439+
: newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
440+
441+
/// The start index of in the new argument list that contains arguments that
442+
/// replace the original.
443+
unsigned newArgIdx;
444+
445+
/// The number of arguments that replaced the original argument.
446+
unsigned newArgSize;
447+
448+
/// The cast value that was created to cast from the new arguments to the
449+
/// old. This only used if 'newArgSize' > 1.
450+
Value castValue;
451+
};
452+
434453
/// Block type conversion. This rewrite is partially reflected in the IR.
435454
class BlockTypeConversionRewrite : public BlockRewrite {
436455
public:
437-
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
438-
Block *block, Block *origBlock,
439-
const TypeConverter *converter)
456+
BlockTypeConversionRewrite(
457+
ConversionPatternRewriterImpl &rewriterImpl, Block *block,
458+
Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
459+
const TypeConverter *converter)
440460
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
441-
origBlock(origBlock), converter(converter) {}
461+
origBlock(origBlock), argInfo(argInfo), converter(converter) {}
442462

443463
static bool classof(const IRRewrite *rewrite) {
444464
return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -458,6 +478,10 @@ class BlockTypeConversionRewrite : public BlockRewrite {
458478
/// The original block that was requested to have its signature converted.
459479
Block *origBlock;
460480

481+
/// The conversion information for each of the arguments. The information is
482+
/// std::nullopt if the argument was dropped during conversion.
483+
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
484+
461485
/// The type converter used to convert the arguments.
462486
const TypeConverter *converter;
463487
};
@@ -666,16 +690,12 @@ class CreateOperationRewrite : public OperationRewrite {
666690
/// The type of materialization.
667691
enum MaterializationKind {
668692
/// This materialization materializes a conversion for an illegal block
669-
/// argument type, to the original one.
693+
/// argument type, to a legal one.
670694
Argument,
671695

672696
/// This materialization materializes a conversion from an illegal type to a
673697
/// legal one.
674-
Target,
675-
676-
/// This materialization materializes a conversion from a legal type back to
677-
/// an illegal one.
678-
Source
698+
Target
679699
};
680700

681701
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -715,7 +735,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
715735
private:
716736
/// The corresponding type converter to use when resolving this
717737
/// materialization, and the kind of this materialization.
718-
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
738+
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
719739
converterAndKind;
720740
};
721741
} // namespace
@@ -834,6 +854,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
834854
ValueRange inputs, Type outputType,
835855
const TypeConverter *converter);
836856

857+
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
858+
ValueRange inputs,
859+
Type outputType,
860+
const TypeConverter *converter);
861+
837862
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
838863
Type outputType,
839864
const TypeConverter *converter);
@@ -963,6 +988,28 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
963988
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
964989
for (Operation *op : block->getUsers())
965990
listener->notifyOperationModified(op);
991+
992+
// Process the remapping for each of the original arguments.
993+
for (auto [origArg, info] :
994+
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
995+
// Handle the case of a 1->0 value mapping.
996+
if (!info) {
997+
if (Value newArg =
998+
rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
999+
rewriter.replaceAllUsesWith(origArg, newArg);
1000+
continue;
1001+
}
1002+
1003+
// Otherwise this is a 1->1+ value mapping.
1004+
Value castValue = info->castValue;
1005+
assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
1006+
1007+
// If the argument is still used, replace it with the generated cast.
1008+
if (!origArg.use_empty()) {
1009+
rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
1010+
castValue, origArg.getType()));
1011+
}
1012+
}
9661013
}
9671014

9681015
void BlockTypeConversionRewrite::rollback() {
@@ -987,12 +1034,14 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
9871034
continue;
9881035

9891036
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
990-
assert(replacementValue && "replacement value not found");
1037+
bool isDroppedArg = replacementValue == origArg;
1038+
if (!isDroppedArg)
1039+
builder.setInsertionPointAfterValue(replacementValue);
9911040
Value newArg;
9921041
if (converter) {
993-
builder.setInsertionPointAfterValue(replacementValue);
9941042
newArg = converter->materializeSourceConversion(
995-
builder, origArg.getLoc(), origArg.getType(), replacementValue);
1043+
builder, origArg.getLoc(), origArg.getType(),
1044+
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
9961045
assert((!newArg || newArg.getType() == origArg.getType()) &&
9971046
"materialization hook did not provide a value of the expected "
9981047
"type");
@@ -1003,6 +1052,8 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
10031052
<< "failed to materialize conversion for block argument #"
10041053
<< it.index() << " that remained live after conversion, type was "
10051054
<< origArg.getType();
1055+
if (!isDroppedArg)
1056+
diag << ", with target type " << replacementValue.getType();
10061057
diag.attachNote(liveUser->getLoc())
10071058
<< "see existing live user here: " << *liveUser;
10081059
return failure();
@@ -1288,64 +1339,73 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12881339
// Replace all uses of the old block with the new block.
12891340
block->replaceAllUsesWith(newBlock);
12901341

1342+
// Remap each of the original arguments as determined by the signature
1343+
// conversion.
1344+
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
1345+
argInfo.resize(origArgCount);
1346+
12911347
for (unsigned i = 0; i != origArgCount; ++i) {
1292-
BlockArgument origArg = block->getArgument(i);
1293-
Type origArgType = origArg.getType();
1294-
1295-
std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1296-
signatureConversion.getInputMapping(i);
1297-
if (!inputMap) {
1298-
// This block argument was dropped and no replacement value was provided.
1299-
// Materialize a replacement value "out of thin air".
1300-
Value repl = buildUnresolvedMaterialization(
1301-
MaterializationKind::Source, newBlock, newBlock->begin(),
1302-
origArg.getLoc(), /*inputs=*/ValueRange(),
1303-
/*outputType=*/origArgType, converter);
1304-
mapping.map(origArg, repl);
1305-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1348+
auto inputMap = signatureConversion.getInputMapping(i);
1349+
if (!inputMap)
13061350
continue;
1307-
}
1351+
BlockArgument origArg = block->getArgument(i);
13081352

1309-
if (Value repl = inputMap->replacementValue) {
1310-
// This block argument was dropped and a replacement value was provided.
1353+
// If inputMap->replacementValue is not nullptr, then the argument is
1354+
// dropped and a replacement value is provided to be the remappedValue.
1355+
if (inputMap->replacementValue) {
13111356
assert(inputMap->size == 0 &&
13121357
"invalid to provide a replacement value when the argument isn't "
13131358
"dropped");
1314-
mapping.map(origArg, repl);
1359+
mapping.map(origArg, inputMap->replacementValue);
13151360
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13161361
continue;
13171362
}
13181363

1319-
// This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1320-
// dialect conversion. Therefore, we need an argument materialization to
1321-
// turn the replacement block arguments into a single SSA value that can be
1322-
// used as a replacement.
1364+
// Otherwise, this is a 1->1+ mapping.
13231365
auto replArgs =
13241366
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1325-
Value argMat = buildUnresolvedMaterialization(
1326-
MaterializationKind::Argument, newBlock, newBlock->begin(),
1327-
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
1328-
mapping.map(origArg, argMat);
1329-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1367+
Value newArg;
13301368

1369+
// If this is a 1->1 mapping and the types of new and replacement arguments
1370+
// match (i.e. it's an identity map), then the argument is mapped to its
1371+
// original type.
13311372
// FIXME: We simply pass through the replacement argument if there wasn't a
13321373
// converter, which isn't great as it allows implicit type conversions to
13331374
// appear. We should properly restructure this code to handle cases where a
13341375
// converter isn't provided and also to properly handle the case where an
13351376
// argument materialization is actually a temporary source materialization
13361377
// (e.g. in the case of 1->N).
1337-
Type legalOutputType;
1338-
if (converter)
1339-
legalOutputType = converter->convertType(origArgType);
1340-
if (legalOutputType && legalOutputType != origArgType) {
1341-
Value targetMat = buildUnresolvedTargetMaterialization(
1342-
origArg.getLoc(), argMat, legalOutputType, converter);
1343-
mapping.map(argMat, targetMat);
1378+
if (replArgs.size() == 1 &&
1379+
(!converter || replArgs[0].getType() == origArg.getType())) {
1380+
newArg = replArgs.front();
1381+
mapping.map(origArg, newArg);
1382+
} else {
1383+
// Build argument materialization: new block arguments -> old block
1384+
// argument type.
1385+
Value argMat = buildUnresolvedArgumentMaterialization(
1386+
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
1387+
mapping.map(origArg, argMat);
1388+
1389+
// Build target materialization: old block argument type -> legal type.
1390+
// Note: This function returns an "empty" type if no valid conversion to
1391+
// a legal type exists. In that case, we continue the conversion with the
1392+
// original block argument type.
1393+
Type legalOutputType = converter->convertType(origArg.getType());
1394+
if (legalOutputType && legalOutputType != origArg.getType()) {
1395+
newArg = buildUnresolvedTargetMaterialization(
1396+
origArg.getLoc(), argMat, legalOutputType, converter);
1397+
mapping.map(argMat, newArg);
1398+
} else {
1399+
newArg = argMat;
1400+
}
13441401
}
1402+
13451403
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1404+
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
13461405
}
13471406

1348-
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1407+
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1408+
converter);
13491409

13501410
// Erase the old block. (It is just unlinked for now and will be erased during
13511411
// cleanup.)
@@ -1376,6 +1436,13 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13761436
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13771437
return convertOp.getResult(0);
13781438
}
1439+
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
1440+
Block *block, Location loc, ValueRange inputs, Type outputType,
1441+
const TypeConverter *converter) {
1442+
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
1443+
block->begin(), loc, inputs, outputType,
1444+
converter);
1445+
}
13791446
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
13801447
Location loc, Value input, Type outputType,
13811448
const TypeConverter *converter) {
@@ -2792,10 +2859,6 @@ static LogicalResult legalizeUnresolvedMaterialization(
27922859
newMaterialization = converter->materializeTargetConversion(
27932860
rewriter, op->getLoc(), outputType, inputOperands);
27942861
break;
2795-
case MaterializationKind::Source:
2796-
newMaterialization = converter->materializeSourceConversion(
2797-
rewriter, op->getLoc(), outputType, inputOperands);
2798-
break;
27992862
}
28002863
if (newMaterialization) {
28012864
assert(newMaterialization.getType() == outputType &&
@@ -2808,8 +2871,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
28082871

28092872
InFlightDiagnostic diag = op->emitError()
28102873
<< "failed to legalize unresolved materialization "
2811-
"from ("
2812-
<< inputOperands.getTypes() << ") to " << outputType
2874+
"from "
2875+
<< inputOperands.getTypes() << " to " << outputType
28132876
<< " that remained live after conversion";
28142877
if (Operation *liveUser = findLiveUser(op->getUsers())) {
28152878
diag.attachNote(liveUser->getLoc())

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33

44
func.func @test_invalid_arg_materialization(
5-
// expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
5+
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
66
%arg0: i16) {
7+
// expected-note@below {{see existing live user here}}
78
"foo.return"(%arg0) : (i16) -> ()
89
}
910

@@ -103,8 +104,9 @@ func.func @test_block_argument_not_converted() {
103104
// Make sure argument type changes aren't implicitly forwarded.
104105
func.func @test_signature_conversion_no_converter() {
105106
"test.signature_conversion_no_converter"() ({
106-
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
107+
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
107108
^bb0(%arg0: f32):
109+
// expected-note@below {{see existing live user here}}
108110
"test.type_consumer"(%arg0) : (f32) -> ()
109111
"test.return"(%arg0) : (f32) -> ()
110112
}) : () -> ()

0 commit comments

Comments
 (0)