@@ -152,17 +152,12 @@ namespace {
152
152
// / This class contains a snapshot of the current conversion rewriter state.
153
153
// / This is useful when saving and undoing a set of rewrites.
154
154
struct RewriterState {
155
- RewriterState (unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
156
- unsigned numRewrites, unsigned numIgnoredOperations,
157
- unsigned numErased)
158
- : numCreatedOps(numCreatedOps),
159
- numUnresolvedMaterializations (numUnresolvedMaterializations),
155
+ RewriterState (unsigned numUnresolvedMaterializations, unsigned numRewrites,
156
+ unsigned numIgnoredOperations, unsigned numErased)
157
+ : numUnresolvedMaterializations(numUnresolvedMaterializations),
160
158
numRewrites (numRewrites), numIgnoredOperations(numIgnoredOperations),
161
159
numErased(numErased) {}
162
160
163
- // / The current number of created operations.
164
- unsigned numCreatedOps;
165
-
166
161
// / The current number of unresolved materializations.
167
162
unsigned numUnresolvedMaterializations;
168
163
@@ -299,7 +294,8 @@ class IRRewrite {
299
294
ReplaceBlockArg,
300
295
MoveOperation,
301
296
ModifyOperation,
302
- ReplaceOperation
297
+ ReplaceOperation,
298
+ CreateOperation
303
299
};
304
300
305
301
virtual ~IRRewrite () = default ;
@@ -372,7 +368,11 @@ class CreateBlockRewrite : public BlockRewrite {
372
368
auto &blockOps = block->getOperations ();
373
369
while (!blockOps.empty ())
374
370
blockOps.remove (blockOps.begin ());
375
- eraseBlock (block);
371
+ if (block->getParent ()) {
372
+ eraseBlock (block);
373
+ } else {
374
+ delete block;
375
+ }
376
376
}
377
377
};
378
378
@@ -602,7 +602,7 @@ class OperationRewrite : public IRRewrite {
602
602
603
603
static bool classof (const IRRewrite *rewrite) {
604
604
return rewrite->getKind () >= Kind::MoveOperation &&
605
- rewrite->getKind () <= Kind::ReplaceOperation ;
605
+ rewrite->getKind () <= Kind::CreateOperation ;
606
606
}
607
607
608
608
protected:
@@ -708,6 +708,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
708
708
// / 1->N conversion of some kind.
709
709
bool changedResults;
710
710
};
711
+
712
+ class CreateOperationRewrite : public OperationRewrite {
713
+ public:
714
+ CreateOperationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
715
+ Operation *op)
716
+ : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
717
+
718
+ static bool classof (const IRRewrite *rewrite) {
719
+ return rewrite->getKind () == Kind::CreateOperation;
720
+ }
721
+
722
+ void rollback () override ;
723
+ };
711
724
} // namespace
712
725
713
726
// / Return "true" if there is an operation rewrite that matches the specified
@@ -925,9 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
925
938
// replacing a value with one of a different type.
926
939
ConversionValueMapping mapping;
927
940
928
- // / Ordered vector of all of the newly created operations during conversion.
929
- SmallVector<Operation *> createdOps;
930
-
931
941
// / Ordered vector of all unresolved type conversion materializations during
932
942
// / conversion.
933
943
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1110,7 +1120,18 @@ void ReplaceOperationRewrite::rollback() {
1110
1120
1111
1121
void ReplaceOperationRewrite::cleanup () { eraseOp (op); }
1112
1122
1123
+ void CreateOperationRewrite::rollback () {
1124
+ for (Region ®ion : op->getRegions ()) {
1125
+ while (!region.getBlocks ().empty ())
1126
+ region.getBlocks ().remove (region.getBlocks ().begin ());
1127
+ }
1128
+ op->dropAllUses ();
1129
+ eraseOp (op);
1130
+ }
1131
+
1113
1132
void ConversionPatternRewriterImpl::detachNestedAndErase (Operation *op) {
1133
+ // if (erasedIR.erasedOps.contains(op)) return;
1134
+
1114
1135
for (Region ®ion : op->getRegions ()) {
1115
1136
for (Block &block : region.getBlocks ()) {
1116
1137
while (!block.getOperations ().empty ())
@@ -1127,8 +1148,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
1127
1148
// Remove any newly created ops.
1128
1149
for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
1129
1150
detachNestedAndErase (materialization.getOp ());
1130
- for (auto *op : llvm::reverse (createdOps))
1131
- detachNestedAndErase (op);
1132
1151
}
1133
1152
1134
1153
void ConversionPatternRewriterImpl::applyRewrites () {
@@ -1148,9 +1167,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
1148
1167
// State Management
1149
1168
1150
1169
RewriterState ConversionPatternRewriterImpl::getCurrentState () {
1151
- return RewriterState (createdOps.size (), unresolvedMaterializations.size (),
1152
- rewrites.size (), ignoredOps.size (),
1153
- eraseRewriter.erased .size ());
1170
+ return RewriterState (unresolvedMaterializations.size (), rewrites.size (),
1171
+ ignoredOps.size (), eraseRewriter.erased .size ());
1154
1172
}
1155
1173
1156
1174
void ConversionPatternRewriterImpl::resetState (RewriterState state) {
@@ -1171,12 +1189,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
1171
1189
detachNestedAndErase (op);
1172
1190
}
1173
1191
1174
- // Pop all of the newly created operations.
1175
- while (createdOps.size () != state.numCreatedOps ) {
1176
- detachNestedAndErase (createdOps.back ());
1177
- createdOps.pop_back ();
1178
- }
1179
-
1180
1192
// Pop all of the recorded ignored operations that are no longer valid.
1181
1193
while (ignoredOps.size () != state.numIgnoredOperations )
1182
1194
ignoredOps.pop_back ();
@@ -1460,7 +1472,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
1460
1472
});
1461
1473
if (!previous.isSet ()) {
1462
1474
// This is a newly created op.
1463
- createdOps. push_back (op);
1475
+ appendRewrite<CreateOperationRewrite> (op);
1464
1476
return ;
1465
1477
}
1466
1478
Operation *prevOp = previous.getPoint () == previous.getBlock ()->end ()
@@ -1961,13 +1973,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
1961
1973
rewriter.replaceOp (op, replacementValues);
1962
1974
1963
1975
// Recursively legalize any new constant operations.
1964
- for (unsigned i = curState.numCreatedOps , e = rewriterImpl.createdOps .size ();
1976
+ for (unsigned i = curState.numRewrites , e = rewriterImpl.rewrites .size ();
1965
1977
i != e; ++i) {
1966
- Operation *cstOp = rewriterImpl.createdOps [i];
1967
- if (failed (legalize (cstOp, rewriter))) {
1978
+ auto *createOp =
1979
+ dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites [i].get ());
1980
+ if (!createOp)
1981
+ continue ;
1982
+ if (failed (legalize (createOp->getOperation (), rewriter))) {
1968
1983
LLVM_DEBUG (logFailure (rewriterImpl.logger ,
1969
1984
" failed to legalize generated constant '{0}'" ,
1970
- cstOp ->getName ()));
1985
+ createOp-> getOperation () ->getName ()));
1971
1986
rewriterImpl.resetState (curState);
1972
1987
return failure ();
1973
1988
}
@@ -2112,9 +2127,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2112
2127
// blocks in regions created by this pattern will already be legalized later
2113
2128
// on. If we haven't built the set yet, build it now.
2114
2129
if (operationsToIgnore.empty ()) {
2115
- auto createdOps = ArrayRef<Operation *>(impl.createdOps )
2116
- .drop_front (state.numCreatedOps );
2117
- operationsToIgnore.insert (createdOps.begin (), createdOps.end ());
2130
+ for (unsigned i = state.numRewrites , e = impl.rewrites .size (); i != e;
2131
+ ++i) {
2132
+ auto *createOp =
2133
+ dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2134
+ if (!createOp)
2135
+ continue ;
2136
+ operationsToIgnore.insert (createOp->getOperation ());
2137
+ }
2118
2138
}
2119
2139
2120
2140
// If this operation should be considered for re-legalization, try it.
@@ -2132,8 +2152,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2132
2152
LogicalResult OperationLegalizer::legalizePatternCreatedOperations (
2133
2153
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2134
2154
RewriterState &state, RewriterState &newState) {
2135
- for (int i = state.numCreatedOps , e = newState.numCreatedOps ; i != e; ++i) {
2136
- Operation *op = impl.createdOps [i];
2155
+ for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2156
+ auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2157
+ if (!createOp)
2158
+ continue ;
2159
+ Operation *op = createOp->getOperation ();
2137
2160
if (failed (legalize (op, rewriter))) {
2138
2161
LLVM_DEBUG (logFailure (impl.logger ,
2139
2162
" failed to legalize generated operation '{0}'({1})" ,
@@ -2563,10 +2586,15 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2563
2586
});
2564
2587
return liveUserIt == val.user_end () ? nullptr : *liveUserIt;
2565
2588
};
2566
- for (auto &r : rewriterImpl.rewrites )
2567
- if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get ()))
2568
- if (failed (rewrite->materializeLiveConversions (findLiveUser)))
2589
+ // Note: `rewrites` may be reallocated as the loop is running.
2590
+ for (int64_t i = 0 ; i < rewriterImpl.rewrites .size (); ++i) {
2591
+ auto &rewrite = rewriterImpl.rewrites [i];
2592
+ if (auto *blockTypeConversionRewrite =
2593
+ dyn_cast<BlockTypeConversionRewrite>(rewrite.get ()))
2594
+ if (failed (blockTypeConversionRewrite->materializeLiveConversions (
2595
+ findLiveUser)))
2569
2596
return failure ();
2597
+ }
2570
2598
return success ();
2571
2599
}
2572
2600
0 commit comments