Skip to content

AST: Use weighted reduction order for opaque return types #81171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion include/swift/AST/SubstitutionMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ class LookUpConformanceInSubstitutionMap {
};

struct OverrideSubsInfo {
ASTContext &Ctx;
unsigned BaseDepth;
unsigned OrigDepth;
SubstitutionMap BaseSubMap;
Expand Down
42 changes: 33 additions & 9 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -7254,9 +7254,10 @@ class GenericTypeParamType : public SubstitutableType,
Identifier Name;
};

unsigned Depth : 15;
unsigned IsDecl : 1;
unsigned Index : 16;
unsigned Depth : 15;
unsigned Weight : 1;
unsigned Index : 15;

/// The kind of generic type parameter this is.
GenericTypeParamKind ParamKind;
Expand All @@ -7281,15 +7282,21 @@ class GenericTypeParamType : public SubstitutableType,
Type valueType, const ASTContext &ctx);

/// Retrieve a canonical generic type parameter with the given kind, depth,
/// index, and optional value type.
/// index, weight, and optional value type.
static GenericTypeParamType *get(GenericTypeParamKind paramKind,
unsigned depth, unsigned index,
unsigned depth, unsigned index, unsigned weight,
Type valueType, const ASTContext &ctx);

/// Retrieve a canonical generic type parameter at the given depth and index.
/// Retrieve a canonical generic type parameter at the given depth and index,
/// with weight 0.
static GenericTypeParamType *getType(unsigned depth, unsigned index,
const ASTContext &ctx);

/// Retrieve a canonical generic type parameter at the given depth and index
/// for an opaque result type, so with weight 1.
static GenericTypeParamType *getOpaqueResultType(unsigned depth, unsigned index,
const ASTContext &ctx);

/// Retrieve a canonical generic parameter pack at the given depth and index.
static GenericTypeParamType *getPack(unsigned depth, unsigned index,
const ASTContext &ctx);
Expand Down Expand Up @@ -7345,6 +7352,14 @@ class GenericTypeParamType : public SubstitutableType,
return Index;
}

/// The weight of this generic parameter in the type parameter order.
///
/// Opaque result types have weight 1, while all other generic parameters
/// have weight 0.
unsigned getWeight() const {
return Weight;
}

/// Returns \c true if this type parameter is declared as a pack.
///
/// \code
Expand All @@ -7366,20 +7381,24 @@ class GenericTypeParamType : public SubstitutableType,

Type getValueType() const;

GenericTypeParamType *withDepth(unsigned depth) const;

void Profile(llvm::FoldingSetNodeID &ID) {
// Note: We explicitly don't use 'getName()' because for canonical forms
// which don't store an identifier we'll go create a tau based form. We
// really want to just plumb down the null Identifier because that's what's
// inside the cache.
Profile(ID, getParamKind(), getDepth(), getIndex(), getValueType(),
Name);
Profile(ID, getParamKind(), getDepth(), getIndex(), getWeight(),
getValueType(), Name);
}
static void Profile(llvm::FoldingSetNodeID &ID,
GenericTypeParamKind paramKind, unsigned depth,
unsigned index, Type valueType, Identifier name) {
unsigned index, unsigned weight, Type valueType,
Identifier name) {
ID.AddInteger((uint8_t)paramKind);
ID.AddInteger(depth);
ID.AddInteger(index);
ID.AddInteger(weight);
ID.AddPointer(valueType.getPointer());
ID.AddPointer(name.get());
}
Expand All @@ -7402,7 +7421,7 @@ class GenericTypeParamType : public SubstitutableType,
const ASTContext &ctx);

explicit GenericTypeParamType(GenericTypeParamKind paramKind, unsigned depth,
unsigned index, Type valueType,
unsigned index, unsigned weight, Type valueType,
RecursiveTypeProperties props,
const ASTContext &ctx);
};
Expand All @@ -7412,6 +7431,11 @@ static CanGenericTypeParamType getType(unsigned depth, unsigned index,
return CanGenericTypeParamType(
GenericTypeParamType::getType(depth, index, C));
}
static CanGenericTypeParamType getOpaqueResultType(unsigned depth, unsigned index,
const ASTContext &C) {
return CanGenericTypeParamType(
GenericTypeParamType::getOpaqueResultType(depth, index, C));
}
END_CAN_TYPE_WRAPPER(GenericTypeParamType, SubstitutableType)

/// A type that refers to a member type of some type that is dependent on a
Expand Down
27 changes: 17 additions & 10 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5099,8 +5099,8 @@ GenericTypeParamType *GenericTypeParamType::get(Identifier name,
Type valueType,
const ASTContext &ctx) {
llvm::FoldingSetNodeID id;
GenericTypeParamType::Profile(id, paramKind, depth, index, valueType,
name);
GenericTypeParamType::Profile(id, paramKind, depth, index, /*weight=*/0,
valueType, name);

void *insertPos;
if (auto gpTy = ctx.getImpl().GenericParamTypes.FindNodeOrInsertPos(id, insertPos))
Expand All @@ -5110,8 +5110,8 @@ GenericTypeParamType *GenericTypeParamType::get(Identifier name,
if (paramKind == GenericTypeParamKind::Pack)
props |= RecursiveTypeProperties::HasParameterPack;

auto canType = GenericTypeParamType::get(paramKind, depth, index, valueType,
ctx);
auto canType = GenericTypeParamType::get(paramKind, depth, index, /*weight=*/0,
valueType, ctx);

auto result = new (ctx, AllocationArena::Permanent)
GenericTypeParamType(name, canType, ctx);
Expand All @@ -5130,10 +5130,10 @@ GenericTypeParamType *GenericTypeParamType::get(GenericTypeParamDecl *param) {

GenericTypeParamType *GenericTypeParamType::get(GenericTypeParamKind paramKind,
unsigned depth, unsigned index,
Type valueType,
unsigned weight, Type valueType,
const ASTContext &ctx) {
llvm::FoldingSetNodeID id;
GenericTypeParamType::Profile(id, paramKind, depth, index, valueType,
GenericTypeParamType::Profile(id, paramKind, depth, index, weight, valueType,
Identifier());

void *insertPos;
Expand All @@ -5145,7 +5145,7 @@ GenericTypeParamType *GenericTypeParamType::get(GenericTypeParamKind paramKind,
props |= RecursiveTypeProperties::HasParameterPack;

auto result = new (ctx, AllocationArena::Permanent)
GenericTypeParamType(paramKind, depth, index, valueType, props, ctx);
GenericTypeParamType(paramKind, depth, index, weight, valueType, props, ctx);
ctx.getImpl().GenericParamTypes.InsertNode(result, insertPos);
return result;
}
Expand All @@ -5154,22 +5154,29 @@ GenericTypeParamType *GenericTypeParamType::getType(unsigned depth,
unsigned index,
const ASTContext &ctx) {
return GenericTypeParamType::get(GenericTypeParamKind::Type, depth, index,
/*valueType*/ Type(), ctx);
/*weight=*/0, /*valueType=*/Type(), ctx);
}

GenericTypeParamType *GenericTypeParamType::getOpaqueResultType(unsigned depth,
unsigned index,
const ASTContext &ctx) {
return GenericTypeParamType::get(GenericTypeParamKind::Type, depth, index,
/*weight=*/1, /*valueType=*/Type(), ctx);
}

GenericTypeParamType *GenericTypeParamType::getPack(unsigned depth,
unsigned index,
const ASTContext &ctx) {
return GenericTypeParamType::get(GenericTypeParamKind::Pack, depth, index,
/*valueType*/ Type(), ctx);
/*weight=*/0, /*valueType=*/Type(), ctx);
}

GenericTypeParamType *GenericTypeParamType::getValue(unsigned depth,
unsigned index,
Type valueType,
const ASTContext &ctx) {
return GenericTypeParamType::get(GenericTypeParamKind::Value, depth, index,
valueType, ctx);
/*weight=*/0, valueType, ctx);
}

ArrayRef<GenericTypeParamType *>
Expand Down
16 changes: 13 additions & 3 deletions lib/AST/GenericSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -831,8 +831,7 @@ int swift::compareAssociatedTypes(AssociatedTypeDecl *assocType1,
return 0;
}

/// Canonical ordering for type parameters.
int swift::compareDependentTypes(Type type1, Type type2) {
static int compareDependentTypesRec(Type type1, Type type2) {
// Fast-path check for equality.
if (type1->isEqual(type2)) return 0;

Expand All @@ -853,7 +852,7 @@ int swift::compareDependentTypes(Type type1, Type type2) {

// - by base, so t_0_n.`P.T` < t_1_m.`P.T`
if (int compareBases =
compareDependentTypes(depMemTy1->getBase(), depMemTy2->getBase()))
compareDependentTypesRec(depMemTy1->getBase(), depMemTy2->getBase()))
return compareBases;

// - by name, so t_n_m.`P.T` < t_n_m.`P.U`
Expand All @@ -869,6 +868,17 @@ int swift::compareDependentTypes(Type type1, Type type2) {
return 0;
}

/// Canonical ordering for type parameters.
int swift::compareDependentTypes(Type type1, Type type2) {
auto *root1 = type1->getRootGenericParam();
auto *root2 = type2->getRootGenericParam();
if (root1->getWeight() != root2->getWeight()) {
return root2->getWeight() ? -1 : +1;
}

return compareDependentTypesRec(type1, type2);
}

#pragma mark Generic signature verification

void GenericSignature::verify() const {
Expand Down
15 changes: 4 additions & 11 deletions lib/AST/RequirementEnvironment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ RequirementEnvironment::RequirementEnvironment(

auto conformanceToWitnessThunkGenericParamFn = [&](GenericTypeParamType *genericParam)
-> GenericTypeParamType * {
return GenericTypeParamType::get(genericParam->getParamKind(),
genericParam->getDepth() + (covariantSelf ? 1 : 0),
genericParam->getIndex(),
genericParam->getValueType(), ctx);
return genericParam->withDepth(
genericParam->getDepth() + (covariantSelf ? 1 : 0));
};

// This is a substitution function from the generic parameters of the
Expand Down Expand Up @@ -109,9 +107,7 @@ RequirementEnvironment::RequirementEnvironment(
// invalid code.
if (genericParam->getDepth() != 1)
return Type();
Type substGenericParam = GenericTypeParamType::get(
genericParam->getParamKind(), depth, genericParam->getIndex(),
genericParam->getValueType(), ctx);
Type substGenericParam = genericParam->withDepth(depth);
if (genericParam->isParameterPack()) {
substGenericParam = PackType::getSingletonPackExpansion(
substGenericParam);
Expand Down Expand Up @@ -210,10 +206,7 @@ RequirementEnvironment::RequirementEnvironment(
}

// Create an equivalent generic parameter at the next depth.
auto substGenericParam = GenericTypeParamType::get(
genericParam->getParamKind(), depth, genericParam->getIndex(),
genericParam->getValueType(), ctx);

auto substGenericParam = genericParam->withDepth(depth);
genericParamTypes.push_back(substGenericParam);
}

Expand Down
53 changes: 38 additions & 15 deletions lib/AST/RequirementMachine/Term.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,42 @@ bool Term::containsNameSymbols() const {
return false;
}

/// Shortlex order on symbol ranges.
/// Weighted shortlex order on symbol ranges, used for implementing
/// Term::compare() and MutableTerm::compare().
///
/// First we compare length, then perform a lexicographic comparison
/// on symbols if the two ranges have the same length.
/// We first compute a weight vector for both terms and compare the
/// vectors lexicographically:
/// - Weight of generic param symbols
/// - Number of name symbols
/// - Number of element symbols
///
/// This is used to implement Term::compare() and MutableTerm::compare()
/// below.
static std::optional<int> shortlexCompare(const Symbol *lhsBegin,
const Symbol *lhsEnd,
const Symbol *rhsBegin,
const Symbol *rhsEnd,
RewriteContext &ctx) {
// First, compare the number of name and pack element symbols.
/// If the terms have the same weight, we compare length.
///
/// If the terms have the same weight and length, we perform a
/// lexicographic comparison on symbols.
///
static std::optional<int> compareImpl(const Symbol *lhsBegin,
const Symbol *lhsEnd,
const Symbol *rhsBegin,
const Symbol *rhsEnd,
RewriteContext &ctx) {
ASSERT(lhsBegin != lhsEnd);
ASSERT(rhsBegin != rhsEnd);

// First compare weights on generic parameters. The implicit
// assumption here is we don't form terms with generic parameter
// symbols in the middle, which is true. Otherwise, we'd need
// to add up their weights like we do below for name symbols,
// of course.
if (lhsBegin->getKind() == Symbol::Kind::GenericParam &&
rhsBegin->getKind() == Symbol::Kind::GenericParam) {
unsigned lhsWeight = lhsBegin->getGenericParam()->getWeight();
unsigned rhsWeight = rhsBegin->getGenericParam()->getWeight();
if (lhsWeight != rhsWeight)
return lhsWeight > rhsWeight ? 1 : -1;
}

// Compare the number of name and pack element symbols.
unsigned lhsNameCount = 0;
unsigned lhsPackElementCount = 0;
for (auto *iter = lhsBegin; iter != lhsEnd; ++iter) {
Expand Down Expand Up @@ -192,17 +215,17 @@ static std::optional<int> shortlexCompare(const Symbol *lhsBegin,
return 0;
}

/// Shortlex order on terms. Returns None if the terms are identical except
/// Reduction order on terms. Returns None if the terms are identical except
/// for an incomparable superclass or concrete type symbol at the end.
std::optional<int> Term::compare(Term other, RewriteContext &ctx) const {
return shortlexCompare(begin(), end(), other.begin(), other.end(), ctx);
return compareImpl(begin(), end(), other.begin(), other.end(), ctx);
}

/// Shortlex order on mutable terms. Returns None if the terms are identical
/// Reduction order on mutable terms. Returns None if the terms are identical
/// except for an incomparable superclass or concrete type symbol at the end.
std::optional<int> MutableTerm::compare(const MutableTerm &other,
RewriteContext &ctx) const {
return shortlexCompare(begin(), end(), other.begin(), other.end(), ctx);
return compareImpl(begin(), end(), other.begin(), other.end(), ctx);
}

/// Replace the subterm in the range [from,to) of this term with \p rhs.
Expand Down
8 changes: 2 additions & 6 deletions lib/AST/SubstitutionMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,7 @@ OverrideSubsInfo::OverrideSubsInfo(const NominalTypeDecl *baseNominal,
const NominalTypeDecl *derivedNominal,
GenericSignature baseSig,
const GenericParamList *derivedParams)
: Ctx(baseSig->getASTContext()),
BaseDepth(0),
: BaseDepth(0),
OrigDepth(0),
DerivedParams(derivedParams) {

Expand Down Expand Up @@ -468,10 +467,7 @@ Type QueryOverrideSubs::operator()(SubstitutableType *type) const {
->getDeclaredInterfaceType();
}

return GenericTypeParamType::get(
gp->getParamKind(),
gp->getDepth() + info.OrigDepth - info.BaseDepth,
gp->getIndex(), gp->getValueType(), info.Ctx);
return gp->withDepth(gp->getDepth() + info.OrigDepth - info.BaseDepth);
}
}

Expand Down
Loading