Skip to content

Commit e404ed7

Browse files
committed
[CSStep] Don't favor choices until the disjunction is picked
1 parent a094c3e commit e404ed7

File tree

4 files changed

+34
-22
lines changed

4 files changed

+34
-22
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5213,8 +5213,9 @@ class ConstraintSystem {
52135213

52145214
/// Pick a disjunction from the InactiveConstraints list.
52155215
///
5216-
/// \returns The selected disjunction.
5217-
Constraint *selectDisjunction();
5216+
/// \returns The selected disjunction and a set of it's favored choices.
5217+
Optional<std::pair<Constraint *, llvm::TinyPtrVector<Constraint *>>>
5218+
selectDisjunction();
52185219

52195220
/// Pick a conjunction from the InactiveConstraints list.
52205221
///
@@ -6143,7 +6144,8 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
61436144
public:
61446145
using Element = DisjunctionChoice;
61456146

6146-
DisjunctionChoiceProducer(ConstraintSystem &cs, Constraint *disjunction)
6147+
DisjunctionChoiceProducer(ConstraintSystem &cs, Constraint *disjunction,
6148+
llvm::TinyPtrVector<Constraint *> &favorites)
61476149
: BindingProducer(cs, disjunction->shouldRememberChoice()
61486150
? disjunction->getLocator()
61496151
: nullptr),
@@ -6153,6 +6155,11 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
61536155
assert(disjunction->getKind() == ConstraintKind::Disjunction);
61546156
assert(!disjunction->shouldRememberChoice() || disjunction->getLocator());
61556157

6158+
// Mark constraints as favored. This information
6159+
// is going to be used by partitioner.
6160+
for (auto *choice : favorites)
6161+
cs.favorConstraint(choice);
6162+
61566163
// Order and partition the disjunction choices.
61576164
partitionDisjunction(Ordering, PartitionBeginning);
61586165
}
@@ -6197,8 +6204,9 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
61976204
// Partition the choices in the disjunction into groups that we will
61986205
// iterate over in an order appropriate to attempt to stop before we
61996206
// have to visit all of the options.
6200-
void partitionDisjunction(SmallVectorImpl<unsigned> &Ordering,
6201-
SmallVectorImpl<unsigned> &PartitionBeginning);
6207+
void
6208+
partitionDisjunction(SmallVectorImpl<unsigned> &Ordering,
6209+
SmallVectorImpl<unsigned> &PartitionBeginning);
62026210

62036211
/// Partition the choices in the range \c first to \c last into groups and
62046212
/// order the groups in the best order to attempt based on the argument

lib/Sema/CSOptimizer.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -468,15 +468,16 @@ selectBestBindingDisjunction(ConstraintSystem &cs,
468468
return firstBindDisjunction;
469469
}
470470

471-
Constraint *ConstraintSystem::selectDisjunction() {
471+
Optional<std::pair<Constraint *, llvm::TinyPtrVector<Constraint *>>>
472+
ConstraintSystem::selectDisjunction() {
472473
SmallVector<Constraint *, 4> disjunctions;
473474

474475
collectDisjunctions(disjunctions);
475476
if (disjunctions.empty())
476-
return nullptr;
477+
return None;
477478

478479
if (auto *disjunction = selectBestBindingDisjunction(*this, disjunctions))
479-
return disjunction;
480+
return std::make_pair(disjunction, llvm::TinyPtrVector<Constraint *>());
480481

481482
llvm::DenseMap<Constraint *, llvm::TinyPtrVector<Constraint *>> favorings;
482483
determineBestChoicesInContext(*this, disjunctions, favorings);
@@ -508,14 +509,8 @@ Constraint *ConstraintSystem::selectDisjunction() {
508509
return firstFavored < secondFavored;
509510
});
510511

511-
if (bestDisjunction != disjunctions.end()) {
512-
// If selected disjunction has any choices that should be favored
513-
// let's record them now.
514-
for (auto *choice : favorings[*bestDisjunction])
515-
favorConstraint(choice);
516-
517-
return *bestDisjunction;
518-
}
512+
if (bestDisjunction != disjunctions.end())
513+
return std::make_pair(*bestDisjunction, favorings[*bestDisjunction]);
519514

520-
return nullptr;
515+
return None;
521516
}

lib/Sema/CSStep.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ StepResult ComponentStep::take(bool prevFailed) {
360360
}
361361
});
362362

363-
auto *disjunction = CS.selectDisjunction();
363+
auto disjunction = CS.selectDisjunction();
364364
auto *conjunction = CS.selectConjunction();
365365

366366
if (CS.isDebugMode()) {
@@ -403,7 +403,8 @@ StepResult ComponentStep::take(bool prevFailed) {
403403
// Bindings usually happen first, but sometimes we want to prioritize a
404404
// disjunction or conjunction.
405405
if (bestBindings) {
406-
if (disjunction && !bestBindings->favoredOverDisjunction(disjunction))
406+
if (disjunction &&
407+
!bestBindings->favoredOverDisjunction(disjunction->first))
407408
return StepKind::Disjunction;
408409

409410
if (conjunction && !bestBindings->favoredOverConjunction(conjunction))
@@ -426,9 +427,9 @@ StepResult ComponentStep::take(bool prevFailed) {
426427
return suspend(
427428
std::make_unique<TypeVariableStep>(*bestBindings, Solutions));
428429
case StepKind::Disjunction: {
429-
CS.retireConstraint(disjunction);
430+
CS.retireConstraint(disjunction->first);
430431
return suspend(
431-
std::make_unique<DisjunctionStep>(CS, disjunction, Solutions));
432+
std::make_unique<DisjunctionStep>(CS, *disjunction, Solutions));
432433
}
433434
case StepKind::Conjunction: {
434435
CS.retireConstraint(conjunction);

lib/Sema/CSStep.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,9 +677,17 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
677677
std::optional<std::pair<Constraint *, Score>> LastSolvedChoice;
678678

679679
public:
680+
DisjunctionStep(
681+
ConstraintSystem &cs,
682+
std::pair<Constraint *, llvm::TinyPtrVector<Constraint *>> &disjunction,
683+
SmallVectorImpl<Solution> &solutions)
684+
: DisjunctionStep(cs, disjunction.first, disjunction.second, solutions) {}
685+
680686
DisjunctionStep(ConstraintSystem &cs, Constraint *disjunction,
687+
llvm::TinyPtrVector<Constraint *> &favoredChoices,
681688
SmallVectorImpl<Solution> &solutions)
682-
: BindingStep(cs, {cs, disjunction}, solutions), Disjunction(disjunction) {
689+
: BindingStep(cs, {cs, disjunction, favoredChoices}, solutions),
690+
Disjunction(disjunction) {
683691
assert(Disjunction->getKind() == ConstraintKind::Disjunction);
684692
pruneOverloadSet(Disjunction);
685693
++cs.solverState->NumDisjunctions;

0 commit comments

Comments
 (0)