@@ -153,9 +153,9 @@ namespace {
153
153
// / This is useful when saving and undoing a set of rewrites.
154
154
struct RewriterState {
155
155
RewriterState (unsigned numRewrites, unsigned numIgnoredOperations,
156
- unsigned numErased)
156
+ unsigned numErased, unsigned numReplacedOps )
157
157
: numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
158
- numErased (numErased) {}
158
+ numErased (numErased), numReplacedOps(numReplacedOps) {}
159
159
160
160
// / The current number of rewrites performed.
161
161
unsigned numRewrites;
@@ -165,6 +165,9 @@ struct RewriterState {
165
165
166
166
// / The current number of erased operations/blocks.
167
167
unsigned numErased;
168
+
169
+ // / The current number of replaced ops that are scheduled for erasure.
170
+ unsigned numReplacedOps;
168
171
};
169
172
170
173
// ===----------------------------------------------------------------------===//
@@ -954,6 +957,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
954
957
// / operation was ignored.
955
958
SetVector<Operation *> ignoredOps;
956
959
960
+ // A set of operations that were erased.
961
+ SetVector<Operation *> replacedOps;
962
+
957
963
// / The current type converter, or nullptr if no type converter is currently
958
964
// / active.
959
965
const TypeConverter *currentTypeConverter = nullptr ;
@@ -1152,7 +1158,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
1152
1158
1153
1159
RewriterState ConversionPatternRewriterImpl::getCurrentState () {
1154
1160
return RewriterState (rewrites.size (), ignoredOps.size (),
1155
- eraseRewriter.erased .size ());
1161
+ eraseRewriter.erased .size (), replacedOps. size () );
1156
1162
}
1157
1163
1158
1164
void ConversionPatternRewriterImpl::resetState (RewriterState state) {
@@ -1165,6 +1171,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
1165
1171
1166
1172
while (eraseRewriter.erased .size () != state.numErased )
1167
1173
eraseRewriter.erased .pop_back ();
1174
+
1175
+ while (replacedOps.size () != state.numReplacedOps )
1176
+ replacedOps.pop_back ();
1168
1177
}
1169
1178
1170
1179
void ConversionPatternRewriterImpl::undoRewrites (unsigned numRewritesToKeep) {
@@ -1228,9 +1237,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
1228
1237
return success ();
1229
1238
}
1230
1239
1240
+ // TODO: This function is a misnomer. It does not actually check if `op` is in
1241
+ // `ignoredOps`.
1231
1242
bool ConversionPatternRewriterImpl::isOpIgnored (Operation *op) const {
1232
1243
// Check to see if this operation or the parent operation is ignored.
1233
- return ignoredOps.count (op->getParentOp ()) || ignoredOps .count (op);
1244
+ return ignoredOps.count (op->getParentOp ()) || replacedOps .count (op);
1234
1245
}
1235
1246
1236
1247
void ConversionPatternRewriterImpl::markNestedOpsIgnored (Operation *op) {
@@ -1479,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
1479
1490
void ConversionPatternRewriterImpl::notifyOpReplaced (Operation *op,
1480
1491
ValueRange newValues) {
1481
1492
assert (newValues.size () == op->getNumResults ());
1482
- assert (!ignoredOps .contains (op) && " operation was already replaced" );
1493
+ assert (!replacedOps .contains (op) && " operation was already replaced" );
1483
1494
1484
1495
// Track if any of the results changed, e.g. erased and replaced with null.
1485
1496
bool resultChanged = false ;
@@ -1500,7 +1511,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1500
1511
1501
1512
// Mark this operation as recursively ignored so that we don't need to
1502
1513
// convert any nested operations.
1503
- ignoredOps .insert (op);
1514
+ replacedOps .insert (op);
1504
1515
markNestedOpsIgnored (op);
1505
1516
}
1506
1517
0 commit comments