Skip to content

Commit e58d1ba

Browse files
authored
TYP: core.reshape (#52531)
1 parent d88f8bf commit e58d1ba

File tree

4 files changed

+51
-28
lines changed

4 files changed

+51
-28
lines changed

pandas/core/reshape/merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,7 @@ def _maybe_add_join_keys(
10041004
result_dtype = find_common_type([lvals.dtype, rvals.dtype])
10051005

10061006
if result._is_label_reference(name):
1007-
result[name] = Series(
1007+
result[name] = result._constructor_sliced(
10081008
key_col, dtype=result_dtype, index=result.index
10091009
)
10101010
elif result._is_level_reference(name):

pandas/core/reshape/pivot.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def _add_margins(
263263
rows,
264264
cols,
265265
aggfunc,
266-
observed=None,
266+
observed: bool,
267267
margins_name: Hashable = "All",
268268
fill_value=None,
269269
):
@@ -292,7 +292,7 @@ def _add_margins(
292292
if not values and isinstance(table, ABCSeries):
293293
# If there are no values and the table is a series, then there is only
294294
# one column in the data. Compute grand margin and return it.
295-
return table._append(Series({key: grand_margin[margins_name]}))
295+
return table._append(table._constructor({key: grand_margin[margins_name]}))
296296

297297
elif values:
298298
marginal_result_set = _generate_marginal_results(
@@ -364,8 +364,16 @@ def _compute_grand_margin(
364364

365365

366366
def _generate_marginal_results(
367-
table, data, values, rows, cols, aggfunc, observed, margins_name: Hashable = "All"
367+
table,
368+
data: DataFrame,
369+
values,
370+
rows,
371+
cols,
372+
aggfunc,
373+
observed: bool,
374+
margins_name: Hashable = "All",
368375
):
376+
margin_keys: list | Index
369377
if len(cols) > 0:
370378
# need to "interleave" the margins
371379
table_pieces = []
@@ -433,23 +441,24 @@ def _all_key(key):
433441
new_order = [len(cols)] + list(range(len(cols)))
434442
row_margin.index = row_margin.index.reorder_levels(new_order)
435443
else:
436-
row_margin = Series(np.nan, index=result.columns)
444+
row_margin = data._constructor_sliced(np.nan, index=result.columns)
437445

438446
return result, margin_keys, row_margin
439447

440448

441449
def _generate_marginal_results_without_values(
442450
table: DataFrame,
443-
data,
451+
data: DataFrame,
444452
rows,
445453
cols,
446454
aggfunc,
447-
observed,
455+
observed: bool,
448456
margins_name: Hashable = "All",
449457
):
458+
margin_keys: list | Index
450459
if len(cols) > 0:
451460
# need to "interleave" the margins
452-
margin_keys: list | Index = []
461+
margin_keys = []
453462

454463
def _all_key():
455464
if len(cols) == 1:
@@ -535,7 +544,9 @@ def pivot(
535544
data.index.get_level_values(i) for i in range(data.index.nlevels)
536545
]
537546
else:
538-
index_list = [Series(data.index, name=data.index.name)]
547+
index_list = [
548+
data._constructor_sliced(data.index, name=data.index.name)
549+
]
539550
else:
540551
index_list = [data[idx] for idx in com.convert_to_list_like(index)]
541552

pandas/core/reshape/reshape.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@
4646
)
4747

4848
if TYPE_CHECKING:
49-
from pandas._typing import npt
49+
from pandas._typing import (
50+
Level,
51+
npt,
52+
)
5053

5154
from pandas.core.arrays import ExtensionArray
5255
from pandas.core.indexes.frozen import FrozenList
@@ -98,9 +101,7 @@ class _Unstacker:
98101
unstacked : DataFrame
99102
"""
100103

101-
def __init__(self, index: MultiIndex, level=-1, constructor=None) -> None:
102-
if constructor is None:
103-
constructor = DataFrame
104+
def __init__(self, index: MultiIndex, level: Level, constructor) -> None:
104105
self.constructor = constructor
105106

106107
self.index = index.remove_unused_levels()
@@ -374,13 +375,14 @@ def new_index(self) -> MultiIndex:
374375
)
375376

376377

377-
def _unstack_multiple(data, clocs, fill_value=None):
378+
def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
378379
if len(clocs) == 0:
379380
return data
380381

381382
# NOTE: This doesn't deal with hierarchical columns yet
382383

383384
index = data.index
385+
index = cast(MultiIndex, index) # caller is responsible for checking
384386

385387
# GH 19966 Make sure if MultiIndexed index has tuple name, they will be
386388
# recognised as a whole
@@ -433,10 +435,10 @@ def _unstack_multiple(data, clocs, fill_value=None):
433435
return result
434436

435437
# GH#42579 deep=False to avoid consolidating
436-
dummy = data.copy(deep=False)
437-
dummy.index = dummy_index
438+
dummy_df = data.copy(deep=False)
439+
dummy_df.index = dummy_index
438440

439-
unstacked = dummy.unstack("__placeholder__", fill_value=fill_value)
441+
unstacked = dummy_df.unstack("__placeholder__", fill_value=fill_value)
440442
if isinstance(unstacked, Series):
441443
unstcols = unstacked.index
442444
else:
@@ -497,7 +499,7 @@ def unstack(obj: Series | DataFrame, level, fill_value=None):
497499
)
498500

499501

500-
def _unstack_frame(obj: DataFrame, level, fill_value=None):
502+
def _unstack_frame(obj: DataFrame, level, fill_value=None) -> DataFrame:
501503
assert isinstance(obj.index, MultiIndex) # checked by caller
502504
unstacker = _Unstacker(obj.index, level=level, constructor=obj._constructor)
503505

@@ -617,7 +619,7 @@ def factorize(index):
617619
return frame._constructor_sliced(new_values, index=new_index)
618620

619621

620-
def stack_multiple(frame, level, dropna: bool = True):
622+
def stack_multiple(frame: DataFrame, level, dropna: bool = True):
621623
# If all passed levels match up to column names, no
622624
# ambiguity about what to do
623625
if all(lev in frame.columns.names for lev in level):

pandas/core/reshape/tile.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232
is_scalar,
3333
is_timedelta64_dtype,
3434
)
35-
from pandas.core.dtypes.dtypes import ExtensionDtype
35+
from pandas.core.dtypes.dtypes import (
36+
DatetimeTZDtype,
37+
ExtensionDtype,
38+
)
3639
from pandas.core.dtypes.generic import ABCSeries
3740
from pandas.core.dtypes.missing import isna
3841

@@ -47,7 +50,10 @@
4750
import pandas.core.algorithms as algos
4851

4952
if TYPE_CHECKING:
50-
from pandas._typing import IntervalLeftRight
53+
from pandas._typing import (
54+
DtypeObj,
55+
IntervalLeftRight,
56+
)
5157

5258

5359
def cut(
@@ -399,7 +405,7 @@ def _bins_to_cuts(
399405
labels=None,
400406
precision: int = 3,
401407
include_lowest: bool = False,
402-
dtype=None,
408+
dtype: DtypeObj | None = None,
403409
duplicates: str = "raise",
404410
ordered: bool = True,
405411
):
@@ -481,7 +487,7 @@ def _coerce_to_type(x):
481487
this method converts it to numeric so that cut or qcut method can
482488
handle it
483489
"""
484-
dtype = None
490+
dtype: DtypeObj | None = None
485491

486492
if is_datetime64tz_dtype(x.dtype):
487493
dtype = x.dtype
@@ -508,7 +514,7 @@ def _coerce_to_type(x):
508514
return x, dtype
509515

510516

511-
def _convert_bin_to_numeric_type(bins, dtype):
517+
def _convert_bin_to_numeric_type(bins, dtype: DtypeObj | None):
512518
"""
513519
if the passed bin is of datetime/timedelta type,
514520
this method converts it to integer
@@ -542,7 +548,7 @@ def _convert_bin_to_numeric_type(bins, dtype):
542548
return bins
543549

544550

545-
def _convert_bin_to_datelike_type(bins, dtype):
551+
def _convert_bin_to_datelike_type(bins, dtype: DtypeObj | None):
546552
"""
547553
Convert bins to a DatetimeIndex or TimedeltaIndex if the original dtype is
548554
datelike
@@ -557,22 +563,26 @@ def _convert_bin_to_datelike_type(bins, dtype):
557563
bins : Array-like of bins, DatetimeIndex or TimedeltaIndex if dtype is
558564
datelike
559565
"""
560-
if is_datetime64tz_dtype(dtype):
566+
if isinstance(dtype, DatetimeTZDtype):
561567
bins = to_datetime(bins.astype(np.int64), utc=True).tz_convert(dtype.tz)
562568
elif is_datetime_or_timedelta_dtype(dtype):
563569
bins = Index(bins.astype(np.int64), dtype=dtype)
564570
return bins
565571

566572

567573
def _format_labels(
568-
bins, precision: int, right: bool = True, include_lowest: bool = False, dtype=None
574+
bins,
575+
precision: int,
576+
right: bool = True,
577+
include_lowest: bool = False,
578+
dtype: DtypeObj | None = None,
569579
):
570580
"""based on the dtype, return our labels"""
571581
closed: IntervalLeftRight = "right" if right else "left"
572582

573583
formatter: Callable[[Any], Timestamp] | Callable[[Any], Timedelta]
574584

575-
if is_datetime64tz_dtype(dtype):
585+
if isinstance(dtype, DatetimeTZDtype):
576586
formatter = lambda x: Timestamp(x, tz=dtype.tz)
577587
adjust = lambda x: x - Timedelta("1ns")
578588
elif is_datetime64_dtype(dtype):

0 commit comments

Comments
 (0)