Skip to content

Commit bc5f70a

Browse files
committed
[CSOptimizer] Allow generic operator overloads without associated type parameters
1 parent 7c1c46d commit bc5f70a

File tree

2 files changed

+59
-59
lines changed

2 files changed

+59
-59
lines changed

lib/Sema/CSOptimizer.cpp

Lines changed: 59 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
//
1515
//===----------------------------------------------------------------------===//
1616

17+
#include "TypeChecker.h"
18+
#include "swift/AST/GenericSignature.h"
1719
#include "swift/Sema/ConstraintGraph.h"
1820
#include "swift/Sema/ConstraintSystem.h"
1921
#include "llvm/ADT/BitVector.h"
@@ -86,34 +88,6 @@ void forEachDisjunctionChoice(
8688
}
8789
}
8890

89-
static bool isSIMDType(Type type) {
90-
auto *NTD = dyn_cast_or_null<StructDecl>(type->getAnyNominal());
91-
if (!NTD)
92-
return false;
93-
94-
auto typeName = NTD->getName().str();
95-
if (!typeName.startswith("SIMD"))
96-
return false;
97-
98-
return NTD->getParentModule()->getName().is("Swift");
99-
}
100-
101-
static bool isArithmeticOperatorOnSIMDProtocol(ValueDecl *decl) {
102-
if (!isSIMDOperator(decl))
103-
return false;
104-
105-
if (!decl->getBaseIdentifier().isArithmeticOperator())
106-
return false;
107-
108-
auto *DC = decl->getDeclContext();
109-
if (auto *P = DC->getSelfProtocolDecl()) {
110-
if (auto knownKind = P->getKnownProtocolKind())
111-
return *knownKind == KnownProtocolKind::SIMD;
112-
}
113-
114-
return false;
115-
}
116-
11791
} // end anonymous namespace
11892

11993
/// Given a set of disjunctions, attempt to determine
@@ -181,32 +155,28 @@ static void determineBestChoicesInContext(
181155
resultTypes.push_back(resultType);
182156
}
183157

184-
auto isViableOverload = [&](ValueDecl *decl) {
185-
// Allow standard arithmetic operator overloads on SIMD protocol
186-
// to be considered because we can favor them when then argument
187-
// is a known SIMD<N> type.
188-
if (isArithmeticOperatorOnSIMDProtocol(decl))
189-
return true;
190-
191-
// Don't consider generic overloads because we need conformance
192-
// checking functionality to determine best favoring, preferring
193-
// such overloads based only on concrete types leads to subpar
194-
// choices due to missed information.
195-
if (decl->getInterfaceType()->is<GenericFunctionType>())
196-
return false;
197-
198-
return true;
199-
};
200-
201158
// The choice with the best score.
202159
double bestScore = 0.0;
203160
SmallVector<std::pair<Constraint *, double>, 2> favoredChoices;
204161

205162
forEachDisjunctionChoice(
206163
cs, disjunction,
207164
[&](Constraint *choice, ValueDecl *decl, FunctionType *overloadType) {
208-
if (!isViableOverload(decl))
209-
return;
165+
GenericSignature genericSig;
166+
{
167+
if (auto *GF = dyn_cast<AbstractFunctionDecl>(decl)) {
168+
genericSig = GF->getGenericSignature();
169+
} else if (auto *SD = dyn_cast<SubscriptDecl>(decl)) {
170+
genericSig = SD->getGenericSignature();
171+
}
172+
173+
// Let's not consider non-operator generic overloads because we
174+
// need conformance checking functionality to determine best
175+
// favoring, preferring such overloads based on concrete types
176+
// alone leads to subpar choices due to missed information.
177+
if (genericSig && !decl->isOperator())
178+
return;
179+
}
210180

211181
ParameterListInfo paramListInfo(
212182
overloadType->getParams(), decl,
@@ -249,6 +219,23 @@ static void determineBestChoicesInContext(
249219
if (candidateArgumentTypes[i].empty())
250220
continue;
251221

222+
// Check protocol requirement(s) if this parameter is a
223+
// generic parameter type.
224+
GenericSignature::RequiredProtocols protocolRequirements;
225+
if (genericSig) {
226+
if (auto *GP = paramType->getAs<GenericTypeParamType>()) {
227+
protocolRequirements = genericSig->getRequiredProtocols(GP);
228+
// It's a generic parameter which might be connected via
229+
// same-type constraints to other generic parameters but
230+
// we cannot check that here, so let's ignore it.
231+
if (protocolRequirements.empty())
232+
continue;
233+
}
234+
235+
if (paramType->getAs<DependentMemberType>())
236+
return;
237+
}
238+
252239
// The idea here is to match the parameter type against
253240
// all of the argument candidate types and pick the best
254241
// match (i.e. exact equality one).
@@ -281,22 +268,36 @@ static void determineBestChoicesInContext(
281268

282269
// The specifier only matters for `inout` check.
283270
candidateType = candidateType->getWithoutSpecifierType();
284-
// Exact match on one of the candidate bindings.
285-
if (candidateType->isEqual(paramType)) {
271+
272+
// We don't check generic requirements against literal default
273+
// types because it creates more noise than signal for operators.
274+
if (!protocolRequirements.empty() && !isLiteralDefault) {
275+
if (llvm::all_of(
276+
protocolRequirements, [&](ProtocolDecl *protocol) {
277+
return TypeChecker::conformsToProtocol(
278+
candidateType, protocol, cs.DC->getParentModule(),
279+
/*allowMissing=*/false);
280+
})) {
281+
// Score is lower here because we still prefer concrete
282+
// overloads over the generic ones when possible.
283+
bestCandidateScore = std::max(bestCandidateScore, 0.7);
284+
continue;
285+
}
286+
} else if (paramType->hasTypeParameter()) {
287+
// i.e. Array<Element> or Optional<Wrapped> as a parameter.
288+
// This is slightly better than all of the conformances matching
289+
// because the parameter is concrete and could split the graph.
290+
if (paramType->getAnyNominal() == candidateType->getAnyNominal()) {
291+
bestCandidateScore = std::max(bestCandidateScore, 0.8);
292+
continue;
293+
}
294+
} else if (candidateType->isEqual(paramType)) {
295+
// Exact match on one of the candidate bindings.
286296
bestCandidateScore =
287297
std::max(bestCandidateScore, isLiteralDefault ? 0.3 : 1.0);
288298
continue;
289299
}
290300

291-
// If argument is SIMD<N> type i.e. SIMD1<...> it's appropriate
292-
// to favor of the overloads that are declared on SIMD protocol
293-
// and expect a particular `Scalar` if it's known.
294-
if (isSIMDType(candidateType) &&
295-
isArithmeticOperatorOnSIMDProtocol(decl)) {
296-
bestCandidateScore = 1.0;
297-
continue;
298-
}
299-
300301
// Only established arguments could be considered mismatches,
301302
// literal default types should be regarded as holes if they
302303
// didn't match.

validation-test/Sema/type_checker_perf/slow/rdar17170728.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar17170728.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ let i: Int? = 1
55
let j: Int?
66
let k: Int? = 2
77

8-
// expected-error@+1 {{the compiler is unable to type-check this expression in reasonable time}}
98
let _ = [i, j, k].reduce(0 as Int?) {
109
$0 != nil && $1 != nil ? $0! + $1! : ($0 != nil ? $0! : ($1 != nil ? $1! : nil))
1110
}

0 commit comments

Comments
 (0)