@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
432
432
Block *insertBeforeBlock;
433
433
};
434
434
435
- // / This structure contains the information pertaining to an argument that has
436
- // / been converted.
437
- struct ConvertedArgInfo {
438
- ConvertedArgInfo (unsigned newArgIdx, unsigned newArgSize,
439
- Value castValue = nullptr )
440
- : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
441
-
442
- // / The start index of in the new argument list that contains arguments that
443
- // / replace the original.
444
- unsigned newArgIdx;
445
-
446
- // / The number of arguments that replaced the original argument.
447
- unsigned newArgSize;
448
-
449
- // / The cast value that was created to cast from the new arguments to the
450
- // / old. This only used if 'newArgSize' > 1.
451
- Value castValue;
452
- };
453
-
454
435
// / Block type conversion. This rewrite is partially reflected in the IR.
455
436
class BlockTypeConversionRewrite : public BlockRewrite {
456
437
public:
457
- BlockTypeConversionRewrite (
458
- ConversionPatternRewriterImpl &rewriterImpl, Block *block,
459
- Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo,
460
- const TypeConverter *converter)
438
+ BlockTypeConversionRewrite (ConversionPatternRewriterImpl &rewriterImpl,
439
+ Block *block, Block *origBlock,
440
+ const TypeConverter *converter)
461
441
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
462
- origBlock (origBlock), argInfo(argInfo), converter(converter) {}
442
+ origBlock (origBlock), converter(converter) {}
463
443
464
444
static bool classof (const IRRewrite *rewrite) {
465
445
return rewrite->getKind () == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
479
459
// / The original block that was requested to have its signature converted.
480
460
Block *origBlock;
481
461
482
- // / The conversion information for each of the arguments. The information is
483
- // / std::nullopt if the argument was dropped during conversion.
484
- SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo;
485
-
486
462
// / The type converter used to convert the arguments.
487
463
const TypeConverter *converter;
488
464
};
@@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
691
667
// / The type of materialization.
692
668
enum MaterializationKind {
693
669
// / This materialization materializes a conversion for an illegal block
694
- // / argument type, to a legal one.
670
+ // / argument type, to the original one.
695
671
Argument,
696
672
697
673
// / This materialization materializes a conversion from an illegal type to a
698
674
// / legal one.
699
- Target
675
+ Target,
676
+
677
+ // / This materialization materializes a conversion from a legal type back to
678
+ // / an illegal one.
679
+ Source
700
680
};
701
681
702
682
// / An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
736
716
private:
737
717
// / The corresponding type converter to use when resolving this
738
718
// / materialization, and the kind of this materialization.
739
- llvm::PointerIntPair<const TypeConverter *, 1 , MaterializationKind>
719
+ llvm::PointerIntPair<const TypeConverter *, 2 , MaterializationKind>
740
720
converterAndKind;
741
721
};
742
722
} // namespace
@@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
855
835
ValueRange inputs, Type outputType,
856
836
const TypeConverter *converter);
857
837
858
- Value buildUnresolvedArgumentMaterialization (Block *block, Location loc,
859
- ValueRange inputs,
860
- Type outputType,
861
- const TypeConverter *converter);
862
-
863
838
Value buildUnresolvedTargetMaterialization (Location loc, Value input,
864
839
Type outputType,
865
840
const TypeConverter *converter);
@@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
989
964
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener ()))
990
965
for (Operation *op : block->getUsers ())
991
966
listener->notifyOperationModified (op);
992
-
993
- // Process the remapping for each of the original arguments.
994
- for (auto [origArg, info] :
995
- llvm::zip_equal (origBlock->getArguments (), argInfo)) {
996
- // Handle the case of a 1->0 value mapping.
997
- if (!info) {
998
- if (Value newArg =
999
- rewriterImpl.mapping .lookupOrNull (origArg, origArg.getType ()))
1000
- rewriter.replaceAllUsesWith (origArg, newArg);
1001
- continue ;
1002
- }
1003
-
1004
- // Otherwise this is a 1->1+ value mapping.
1005
- Value castValue = info->castValue ;
1006
- assert (info->newArgSize >= 1 && castValue && " expected 1->1+ mapping" );
1007
-
1008
- // If the argument is still used, replace it with the generated cast.
1009
- if (!origArg.use_empty ()) {
1010
- rewriter.replaceAllUsesWith (origArg, rewriterImpl.mapping .lookupOrDefault (
1011
- castValue, origArg.getType ()));
1012
- }
1013
- }
1014
967
}
1015
968
1016
969
void BlockTypeConversionRewrite::rollback () {
@@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1035
988
continue ;
1036
989
1037
990
Value replacementValue = rewriterImpl.mapping .lookupOrDefault (origArg);
1038
- bool isDroppedArg = replacementValue == origArg;
1039
- if (!isDroppedArg)
1040
- builder.setInsertionPointAfterValue (replacementValue);
991
+ assert (replacementValue && " replacement value not found" );
1041
992
Value newArg;
1042
993
if (converter) {
994
+ builder.setInsertionPointAfterValue (replacementValue);
1043
995
newArg = converter->materializeSourceConversion (
1044
- builder, origArg.getLoc (), origArg.getType (),
1045
- isDroppedArg ? ValueRange () : ValueRange (replacementValue));
996
+ builder, origArg.getLoc (), origArg.getType (), replacementValue);
1046
997
assert ((!newArg || newArg.getType () == origArg.getType ()) &&
1047
998
" materialization hook did not provide a value of the expected "
1048
999
" type" );
@@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1053
1004
<< " failed to materialize conversion for block argument #"
1054
1005
<< it.index () << " that remained live after conversion, type was "
1055
1006
<< origArg.getType ();
1056
- if (!isDroppedArg)
1057
- diag << " , with target type " << replacementValue.getType ();
1058
1007
diag.attachNote (liveUser->getLoc ())
1059
1008
<< " see existing live user here: " << *liveUser;
1060
1009
return failure ();
@@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1340
1289
// Replace all uses of the old block with the new block.
1341
1290
block->replaceAllUsesWith (newBlock);
1342
1291
1343
- // Remap each of the original arguments as determined by the signature
1344
- // conversion.
1345
- SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo;
1346
- argInfo.resize (origArgCount);
1347
-
1348
1292
for (unsigned i = 0 ; i != origArgCount; ++i) {
1349
- auto inputMap = signatureConversion.getInputMapping (i);
1350
- if (!inputMap)
1351
- continue ;
1352
1293
BlockArgument origArg = block->getArgument (i);
1294
+ Type origArgType = origArg.getType ();
1295
+
1296
+ std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1297
+ signatureConversion.getInputMapping (i);
1298
+ if (!inputMap) {
1299
+ // This block argument was dropped and no replacement value was provided.
1300
+ // Materialize a replacement value "out of thin air".
1301
+ Value repl = buildUnresolvedMaterialization (
1302
+ MaterializationKind::Source, newBlock, newBlock->begin (),
1303
+ origArg.getLoc (), /* inputs=*/ ValueRange (),
1304
+ /* outputType=*/ origArgType, converter);
1305
+ mapping.map (origArg, repl);
1306
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1307
+ continue ;
1308
+ }
1353
1309
1354
- // If inputMap->replacementValue is not nullptr, then the argument is
1355
- // dropped and a replacement value is provided to be the remappedValue.
1356
- if (inputMap->replacementValue ) {
1310
+ if (Value repl = inputMap->replacementValue ) {
1311
+ // This block argument was dropped and a replacement value was provided.
1357
1312
assert (inputMap->size == 0 &&
1358
1313
" invalid to provide a replacement value when the argument isn't "
1359
1314
" dropped" );
1360
- mapping.map (origArg, inputMap-> replacementValue );
1315
+ mapping.map (origArg, repl );
1361
1316
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1362
1317
continue ;
1363
1318
}
1364
1319
1365
- // Otherwise, this is a 1->1+ mapping.
1320
+ // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1321
+ // dialect conversion. Therefore, we need an argument materialization to
1322
+ // turn the replacement block arguments into a single SSA value that can be
1323
+ // used as a replacement.
1366
1324
auto replArgs =
1367
1325
newBlock->getArguments ().slice (inputMap->inputNo , inputMap->size );
1368
- Value newArg;
1326
+ Value argMat = buildUnresolvedMaterialization (
1327
+ MaterializationKind::Argument, newBlock, newBlock->begin (),
1328
+ origArg.getLoc (), /* inputs=*/ replArgs, origArgType, converter);
1329
+ mapping.map (origArg, argMat);
1330
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1369
1331
1370
- // If this is a 1->1 mapping and the types of new and replacement arguments
1371
- // match (i.e. it's an identity map), then the argument is mapped to its
1372
- // original type.
1373
1332
// FIXME: We simply pass through the replacement argument if there wasn't a
1374
1333
// converter, which isn't great as it allows implicit type conversions to
1375
1334
// appear. We should properly restructure this code to handle cases where a
1376
1335
// converter isn't provided and also to properly handle the case where an
1377
1336
// argument materialization is actually a temporary source materialization
1378
1337
// (e.g. in the case of 1->N).
1379
- if (replArgs.size () == 1 &&
1380
- (!converter || replArgs[0 ].getType () == origArg.getType ())) {
1381
- newArg = replArgs.front ();
1382
- mapping.map (origArg, newArg);
1383
- } else {
1384
- // Build argument materialization: new block arguments -> old block
1385
- // argument type.
1386
- Value argMat = buildUnresolvedArgumentMaterialization (
1387
- newBlock, origArg.getLoc (), replArgs, origArg.getType (), converter);
1388
- mapping.map (origArg, argMat);
1389
-
1390
- // Build target materialization: old block argument type -> legal type.
1391
- // Note: This function returns an "empty" type if no valid conversion to
1392
- // a legal type exists. In that case, we continue the conversion with the
1393
- // original block argument type.
1394
- Type legalOutputType = converter->convertType (origArg.getType ());
1395
- if (legalOutputType && legalOutputType != origArg.getType ()) {
1396
- newArg = buildUnresolvedTargetMaterialization (
1397
- origArg.getLoc (), argMat, legalOutputType, converter);
1398
- mapping.map (argMat, newArg);
1399
- } else {
1400
- newArg = argMat;
1401
- }
1338
+ Type legalOutputType;
1339
+ if (converter)
1340
+ legalOutputType = converter->convertType (origArgType);
1341
+ if (legalOutputType && legalOutputType != origArgType) {
1342
+ Value targetMat = buildUnresolvedTargetMaterialization (
1343
+ origArg.getLoc (), argMat, legalOutputType, converter);
1344
+ mapping.map (argMat, targetMat);
1402
1345
}
1403
-
1404
1346
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1405
- argInfo[i] = ConvertedArgInfo (inputMap->inputNo , inputMap->size , newArg);
1406
1347
}
1407
1348
1408
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1409
- converter);
1349
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1410
1350
1411
1351
// Erase the old block. (It is just unlinked for now and will be erased during
1412
1352
// cleanup.)
@@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1437
1377
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1438
1378
return convertOp.getResult (0 );
1439
1379
}
1440
- Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization (
1441
- Block *block, Location loc, ValueRange inputs, Type outputType,
1442
- const TypeConverter *converter) {
1443
- return buildUnresolvedMaterialization (MaterializationKind::Argument, block,
1444
- block->begin (), loc, inputs, outputType,
1445
- converter);
1446
- }
1447
1380
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization (
1448
1381
Location loc, Value input, Type outputType,
1449
1382
const TypeConverter *converter) {
@@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
2862
2795
newMaterialization = converter->materializeTargetConversion (
2863
2796
rewriter, op->getLoc (), outputType, inputOperands);
2864
2797
break ;
2798
+ case MaterializationKind::Source:
2799
+ newMaterialization = converter->materializeSourceConversion (
2800
+ rewriter, op->getLoc (), outputType, inputOperands);
2801
+ break ;
2865
2802
}
2866
2803
if (newMaterialization) {
2867
2804
assert (newMaterialization.getType () == outputType &&
@@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
2874
2811
2875
2812
InFlightDiagnostic diag = op->emitError ()
2876
2813
<< " failed to legalize unresolved materialization "
2877
- " from "
2878
- << inputOperands.getTypes () << " to " << outputType
2814
+ " from ( "
2815
+ << inputOperands.getTypes () << " ) to " << outputType
2879
2816
<< " that remained live after conversion" ;
2880
2817
if (Operation *liveUser = findLiveUser (op->getUsers ())) {
2881
2818
diag.attachNote (liveUser->getLoc ())
0 commit comments