@@ -431,14 +431,34 @@ class MoveBlockRewrite : public BlockRewrite {
431
431
Block *insertBeforeBlock;
432
432
};
433
433
434
+ // / This structure contains the information pertaining to an argument that has
435
+ // / been converted.
436
+ struct ConvertedArgInfo {
437
+ ConvertedArgInfo (unsigned newArgIdx, unsigned newArgSize,
438
+ Value castValue = nullptr )
439
+ : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
440
+
441
+ // / The start index of in the new argument list that contains arguments that
442
+ // / replace the original.
443
+ unsigned newArgIdx;
444
+
445
+ // / The number of arguments that replaced the original argument.
446
+ unsigned newArgSize;
447
+
448
+ // / The cast value that was created to cast from the new arguments to the
449
+ // / old. This only used if 'newArgSize' > 1.
450
+ Value castValue;
451
+ };
452
+
434
453
// / Block type conversion. This rewrite is partially reflected in the IR.
435
454
class BlockTypeConversionRewrite : public BlockRewrite {
436
455
public:
437
- BlockTypeConversionRewrite (ConversionPatternRewriterImpl &rewriterImpl,
438
- Block *block, Block *origBlock,
439
- const TypeConverter *converter)
456
+ BlockTypeConversionRewrite (
457
+ ConversionPatternRewriterImpl &rewriterImpl, Block *block,
458
+ Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo,
459
+ const TypeConverter *converter)
440
460
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
441
- origBlock (origBlock), converter(converter) {}
461
+ origBlock (origBlock), argInfo(argInfo), converter(converter) {}
442
462
443
463
static bool classof (const IRRewrite *rewrite) {
444
464
return rewrite->getKind () == Kind::BlockTypeConversion;
@@ -458,6 +478,10 @@ class BlockTypeConversionRewrite : public BlockRewrite {
458
478
// / The original block that was requested to have its signature converted.
459
479
Block *origBlock;
460
480
481
+ // / The conversion information for each of the arguments. The information is
482
+ // / std::nullopt if the argument was dropped during conversion.
483
+ SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo;
484
+
461
485
// / The type converter used to convert the arguments.
462
486
const TypeConverter *converter;
463
487
};
@@ -666,16 +690,12 @@ class CreateOperationRewrite : public OperationRewrite {
666
690
// / The type of materialization.
667
691
enum MaterializationKind {
668
692
// / This materialization materializes a conversion for an illegal block
669
- // / argument type, to the original one.
693
+ // / argument type, to a legal one.
670
694
Argument,
671
695
672
696
// / This materialization materializes a conversion from an illegal type to a
673
697
// / legal one.
674
- Target,
675
-
676
- // / This materialization materializes a conversion from a legal type back to
677
- // / an illegal one.
678
- Source
698
+ Target
679
699
};
680
700
681
701
// / An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -715,7 +735,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
715
735
private:
716
736
// / The corresponding type converter to use when resolving this
717
737
// / materialization, and the kind of this materialization.
718
- llvm::PointerIntPair<const TypeConverter *, 2 , MaterializationKind>
738
+ llvm::PointerIntPair<const TypeConverter *, 1 , MaterializationKind>
719
739
converterAndKind;
720
740
};
721
741
} // namespace
@@ -834,6 +854,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
834
854
ValueRange inputs, Type outputType,
835
855
const TypeConverter *converter);
836
856
857
+ Value buildUnresolvedArgumentMaterialization (Block *block, Location loc,
858
+ ValueRange inputs,
859
+ Type outputType,
860
+ const TypeConverter *converter);
861
+
837
862
Value buildUnresolvedTargetMaterialization (Location loc, Value input,
838
863
Type outputType,
839
864
const TypeConverter *converter);
@@ -963,6 +988,28 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
963
988
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener ()))
964
989
for (Operation *op : block->getUsers ())
965
990
listener->notifyOperationModified (op);
991
+
992
+ // Process the remapping for each of the original arguments.
993
+ for (auto [origArg, info] :
994
+ llvm::zip_equal (origBlock->getArguments (), argInfo)) {
995
+ // Handle the case of a 1->0 value mapping.
996
+ if (!info) {
997
+ if (Value newArg =
998
+ rewriterImpl.mapping .lookupOrNull (origArg, origArg.getType ()))
999
+ rewriter.replaceAllUsesWith (origArg, newArg);
1000
+ continue ;
1001
+ }
1002
+
1003
+ // Otherwise this is a 1->1+ value mapping.
1004
+ Value castValue = info->castValue ;
1005
+ assert (info->newArgSize >= 1 && castValue && " expected 1->1+ mapping" );
1006
+
1007
+ // If the argument is still used, replace it with the generated cast.
1008
+ if (!origArg.use_empty ()) {
1009
+ rewriter.replaceAllUsesWith (origArg, rewriterImpl.mapping .lookupOrDefault (
1010
+ castValue, origArg.getType ()));
1011
+ }
1012
+ }
966
1013
}
967
1014
968
1015
void BlockTypeConversionRewrite::rollback () {
@@ -987,12 +1034,14 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
987
1034
continue ;
988
1035
989
1036
Value replacementValue = rewriterImpl.mapping .lookupOrDefault (origArg);
990
- assert (replacementValue && " replacement value not found" );
1037
+ bool isDroppedArg = replacementValue == origArg;
1038
+ if (!isDroppedArg)
1039
+ builder.setInsertionPointAfterValue (replacementValue);
991
1040
Value newArg;
992
1041
if (converter) {
993
- builder.setInsertionPointAfterValue (replacementValue);
994
1042
newArg = converter->materializeSourceConversion (
995
- builder, origArg.getLoc (), origArg.getType (), replacementValue);
1043
+ builder, origArg.getLoc (), origArg.getType (),
1044
+ isDroppedArg ? ValueRange () : ValueRange (replacementValue));
996
1045
assert ((!newArg || newArg.getType () == origArg.getType ()) &&
997
1046
" materialization hook did not provide a value of the expected "
998
1047
" type" );
@@ -1003,6 +1052,8 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1003
1052
<< " failed to materialize conversion for block argument #"
1004
1053
<< it.index () << " that remained live after conversion, type was "
1005
1054
<< origArg.getType ();
1055
+ if (!isDroppedArg)
1056
+ diag << " , with target type " << replacementValue.getType ();
1006
1057
diag.attachNote (liveUser->getLoc ())
1007
1058
<< " see existing live user here: " << *liveUser;
1008
1059
return failure ();
@@ -1288,64 +1339,73 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1288
1339
// Replace all uses of the old block with the new block.
1289
1340
block->replaceAllUsesWith (newBlock);
1290
1341
1342
+ // Remap each of the original arguments as determined by the signature
1343
+ // conversion.
1344
+ SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo;
1345
+ argInfo.resize (origArgCount);
1346
+
1291
1347
for (unsigned i = 0 ; i != origArgCount; ++i) {
1292
- BlockArgument origArg = block->getArgument (i);
1293
- Type origArgType = origArg.getType ();
1294
-
1295
- std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1296
- signatureConversion.getInputMapping (i);
1297
- if (!inputMap) {
1298
- // This block argument was dropped and no replacement value was provided.
1299
- // Materialize a replacement value "out of thin air".
1300
- Value repl = buildUnresolvedMaterialization (
1301
- MaterializationKind::Source, newBlock, newBlock->begin (),
1302
- origArg.getLoc (), /* inputs=*/ ValueRange (),
1303
- /* outputType=*/ origArgType, converter);
1304
- mapping.map (origArg, repl);
1305
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1348
+ auto inputMap = signatureConversion.getInputMapping (i);
1349
+ if (!inputMap)
1306
1350
continue ;
1307
- }
1351
+ BlockArgument origArg = block-> getArgument (i);
1308
1352
1309
- if (Value repl = inputMap->replacementValue ) {
1310
- // This block argument was dropped and a replacement value was provided.
1353
+ // If inputMap->replacementValue is not nullptr, then the argument is
1354
+ // dropped and a replacement value is provided to be the remappedValue.
1355
+ if (inputMap->replacementValue ) {
1311
1356
assert (inputMap->size == 0 &&
1312
1357
" invalid to provide a replacement value when the argument isn't "
1313
1358
" dropped" );
1314
- mapping.map (origArg, repl );
1359
+ mapping.map (origArg, inputMap-> replacementValue );
1315
1360
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1316
1361
continue ;
1317
1362
}
1318
1363
1319
- // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1320
- // dialect conversion. Therefore, we need an argument materialization to
1321
- // turn the replacement block arguments into a single SSA value that can be
1322
- // used as a replacement.
1364
+ // Otherwise, this is a 1->1+ mapping.
1323
1365
auto replArgs =
1324
1366
newBlock->getArguments ().slice (inputMap->inputNo , inputMap->size );
1325
- Value argMat = buildUnresolvedMaterialization (
1326
- MaterializationKind::Argument, newBlock, newBlock->begin (),
1327
- origArg.getLoc (), /* inputs=*/ replArgs, origArgType, converter);
1328
- mapping.map (origArg, argMat);
1329
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1367
+ Value newArg;
1330
1368
1369
+ // If this is a 1->1 mapping and the types of new and replacement arguments
1370
+ // match (i.e. it's an identity map), then the argument is mapped to its
1371
+ // original type.
1331
1372
// FIXME: We simply pass through the replacement argument if there wasn't a
1332
1373
// converter, which isn't great as it allows implicit type conversions to
1333
1374
// appear. We should properly restructure this code to handle cases where a
1334
1375
// converter isn't provided and also to properly handle the case where an
1335
1376
// argument materialization is actually a temporary source materialization
1336
1377
// (e.g. in the case of 1->N).
1337
- Type legalOutputType;
1338
- if (converter)
1339
- legalOutputType = converter->convertType (origArgType);
1340
- if (legalOutputType && legalOutputType != origArgType) {
1341
- Value targetMat = buildUnresolvedTargetMaterialization (
1342
- origArg.getLoc (), argMat, legalOutputType, converter);
1343
- mapping.map (argMat, targetMat);
1378
+ if (replArgs.size () == 1 &&
1379
+ (!converter || replArgs[0 ].getType () == origArg.getType ())) {
1380
+ newArg = replArgs.front ();
1381
+ mapping.map (origArg, newArg);
1382
+ } else {
1383
+ // Build argument materialization: new block arguments -> old block
1384
+ // argument type.
1385
+ Value argMat = buildUnresolvedArgumentMaterialization (
1386
+ newBlock, origArg.getLoc (), replArgs, origArg.getType (), converter);
1387
+ mapping.map (origArg, argMat);
1388
+
1389
+ // Build target materialization: old block argument type -> legal type.
1390
+ // Note: This function returns an "empty" type if no valid conversion to
1391
+ // a legal type exists. In that case, we continue the conversion with the
1392
+ // original block argument type.
1393
+ Type legalOutputType = converter->convertType (origArg.getType ());
1394
+ if (legalOutputType && legalOutputType != origArg.getType ()) {
1395
+ newArg = buildUnresolvedTargetMaterialization (
1396
+ origArg.getLoc (), argMat, legalOutputType, converter);
1397
+ mapping.map (argMat, newArg);
1398
+ } else {
1399
+ newArg = argMat;
1400
+ }
1344
1401
}
1402
+
1345
1403
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1404
+ argInfo[i] = ConvertedArgInfo (inputMap->inputNo , inputMap->size , newArg);
1346
1405
}
1347
1406
1348
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1407
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1408
+ converter);
1349
1409
1350
1410
// Erase the old block. (It is just unlinked for now and will be erased during
1351
1411
// cleanup.)
@@ -1376,6 +1436,13 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1376
1436
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1377
1437
return convertOp.getResult (0 );
1378
1438
}
1439
+ Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization (
1440
+ Block *block, Location loc, ValueRange inputs, Type outputType,
1441
+ const TypeConverter *converter) {
1442
+ return buildUnresolvedMaterialization (MaterializationKind::Argument, block,
1443
+ block->begin (), loc, inputs, outputType,
1444
+ converter);
1445
+ }
1379
1446
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization (
1380
1447
Location loc, Value input, Type outputType,
1381
1448
const TypeConverter *converter) {
@@ -2792,10 +2859,6 @@ static LogicalResult legalizeUnresolvedMaterialization(
2792
2859
newMaterialization = converter->materializeTargetConversion (
2793
2860
rewriter, op->getLoc (), outputType, inputOperands);
2794
2861
break ;
2795
- case MaterializationKind::Source:
2796
- newMaterialization = converter->materializeSourceConversion (
2797
- rewriter, op->getLoc (), outputType, inputOperands);
2798
- break ;
2799
2862
}
2800
2863
if (newMaterialization) {
2801
2864
assert (newMaterialization.getType () == outputType &&
@@ -2808,8 +2871,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
2808
2871
2809
2872
InFlightDiagnostic diag = op->emitError ()
2810
2873
<< " failed to legalize unresolved materialization "
2811
- " from ( "
2812
- << inputOperands.getTypes () << " ) to " << outputType
2874
+ " from "
2875
+ << inputOperands.getTypes () << " to " << outputType
2813
2876
<< " that remained live after conversion" ;
2814
2877
if (Operation *liveUser = findLiveUser (op->getUsers ())) {
2815
2878
diag.attachNote (liveUser->getLoc ())
0 commit comments