|
14 | 14 | //
|
15 | 15 | //===----------------------------------------------------------------------===//
|
16 | 16 |
|
| 17 | +#include "TypeChecker.h" |
| 18 | +#include "swift/AST/GenericSignature.h" |
17 | 19 | #include "swift/Sema/ConstraintGraph.h"
|
18 | 20 | #include "swift/Sema/ConstraintSystem.h"
|
19 | 21 | #include "llvm/ADT/BitVector.h"
|
@@ -86,34 +88,6 @@ void forEachDisjunctionChoice(
|
86 | 88 | }
|
87 | 89 | }
|
88 | 90 |
|
89 |
| -static bool isSIMDType(Type type) { |
90 |
| - auto *NTD = dyn_cast_or_null<StructDecl>(type->getAnyNominal()); |
91 |
| - if (!NTD) |
92 |
| - return false; |
93 |
| - |
94 |
| - auto typeName = NTD->getName().str(); |
95 |
| - if (!typeName.startswith("SIMD")) |
96 |
| - return false; |
97 |
| - |
98 |
| - return NTD->getParentModule()->getName().is("Swift"); |
99 |
| -} |
100 |
| - |
101 |
| -static bool isArithmeticOperatorOnSIMDProtocol(ValueDecl *decl) { |
102 |
| - if (!isSIMDOperator(decl)) |
103 |
| - return false; |
104 |
| - |
105 |
| - if (!decl->getBaseIdentifier().isArithmeticOperator()) |
106 |
| - return false; |
107 |
| - |
108 |
| - auto *DC = decl->getDeclContext(); |
109 |
| - if (auto *P = DC->getSelfProtocolDecl()) { |
110 |
| - if (auto knownKind = P->getKnownProtocolKind()) |
111 |
| - return *knownKind == KnownProtocolKind::SIMD; |
112 |
| - } |
113 |
| - |
114 |
| - return false; |
115 |
| -} |
116 |
| - |
117 | 91 | } // end anonymous namespace
|
118 | 92 |
|
119 | 93 | /// Given a set of disjunctions, attempt to determine
|
@@ -181,32 +155,28 @@ static void determineBestChoicesInContext(
|
181 | 155 | resultTypes.push_back(resultType);
|
182 | 156 | }
|
183 | 157 |
|
184 |
| - auto isViableOverload = [&](ValueDecl *decl) { |
185 |
| - // Allow standard arithmetic operator overloads on SIMD protocol |
186 |
| - // to be considered because we can favor them when then argument |
187 |
| - // is a known SIMD<N> type. |
188 |
| - if (isArithmeticOperatorOnSIMDProtocol(decl)) |
189 |
| - return true; |
190 |
| - |
191 |
| - // Don't consider generic overloads because we need conformance |
192 |
| - // checking functionality to determine best favoring, preferring |
193 |
| - // such overloads based only on concrete types leads to subpar |
194 |
| - // choices due to missed information. |
195 |
| - if (decl->getInterfaceType()->is<GenericFunctionType>()) |
196 |
| - return false; |
197 |
| - |
198 |
| - return true; |
199 |
| - }; |
200 |
| - |
201 | 158 | // The choice with the best score.
|
202 | 159 | double bestScore = 0.0;
|
203 | 160 | SmallVector<std::pair<Constraint *, double>, 2> favoredChoices;
|
204 | 161 |
|
205 | 162 | forEachDisjunctionChoice(
|
206 | 163 | cs, disjunction,
|
207 | 164 | [&](Constraint *choice, ValueDecl *decl, FunctionType *overloadType) {
|
208 |
| - if (!isViableOverload(decl)) |
209 |
| - return; |
| 165 | + GenericSignature genericSig; |
| 166 | + { |
| 167 | + if (auto *GF = dyn_cast<AbstractFunctionDecl>(decl)) { |
| 168 | + genericSig = GF->getGenericSignature(); |
| 169 | + } else if (auto *SD = dyn_cast<SubscriptDecl>(decl)) { |
| 170 | + genericSig = SD->getGenericSignature(); |
| 171 | + } |
| 172 | + |
| 173 | + // Let's not consider non-operator generic overloads because we |
| 174 | + // need conformance checking functionality to determine best |
| 175 | + // favoring, preferring such overloads based on concrete types |
| 176 | + // alone leads to subpar choices due to missed information. |
| 177 | + if (genericSig && !decl->isOperator()) |
| 178 | + return; |
| 179 | + } |
210 | 180 |
|
211 | 181 | ParameterListInfo paramListInfo(
|
212 | 182 | overloadType->getParams(), decl,
|
@@ -249,6 +219,23 @@ static void determineBestChoicesInContext(
|
249 | 219 | if (candidateArgumentTypes[i].empty())
|
250 | 220 | continue;
|
251 | 221 |
|
| 222 | + // Check protocol requirement(s) if this parameter is a |
| 223 | + // generic parameter type. |
| 224 | + GenericSignature::RequiredProtocols protocolRequirements; |
| 225 | + if (genericSig) { |
| 226 | + if (auto *GP = paramType->getAs<GenericTypeParamType>()) { |
| 227 | + protocolRequirements = genericSig->getRequiredProtocols(GP); |
| 228 | + // It's a generic parameter which might be connected via |
| 229 | + // same-type constraints to other generic parameters but |
| 230 | + // we cannot check that here, so let's ignore it. |
| 231 | + if (protocolRequirements.empty()) |
| 232 | + continue; |
| 233 | + } |
| 234 | + |
| 235 | + if (paramType->getAs<DependentMemberType>()) |
| 236 | + return; |
| 237 | + } |
| 238 | + |
252 | 239 | // The idea here is to match the parameter type against
|
253 | 240 | // all of the argument candidate types and pick the best
|
254 | 241 | // match (i.e. exact equality one).
|
@@ -281,22 +268,36 @@ static void determineBestChoicesInContext(
|
281 | 268 |
|
282 | 269 | // The specifier only matters for `inout` check.
|
283 | 270 | candidateType = candidateType->getWithoutSpecifierType();
|
284 |
| - // Exact match on one of the candidate bindings. |
285 |
| - if (candidateType->isEqual(paramType)) { |
| 271 | + |
| 272 | + // We don't check generic requirements against literal default |
| 273 | + // types because it creates more noise than signal for operators. |
| 274 | + if (!protocolRequirements.empty() && !isLiteralDefault) { |
| 275 | + if (llvm::all_of( |
| 276 | + protocolRequirements, [&](ProtocolDecl *protocol) { |
| 277 | + return TypeChecker::conformsToProtocol( |
| 278 | + candidateType, protocol, cs.DC->getParentModule(), |
| 279 | + /*allowMissing=*/false); |
| 280 | + })) { |
| 281 | + // Score is lower here because we still prefer concrete |
| 282 | + // overloads over the generic ones when possible. |
| 283 | + bestCandidateScore = std::max(bestCandidateScore, 0.7); |
| 284 | + continue; |
| 285 | + } |
| 286 | + } else if (paramType->hasTypeParameter()) { |
| 287 | + // i.e. Array<Element> or Optional<Wrapped> as a parameter. |
| 288 | + // This is slightly better than all of the conformances matching |
| 289 | + // because the parameter is concrete and could split the graph. |
| 290 | + if (paramType->getAnyNominal() == candidateType->getAnyNominal()) { |
| 291 | + bestCandidateScore = std::max(bestCandidateScore, 0.8); |
| 292 | + continue; |
| 293 | + } |
| 294 | + } else if (candidateType->isEqual(paramType)) { |
| 295 | + // Exact match on one of the candidate bindings. |
286 | 296 | bestCandidateScore =
|
287 | 297 | std::max(bestCandidateScore, isLiteralDefault ? 0.3 : 1.0);
|
288 | 298 | continue;
|
289 | 299 | }
|
290 | 300 |
|
291 |
| - // If argument is SIMD<N> type i.e. SIMD1<...> it's appropriate |
292 |
| - // to favor of the overloads that are declared on SIMD protocol |
293 |
| - // and expect a particular `Scalar` if it's known. |
294 |
| - if (isSIMDType(candidateType) && |
295 |
| - isArithmeticOperatorOnSIMDProtocol(decl)) { |
296 |
| - bestCandidateScore = 1.0; |
297 |
| - continue; |
298 |
| - } |
299 |
| - |
300 | 301 | // Only established arguments could be considered mismatches,
|
301 | 302 | // literal default types should be regarded as holes if they
|
302 | 303 | // didn't match.
|
|
0 commit comments