@@ -152,17 +152,12 @@ namespace {
152
152
// / This class contains a snapshot of the current conversion rewriter state.
153
153
// / This is useful when saving and undoing a set of rewrites.
154
154
struct RewriterState {
155
- RewriterState (unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
156
- unsigned numRewrites, unsigned numIgnoredOperations,
157
- unsigned numErased)
158
- : numCreatedOps(numCreatedOps),
159
- numUnresolvedMaterializations (numUnresolvedMaterializations),
155
+ RewriterState (unsigned numUnresolvedMaterializations, unsigned numRewrites,
156
+ unsigned numIgnoredOperations, unsigned numErased)
157
+ : numUnresolvedMaterializations(numUnresolvedMaterializations),
160
158
numRewrites (numRewrites), numIgnoredOperations(numIgnoredOperations),
161
159
numErased(numErased) {}
162
160
163
- // / The current number of created operations.
164
- unsigned numCreatedOps;
165
-
166
161
// / The current number of unresolved materializations.
167
162
unsigned numUnresolvedMaterializations;
168
163
@@ -299,7 +294,8 @@ class IRRewrite {
299
294
ReplaceBlockArg,
300
295
MoveOperation,
301
296
ModifyOperation,
302
- ReplaceOperation
297
+ ReplaceOperation,
298
+ CreateOperation
303
299
};
304
300
305
301
virtual ~IRRewrite () = default ;
@@ -372,7 +368,11 @@ class CreateBlockRewrite : public BlockRewrite {
372
368
auto &blockOps = block->getOperations ();
373
369
while (!blockOps.empty ())
374
370
blockOps.remove (blockOps.begin ());
375
- eraseBlock (block);
371
+ if (block->getParent ()) {
372
+ eraseBlock (block);
373
+ } else {
374
+ delete block;
375
+ }
376
376
}
377
377
};
378
378
@@ -602,7 +602,7 @@ class OperationRewrite : public IRRewrite {
602
602
603
603
static bool classof (const IRRewrite *rewrite) {
604
604
return rewrite->getKind () >= Kind::MoveOperation &&
605
- rewrite->getKind () <= Kind::ReplaceOperation ;
605
+ rewrite->getKind () <= Kind::CreateOperation ;
606
606
}
607
607
608
608
protected:
@@ -708,6 +708,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
708
708
// / 1->N conversion of some kind.
709
709
bool changedResults;
710
710
};
711
+
712
+ class CreateOperationRewrite : public OperationRewrite {
713
+ public:
714
+ CreateOperationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
715
+ Operation *op)
716
+ : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
717
+
718
+ static bool classof (const IRRewrite *rewrite) {
719
+ return rewrite->getKind () == Kind::CreateOperation;
720
+ }
721
+
722
+ void rollback () override ;
723
+ };
711
724
} // namespace
712
725
713
726
// / Return "true" if there is an operation rewrite that matches the specified
@@ -925,9 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
925
938
// replacing a value with one of a different type.
926
939
ConversionValueMapping mapping;
927
940
928
- // / Ordered vector of all of the newly created operations during conversion.
929
- SmallVector<Operation *> createdOps;
930
-
931
941
// / Ordered vector of all unresolved type conversion materializations during
932
942
// / conversion.
933
943
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1110,7 +1120,18 @@ void ReplaceOperationRewrite::rollback() {
1110
1120
1111
1121
void ReplaceOperationRewrite::cleanup () { eraseOp (op); }
1112
1122
1123
+ void CreateOperationRewrite::rollback () {
1124
+ for (Region ®ion : op->getRegions ()) {
1125
+ while (!region.getBlocks ().empty ())
1126
+ region.getBlocks ().remove (region.getBlocks ().begin ());
1127
+ }
1128
+ op->dropAllUses ();
1129
+ eraseOp (op);
1130
+ }
1131
+
1113
1132
void ConversionPatternRewriterImpl::detachNestedAndErase (Operation *op) {
1133
+ // if (erasedIR.erasedOps.contains(op)) return;
1134
+
1114
1135
for (Region ®ion : op->getRegions ()) {
1115
1136
for (Block &block : region.getBlocks ()) {
1116
1137
while (!block.getOperations ().empty ())
@@ -1127,8 +1148,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
1127
1148
// Remove any newly created ops.
1128
1149
for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
1129
1150
detachNestedAndErase (materialization.getOp ());
1130
- for (auto *op : llvm::reverse (createdOps))
1131
- detachNestedAndErase (op);
1132
1151
}
1133
1152
1134
1153
void ConversionPatternRewriterImpl::applyRewrites () {
@@ -1148,9 +1167,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
1148
1167
// State Management
1149
1168
1150
1169
RewriterState ConversionPatternRewriterImpl::getCurrentState () {
1151
- return RewriterState (createdOps.size (), unresolvedMaterializations.size (),
1152
- rewrites.size (), ignoredOps.size (),
1153
- eraseRewriter.erased .size ());
1170
+ return RewriterState (unresolvedMaterializations.size (), rewrites.size (),
1171
+ ignoredOps.size (), eraseRewriter.erased .size ());
1154
1172
}
1155
1173
1156
1174
void ConversionPatternRewriterImpl::resetState (RewriterState state) {
@@ -1171,12 +1189,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
1171
1189
detachNestedAndErase (op);
1172
1190
}
1173
1191
1174
- // Pop all of the newly created operations.
1175
- while (createdOps.size () != state.numCreatedOps ) {
1176
- detachNestedAndErase (createdOps.back ());
1177
- createdOps.pop_back ();
1178
- }
1179
-
1180
1192
// Pop all of the recorded ignored operations that are no longer valid.
1181
1193
while (ignoredOps.size () != state.numIgnoredOperations )
1182
1194
ignoredOps.pop_back ();
@@ -1444,7 +1456,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
1444
1456
});
1445
1457
if (!previous.isSet ()) {
1446
1458
// This is a newly created op.
1447
- createdOps. push_back (op);
1459
+ appendRewrite<CreateOperationRewrite> (op);
1448
1460
return ;
1449
1461
}
1450
1462
Operation *prevOp = previous.getPoint () == previous.getBlock ()->end ()
@@ -1945,13 +1957,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
1945
1957
rewriter.replaceOp (op, replacementValues);
1946
1958
1947
1959
// Recursively legalize any new constant operations.
1948
- for (unsigned i = curState.numCreatedOps , e = rewriterImpl.createdOps .size ();
1960
+ for (unsigned i = curState.numRewrites , e = rewriterImpl.rewrites .size ();
1949
1961
i != e; ++i) {
1950
- Operation *cstOp = rewriterImpl.createdOps [i];
1951
- if (failed (legalize (cstOp, rewriter))) {
1962
+ auto *createOp =
1963
+ dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites [i].get ());
1964
+ if (!createOp)
1965
+ continue ;
1966
+ if (failed (legalize (createOp->getOperation (), rewriter))) {
1952
1967
LLVM_DEBUG (logFailure (rewriterImpl.logger ,
1953
1968
" failed to legalize generated constant '{0}'" ,
1954
- cstOp ->getName ()));
1969
+ createOp-> getOperation () ->getName ()));
1955
1970
rewriterImpl.resetState (curState);
1956
1971
return failure ();
1957
1972
}
@@ -2098,9 +2113,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2098
2113
// blocks in regions created by this pattern will already be legalized later
2099
2114
// on. If we haven't built the set yet, build it now.
2100
2115
if (operationsToIgnore.empty ()) {
2101
- auto createdOps = ArrayRef<Operation *>(impl.createdOps )
2102
- .drop_front (state.numCreatedOps );
2103
- operationsToIgnore.insert (createdOps.begin (), createdOps.end ());
2116
+ for (unsigned i = state.numRewrites , e = impl.rewrites .size (); i != e;
2117
+ ++i) {
2118
+ auto *createOp =
2119
+ dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2120
+ if (!createOp)
2121
+ continue ;
2122
+ operationsToIgnore.insert (createOp->getOperation ());
2123
+ }
2104
2124
}
2105
2125
2106
2126
// If this operation should be considered for re-legalization, try it.
@@ -2118,8 +2138,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2118
2138
LogicalResult OperationLegalizer::legalizePatternCreatedOperations (
2119
2139
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2120
2140
RewriterState &state, RewriterState &newState) {
2121
- for (int i = state.numCreatedOps , e = newState.numCreatedOps ; i != e; ++i) {
2122
- Operation *op = impl.createdOps [i];
2141
+ for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2142
+ auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2143
+ if (!createOp)
2144
+ continue ;
2145
+ Operation *op = createOp->getOperation ();
2123
2146
if (failed (legalize (op, rewriter))) {
2124
2147
LLVM_DEBUG (logFailure (impl.logger ,
2125
2148
" failed to legalize generated operation '{0}'({1})" ,
@@ -2549,10 +2572,15 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2549
2572
});
2550
2573
return liveUserIt == val.user_end () ? nullptr : *liveUserIt;
2551
2574
};
2552
- for (auto &r : rewriterImpl.rewrites )
2553
- if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get ()))
2554
- if (failed (rewrite->materializeLiveConversions (findLiveUser)))
2575
+ // Note: `rewrites` may be reallocated as the loop is running.
2576
+ for (int64_t i = 0 ; i < rewriterImpl.rewrites .size (); ++i) {
2577
+ auto &rewrite = rewriterImpl.rewrites [i];
2578
+ if (auto *blockTypeConversionRewrite =
2579
+ dyn_cast<BlockTypeConversionRewrite>(rewrite.get ()))
2580
+ if (failed (blockTypeConversionRewrite->materializeLiveConversions (
2581
+ findLiveUser)))
2555
2582
return failure ();
2583
+ }
2556
2584
return success ();
2557
2585
}
2558
2586
0 commit comments