Skip to content

Commit 89948ff

Browse files
committed
RequirementMachine: Compare weight before length in Term/MutableTerm::compare()
1 parent 118b9c2 commit 89948ff

File tree

2 files changed

+40
-17
lines changed

2 files changed

+40
-17
lines changed

lib/AST/RequirementMachine/Term.cpp

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -122,19 +122,42 @@ bool Term::containsNameSymbols() const {
122122
return false;
123123
}
124124

125-
/// Shortlex order on symbol ranges.
125+
/// Weighted shortlex order on symbol ranges, used for implementing
126+
/// Term::compare() and MutableTerm::compare().
126127
///
127-
/// First we compare length, then perform a lexicographic comparison
128-
/// on symbols if the two ranges have the same length.
128+
/// We first compute a weight vector for both terms and compare the
129+
/// vectors lexicographically:
130+
/// - Weight of generic param symbols
131+
/// - Number of name symbols
132+
/// - Number of element symbols
129133
///
130-
/// This is used to implement Term::compare() and MutableTerm::compare()
131-
/// below.
132-
static std::optional<int> shortlexCompare(const Symbol *lhsBegin,
133-
const Symbol *lhsEnd,
134-
const Symbol *rhsBegin,
135-
const Symbol *rhsEnd,
136-
RewriteContext &ctx) {
137-
// First, compare the number of name and pack element symbols.
134+
/// If the terms have the same weight, we compare length.
135+
///
136+
/// If the terms have the same weight and length, we perform a
137+
/// lexicographic comparison on symbols.
138+
///
139+
static std::optional<int> compareImpl(const Symbol *lhsBegin,
140+
const Symbol *lhsEnd,
141+
const Symbol *rhsBegin,
142+
const Symbol *rhsEnd,
143+
RewriteContext &ctx) {
144+
ASSERT(lhsBegin != lhsEnd);
145+
ASSERT(rhsBegin != rhsEnd);
146+
147+
// First compare weights on generic parameters. The implicit
148+
// assumption here is we don't form terms with generic parameter
149+
// symbols in the middle, which is true. Otherwise, we'd need
150+
// to add up their weights like we do below for name symbols,
151+
// of course.
152+
if (lhsBegin->getKind() == Symbol::Kind::GenericParam &&
153+
rhsBegin->getKind() == Symbol::Kind::GenericParam) {
154+
unsigned lhsWeight = lhsBegin->getGenericParam()->getWeight();
155+
unsigned rhsWeight = rhsBegin->getGenericParam()->getWeight();
156+
if (lhsWeight != rhsWeight)
157+
return lhsWeight > rhsWeight ? 1 : -1;
158+
}
159+
160+
// Compare the number of name and pack element symbols.
138161
unsigned lhsNameCount = 0;
139162
unsigned lhsPackElementCount = 0;
140163
for (auto *iter = lhsBegin; iter != lhsEnd; ++iter) {
@@ -192,17 +215,17 @@ static std::optional<int> shortlexCompare(const Symbol *lhsBegin,
192215
return 0;
193216
}
194217

195-
/// Shortlex order on terms. Returns None if the terms are identical except
218+
/// Reduction order on terms. Returns None if the terms are identical except
196219
/// for an incomparable superclass or concrete type symbol at the end.
197220
std::optional<int> Term::compare(Term other, RewriteContext &ctx) const {
198-
return shortlexCompare(begin(), end(), other.begin(), other.end(), ctx);
221+
return compareImpl(begin(), end(), other.begin(), other.end(), ctx);
199222
}
200223

201-
/// Shortlex order on mutable terms. Returns None if the terms are identical
224+
/// Reduction order on mutable terms. Returns None if the terms are identical
202225
/// except for an incomparable superclass or concrete type symbol at the end.
203226
std::optional<int> MutableTerm::compare(const MutableTerm &other,
204227
RewriteContext &ctx) const {
205-
return shortlexCompare(begin(), end(), other.begin(), other.end(), ctx);
228+
return compareImpl(begin(), end(), other.begin(), other.end(), ctx);
206229
}
207230

208231
/// Replace the subterm in the range [from,to) of this term with \p rhs.

lib/Sema/TypeCheckGeneric.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ OpaqueResultTypeRequest::evaluate(Evaluator &evaluator,
142142
for (unsigned i = 0; i < opaqueReprs.size(); ++i) {
143143
auto *currentRepr = opaqueReprs[i];
144144

145-
if( auto opaqueReturn = dyn_cast<OpaqueReturnTypeRepr>(currentRepr) ) {
145+
if (auto opaqueReturn = dyn_cast<OpaqueReturnTypeRepr>(currentRepr)) {
146146
// Usually, we resolve the opaque constraint and bail if it isn't a class
147147
// or existential type (see below). However, in this case we know we will
148148
// fail, so we can bail early and provide a better diagnostic.
@@ -169,7 +169,7 @@ OpaqueResultTypeRequest::evaluate(Evaluator &evaluator,
169169

170170
TypeRepr *constraint = currentRepr;
171171

172-
if (auto opaqueReturn = dyn_cast<OpaqueReturnTypeRepr>(currentRepr)){
172+
if (auto opaqueReturn = dyn_cast<OpaqueReturnTypeRepr>(currentRepr)) {
173173
constraint = opaqueReturn->getConstraint();
174174
}
175175
// Try to resolve the constraint repr in the parent decl context. It

0 commit comments

Comments
 (0)