@@ -512,8 +512,7 @@ checkEffects(AbstractStorageDecl *witness, AbstractStorageDecl *req) {
512
512
// / to be used by `matchWitness`.
513
513
static std::optional<RequirementMatch>
514
514
matchWitnessStructureImpl (ValueDecl *req, ValueDecl *witness,
515
- bool &decomposeFunctionType, bool &ignoreReturnType,
516
- Type &reqThrownError, Type &witnessThrownError) {
515
+ bool &decomposeFunctionType, bool &ignoreReturnType) {
517
516
assert (!req->isInvalid () && " Cannot have an invalid requirement here" );
518
517
519
518
// / Make sure the witness is of the same kind as the requirement.
@@ -658,23 +657,6 @@ matchWitnessStructureImpl(ValueDecl *req, ValueDecl *witness,
658
657
659
658
// Decompose the parameters for subscript declarations.
660
659
decomposeFunctionType = isa<SubscriptDecl>(req);
661
-
662
- // Dig out the thrown error types from the getter so we can compare them
663
- // later.
664
- auto getThrownErrorType = [](AbstractStorageDecl *asd) -> Type {
665
- if (auto getter = asd->getEffectfulGetAccessor ()) {
666
- if (Type thrownErrorType = getter->getThrownInterfaceType ()) {
667
- return thrownErrorType;
668
- } else if (getter->hasThrows ()) {
669
- return asd->getASTContext ().getAnyExistentialType ();
670
- }
671
- }
672
-
673
- return asd->getASTContext ().getNeverType ();
674
- };
675
-
676
- reqThrownError = getThrownErrorType (reqASD);
677
- witnessThrownError = getThrownErrorType (witnessASD);
678
660
} else if (isa<ConstructorDecl>(witness)) {
679
661
decomposeFunctionType = true ;
680
662
ignoreReturnType = true ;
@@ -713,39 +695,33 @@ bool swift::TypeChecker::witnessStructureMatches(ValueDecl *req,
713
695
const ValueDecl *witness) {
714
696
bool decomposeFunctionType = false ;
715
697
bool ignoreReturnType = false ;
716
- Type reqThrownError;
717
- Type witnessThrownError;
718
698
return matchWitnessStructureImpl (req, const_cast <ValueDecl *>(witness),
719
- decomposeFunctionType, ignoreReturnType,
720
- reqThrownError,
721
- witnessThrownError) == std::nullopt;
699
+ decomposeFunctionType, ignoreReturnType)
700
+ == std::nullopt;
722
701
}
723
702
724
703
RequirementMatch swift::matchWitness (
725
704
DeclContext *dc, ValueDecl *req, ValueDecl *witness,
726
705
llvm::function_ref<
727
- std::tuple<std::optional<RequirementMatch>, Type, Type>(void )>
706
+ std::tuple<std::optional<RequirementMatch>, Type, Type, Type, Type >(void )>
728
707
setup,
729
708
llvm::function_ref<std::optional<RequirementMatch>(Type, Type)> matchTypes,
730
709
llvm::function_ref<RequirementMatch(bool , ArrayRef<OptionalAdjustment>)>
731
710
finalize) {
732
711
bool decomposeFunctionType = false ;
733
712
bool ignoreReturnType = false ;
734
- Type reqThrownError;
735
- Type witnessThrownError;
736
713
737
714
if (auto StructuralMismatch = matchWitnessStructureImpl (
738
- req, witness, decomposeFunctionType, ignoreReturnType, reqThrownError,
739
- witnessThrownError)) {
715
+ req, witness, decomposeFunctionType, ignoreReturnType)) {
740
716
return *StructuralMismatch;
741
717
}
742
718
743
719
// Set up the match, determining the requirement and witness types
744
720
// in the process.
745
- Type reqType, witnessType;
721
+ Type reqType, witnessType, reqThrownError, witnessThrownError ;
746
722
{
747
723
std::optional<RequirementMatch> result;
748
- std::tie (result, reqType, witnessType) = setup ();
724
+ std::tie (result, reqType, witnessType, reqThrownError, witnessThrownError ) = setup ();
749
725
if (result) {
750
726
return std::move (result.value ());
751
727
}
@@ -936,7 +912,8 @@ RequirementMatch swift::matchWitness(
936
912
937
913
case ThrownErrorSubtyping::Subtype:
938
914
// If there were no type parameters, we're done.
939
- if (!reqThrownError->hasTypeParameter ())
915
+ if (!reqThrownError->hasTypeVariable () &&
916
+ !reqThrownError->hasTypeParameter ())
940
917
break ;
941
918
942
919
LLVM_FALLTHROUGH;
@@ -1186,7 +1163,7 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
1186
1163
1187
1164
// Set up the constraint system for matching.
1188
1165
auto setup =
1189
- [&]() -> std::tuple<std::optional<RequirementMatch>, Type, Type> {
1166
+ [&]() -> std::tuple<std::optional<RequirementMatch>, Type, Type, Type, Type > {
1190
1167
// Construct a constraint system to use to solve the equality between
1191
1168
// the required type and the witness type.
1192
1169
cs.emplace (dc, ConstraintSystemFlags::AllowFixes);
@@ -1199,10 +1176,12 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
1199
1176
if (syntheticEnv)
1200
1177
selfTy = syntheticEnv->mapTypeIntoContext (selfTy);
1201
1178
1179
+
1202
1180
// Open up the type of the requirement.
1181
+ SmallVector<OpenedType, 4 > reqReplacements;
1182
+
1203
1183
reqLocator =
1204
1184
cs->getConstraintLocator (req, ConstraintLocator::ProtocolRequirement);
1205
- SmallVector<OpenedType, 4 > reqReplacements;
1206
1185
reqType =
1207
1186
cs->getTypeOfMemberReference (selfTy, req, dc,
1208
1187
/* isDynamicResult=*/ false ,
@@ -1235,14 +1214,17 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
1235
1214
}
1236
1215
1237
1216
// Open up the witness type.
1217
+ SmallVector<OpenedType, 4 > witnessReplacements;
1218
+
1238
1219
witnessType = witness->getInterfaceType ();
1239
1220
witnessLocator =
1240
1221
cs->getConstraintLocator (req, LocatorPathElt::Witness (witness));
1241
1222
if (witness->getDeclContext ()->isTypeContext ()) {
1242
1223
openWitnessType =
1243
- cs->getTypeOfMemberReference (
1244
- selfTy, witness, dc, /* isDynamicResult=*/ false ,
1245
- FunctionRefInfo::doubleBaseNameApply (), witnessLocator)
1224
+ cs->getTypeOfMemberReference (selfTy, witness, dc,
1225
+ /* isDynamicResult=*/ false ,
1226
+ FunctionRefInfo::doubleBaseNameApply (),
1227
+ witnessLocator, &witnessReplacements)
1246
1228
.adjustedReferenceType ;
1247
1229
} else {
1248
1230
openWitnessType =
@@ -1253,7 +1235,37 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
1253
1235
}
1254
1236
openWitnessType = openWitnessType->getRValueType ();
1255
1237
1256
- return std::make_tuple (std::nullopt, reqType, openWitnessType);
1238
+ Type reqThrownError;
1239
+ Type witnessThrownError;
1240
+
1241
+ if (auto *witnessASD = dyn_cast<AbstractStorageDecl>(witness)) {
1242
+ auto *reqASD = cast<AbstractStorageDecl>(req);
1243
+
1244
+ // Dig out the thrown error types from the getter so we can compare them
1245
+ // later.
1246
+ auto getThrownErrorType = [](AbstractStorageDecl *asd) -> Type {
1247
+ if (auto getter = asd->getEffectfulGetAccessor ()) {
1248
+ if (Type thrownErrorType = getter->getThrownInterfaceType ()) {
1249
+ return thrownErrorType;
1250
+ } else if (getter->hasThrows ()) {
1251
+ return asd->getASTContext ().getErrorExistentialType ();
1252
+ }
1253
+ }
1254
+
1255
+ return asd->getASTContext ().getNeverType ();
1256
+ };
1257
+
1258
+ reqThrownError = getThrownErrorType (reqASD);
1259
+ reqThrownError = cs->openType (reqThrownError, reqReplacements,
1260
+ reqLocator);
1261
+
1262
+ witnessThrownError = getThrownErrorType (witnessASD);
1263
+ witnessThrownError = cs->openType (witnessThrownError, witnessReplacements,
1264
+ witnessLocator);
1265
+ }
1266
+
1267
+ return std::make_tuple (std::nullopt, reqType, openWitnessType,
1268
+ reqThrownError, witnessThrownError);
1257
1269
};
1258
1270
1259
1271
// Match a type in the requirement to a type in the witness.
0 commit comments