2
2
3
3
from copy import copy as copy_func
4
4
from datetime import datetime
5
+ import functools
5
6
from itertools import zip_longest
6
7
import operator
7
8
from typing import (
39
40
ArrayLike ,
40
41
Dtype ,
41
42
DtypeObj ,
43
+ F ,
42
44
Shape ,
43
45
T ,
44
46
final ,
186
188
# NumPy dtype instance for "object", constructed once when the module is loaded.
_o_dtype = np .dtype ("object" )
187
189
188
190
191
+ def _maybe_return_indexers (meth : F ) -> F :
192
+ """
193
+ Decorator to simplify 'return_indexers' checks in Index.join.
194
+ """
195
+
196
+ @functools .wraps (meth )
197
+ def join (
198
+ self ,
199
+ other ,
200
+ how : str_t = "left" ,
201
+ level = None ,
202
+ return_indexers : bool = False ,
203
+ sort : bool = False ,
204
+ ):
205
+ join_index , lidx , ridx = meth (self , other , how = how , level = level , sort = sort )
206
+ if not return_indexers :
207
+ return join_index
208
+
209
+ if lidx is not None :
210
+ lidx = ensure_platform_int (lidx )
211
+ if ridx is not None :
212
+ ridx = ensure_platform_int (ridx )
213
+ return join_index , lidx , ridx
214
+
215
+ return cast (F , join )
216
+
217
+
189
218
def disallow_kwargs(kwargs: dict[str, Any]):
    """
    Reject any keyword arguments that were passed.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments to check.

    Raises
    ------
    TypeError
        If ``kwargs`` is non-empty. The message contains the set of its keys.
    """
    if not kwargs:
        return
    unexpected = set(kwargs)
    raise TypeError(f"Unexpected keyword arguments {repr(unexpected)}")
@@ -3761,9 +3790,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
3761
3790
if level is not None :
3762
3791
if method is not None :
3763
3792
raise TypeError ("Fill method not supported if level passed" )
3764
- _ , indexer , _ = self ._join_level (
3765
- target , level , how = "right" , return_indexers = True
3766
- )
3793
+ _ , indexer , _ = self ._join_level (target , level , how = "right" )
3767
3794
else :
3768
3795
if self .equals (target ):
3769
3796
indexer = None
@@ -3859,6 +3886,7 @@ def _reindex_non_unique(self, target):
3859
3886
# --------------------------------------------------------------------
3860
3887
# Join Methods
3861
3888
3889
+ @_maybe_return_indexers
3862
3890
def join (
3863
3891
self ,
3864
3892
other ,
@@ -3900,60 +3928,44 @@ def join(
3900
3928
if self .names == other .names :
3901
3929
pass
3902
3930
else :
3903
- return self ._join_multi (other , how = how , return_indexers = return_indexers )
3931
+ return self ._join_multi (other , how = how )
3904
3932
3905
3933
# join on the level
3906
3934
if level is not None and (self_is_mi or other_is_mi ):
3907
- return self ._join_level (
3908
- other , level , how = how , return_indexers = return_indexers
3909
- )
3935
+ return self ._join_level (other , level , how = how )
3910
3936
3911
3937
if len (other ) == 0 and how in ("left" , "outer" ):
3912
3938
join_index = self ._view ()
3913
- if return_indexers :
3914
- rindexer = np .repeat (np .intp (- 1 ), len (join_index ))
3915
- return join_index , None , rindexer
3916
- else :
3917
- return join_index
3939
+ rindexer = np .repeat (np .intp (- 1 ), len (join_index ))
3940
+ return join_index , None , rindexer
3918
3941
3919
3942
if len (self ) == 0 and how in ("right" , "outer" ):
3920
3943
join_index = other ._view ()
3921
- if return_indexers :
3922
- lindexer = np .repeat (np .intp (- 1 ), len (join_index ))
3923
- return join_index , lindexer , None
3924
- else :
3925
- return join_index
3944
+ lindexer = np .repeat (np .intp (- 1 ), len (join_index ))
3945
+ return join_index , lindexer , None
3926
3946
3927
3947
if self ._join_precedence < other ._join_precedence :
3928
3948
how = {"right" : "left" , "left" : "right" }.get (how , how )
3929
- result = other .join (
3930
- self , how = how , level = level , return_indexers = return_indexers
3949
+ join_index , lidx , ridx = other .join (
3950
+ self , how = how , level = level , return_indexers = True
3931
3951
)
3932
- if return_indexers :
3933
- x , y , z = result
3934
- result = x , z , y
3935
- return result
3952
+ lidx , ridx = ridx , lidx
3953
+ return join_index , lidx , ridx
3936
3954
3937
3955
if not is_dtype_equal (self .dtype , other .dtype ):
3938
3956
this = self .astype ("O" )
3939
3957
other = other .astype ("O" )
3940
- return this .join (other , how = how , return_indexers = return_indexers )
3958
+ return this .join (other , how = how , return_indexers = True )
3941
3959
3942
3960
_validate_join_method (how )
3943
3961
3944
3962
if not self .is_unique and not other .is_unique :
3945
- return self ._join_non_unique (
3946
- other , how = how , return_indexers = return_indexers
3947
- )
3963
+ return self ._join_non_unique (other , how = how )
3948
3964
elif not self .is_unique or not other .is_unique :
3949
3965
if self .is_monotonic and other .is_monotonic :
3950
- return self ._join_monotonic (
3951
- other , how = how , return_indexers = return_indexers
3952
- )
3966
+ return self ._join_monotonic (other , how = how )
3953
3967
else :
3954
- return self ._join_non_unique (
3955
- other , how = how , return_indexers = return_indexers
3956
- )
3968
+ return self ._join_non_unique (other , how = how )
3957
3969
elif (
3958
3970
self .is_monotonic
3959
3971
and other .is_monotonic
@@ -3965,9 +3977,7 @@ def join(
3965
3977
# Categorical is monotonic if data are ordered as categories, but join can
3966
3978
# not handle this in case of not lexicographically monotonic GH#38502
3967
3979
try :
3968
- return self ._join_monotonic (
3969
- other , how = how , return_indexers = return_indexers
3970
- )
3980
+ return self ._join_monotonic (other , how = how )
3971
3981
except TypeError :
3972
3982
pass
3973
3983
@@ -3987,21 +3997,18 @@ def join(
3987
3997
if sort :
3988
3998
join_index = join_index .sort_values ()
3989
3999
3990
- if return_indexers :
3991
- if join_index is self :
3992
- lindexer = None
3993
- else :
3994
- lindexer = self .get_indexer (join_index )
3995
- if join_index is other :
3996
- rindexer = None
3997
- else :
3998
- rindexer = other .get_indexer (join_index )
3999
- return join_index , lindexer , rindexer
4000
+ if join_index is self :
4001
+ lindexer = None
4000
4002
else :
4001
- return join_index
4003
+ lindexer = self .get_indexer (join_index )
4004
+ if join_index is other :
4005
+ rindexer = None
4006
+ else :
4007
+ rindexer = other .get_indexer (join_index )
4008
+ return join_index , lindexer , rindexer
4002
4009
4003
4010
@final
4004
- def _join_multi (self , other , how , return_indexers = True ):
4011
+ def _join_multi (self , other , how ):
4005
4012
from pandas .core .indexes .multi import MultiIndex
4006
4013
from pandas .core .reshape .merge import restore_dropped_levels_multijoin
4007
4014
@@ -4054,10 +4061,7 @@ def _join_multi(self, other, how, return_indexers=True):
4054
4061
4055
4062
multi_join_idx = multi_join_idx .remove_unused_levels ()
4056
4063
4057
- if return_indexers :
4058
- return multi_join_idx , lidx , ridx
4059
- else :
4060
- return multi_join_idx
4064
+ return multi_join_idx , lidx , ridx
4061
4065
4062
4066
jl = list (overlap )[0 ]
4063
4067
@@ -4071,16 +4075,14 @@ def _join_multi(self, other, how, return_indexers=True):
4071
4075
how = {"right" : "left" , "left" : "right" }.get (how , how )
4072
4076
4073
4077
level = other .names .index (jl )
4074
- result = self ._join_level (
4075
- other , level , how = how , return_indexers = return_indexers
4076
- )
4078
+ result = self ._join_level (other , level , how = how )
4077
4079
4078
- if flip_order and isinstance ( result , tuple ) :
4080
+ if flip_order :
4079
4081
return result [0 ], result [2 ], result [1 ]
4080
4082
return result
4081
4083
4082
4084
@final
4083
- def _join_non_unique (self , other , how = "left" , return_indexers = False ):
4085
+ def _join_non_unique (self , other , how = "left" ):
4084
4086
from pandas .core .reshape .merge import get_join_indexers
4085
4087
4086
4088
# We only get here if dtypes match
@@ -4102,15 +4104,10 @@ def _join_non_unique(self, other, how="left", return_indexers=False):
4102
4104
4103
4105
join_index = self ._wrap_joined_index (join_array , other )
4104
4106
4105
- if return_indexers :
4106
- return join_index , left_idx , right_idx
4107
- else :
4108
- return join_index
4107
+ return join_index , left_idx , right_idx
4109
4108
4110
4109
@final
4111
- def _join_level (
4112
- self , other , level , how = "left" , return_indexers = False , keep_order = True
4113
- ):
4110
+ def _join_level (self , other , level , how = "left" , keep_order = True ):
4114
4111
"""
4115
4112
The join method *only* affects the level of the resulting
4116
4113
MultiIndex. Otherwise it just exactly aligns the Index data to the
@@ -4249,28 +4246,22 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> np.ndarray:
4249
4246
if flip_order :
4250
4247
left_indexer , right_indexer = right_indexer , left_indexer
4251
4248
4252
- if return_indexers :
4253
- left_indexer = (
4254
- None if left_indexer is None else ensure_platform_int (left_indexer )
4255
- )
4256
- right_indexer = (
4257
- None if right_indexer is None else ensure_platform_int (right_indexer )
4258
- )
4259
- return join_index , left_indexer , right_indexer
4260
- else :
4261
- return join_index
4249
+ left_indexer = (
4250
+ None if left_indexer is None else ensure_platform_int (left_indexer )
4251
+ )
4252
+ right_indexer = (
4253
+ None if right_indexer is None else ensure_platform_int (right_indexer )
4254
+ )
4255
+ return join_index , left_indexer , right_indexer
4262
4256
4263
4257
@final
4264
- def _join_monotonic (self , other , how = "left" , return_indexers = False ):
4258
+ def _join_monotonic (self , other : Index , how = "left" ):
4265
4259
# We only get here with matching dtypes
4266
4260
assert other .dtype == self .dtype
4267
4261
4268
4262
if self .equals (other ):
4269
4263
ret_index = other if how == "right" else self
4270
- if return_indexers :
4271
- return ret_index , None , None
4272
- else :
4273
- return ret_index
4264
+ return ret_index , None , None
4274
4265
4275
4266
sv = self ._get_engine_target ()
4276
4267
ov = other ._get_engine_target ()
@@ -4306,12 +4297,9 @@ def _join_monotonic(self, other, how="left", return_indexers=False):
4306
4297
4307
4298
join_index = self ._wrap_joined_index (join_array , other )
4308
4299
4309
- if return_indexers :
4310
- lidx = None if lidx is None else ensure_platform_int (lidx )
4311
- ridx = None if ridx is None else ensure_platform_int (ridx )
4312
- return join_index , lidx , ridx
4313
- else :
4314
- return join_index
4300
+ lidx = None if lidx is None else ensure_platform_int (lidx )
4301
+ ridx = None if ridx is None else ensure_platform_int (ridx )
4302
+ return join_index , lidx , ridx
4315
4303
4316
4304
def _wrap_joined_index (
4317
4305
self : _IndexT , joined : np .ndarray , other : _IndexT
0 commit comments