Skip to content

Commit 33bf271

Browse files
authored
REF: simplify Index.join dispatch/wrapping (#40793)
1 parent 094f630 commit 33bf271

File tree

2 files changed

+74
-86
lines changed

2 files changed

+74
-86
lines changed

pandas/core/indexes/base.py

Lines changed: 73 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from copy import copy as copy_func
44
from datetime import datetime
5+
import functools
56
from itertools import zip_longest
67
import operator
78
from typing import (
@@ -39,6 +40,7 @@
3940
ArrayLike,
4041
Dtype,
4142
DtypeObj,
43+
F,
4244
Shape,
4345
T,
4446
final,
@@ -186,6 +188,33 @@
186188
_o_dtype = np.dtype("object")
187189

188190

191+
def _maybe_return_indexers(meth: F) -> F:
192+
"""
193+
Decorator to simplify 'return_indexers' checks in Index.join.
194+
"""
195+
196+
@functools.wraps(meth)
197+
def join(
198+
self,
199+
other,
200+
how: str_t = "left",
201+
level=None,
202+
return_indexers: bool = False,
203+
sort: bool = False,
204+
):
205+
join_index, lidx, ridx = meth(self, other, how=how, level=level, sort=sort)
206+
if not return_indexers:
207+
return join_index
208+
209+
if lidx is not None:
210+
lidx = ensure_platform_int(lidx)
211+
if ridx is not None:
212+
ridx = ensure_platform_int(ridx)
213+
return join_index, lidx, ridx
214+
215+
return cast(F, join)
216+
217+
189218
def disallow_kwargs(kwargs: dict[str, Any]):
190219
if kwargs:
191220
raise TypeError(f"Unexpected keyword arguments {repr(set(kwargs))}")
@@ -3761,9 +3790,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
37613790
if level is not None:
37623791
if method is not None:
37633792
raise TypeError("Fill method not supported if level passed")
3764-
_, indexer, _ = self._join_level(
3765-
target, level, how="right", return_indexers=True
3766-
)
3793+
_, indexer, _ = self._join_level(target, level, how="right")
37673794
else:
37683795
if self.equals(target):
37693796
indexer = None
@@ -3859,6 +3886,7 @@ def _reindex_non_unique(self, target):
38593886
# --------------------------------------------------------------------
38603887
# Join Methods
38613888

3889+
@_maybe_return_indexers
38623890
def join(
38633891
self,
38643892
other,
@@ -3900,60 +3928,44 @@ def join(
39003928
if self.names == other.names:
39013929
pass
39023930
else:
3903-
return self._join_multi(other, how=how, return_indexers=return_indexers)
3931+
return self._join_multi(other, how=how)
39043932

39053933
# join on the level
39063934
if level is not None and (self_is_mi or other_is_mi):
3907-
return self._join_level(
3908-
other, level, how=how, return_indexers=return_indexers
3909-
)
3935+
return self._join_level(other, level, how=how)
39103936

39113937
if len(other) == 0 and how in ("left", "outer"):
39123938
join_index = self._view()
3913-
if return_indexers:
3914-
rindexer = np.repeat(np.intp(-1), len(join_index))
3915-
return join_index, None, rindexer
3916-
else:
3917-
return join_index
3939+
rindexer = np.repeat(np.intp(-1), len(join_index))
3940+
return join_index, None, rindexer
39183941

39193942
if len(self) == 0 and how in ("right", "outer"):
39203943
join_index = other._view()
3921-
if return_indexers:
3922-
lindexer = np.repeat(np.intp(-1), len(join_index))
3923-
return join_index, lindexer, None
3924-
else:
3925-
return join_index
3944+
lindexer = np.repeat(np.intp(-1), len(join_index))
3945+
return join_index, lindexer, None
39263946

39273947
if self._join_precedence < other._join_precedence:
39283948
how = {"right": "left", "left": "right"}.get(how, how)
3929-
result = other.join(
3930-
self, how=how, level=level, return_indexers=return_indexers
3949+
join_index, lidx, ridx = other.join(
3950+
self, how=how, level=level, return_indexers=True
39313951
)
3932-
if return_indexers:
3933-
x, y, z = result
3934-
result = x, z, y
3935-
return result
3952+
lidx, ridx = ridx, lidx
3953+
return join_index, lidx, ridx
39363954

39373955
if not is_dtype_equal(self.dtype, other.dtype):
39383956
this = self.astype("O")
39393957
other = other.astype("O")
3940-
return this.join(other, how=how, return_indexers=return_indexers)
3958+
return this.join(other, how=how, return_indexers=True)
39413959

39423960
_validate_join_method(how)
39433961

39443962
if not self.is_unique and not other.is_unique:
3945-
return self._join_non_unique(
3946-
other, how=how, return_indexers=return_indexers
3947-
)
3963+
return self._join_non_unique(other, how=how)
39483964
elif not self.is_unique or not other.is_unique:
39493965
if self.is_monotonic and other.is_monotonic:
3950-
return self._join_monotonic(
3951-
other, how=how, return_indexers=return_indexers
3952-
)
3966+
return self._join_monotonic(other, how=how)
39533967
else:
3954-
return self._join_non_unique(
3955-
other, how=how, return_indexers=return_indexers
3956-
)
3968+
return self._join_non_unique(other, how=how)
39573969
elif (
39583970
self.is_monotonic
39593971
and other.is_monotonic
@@ -3965,9 +3977,7 @@ def join(
39653977
# Categorical is monotonic if data are ordered as categories, but join can
39663978
# not handle this in case of not lexicographically monotonic GH#38502
39673979
try:
3968-
return self._join_monotonic(
3969-
other, how=how, return_indexers=return_indexers
3970-
)
3980+
return self._join_monotonic(other, how=how)
39713981
except TypeError:
39723982
pass
39733983

@@ -3987,21 +3997,18 @@ def join(
39873997
if sort:
39883998
join_index = join_index.sort_values()
39893999

3990-
if return_indexers:
3991-
if join_index is self:
3992-
lindexer = None
3993-
else:
3994-
lindexer = self.get_indexer(join_index)
3995-
if join_index is other:
3996-
rindexer = None
3997-
else:
3998-
rindexer = other.get_indexer(join_index)
3999-
return join_index, lindexer, rindexer
4000+
if join_index is self:
4001+
lindexer = None
40004002
else:
4001-
return join_index
4003+
lindexer = self.get_indexer(join_index)
4004+
if join_index is other:
4005+
rindexer = None
4006+
else:
4007+
rindexer = other.get_indexer(join_index)
4008+
return join_index, lindexer, rindexer
40024009

40034010
@final
4004-
def _join_multi(self, other, how, return_indexers=True):
4011+
def _join_multi(self, other, how):
40054012
from pandas.core.indexes.multi import MultiIndex
40064013
from pandas.core.reshape.merge import restore_dropped_levels_multijoin
40074014

@@ -4054,10 +4061,7 @@ def _join_multi(self, other, how, return_indexers=True):
40544061

40554062
multi_join_idx = multi_join_idx.remove_unused_levels()
40564063

4057-
if return_indexers:
4058-
return multi_join_idx, lidx, ridx
4059-
else:
4060-
return multi_join_idx
4064+
return multi_join_idx, lidx, ridx
40614065

40624066
jl = list(overlap)[0]
40634067

@@ -4071,16 +4075,14 @@ def _join_multi(self, other, how, return_indexers=True):
40714075
how = {"right": "left", "left": "right"}.get(how, how)
40724076

40734077
level = other.names.index(jl)
4074-
result = self._join_level(
4075-
other, level, how=how, return_indexers=return_indexers
4076-
)
4078+
result = self._join_level(other, level, how=how)
40774079

4078-
if flip_order and isinstance(result, tuple):
4080+
if flip_order:
40794081
return result[0], result[2], result[1]
40804082
return result
40814083

40824084
@final
4083-
def _join_non_unique(self, other, how="left", return_indexers=False):
4085+
def _join_non_unique(self, other, how="left"):
40844086
from pandas.core.reshape.merge import get_join_indexers
40854087

40864088
# We only get here if dtypes match
@@ -4102,15 +4104,10 @@ def _join_non_unique(self, other, how="left", return_indexers=False):
41024104

41034105
join_index = self._wrap_joined_index(join_array, other)
41044106

4105-
if return_indexers:
4106-
return join_index, left_idx, right_idx
4107-
else:
4108-
return join_index
4107+
return join_index, left_idx, right_idx
41094108

41104109
@final
4111-
def _join_level(
4112-
self, other, level, how="left", return_indexers=False, keep_order=True
4113-
):
4110+
def _join_level(self, other, level, how="left", keep_order=True):
41144111
"""
41154112
The join method *only* affects the level of the resulting
41164113
MultiIndex. Otherwise it just exactly aligns the Index data to the
@@ -4249,28 +4246,22 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> np.ndarray:
42494246
if flip_order:
42504247
left_indexer, right_indexer = right_indexer, left_indexer
42514248

4252-
if return_indexers:
4253-
left_indexer = (
4254-
None if left_indexer is None else ensure_platform_int(left_indexer)
4255-
)
4256-
right_indexer = (
4257-
None if right_indexer is None else ensure_platform_int(right_indexer)
4258-
)
4259-
return join_index, left_indexer, right_indexer
4260-
else:
4261-
return join_index
4249+
left_indexer = (
4250+
None if left_indexer is None else ensure_platform_int(left_indexer)
4251+
)
4252+
right_indexer = (
4253+
None if right_indexer is None else ensure_platform_int(right_indexer)
4254+
)
4255+
return join_index, left_indexer, right_indexer
42624256

42634257
@final
4264-
def _join_monotonic(self, other, how="left", return_indexers=False):
4258+
def _join_monotonic(self, other: Index, how="left"):
42654259
# We only get here with matching dtypes
42664260
assert other.dtype == self.dtype
42674261

42684262
if self.equals(other):
42694263
ret_index = other if how == "right" else self
4270-
if return_indexers:
4271-
return ret_index, None, None
4272-
else:
4273-
return ret_index
4264+
return ret_index, None, None
42744265

42754266
sv = self._get_engine_target()
42764267
ov = other._get_engine_target()
@@ -4306,12 +4297,9 @@ def _join_monotonic(self, other, how="left", return_indexers=False):
43064297

43074298
join_index = self._wrap_joined_index(join_array, other)
43084299

4309-
if return_indexers:
4310-
lidx = None if lidx is None else ensure_platform_int(lidx)
4311-
ridx = None if ridx is None else ensure_platform_int(ridx)
4312-
return join_index, lidx, ridx
4313-
else:
4314-
return join_index
4300+
lidx = None if lidx is None else ensure_platform_int(lidx)
4301+
ridx = None if ridx is None else ensure_platform_int(ridx)
4302+
return join_index, lidx, ridx
43154303

43164304
def _wrap_joined_index(
43174305
self: _IndexT, joined: np.ndarray, other: _IndexT

pandas/core/indexes/multi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2532,7 +2532,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
25322532
else:
25332533
target = ensure_index(target)
25342534
target, indexer, _ = self._join_level(
2535-
target, level, how="right", return_indexers=True, keep_order=False
2535+
target, level, how="right", keep_order=False
25362536
)
25372537
else:
25382538
target = ensure_index(target)

0 commit comments

Comments
 (0)