Skip to content

Commit b7a2bb0

Browse files
authored
Merge pull request #184 from pandas-dev/master
Sync Fork from Upstream Repo
2 parents a4e1651 + 5525561 commit b7a2bb0

File tree

25 files changed: +202 additions, −243 deletions

pandas/core/algorithms.py

Lines changed: 10 additions & 6 deletions
@@ -1876,29 +1876,33 @@ def _sort_tuples(values: np.ndarray) -> np.ndarray:
     return values[indexer]
 
 
-def union_with_duplicates(lvals: np.ndarray, rvals: np.ndarray) -> np.ndarray:
+def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike:
     """
     Extracts the union from lvals and rvals with respect to duplicates and nans in
     both arrays.
 
     Parameters
     ----------
-    lvals: np.ndarray
+    lvals: np.ndarray or ExtensionArray
         left values which is ordered in front.
-    rvals: np.ndarray
+    rvals: np.ndarray or ExtensionArray
         right values ordered after lvals.
 
     Returns
     -------
-    np.ndarray containing the unsorted union of both arrays
+    np.ndarray or ExtensionArray
+        Containing the unsorted union of both arrays.
     """
     indexer = []
     l_count = value_counts(lvals, dropna=False)
     r_count = value_counts(rvals, dropna=False)
     l_count, r_count = l_count.align(r_count, fill_value=0)
     unique_array = unique(np.append(lvals, rvals))
-    if is_extension_array_dtype(lvals) or is_extension_array_dtype(rvals):
-        unique_array = pd_array(unique_array)
+    if not isinstance(lvals, np.ndarray):
+        # i.e. ExtensionArray
+        # Note: we only get here with lvals.dtype == rvals.dtype
+        # TODO: are there any cases where union won't be type/dtype preserving?
+        unique_array = type(lvals)._from_sequence(unique_array, dtype=lvals.dtype)
     for i, value in enumerate(unique_array):
         indexer += [i] * int(max(l_count[value], r_count[value]))
     return unique_array.take(indexer)

pandas/core/arrays/string_arrow.py

Lines changed: 10 additions & 5 deletions
@@ -675,13 +675,18 @@ def value_counts(self, dropna: bool = True) -> Series:
 
         vc = self._data.value_counts()
 
-        # Index cannot hold ExtensionArrays yet
-        index = Index(type(self)(vc.field(0)).astype(object))
+        values = vc.field(0)
+        counts = vc.field(1)
+        if dropna and self._data.null_count > 0:
+            mask = values.is_valid()
+            values = values.filter(mask)
+            counts = counts.filter(mask)
+
         # No missing values so we can adhere to the interface and return a numpy array.
-        counts = np.array(vc.field(1))
+        counts = np.array(counts)
 
-        if dropna and self._data.null_count > 0:
-            raise NotImplementedError("yo")
+        # Index cannot hold ExtensionArrays yet
+        index = Index(type(self)(values)).astype(object)
 
         return Series(counts, index=index).astype("Int64")

pandas/core/computation/engines.py

Lines changed: 3 additions & 2 deletions
@@ -12,6 +12,7 @@
     align_terms,
     reconstruct_object,
 )
+from pandas.core.computation.expr import Expr
 from pandas.core.computation.ops import (
     MATHOPS,
     REDUCTIONS,
@@ -26,13 +27,13 @@ class NumExprClobberingError(NameError):
     pass
 
 
-def _check_ne_builtin_clash(expr):
+def _check_ne_builtin_clash(expr: Expr) -> None:
     """
     Attempt to prevent foot-shooting in a helpful way.
 
     Parameters
     ----------
-    terms : Term
+    expr : Expr
         Terms can contain
     """
     names = expr.names

pandas/core/computation/eval.py

Lines changed: 8 additions & 5 deletions
@@ -1,9 +1,9 @@
 """
 Top level ``eval`` module.
 """
+from __future__ import annotations
 
 import tokenize
-from typing import Optional
 import warnings
 
 from pandas._libs.lib import no_default
@@ -14,13 +14,14 @@
     PARSERS,
     Expr,
 )
+from pandas.core.computation.ops import BinOp
 from pandas.core.computation.parsing import tokenize_string
 from pandas.core.computation.scope import ensure_scope
 
 from pandas.io.formats.printing import pprint_thing
 
 
-def _check_engine(engine: Optional[str]) -> str:
+def _check_engine(engine: str | None) -> str:
     """
     Make sure a valid engine is passed.
 
@@ -161,9 +162,9 @@ def _check_for_locals(expr: str, stack_level: int, parser: str):
 
 
 def eval(
-    expr,
-    parser="pandas",
-    engine: Optional[str] = None,
+    expr: str | BinOp,  # we leave BinOp out of the docstr bc it isn't for users
+    parser: str = "pandas",
+    engine: str | None = None,
     truediv=no_default,
     local_dict=None,
     global_dict=None,
@@ -309,10 +310,12 @@ def eval(
             stacklevel=2,
         )
 
+    exprs: list[str | BinOp]
     if isinstance(expr, str):
         _check_expression(expr)
         exprs = [e.strip() for e in expr.splitlines() if e.strip() != ""]
     else:
+        # ops.BinOp; for internal compat, not intended to be passed by users
         exprs = [expr]
     multi_line = len(exprs) > 1

pandas/core/computation/pytables.py

Lines changed: 3 additions & 1 deletion
@@ -546,6 +546,7 @@ class PyTablesExpr(expr.Expr):
 
     _visitor: PyTablesExprVisitor | None
     env: PyTablesScope
+    expr: str
 
     def __init__(
         self,
@@ -570,7 +571,7 @@ def __init__(
             local_dict = where.env.scope
             _where = where.expr
 
-        elif isinstance(where, (list, tuple)):
+        elif is_list_like(where):
             where = list(where)
             for idx, w in enumerate(where):
                 if isinstance(w, PyTablesExpr):
@@ -580,6 +581,7 @@ def __init__(
                     where[idx] = w
             _where = " & ".join(f"({w})" for w in com.flatten(where))
         else:
+            # _validate_where ensures we otherwise have a string
            _where = where
 
        self.expr = _where

pandas/core/computation/scope.py

Lines changed: 10 additions & 7 deletions
@@ -106,9 +106,13 @@ class Scope:
     """
 
     __slots__ = ["level", "scope", "target", "resolvers", "temps"]
+    level: int
+    scope: DeepChainMap
+    resolvers: DeepChainMap
+    temps: dict
 
     def __init__(
-        self, level, global_dict=None, local_dict=None, resolvers=(), target=None
+        self, level: int, global_dict=None, local_dict=None, resolvers=(), target=None
     ):
         self.level = level + 1
 
@@ -146,8 +150,7 @@ def __init__(
 
         # assumes that resolvers are going from outermost scope to inner
         if isinstance(local_dict, Scope):
-            # error: Cannot determine type of 'resolvers'
-            resolvers += tuple(local_dict.resolvers.maps)  # type: ignore[has-type]
+            resolvers += tuple(local_dict.resolvers.maps)
         self.resolvers = DeepChainMap(*resolvers)
         self.temps = {}
 
@@ -212,7 +215,7 @@ def resolve(self, key: str, is_local: bool):
 
         raise UndefinedVariableError(key, is_local) from err
 
-    def swapkey(self, old_key: str, new_key: str, new_value=None):
+    def swapkey(self, old_key: str, new_key: str, new_value=None) -> None:
         """
         Replace a variable name, with a potentially new value.
 
@@ -238,7 +241,7 @@ def swapkey(self, old_key: str, new_key: str, new_value=None):
                 mapping[new_key] = new_value  # type: ignore[index]
                 return
 
-    def _get_vars(self, stack, scopes: list[str]):
+    def _get_vars(self, stack, scopes: list[str]) -> None:
         """
         Get specifically scoped variables from a list of stack frames.
 
@@ -263,7 +266,7 @@ def _get_vars(self, stack, scopes: list[str]):
                 # scope after the loop
                 del frame
 
-    def _update(self, level: int):
+    def _update(self, level: int) -> None:
         """
         Update the current scope by going back `level` levels.
 
@@ -313,7 +316,7 @@ def ntemps(self) -> int:
         return len(self.temps)
 
     @property
-    def full_scope(self):
+    def full_scope(self) -> DeepChainMap:
        """
        Return the full scope for use with passing to engines transparently
        as a mapping.

pandas/core/dtypes/cast.py

Lines changed: 6 additions & 8 deletions
@@ -55,7 +55,6 @@
     ensure_str,
     is_bool,
     is_bool_dtype,
-    is_categorical_dtype,
     is_complex,
     is_complex_dtype,
     is_datetime64_dtype,
@@ -79,6 +78,7 @@
     pandas_dtype,
 )
 from pandas.core.dtypes.dtypes import (
+    CategoricalDtype,
     DatetimeTZDtype,
     ExtensionDtype,
     IntervalDtype,
@@ -359,15 +359,15 @@ def trans(x):
         return result
 
 
-def maybe_cast_result(
+def maybe_cast_pointwise_result(
     result: ArrayLike,
     dtype: DtypeObj,
     numeric_only: bool = False,
-    how: str = "",
     same_dtype: bool = True,
 ) -> ArrayLike:
     """
-    Try casting result to a different type if appropriate
+    Try casting result of a pointwise operation back to the original dtype if
+    appropriate.
 
     Parameters
     ----------
@@ -377,8 +377,6 @@ def maybe_cast_result(
         Input Series from which result was calculated.
     numeric_only : bool, default False
         Whether to cast only numerics or datetimes as well.
-    how : str, default ""
-        How the result was computed.
     same_dtype : bool, default True
         Specify dtype when calling _from_sequence
 
@@ -387,12 +385,12 @@ def maybe_cast_result(
     result : array-like
         result maybe casted to the dtype.
     """
-    dtype = maybe_cast_result_dtype(dtype, how)
 
     assert not is_scalar(result)
 
     if isinstance(dtype, ExtensionDtype):
-        if not is_categorical_dtype(dtype) and dtype.kind != "M":
+        if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
+            # TODO: avoid this special-casing
             # We have to special case categorical so as not to upcast
             # things like counts back to categorical

pandas/core/dtypes/common.py

Lines changed: 0 additions & 45 deletions
@@ -128,51 +128,6 @@ def ensure_str(value: Union[bytes, Any]) -> str:
     return value
 
 
-def ensure_int_or_float(arr: ArrayLike, copy: bool = False) -> np.ndarray:
-    """
-    Ensure that an dtype array of some integer dtype
-    has an int64 dtype if possible.
-    If it's not possible, potentially because of overflow,
-    convert the array to float64 instead.
-
-    Parameters
-    ----------
-    arr : array-like
-        The array whose data type we want to enforce.
-    copy: bool
-        Whether to copy the original array or reuse
-        it in place, if possible.
-
-    Returns
-    -------
-    out_arr : The input array cast as int64 if
-        possible without overflow.
-        Otherwise the input array cast to float64.
-
-    Notes
-    -----
-    If the array is explicitly of type uint64 the type
-    will remain unchanged.
-    """
-    # TODO: GH27506 potential bug with ExtensionArrays
-    try:
-        # error: Unexpected keyword argument "casting" for "astype"
-        return arr.astype("int64", copy=copy, casting="safe")  # type: ignore[call-arg]
-    except TypeError:
-        pass
-    try:
-        # error: Unexpected keyword argument "casting" for "astype"
-        return arr.astype("uint64", copy=copy, casting="safe")  # type: ignore[call-arg]
-    except TypeError:
-        if is_extension_array_dtype(arr.dtype):
-            # pandas/core/dtypes/common.py:168: error: Item "ndarray" of
-            # "Union[ExtensionArray, ndarray]" has no attribute "to_numpy" [union-attr]
-            return arr.to_numpy(  # type: ignore[union-attr]
-                dtype="float64", na_value=np.nan
-            )
-        return arr.astype("float64", copy=copy)
-
-
 def ensure_python_int(value: Union[int, np.integer]) -> int:
     """
     Ensure that a value is a python int.

pandas/core/groupby/ops.py

Lines changed: 7 additions & 6 deletions
@@ -36,14 +36,13 @@
 from pandas.util._decorators import cache_readonly
 
 from pandas.core.dtypes.cast import (
-    maybe_cast_result,
+    maybe_cast_pointwise_result,
     maybe_cast_result_dtype,
     maybe_downcast_to_dtype,
 )
 from pandas.core.dtypes.common import (
     ensure_float64,
     ensure_int64,
-    ensure_int_or_float,
     ensure_platform_int,
     is_bool_dtype,
     is_categorical_dtype,
@@ -582,7 +581,7 @@ def _ea_wrap_cython_operation(
 
         elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):
             # IntegerArray or BooleanArray
-            values = ensure_int_or_float(values)
+            values = values.to_numpy("float64", na_value=np.nan)
             res_values = self._cython_operation(
                 kind, values, how, axis, min_count, **kwargs
             )
@@ -660,9 +659,11 @@ def _cython_operation(
             values = values.view("int64")
             is_numeric = True
         elif is_bool_dtype(dtype):
-            values = ensure_int_or_float(values)
+            values = values.astype("int64")
         elif is_integer_dtype(dtype):
-            values = ensure_int_or_float(values)
+            # e.g. uint8 -> uint64, int16 -> int64
+            dtype = dtype.kind + "8"
+            values = values.astype(dtype, copy=False)
         elif is_numeric:
             if not is_complex_dtype(dtype):
                 values = ensure_float64(values)
@@ -797,7 +798,7 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
             result[label] = res
 
         out = lib.maybe_convert_objects(result, try_float=False)
-        out = maybe_cast_result(out, obj.dtype, numeric_only=True)
+        out = maybe_cast_pointwise_result(out, obj.dtype, numeric_only=True)
 
         return out, counts

0 commit comments

Comments
 (0)