Skip to content

Add inference for 'Promise' based on call to 'resolve' #40466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 94 additions & 20 deletions src/compiler/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19458,7 +19458,7 @@ namespace ts {
source = getUnionType(sources);
}
else if (target.flags & TypeFlags.Intersection && some((<IntersectionType>target).types,
t => !!getInferenceInfoForType(t) || (isGenericMappedType(t) && !!getInferenceInfoForType(getHomomorphicTypeVariable(t) || neverType)))) {
t => !!getInferenceInfoForType(inferences, t) || (isGenericMappedType(t) && !!getInferenceInfoForType(inferences, getHomomorphicTypeVariable(t) || neverType)))) {
// We reduce intersection types only when they contain naked type parameters. For example, when
// inferring from 'string[] & { extra: any }' to 'string[] & T' we want to remove string[] and
// infer { extra: any } for T. But when inferring to 'string[] & Iterable<T>' we want to keep the
Expand Down Expand Up @@ -19490,7 +19490,7 @@ namespace ts {
(priority & InferencePriority.ReturnType && (source === autoType || source === autoArrayType)) || isFromInferenceBlockedSource(source)) {
return;
}
const inference = getInferenceInfoForType(target);
const inference = getInferenceInfoForType(inferences, target);
if (inference) {
if (!inference.isFixed) {
if (inference.priority === undefined || priority < inference.priority) {
Expand Down Expand Up @@ -19687,21 +19687,10 @@ namespace ts {
}
}

function getInferenceInfoForType(type: Type) {
if (type.flags & TypeFlags.TypeVariable) {
for (const inference of inferences) {
if (type === inference.typeParameter) {
return inference;
}
}
}
return undefined;
}

function getSingleTypeVariableFromIntersectionTypes(types: Type[]) {
let typeVariable: Type | undefined;
for (const type of types) {
const t = type.flags & TypeFlags.Intersection && find((<IntersectionType>type).types, t => !!getInferenceInfoForType(t));
const t = type.flags & TypeFlags.Intersection && find((<IntersectionType>type).types, t => !!getInferenceInfoForType(inferences, t));
if (!t || typeVariable && t !== typeVariable) {
return undefined;
}
Expand All @@ -19722,7 +19711,7 @@ namespace ts {
// equal priority (i.e. of equal quality) to what we would infer for a naked type
// parameter.
for (const t of targets) {
if (getInferenceInfoForType(t)) {
if (getInferenceInfoForType(inferences, t)) {
nakedTypeVariable = t;
typeVariableCount++;
}
Expand Down Expand Up @@ -19764,7 +19753,7 @@ namespace ts {
// make from nested naked type variables and given slightly higher priority by virtue
// of being first in the candidates array.
for (const t of targets) {
if (getInferenceInfoForType(t)) {
if (getInferenceInfoForType(inferences, t)) {
typeVariableCount++;
}
else {
Expand All @@ -19778,7 +19767,7 @@ namespace ts {
// we only infer to single naked type variables.
if (targetFlags & TypeFlags.Intersection ? typeVariableCount === 1 : typeVariableCount > 0) {
for (const t of targets) {
if (getInferenceInfoForType(t)) {
if (getInferenceInfoForType(inferences, t)) {
inferWithPriority(source, t, InferencePriority.NakedTypeVariable);
}
}
Expand All @@ -19798,7 +19787,7 @@ namespace ts {
// where T is a type variable. Use inferTypeForHomomorphicMappedType to infer a suitable source
// type and then make a secondary inference from that type to T. We make a secondary inference
// such that direct inferences to T get priority over inferences to Partial<T>, for example.
const inference = getInferenceInfoForType((<IndexType>constraintType).type);
const inference = getInferenceInfoForType(inferences, (<IndexType>constraintType).type);
if (inference && !inference.isFixed && !isFromInferenceBlockedSource(source)) {
const inferredType = inferTypeForHomomorphicMappedType(source, target, <IndexType>constraintType);
if (inferredType) {
Expand Down Expand Up @@ -19909,7 +19898,7 @@ namespace ts {
const middleLength = targetArity - startLength - endLength;
if (middleLength === 2 && elementFlags[startLength] & elementFlags[startLength + 1] & ElementFlags.Variadic && isTupleType(source)) {
// Middle of target is [...T, ...U] and source is tuple type
const targetInfo = getInferenceInfoForType(elementTypes[startLength]);
const targetInfo = getInferenceInfoForType(inferences, elementTypes[startLength]);
if (targetInfo && targetInfo.impliedArity !== undefined) {
// Infer slices from source based on implied arity of T.
inferFromTypes(sliceTupleType(source, startLength, sourceEndLength + sourceArity - targetInfo.impliedArity), elementTypes[startLength]);
Expand Down Expand Up @@ -20004,6 +19993,22 @@ namespace ts {
}
}

function getInferenceInfoForType(inferences: InferenceInfo[], type: Type) {
if (type.flags & TypeFlags.TypeVariable) {
for (const inference of inferences) {
if (type === inference.typeParameter) {
return inference;
}
}
}
return undefined;
}

function hasHigherPriorityInference(inferences: InferenceInfo[], type: Type, priority: InferencePriority) {
const inference = getInferenceInfoForType(inferences, type);
return !!inference && (inference.isFixed || inference.priority !== undefined && inference.priority < priority);
}

function isTypeOrBaseIdenticalTo(s: Type, t: Type) {
return isTypeIdenticalTo(s, t) || !!(t.flags & TypeFlags.String && s.flags & TypeFlags.StringLiteral || t.flags & TypeFlags.Number && s.flags & TypeFlags.NumberLiteral);
}
Expand Down Expand Up @@ -20661,7 +20666,7 @@ namespace ts {
}

function isTypeSubsetOf(source: Type, target: Type) {
return source === target || target.flags & TypeFlags.Union && isTypeSubsetOfUnion(source, <UnionType>target);
return source === target || !!(target.flags & TypeFlags.Union) && isTypeSubsetOfUnion(source, <UnionType>target);
}

function isTypeSubsetOfUnion(source: Type, target: UnionType) {
Expand Down Expand Up @@ -26020,6 +26025,75 @@ namespace ts {
inferTypes(context.inferences, spreadType, restType);
}

// Attempt to solve for `T` in `new Promise<T>(resolve => resolve(t))` (also known as the "revealing constructor" pattern).
// To avoid too much complexity, we use a very restrictive heuristic:
// - Restrict to NewExpression to reduce overhead.
// - `signature` has a single parameter (`callbackType`)
// - `callbackType` has a single call signature (`callbackSignature`) (i.e., `executor: (resolve: (value: T | PromiseLike<T>) => void) => void`)
// - `callbackSignature` has at least one parameter (`innerCallbackType`)
// - `innerCallbackType` has a single call signature (`innerCallbackSignature`) (i.e., `resolve: (value: T | PromiseLike<T>) => void`)
// - `innerCallbackSignature` has a single parameter (`innerCallbackValueType`)
// - `innerCallbackValueType` contains type variable for which we are gathering inferences (i.e. `value: T | PromiseLike<T>`)
// - The function (`callbackFunc`) passed as the argument to the parameter `callbackType` must be inline (i.e., an arrow function or function expression)
// - `callbackFunc` must have one parameter (`innerCallbackParam`) that is untyped (and thus would be contextually typed by `innerCallbackType`)
// If the above conditions are met then:
// - Determine the name in function `callbackFunc` given to the parameter `innerCallbackParam`
// - Find all references to that name in the body of the function `callbackFunc`
// - If `innerCallbackParam` is called directly, collect inferences for the type of the argument passed to the parameter (`innerCallbackValueType`)
// - If `innerCallbackParam` is passed as the argument to another function, we can attempt to use the contextual type of that parameter for inference.
if (isNewExpression(node) && argCount === 1) {
const callbackType = getTypeAtPosition(signature, 0); // executor: ...
const callbackSignature = getSingleCallSignature(callbackType); // (resolve: (...) => ...) => ...
const callbackFunc = skipParentheses(args[0]);
if (callbackSignature && isFunctionExpressionOrArrowFunction(callbackFunc)) {
const sourceFile = getSourceFileOfNode(callbackFunc);
for (let callbackParamIndex = 0; callbackParamIndex < callbackFunc.parameters.length; callbackParamIndex++) {
const innerCallbackType = tryGetTypeAtPosition(callbackSignature, callbackParamIndex); // resolve: ...
const innerCallbackSignature = innerCallbackType && getSingleCallSignature(innerCallbackType); // (value: T | PromiseLike<T>) => ...
const innerCallbackParam = callbackFunc.parameters[callbackParamIndex];
if (innerCallbackSignature && getParameterCount(innerCallbackSignature) === 1 && isIdentifier(innerCallbackParam.name) && !getEffectiveTypeAnnotationNode(innerCallbackParam)) {
const innerCallbackValueType = getTypeAtPosition(innerCallbackSignature, 0); // value: ...
// Don't do the work if we already have a higher-priority inference.
if (some(signature.typeParameters, typeParam => isTypeSubsetOf(typeParam, innerCallbackValueType) && !hasHigherPriorityInference(context.inferences, typeParam, InferencePriority.RevealingConstructor))) {
const innerCallbackSymbol = getSymbolOfNode(innerCallbackParam);
const positions = getPossibleSymbolReferencePositions(sourceFile, idText(innerCallbackParam.name), callbackFunc);
if (positions.length) {
const candidateReferences = findNodesAtPositions(callbackFunc, positions, sourceFile);
if (candidateReferences.length) {
// The callback will not have a type associated with it, so we temporarily assign it `anyFunctionType` so that
// we do not trigger implicit `any` errors and so that we do not create inferences from it.
const links = getSymbolLinks(innerCallbackSymbol);
const savedType = links.type;
links.type = anyFunctionType;
// collect types for inferences to ppB
for (const candidateReference of candidateReferences) {
if (!isIdentifier(candidateReference) || candidateReference === innerCallbackParam.name) continue;
const candidateReferenceSymbol = resolveName(candidateReference, candidateReference.escapedText, SymbolFlags.Value, /*nameNotFoundMessage*/ undefined, /*nameArg*/ undefined, /*isUse*/ false);
if (candidateReferenceSymbol !== innerCallbackSymbol) continue;
if (isCallExpression(candidateReference.parent) && candidateReference === candidateReference.parent.expression) {
const argType =
candidateReference.parent.arguments.length >= 1 ? checkExpression(candidateReference.parent.arguments[0]) :
voidType;
inferTypes(context.inferences, argType, innerCallbackValueType, InferencePriority.RevealingConstructor);
}
else if (isCallOrNewExpression(candidateReference.parent) && contains(candidateReference.parent.arguments, candidateReference)) {
const callbackType = getContextualType(candidateReference);
const callbackSignature = callbackType && getSingleCallSignature(callbackType);
const callbackParamType = callbackSignature && tryGetTypeAtPosition(callbackSignature, 0);
if (callbackParamType) {
inferTypes(context.inferences, callbackParamType, innerCallbackValueType, InferencePriority.RevealingConstructor);
}
}
}
links.type = savedType;
}
}
}
}
}
}
}

return getInferredTypes(context);
}

Expand Down
7 changes: 4 additions & 3 deletions src/compiler/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5437,10 +5437,11 @@ namespace ts {
ReturnType = 1 << 6, // Inference made from return type of generic function
LiteralKeyof = 1 << 7, // Inference made from a string literal to a keyof T
NoConstraints = 1 << 8, // Don't infer from constraints of instantiable types
AlwaysStrict = 1 << 9, // Always use strict rules for contravariant inferences
MaxValue = 1 << 10, // Seed for inference priority tracking
RevealingConstructor = 1 << 9, // Inference made to a callback in a "revealing constructor" (i.e., `new Promise(resolve => resolve(1))`)
AlwaysStrict = 1 << 10, // Always use strict rules for contravariant inferences
MaxValue = 1 << 11, // Seed for inference priority tracking

PriorityImpliesCombination = ReturnType | MappedTypeConstraint | LiteralKeyof, // These priorities imply that the resulting type should be a combination of all candidates
PriorityImpliesCombination = ReturnType | MappedTypeConstraint | LiteralKeyof | RevealingConstructor, // These priorities imply that the resulting type should be a combination of all candidates
Circularity = -1, // Inference circularity (value less than all other priorities)
}

Expand Down
73 changes: 73 additions & 0 deletions src/compiler/utilities.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6903,4 +6903,77 @@ namespace ts {
return bindParentToChildIgnoringJSDoc(child, parent) || bindJSDoc(child);
}
}

export function getPossibleSymbolReferencePositions(sourceFile: SourceFile, symbolName: string, container: Node = sourceFile) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work on something with Unicode escapes. It's not a blocker, but it'd be very strange.

new Promise(resolve => {
    \u0072\u0065\u0073\u006f\u006c\u0076\u0065(100);
})

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, and that's documented in the find-all-references code where this function originally came from. Its unfortunate, but this heuristic is intended to improve inference in a very specific case. It might be feasible to rewrite the function to be a bit smarter (though possibly slower), by scanning the text instead of using .indexOf. That would likely help with find-all-references as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Though I'm not sure a reasonable change based on scanning could be made prior to 4.1-beta)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, scratch that. Its feasible to do. The implementation that supports unicode escapes depends on using a regular expression, so is costlier. I have a version I can push that actually tracks on the source file whether it contains identifiers with unicode escapes, and if it doesn't then it can use the existing indexOf logic, which should be faster:

image

I might wait on that until after this PR, however, as I'd need to also add fourslash tests for find-all-references which could be time consuming and could possibly wait until after 4.1-beta.

const positions: number[] = [];

/// TODO: Cache symbol existence for files to save text search
// Also, need to make this work for unicode escapes.

// Be resilient in the face of a symbol with no name or zero length name
if (!symbolName || !symbolName.length) {
return positions as readonly number[] as SortedReadonlyArray<number>;
}

const text = sourceFile.text;
const sourceLength = text.length;
const symbolNameLength = symbolName.length;

let position = text.indexOf(symbolName, container.pos);
while (position >= 0) {
// If we are past the end, stop looking
if (position > container.end) break;

// We found a match. Make sure it's not part of a larger word (i.e. the char
// before and after it have to be a non-identifier char).
const endPosition = position + symbolNameLength;

if ((position === 0 || !isIdentifierPart(text.charCodeAt(position - 1), ScriptTarget.Latest)) &&
(endPosition === sourceLength || !isIdentifierPart(text.charCodeAt(endPosition), ScriptTarget.Latest))) {
// Found a real match. Keep searching.
positions.push(position);
}
position = text.indexOf(symbolName, position + symbolNameLength + 1);
}

return positions as readonly number[] as SortedReadonlyArray<number>;
}

export function findNodesAtPositions(container: Node, positions: SortedReadonlyArray<number>, sourceFile = getSourceFileOfNode(container)) {
let i = 0;
const results: Node[] = [];
visit(container);
return results;
function visit(node: Node) {
const startPos = skipTrivia(sourceFile.text, node.pos);
while (i < positions.length) {
const pos = positions[i];
const startOffset = i;
if (pos >= node.pos && pos < node.end) {
if (pos < startPos) {
// The position exists in the node's trivia, so we should skip it and
// move on to the next position
i++;
}
else {
const length = results.length;
forEachChild(node, visit);
if (length === results.length) {
// no children were added, so add this node
results.push(node);
// advance to the next position
i++;
}
}
}
else {
// If we've advanced past the end of our parent we should break out of
// the containing `forEachChild`. Otherwise, the position is not contained
// within this node so we should skip to the next node
return !!node.parent && pos > node.parent.end;
}
Debug.assert(i !== startOffset, "Position did not advance");
}
}
}
}
Loading