@@ -152,17 +152,12 @@ namespace {
152
152
// / This class contains a snapshot of the current conversion rewriter state.
153
153
// / This is useful when saving and undoing a set of rewrites.
154
154
struct RewriterState {
155
- RewriterState (unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
156
- unsigned numRewrites, unsigned numIgnoredOperations,
157
- unsigned numErased)
158
- : numCreatedOps(numCreatedOps),
159
- numUnresolvedMaterializations (numUnresolvedMaterializations),
155
+ RewriterState (unsigned numUnresolvedMaterializations, unsigned numRewrites,
156
+ unsigned numIgnoredOperations, unsigned numErased)
157
+ : numUnresolvedMaterializations(numUnresolvedMaterializations),
160
158
numRewrites (numRewrites), numIgnoredOperations(numIgnoredOperations),
161
159
numErased(numErased) {}
162
160
163
- // / The current number of created operations.
164
- unsigned numCreatedOps;
165
-
166
161
// / The current number of unresolved materializations.
167
162
unsigned numUnresolvedMaterializations;
168
163
@@ -303,7 +298,8 @@ class IRRewrite {
303
298
// Operation rewrites
304
299
MoveOperation,
305
300
ModifyOperation,
306
- ReplaceOperation
301
+ ReplaceOperation,
302
+ CreateOperation
307
303
};
308
304
309
305
virtual ~IRRewrite () = default ;
@@ -376,7 +372,10 @@ class CreateBlockRewrite : public BlockRewrite {
376
372
auto &blockOps = block->getOperations ();
377
373
while (!blockOps.empty ())
378
374
blockOps.remove (blockOps.begin ());
379
- eraseBlock (block);
375
+ if (block->getParent ())
376
+ eraseBlock (block);
377
+ else
378
+ delete block;
380
379
}
381
380
};
382
381
@@ -606,7 +605,7 @@ class OperationRewrite : public IRRewrite {
606
605
607
606
static bool classof (const IRRewrite *rewrite) {
608
607
return rewrite->getKind () >= Kind::MoveOperation &&
609
- rewrite->getKind () <= Kind::ReplaceOperation ;
608
+ rewrite->getKind () <= Kind::CreateOperation ;
610
609
}
611
610
612
611
protected:
@@ -740,6 +739,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
740
739
// / A boolean flag that indicates whether result types have changed or not.
741
740
bool changedResults;
742
741
};
742
+
743
+ class CreateOperationRewrite : public OperationRewrite {
744
+ public:
745
+ CreateOperationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
746
+ Operation *op)
747
+ : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
748
+
749
+ static bool classof (const IRRewrite *rewrite) {
750
+ return rewrite->getKind () == Kind::CreateOperation;
751
+ }
752
+
753
+ void rollback () override ;
754
+ };
743
755
} // namespace
744
756
745
757
// / Return "true" if there is an operation rewrite that matches the specified
@@ -957,9 +969,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
957
969
// replacing a value with one of a different type.
958
970
ConversionValueMapping mapping;
959
971
960
- // / Ordered vector of all of the newly created operations during conversion.
961
- SmallVector<Operation *> createdOps;
962
-
963
972
// / Ordered vector of all unresolved type conversion materializations during
964
973
// / conversion.
965
974
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1144,6 +1153,15 @@ void ReplaceOperationRewrite::rollback() {
1144
1153
1145
1154
void ReplaceOperationRewrite::cleanup () { eraseOp (op); }
1146
1155
1156
+ void CreateOperationRewrite::rollback () {
1157
+ for (Region ®ion : op->getRegions ()) {
1158
+ while (!region.getBlocks ().empty ())
1159
+ region.getBlocks ().remove (region.getBlocks ().begin ());
1160
+ }
1161
+ op->dropAllUses ();
1162
+ eraseOp (op);
1163
+ }
1164
+
1147
1165
void ConversionPatternRewriterImpl::detachNestedAndErase (Operation *op) {
1148
1166
for (Region ®ion : op->getRegions ()) {
1149
1167
for (Block &block : region.getBlocks ()) {
@@ -1161,8 +1179,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
1161
1179
// Remove any newly created ops.
1162
1180
for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
1163
1181
detachNestedAndErase (materialization.getOp ());
1164
- for (auto *op : llvm::reverse (createdOps))
1165
- detachNestedAndErase (op);
1166
1182
}
1167
1183
1168
1184
void ConversionPatternRewriterImpl::applyRewrites () {
@@ -1182,9 +1198,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
1182
1198
// State Management
1183
1199
1184
1200
RewriterState ConversionPatternRewriterImpl::getCurrentState () {
1185
- return RewriterState (createdOps.size (), unresolvedMaterializations.size (),
1186
- rewrites.size (), ignoredOps.size (),
1187
- eraseRewriter.erased .size ());
1201
+ return RewriterState (unresolvedMaterializations.size (), rewrites.size (),
1202
+ ignoredOps.size (), eraseRewriter.erased .size ());
1188
1203
}
1189
1204
1190
1205
void ConversionPatternRewriterImpl::resetState (RewriterState state) {
@@ -1205,12 +1220,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
1205
1220
detachNestedAndErase (op);
1206
1221
}
1207
1222
1208
- // Pop all of the newly created operations.
1209
- while (createdOps.size () != state.numCreatedOps ) {
1210
- detachNestedAndErase (createdOps.back ());
1211
- createdOps.pop_back ();
1212
- }
1213
-
1214
1223
// Pop all of the recorded ignored operations that are no longer valid.
1215
1224
while (ignoredOps.size () != state.numIgnoredOperations )
1216
1225
ignoredOps.pop_back ();
@@ -1478,7 +1487,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
1478
1487
});
1479
1488
if (!previous.isSet ()) {
1480
1489
// This is a newly created op.
1481
- createdOps. push_back (op);
1490
+ appendRewrite<CreateOperationRewrite> (op);
1482
1491
return ;
1483
1492
}
1484
1493
Operation *prevOp = previous.getPoint () == previous.getBlock ()->end ()
@@ -1979,13 +1988,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
1979
1988
rewriter.replaceOp (op, replacementValues);
1980
1989
1981
1990
// Recursively legalize any new constant operations.
1982
- for (unsigned i = curState.numCreatedOps , e = rewriterImpl.createdOps .size ();
1991
+ for (unsigned i = curState.numRewrites , e = rewriterImpl.rewrites .size ();
1983
1992
i != e; ++i) {
1984
- Operation *cstOp = rewriterImpl.createdOps [i];
1985
- if (failed (legalize (cstOp, rewriter))) {
1993
+ auto *createOp =
1994
+ dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites [i].get ());
1995
+ if (!createOp)
1996
+ continue ;
1997
+ if (failed (legalize (createOp->getOperation (), rewriter))) {
1986
1998
LLVM_DEBUG (logFailure (rewriterImpl.logger ,
1987
1999
" failed to legalize generated constant '{0}'" ,
1988
- cstOp ->getName ()));
2000
+ createOp-> getOperation () ->getName ()));
1989
2001
rewriterImpl.resetState (curState);
1990
2002
return failure ();
1991
2003
}
@@ -2132,9 +2144,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2132
2144
// blocks in regions created by this pattern will already be legalized later
2133
2145
// on. If we haven't built the set yet, build it now.
2134
2146
if (operationsToIgnore.empty ()) {
2135
- auto createdOps = ArrayRef<Operation *>(impl.createdOps )
2136
- .drop_front (state.numCreatedOps );
2137
- operationsToIgnore.insert (createdOps.begin (), createdOps.end ());
2147
+ for (unsigned i = state.numRewrites , e = impl.rewrites .size (); i != e;
2148
+ ++i) {
2149
+ auto *createOp =
2150
+ dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2151
+ if (!createOp)
2152
+ continue ;
2153
+ operationsToIgnore.insert (createOp->getOperation ());
2154
+ }
2138
2155
}
2139
2156
2140
2157
// If this operation should be considered for re-legalization, try it.
@@ -2152,8 +2169,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2152
2169
LogicalResult OperationLegalizer::legalizePatternCreatedOperations (
2153
2170
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2154
2171
RewriterState &state, RewriterState &newState) {
2155
- for (int i = state.numCreatedOps , e = newState.numCreatedOps ; i != e; ++i) {
2156
- Operation *op = impl.createdOps [i];
2172
+ for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2173
+ auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2174
+ if (!createOp)
2175
+ continue ;
2176
+ Operation *op = createOp->getOperation ();
2157
2177
if (failed (legalize (op, rewriter))) {
2158
2178
LLVM_DEBUG (logFailure (impl.logger ,
2159
2179
" failed to legalize generated operation '{0}'({1})" ,
@@ -2583,10 +2603,16 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2583
2603
});
2584
2604
return liveUserIt == val.user_end () ? nullptr : *liveUserIt;
2585
2605
};
2586
- for (auto &r : rewriterImpl.rewrites )
2587
- if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get ()))
2588
- if (failed (rewrite->materializeLiveConversions (findLiveUser)))
2606
+ // Note: `rewrites` may be reallocated as the loop is running.
2607
+ for (int64_t i = 0 ; i < static_cast <int64_t >(rewriterImpl.rewrites .size ());
2608
+ ++i) {
2609
+ auto &rewrite = rewriterImpl.rewrites [i];
2610
+ if (auto *blockTypeConversionRewrite =
2611
+ dyn_cast<BlockTypeConversionRewrite>(rewrite.get ()))
2612
+ if (failed (blockTypeConversionRewrite->materializeLiveConversions (
2613
+ findLiveUser)))
2589
2614
return failure ();
2615
+ }
2590
2616
return success ();
2591
2617
}
2592
2618
0 commit comments