Skip to content

Commit 651e0af

Browse files
committed
RequirementMachine: Compare weight before length in Term/MutableTerm::compare()
1 parent fa3d8c5 commit 651e0af

File tree

1 file changed

+38
-15
lines changed

1 file changed

+38
-15
lines changed

lib/AST/RequirementMachine/Term.cpp

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -122,19 +122,42 @@ bool Term::containsNameSymbols() const {
122122
return false;
123123
}
124124

125-
/// Shortlex order on symbol ranges.
125+
/// Weighted shortlex order on symbol ranges, used for implementing
126+
/// Term::compare() and MutableTerm::compare().
126127
///
127-
/// First we compare length, then perform a lexicographic comparison
128-
/// on symbols if the two ranges have the same length.
128+
/// We first compute a weight vector for both terms and compare the
129+
/// vectors lexicographically:
130+
/// - Weight of generic param symbols
131+
/// - Number of name symbols
132+
/// - Number of element symbols
129133
///
130-
/// This is used to implement Term::compare() and MutableTerm::compare()
131-
/// below.
132-
static std::optional<int> shortlexCompare(const Symbol *lhsBegin,
133-
const Symbol *lhsEnd,
134-
const Symbol *rhsBegin,
135-
const Symbol *rhsEnd,
136-
RewriteContext &ctx) {
137-
// First, compare the number of name and pack element symbols.
134+
/// If the terms have the same weight, we compare length.
135+
///
136+
/// If the terms have the same weight and length, we perform a
137+
/// lexicographic comparison on symbols.
138+
///
139+
static std::optional<int> compareImpl(const Symbol *lhsBegin,
140+
const Symbol *lhsEnd,
141+
const Symbol *rhsBegin,
142+
const Symbol *rhsEnd,
143+
RewriteContext &ctx) {
144+
ASSERT(lhsBegin != lhsEnd);
145+
ASSERT(rhsBegin != rhsEnd);
146+
147+
// First compare weights on generic parameters. The implicit
148+
// assumption here is we don't form terms with generic parameter
149+
// symbols in the middle, which is true. Otherwise, we'd need
150+
// to add up their weights like we do below for name symbols,
151+
// of course.
152+
if (lhsBegin->getKind() == Symbol::Kind::GenericParam &&
153+
rhsBegin->getKind() == Symbol::Kind::GenericParam) {
154+
unsigned lhsWeight = lhsBegin->getGenericParam()->getWeight();
155+
unsigned rhsWeight = rhsBegin->getGenericParam()->getWeight();
156+
if (lhsWeight != rhsWeight)
157+
return lhsWeight > rhsWeight ? 1 : -1;
158+
}
159+
160+
// Compare the number of name and pack element symbols.
138161
unsigned lhsNameCount = 0;
139162
unsigned lhsPackElementCount = 0;
140163
for (auto *iter = lhsBegin; iter != lhsEnd; ++iter) {
@@ -192,17 +215,17 @@ static std::optional<int> shortlexCompare(const Symbol *lhsBegin,
192215
return 0;
193216
}
194217

195-
/// Shortlex order on terms. Returns None if the terms are identical except
218+
/// Reduction order on terms. Returns None if the terms are identical except
196219
/// for an incomparable superclass or concrete type symbol at the end.
197220
std::optional<int> Term::compare(Term other, RewriteContext &ctx) const {
198-
return shortlexCompare(begin(), end(), other.begin(), other.end(), ctx);
221+
return compareImpl(begin(), end(), other.begin(), other.end(), ctx);
199222
}
200223

201-
/// Shortlex order on mutable terms. Returns None if the terms are identical
224+
/// Reduction order on mutable terms. Returns None if the terms are identical
202225
/// except for an incomparable superclass or concrete type symbol at the end.
203226
std::optional<int> MutableTerm::compare(const MutableTerm &other,
204227
RewriteContext &ctx) const {
205-
return shortlexCompare(begin(), end(), other.begin(), other.end(), ctx);
228+
return compareImpl(begin(), end(), other.begin(), other.end(), ctx);
206229
}
207230

208231
/// Replace the subterm in the range [from,to) of this term with \p rhs.

0 commit comments

Comments
 (0)