Skip to content

Commit 03d8fd7

Browse files
authored
Merge pull request #80445 from slavapestov/fix-issue-80288
Sema: Fix handling of getter typed throws in witness matching
2 parents 6c0b778 + b26c599 commit 03d8fd7

File tree

6 files changed

+187
-75
lines changed

6 files changed

+187
-75
lines changed

include/swift/AST/IRGenOptions.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#define SWIFT_AST_IRGENOPTIONS_H
2020

2121
#include "swift/AST/LinkLibrary.h"
22+
#include "swift/Basic/Assertions.h"
2223
#include "swift/Basic/PathRemapper.h"
2324
#include "swift/Basic/Sanitizers.h"
2425
#include "swift/Basic/OptionSet.h"
@@ -655,11 +656,7 @@ class IRGenOptions {
655656
TypeInfoFilter(TypeInfoDumpFilter::All),
656657
PlatformCCallingConvention(llvm::CallingConv::C), UseCASBackend(false),
657658
CASObjMode(llvm::CASBackendMode::Native) {
658-
#ifndef NDEBUG
659-
DisableRoundTripDebugTypes = false;
660-
#else
661-
DisableRoundTripDebugTypes = true;
662-
#endif
659+
DisableRoundTripDebugTypes = !CONDITIONAL_ASSERT_enabled();
663660
}
664661

665662
/// Appends to \p os an arbitrary string representing all options which

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 72 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1908,27 +1908,16 @@ static Type getWithoutProtocolTypeAliases(Type type) {
19081908
///
19091909
/// Also see simplifyCurrentTypeWitnesses().
19101910
static Type getWitnessTypeForMatching(NormalProtocolConformance *conformance,
1911-
ValueDecl *witness) {
1912-
if (witness->isRecursiveValidation()) {
1913-
LLVM_DEBUG(llvm::dbgs() << "Recursive validation\n";);
1914-
return Type();
1915-
}
1916-
1917-
if (witness->isInvalid()) {
1918-
LLVM_DEBUG(llvm::dbgs() << "Invalid witness\n";);
1919-
return Type();
1920-
}
1921-
1911+
ValueDecl *witness, Type type) {
19221912
if (!witness->getDeclContext()->isTypeContext()) {
19231913
// FIXME: Could we infer from 'Self' to make these work?
1924-
return witness->getInterfaceType();
1914+
return type;
19251915
}
19261916

19271917
// Retrieve the set of substitutions to be applied to the witness.
19281918
Type model =
19291919
conformance->getDeclContext()->mapTypeIntoContext(conformance->getType());
19301920
TypeSubstitutionMap substitutions = model->getMemberSubstitutions(witness);
1931-
Type type = witness->getInterfaceType()->getReferenceStorageReferent();
19321921

19331922
type = getWithoutProtocolTypeAliases(type);
19341923

@@ -2082,14 +2071,20 @@ AssociatedTypeInference::inferTypeWitnessesViaAssociatedType(
20822071
!witnessHasImplementsAttrForRequiredName(typeDecl, assocType))
20832072
continue;
20842073

2085-
// Determine the witness type.
2086-
Type witnessType = getWitnessTypeForMatching(conformance, typeDecl);
2087-
if (!witnessType) continue;
2074+
if (typeDecl->isInvalid()) {
2075+
LLVM_DEBUG(llvm::dbgs() << "Recursive validation\n";);
2076+
continue;
2077+
}
20882078

2089-
if (auto witnessMetaType = witnessType->getAs<AnyMetatypeType>())
2090-
witnessType = witnessMetaType->getInstanceType();
2091-
else
2079+
if (typeDecl->isRecursiveValidation()) {
2080+
LLVM_DEBUG(llvm::dbgs() << "Recursive validation\n";);
20922081
continue;
2082+
}
2083+
2084+
// Determine the witness type.
2085+
Type witnessType = getWitnessTypeForMatching(conformance, typeDecl,
2086+
typeDecl->getDeclaredInterfaceType());
2087+
if (!witnessType) continue;
20932088

20942089
if (result.empty()) {
20952090
// If we found at least one default candidate, we must allow for the
@@ -2177,21 +2172,60 @@ AssociatedTypeInference::getPotentialTypeWitnessesByMatchingTypes(ValueDecl *req
21772172
InferredAssociatedTypesByWitness inferred;
21782173
inferred.Witness = witness;
21792174

2180-
// Compute the requirement and witness types we'll use for matching.
2181-
Type fullWitnessType = getWitnessTypeForMatching(conformance, witness);
2182-
if (!fullWitnessType) {
2175+
auto reqType = removeSelfParam(req, req->getInterfaceType());
2176+
Type witnessType;
2177+
2178+
if (witness->isRecursiveValidation()) {
2179+
LLVM_DEBUG(llvm::dbgs() << "Recursive validation\n";);
21832180
return inferred;
21842181
}
21852182

2186-
LLVM_DEBUG(llvm::dbgs() << "Witness type for matching is "
2187-
<< fullWitnessType << "\n";);
2183+
if (witness->isInvalid()) {
2184+
LLVM_DEBUG(llvm::dbgs() << "Invalid witness\n";);
2185+
return inferred;
2186+
}
21882187

21892188
auto setup =
2190-
[&]() -> std::tuple<std::optional<RequirementMatch>, Type, Type> {
2191-
fullWitnessType = removeSelfParam(witness, fullWitnessType);
2189+
[&]() -> std::tuple<std::optional<RequirementMatch>, Type, Type, Type, Type> {
2190+
// Compute the requirement and witness types we'll use for matching.
2191+
witnessType = witness->getInterfaceType()->getReferenceStorageReferent();
2192+
witnessType = getWitnessTypeForMatching(conformance, witness, witnessType);
2193+
2194+
LLVM_DEBUG(llvm::dbgs() << "Witness type for matching is "
2195+
<< witnessType << "\n";);
2196+
2197+
witnessType = removeSelfParam(witness, witnessType);
2198+
2199+
Type reqThrownError;
2200+
Type witnessThrownError;
2201+
2202+
if (auto *witnessASD = dyn_cast<AbstractStorageDecl>(witness)) {
2203+
auto *reqASD = cast<AbstractStorageDecl>(req);
2204+
2205+
// Dig out the thrown error types from the getter so we can compare them
2206+
// later.
2207+
auto getThrownErrorType = [](AbstractStorageDecl *asd) -> Type {
2208+
if (auto getter = asd->getEffectfulGetAccessor()) {
2209+
if (Type thrownErrorType = getter->getThrownInterfaceType()) {
2210+
return thrownErrorType;
2211+
} else if (getter->hasThrows()) {
2212+
return asd->getASTContext().getErrorExistentialType();
2213+
}
2214+
}
2215+
2216+
return asd->getASTContext().getNeverType();
2217+
};
2218+
2219+
reqThrownError = getThrownErrorType(reqASD);
2220+
2221+
witnessThrownError = getThrownErrorType(witnessASD);
2222+
witnessThrownError = getWitnessTypeForMatching(conformance, witness,
2223+
witnessThrownError);
2224+
}
2225+
21922226
return std::make_tuple(std::nullopt,
2193-
removeSelfParam(req, req->getInterfaceType()),
2194-
fullWitnessType);
2227+
reqType, witnessType,
2228+
reqThrownError, witnessThrownError);
21952229
};
21962230

21972231
/// Visits a requirement type to match it to a potential witness for
@@ -2327,7 +2361,7 @@ AssociatedTypeInference::getPotentialTypeWitnessesByMatchingTypes(ValueDecl *req
23272361
Type witnessType) -> std::optional<RequirementMatch> {
23282362
if (!matchVisitor.match(reqType, witnessType)) {
23292363
return RequirementMatch(witness, MatchKind::TypeConflict,
2330-
fullWitnessType);
2364+
witnessType);
23312365
}
23322366

23332367
return std::nullopt;
@@ -2340,7 +2374,7 @@ AssociatedTypeInference::getPotentialTypeWitnessesByMatchingTypes(ValueDecl *req
23402374
return RequirementMatch(witness,
23412375
anyRenaming ? MatchKind::RenamedMatch
23422376
: MatchKind::ExactMatch,
2343-
fullWitnessType);
2377+
witnessType);
23442378

23452379
};
23462380

@@ -4587,5 +4621,13 @@ ReferencedAssociatedTypesRequest::evaluate(Evaluator &eval,
45874621
reqTy->getCanonicalType().walk(walker);
45884622
}
45894623

4624+
if (auto *asd = dyn_cast<AbstractStorageDecl>(req)) {
4625+
if (auto getter = asd->getEffectfulGetAccessor()) {
4626+
if (Type thrownErrorType = getter->getThrownInterfaceType()) {
4627+
thrownErrorType->getCanonicalType().walk(walker);
4628+
}
4629+
}
4630+
}
4631+
45904632
return assocTypes;
45914633
}

lib/Sema/TypeCheckEffects.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4897,7 +4897,11 @@ static ThrownErrorClassification classifyThrownErrorType(Type type) {
48974897
return ThrownErrorClassification::AnyError;
48984898
}
48994899

4900-
if (type->hasTypeVariable() || type->hasTypeParameter())
4900+
// All three cases come up. The first one from the "real" witness matcher,
4901+
// and the other two from associated type inference.
4902+
if (type->hasTypeVariable() ||
4903+
type->hasTypeParameter() ||
4904+
type->hasPrimaryArchetype())
49014905
return ThrownErrorClassification::Dependent;
49024906

49034907
return ThrownErrorClassification::Specific;

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -512,8 +512,7 @@ checkEffects(AbstractStorageDecl *witness, AbstractStorageDecl *req) {
512512
/// to be used by `matchWitness`.
513513
static std::optional<RequirementMatch>
514514
matchWitnessStructureImpl(ValueDecl *req, ValueDecl *witness,
515-
bool &decomposeFunctionType, bool &ignoreReturnType,
516-
Type &reqThrownError, Type &witnessThrownError) {
515+
bool &decomposeFunctionType, bool &ignoreReturnType) {
517516
assert(!req->isInvalid() && "Cannot have an invalid requirement here");
518517

519518
/// Make sure the witness is of the same kind as the requirement.
@@ -658,23 +657,6 @@ matchWitnessStructureImpl(ValueDecl *req, ValueDecl *witness,
658657

659658
// Decompose the parameters for subscript declarations.
660659
decomposeFunctionType = isa<SubscriptDecl>(req);
661-
662-
// Dig out the thrown error types from the getter so we can compare them
663-
// later.
664-
auto getThrownErrorType = [](AbstractStorageDecl *asd) -> Type {
665-
if (auto getter = asd->getEffectfulGetAccessor()) {
666-
if (Type thrownErrorType = getter->getThrownInterfaceType()) {
667-
return thrownErrorType;
668-
} else if (getter->hasThrows()) {
669-
return asd->getASTContext().getAnyExistentialType();
670-
}
671-
}
672-
673-
return asd->getASTContext().getNeverType();
674-
};
675-
676-
reqThrownError = getThrownErrorType(reqASD);
677-
witnessThrownError = getThrownErrorType(witnessASD);
678660
} else if (isa<ConstructorDecl>(witness)) {
679661
decomposeFunctionType = true;
680662
ignoreReturnType = true;
@@ -713,39 +695,33 @@ bool swift::TypeChecker::witnessStructureMatches(ValueDecl *req,
713695
const ValueDecl *witness) {
714696
bool decomposeFunctionType = false;
715697
bool ignoreReturnType = false;
716-
Type reqThrownError;
717-
Type witnessThrownError;
718698
return matchWitnessStructureImpl(req, const_cast<ValueDecl *>(witness),
719-
decomposeFunctionType, ignoreReturnType,
720-
reqThrownError,
721-
witnessThrownError) == std::nullopt;
699+
decomposeFunctionType, ignoreReturnType)
700+
== std::nullopt;
722701
}
723702

724703
RequirementMatch swift::matchWitness(
725704
DeclContext *dc, ValueDecl *req, ValueDecl *witness,
726705
llvm::function_ref<
727-
std::tuple<std::optional<RequirementMatch>, Type, Type>(void)>
706+
std::tuple<std::optional<RequirementMatch>, Type, Type, Type, Type>(void)>
728707
setup,
729708
llvm::function_ref<std::optional<RequirementMatch>(Type, Type)> matchTypes,
730709
llvm::function_ref<RequirementMatch(bool, ArrayRef<OptionalAdjustment>)>
731710
finalize) {
732711
bool decomposeFunctionType = false;
733712
bool ignoreReturnType = false;
734-
Type reqThrownError;
735-
Type witnessThrownError;
736713

737714
if (auto StructuralMismatch = matchWitnessStructureImpl(
738-
req, witness, decomposeFunctionType, ignoreReturnType, reqThrownError,
739-
witnessThrownError)) {
715+
req, witness, decomposeFunctionType, ignoreReturnType)) {
740716
return *StructuralMismatch;
741717
}
742718

743719
// Set up the match, determining the requirement and witness types
744720
// in the process.
745-
Type reqType, witnessType;
721+
Type reqType, witnessType, reqThrownError, witnessThrownError;
746722
{
747723
std::optional<RequirementMatch> result;
748-
std::tie(result, reqType, witnessType) = setup();
724+
std::tie(result, reqType, witnessType, reqThrownError, witnessThrownError) = setup();
749725
if (result) {
750726
return std::move(result.value());
751727
}
@@ -936,7 +912,8 @@ RequirementMatch swift::matchWitness(
936912

937913
case ThrownErrorSubtyping::Subtype:
938914
// If there were no type parameters, we're done.
939-
if (!reqThrownError->hasTypeParameter())
915+
if (!reqThrownError->hasTypeVariable() &&
916+
!reqThrownError->hasTypeParameter())
940917
break;
941918

942919
LLVM_FALLTHROUGH;
@@ -1186,7 +1163,7 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
11861163

11871164
// Set up the constraint system for matching.
11881165
auto setup =
1189-
[&]() -> std::tuple<std::optional<RequirementMatch>, Type, Type> {
1166+
[&]() -> std::tuple<std::optional<RequirementMatch>, Type, Type, Type, Type> {
11901167
// Construct a constraint system to use to solve the equality between
11911168
// the required type and the witness type.
11921169
cs.emplace(dc, ConstraintSystemFlags::AllowFixes);
@@ -1199,10 +1176,12 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
11991176
if (syntheticEnv)
12001177
selfTy = syntheticEnv->mapTypeIntoContext(selfTy);
12011178

1179+
12021180
// Open up the type of the requirement.
1181+
SmallVector<OpenedType, 4> reqReplacements;
1182+
12031183
reqLocator =
12041184
cs->getConstraintLocator(req, ConstraintLocator::ProtocolRequirement);
1205-
SmallVector<OpenedType, 4> reqReplacements;
12061185
reqType =
12071186
cs->getTypeOfMemberReference(selfTy, req, dc,
12081187
/*isDynamicResult=*/false,
@@ -1235,14 +1214,17 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
12351214
}
12361215

12371216
// Open up the witness type.
1217+
SmallVector<OpenedType, 4> witnessReplacements;
1218+
12381219
witnessType = witness->getInterfaceType();
12391220
witnessLocator =
12401221
cs->getConstraintLocator(req, LocatorPathElt::Witness(witness));
12411222
if (witness->getDeclContext()->isTypeContext()) {
12421223
openWitnessType =
1243-
cs->getTypeOfMemberReference(
1244-
selfTy, witness, dc, /*isDynamicResult=*/false,
1245-
FunctionRefInfo::doubleBaseNameApply(), witnessLocator)
1224+
cs->getTypeOfMemberReference(selfTy, witness, dc,
1225+
/*isDynamicResult=*/false,
1226+
FunctionRefInfo::doubleBaseNameApply(),
1227+
witnessLocator, &witnessReplacements)
12461228
.adjustedReferenceType;
12471229
} else {
12481230
openWitnessType =
@@ -1253,7 +1235,37 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
12531235
}
12541236
openWitnessType = openWitnessType->getRValueType();
12551237

1256-
return std::make_tuple(std::nullopt, reqType, openWitnessType);
1238+
Type reqThrownError;
1239+
Type witnessThrownError;
1240+
1241+
if (auto *witnessASD = dyn_cast<AbstractStorageDecl>(witness)) {
1242+
auto *reqASD = cast<AbstractStorageDecl>(req);
1243+
1244+
// Dig out the thrown error types from the getter so we can compare them
1245+
// later.
1246+
auto getThrownErrorType = [](AbstractStorageDecl *asd) -> Type {
1247+
if (auto getter = asd->getEffectfulGetAccessor()) {
1248+
if (Type thrownErrorType = getter->getThrownInterfaceType()) {
1249+
return thrownErrorType;
1250+
} else if (getter->hasThrows()) {
1251+
return asd->getASTContext().getErrorExistentialType();
1252+
}
1253+
}
1254+
1255+
return asd->getASTContext().getNeverType();
1256+
};
1257+
1258+
reqThrownError = getThrownErrorType(reqASD);
1259+
reqThrownError = cs->openType(reqThrownError, reqReplacements,
1260+
reqLocator);
1261+
1262+
witnessThrownError = getThrownErrorType(witnessASD);
1263+
witnessThrownError = cs->openType(witnessThrownError, witnessReplacements,
1264+
witnessLocator);
1265+
}
1266+
1267+
return std::make_tuple(std::nullopt, reqType, openWitnessType,
1268+
reqThrownError, witnessThrownError);
12571269
};
12581270

12591271
// Match a type in the requirement to a type in the witness.

lib/Sema/TypeCheckProtocol.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ class ConformanceChecker : public WitnessChecker {
180180
RequirementMatch matchWitness(
181181
DeclContext *dc, ValueDecl *req, ValueDecl *witness,
182182
llvm::function_ref<
183-
std::tuple<std::optional<RequirementMatch>, Type, Type>(void)>
183+
std::tuple<std::optional<RequirementMatch>, Type, Type, Type, Type>(void)>
184184
setup,
185185
llvm::function_ref<std::optional<RequirementMatch>(Type, Type)> matchTypes,
186186
llvm::function_ref<RequirementMatch(bool, ArrayRef<OptionalAdjustment>)>

0 commit comments

Comments
 (0)