Skip to content

Make TypeQuery more general, handling nonboolean queries. #3084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2375,24 +2375,24 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
return c.copy_modified(ret_type=new_ret_type)


class ArgInferSecondPassQuery(types.TypeQuery):
class ArgInferSecondPassQuery(types.TypeQuery[bool]):
"""Query whether an argument type should be inferred in the second pass.

The result is True if the type has a type variable in a callable return
type anywhere. For example, the result for Callable[[], T] is True if t is
a type variable.
"""
def __init__(self) -> None:
super().__init__(False, types.ANY_TYPE_STRATEGY)
super().__init__(any)

def visit_callable_type(self, t: CallableType) -> bool:
return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery())


class HasTypeVarQuery(types.TypeQuery):
class HasTypeVarQuery(types.TypeQuery[bool]):
"""Visitor for querying whether a type has a type variable component."""
def __init__(self) -> None:
super().__init__(False, types.ANY_TYPE_STRATEGY)
super().__init__(any)

def visit_type_var(self, t: TypeVarType) -> bool:
return True
Expand All @@ -2402,10 +2402,10 @@ def has_erased_component(t: Type) -> bool:
return t is not None and t.accept(HasErasedComponentsQuery())


class HasErasedComponentsQuery(types.TypeQuery):
class HasErasedComponentsQuery(types.TypeQuery[bool]):
"""Visitor for querying whether a type has an erased component."""
def __init__(self) -> None:
super().__init__(False, types.ANY_TYPE_STRATEGY)
super().__init__(any)

def visit_erased_type(self, t: ErasedType) -> bool:
return True
Expand Down
7 changes: 3 additions & 4 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from mypy.types import (
CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneTyp, TypeVarType,
Instance, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType,
DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, ALL_TYPES_STRATEGY,
is_named_instance
DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance
)
from mypy.maptype import map_instance_to_supertype
from mypy import nodes
Expand Down Expand Up @@ -250,9 +249,9 @@ def is_complete_type(typ: Type) -> bool:
return typ.accept(CompleteTypeVisitor())


class CompleteTypeVisitor(TypeQuery):
class CompleteTypeVisitor(TypeQuery[bool]):
def __init__(self) -> None:
super().__init__(default=True, strategy=ALL_TYPES_STRATEGY)
super().__init__(all)

def visit_none_type(self, t: NoneTyp) -> bool:
return experiments.STRICT_OPTIONAL
Expand Down
7 changes: 3 additions & 4 deletions mypy/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

from mypy.traverser import TraverserVisitor
from mypy.types import (
Type, AnyType, Instance, FunctionLike, TupleType, TypeVarType,
TypeQuery, ANY_TYPE_STRATEGY, CallableType
Type, AnyType, Instance, FunctionLike, TupleType, TypeVarType, TypeQuery, CallableType
)
from mypy import nodes
from mypy.nodes import (
Expand Down Expand Up @@ -226,9 +225,9 @@ def is_imprecise(t: Type) -> bool:
return t.accept(HasAnyQuery())


class HasAnyQuery(TypeQuery):
class HasAnyQuery(TypeQuery[bool]):
def __init__(self) -> None:
super().__init__(False, ANY_TYPE_STRATEGY)
super().__init__(any)

def visit_any(self, t: AnyType) -> bool:
return True
Expand Down
107 changes: 38 additions & 69 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import OrderedDict
from typing import (
Any, TypeVar, Dict, List, Tuple, cast, Generic, Set, Sequence, Optional, Union, Iterable,
NamedTuple,
NamedTuple, Callable,
)

import mypy.nodes
Expand Down Expand Up @@ -1500,109 +1500,78 @@ def keywords_str(self, a: Iterable[Tuple[str, Type]]) -> str:
])


# These constants define the method used by TypeQuery to combine multiple
# query results, e.g. for tuple types. The strategy is not used for empty
# result lists; in that case the default value takes precedence.
ANY_TYPE_STRATEGY = 0 # Return True if any of the results are True.
ALL_TYPES_STRATEGY = 1 # Return True if all of the results are True.
class TypeQuery(Generic[T], TypeVisitor[T]):
"""Visitor for performing queries of types.

strategy is used to combine results for a series of types

class TypeQuery(TypeVisitor[bool]):
"""Visitor for performing simple boolean queries of types.

This class allows defining the default value for leafs to simplify the
implementation of many queries.
Common use cases involve a boolean query using `any` or `all`
"""

default = False # Default result
strategy = 0 # Strategy for combining multiple values (ANY_TYPE_STRATEGY or ALL_TYPES_...).

def __init__(self, default: bool, strategy: int) -> None:
"""Construct a query visitor.

Use the given default result and strategy for combining
multiple results. The strategy must be either
ANY_TYPE_STRATEGY or ALL_TYPES_STRATEGY.
"""
self.default = default
def __init__(self, strategy: Callable[[Iterable[T]], T]) -> None:
self.strategy = strategy

def visit_unbound_type(self, t: UnboundType) -> bool:
return self.default
def visit_unbound_type(self, t: UnboundType) -> T:
return self.query_types(t.args)

def visit_type_list(self, t: TypeList) -> bool:
return self.default
def visit_type_list(self, t: TypeList) -> T:
return self.query_types(t.items)

def visit_any(self, t: AnyType) -> bool:
return self.default
def visit_any(self, t: AnyType) -> T:
return self.strategy([])

def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
return self.default
def visit_uninhabited_type(self, t: UninhabitedType) -> T:
return self.strategy([])

def visit_none_type(self, t: NoneTyp) -> bool:
return self.default
def visit_none_type(self, t: NoneTyp) -> T:
return self.strategy([])

def visit_erased_type(self, t: ErasedType) -> bool:
return self.default
def visit_erased_type(self, t: ErasedType) -> T:
return self.strategy([])

def visit_deleted_type(self, t: DeletedType) -> bool:
return self.default
def visit_deleted_type(self, t: DeletedType) -> T:
return self.strategy([])

def visit_type_var(self, t: TypeVarType) -> bool:
return self.default
def visit_type_var(self, t: TypeVarType) -> T:
return self.strategy([])

def visit_partial_type(self, t: PartialType) -> bool:
return self.default
def visit_partial_type(self, t: PartialType) -> T:
return self.query_types(t.inner_types)

def visit_instance(self, t: Instance) -> bool:
def visit_instance(self, t: Instance) -> T:
return self.query_types(t.args)

def visit_callable_type(self, t: CallableType) -> bool:
def visit_callable_type(self, t: CallableType) -> T:
# FIX generics
return self.query_types(t.arg_types + [t.ret_type])

def visit_tuple_type(self, t: TupleType) -> bool:
def visit_tuple_type(self, t: TupleType) -> T:
return self.query_types(t.items)

def visit_typeddict_type(self, t: TypedDictType) -> bool:
def visit_typeddict_type(self, t: TypedDictType) -> T:
return self.query_types(t.items.values())

def visit_star_type(self, t: StarType) -> bool:
def visit_star_type(self, t: StarType) -> T:
return t.type.accept(self)

def visit_union_type(self, t: UnionType) -> bool:
def visit_union_type(self, t: UnionType) -> T:
return self.query_types(t.items)

def visit_overloaded(self, t: Overloaded) -> bool:
def visit_overloaded(self, t: Overloaded) -> T:
return self.query_types(t.items())

def visit_type_type(self, t: TypeType) -> bool:
def visit_type_type(self, t: TypeType) -> T:
return t.item.accept(self)

def query_types(self, types: Iterable[Type]) -> bool:
def visit_ellipsis_type(self, t: EllipsisType) -> T:
return self.strategy([])

def query_types(self, types: Iterable[Type]) -> T:
"""Perform a query for a list of types.

Use the strategy constant to combine the results.
Use the strategy to combine the results.
"""
if not types:
# Use default result for empty list.
return self.default
if self.strategy == ANY_TYPE_STRATEGY:
# Return True if at least one component is true.
res = False
for t in types:
res = res or t.accept(self)
if res:
break
return res
else:
# Return True if all components are true.
res = True
for t in types:
res = res and t.accept(self)
if not res:
break
return res
return self.strategy(t.accept(self) for t in types)


def strip_type(typ: Type) -> Type:
Expand Down