Skip to content

Commit c704d04

Browse files
authored
REF: give rank1d/2d same nan filling (#41916)
1 parent b2b2baa commit c704d04

File tree

1 file changed

+50
-52
lines changed

1 file changed

+50
-52
lines changed

pandas/_libs/algos.pyx

Lines changed: 50 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -931,6 +931,32 @@ ctypedef fused rank_t:
931931
int64_t
932932

933933

934+
cdef rank_t get_rank_nan_fill_val(bint rank_nans_highest, rank_t[:] _=None):
935+
"""
936+
Return the value we'll use to represent missing values when sorting depending
937+
on if we'd like missing values to end up at the top/bottom. (The second parameter
938+
is unused, but needed for fused type specialization)
939+
"""
940+
if rank_nans_highest:
941+
if rank_t is object:
942+
return Infinity()
943+
elif rank_t is int64_t:
944+
return util.INT64_MAX
945+
elif rank_t is uint64_t:
946+
return util.UINT64_MAX
947+
else:
948+
return np.inf
949+
else:
950+
if rank_t is object:
951+
return NegInfinity()
952+
elif rank_t is int64_t:
953+
return NPY_NAT
954+
elif rank_t is uint64_t:
955+
return 0
956+
else:
957+
return -np.inf
958+
959+
934960
@cython.wraparound(False)
935961
@cython.boundscheck(False)
936962
def rank_1d(
@@ -980,7 +1006,7 @@ def rank_1d(
9801006
ndarray[rank_t, ndim=1] masked_vals
9811007
rank_t[:] masked_vals_memview
9821008
uint8_t[:] mask
983-
bint keep_na, check_labels, check_mask
1009+
bint keep_na, nans_rank_highest, check_labels, check_mask
9841010
rank_t nan_fill_val
9851011

9861012
tiebreak = tiebreakers[ties_method]
@@ -1026,27 +1052,12 @@ def rank_1d(
10261052
# If descending, fill with highest value since descending
10271053
# will flip the ordering to still end up with lowest rank.
10281054
# Symmetric logic applies to `na_option == 'bottom'`
1029-
if ascending ^ (na_option == 'top'):
1030-
if rank_t is object:
1031-
nan_fill_val = Infinity()
1032-
elif rank_t is int64_t:
1033-
nan_fill_val = util.INT64_MAX
1034-
elif rank_t is uint64_t:
1035-
nan_fill_val = util.UINT64_MAX
1036-
else:
1037-
nan_fill_val = np.inf
1055+
nans_rank_highest = ascending ^ (na_option == 'top')
1056+
nan_fill_val = get_rank_nan_fill_val[rank_t](nans_rank_highest)
1057+
if nans_rank_highest:
10381058
order = (masked_vals, mask, labels)
10391059
else:
1040-
if rank_t is object:
1041-
nan_fill_val = NegInfinity()
1042-
elif rank_t is int64_t:
1043-
nan_fill_val = NPY_NAT
1044-
elif rank_t is uint64_t:
1045-
nan_fill_val = 0
1046-
else:
1047-
nan_fill_val = -np.inf
1048-
1049-
order = (masked_vals, ~(np.array(mask, copy=False)), labels)
1060+
order = (masked_vals, ~(np.asarray(mask)), labels)
10501061

10511062
np.putmask(masked_vals, mask, nan_fill_val)
10521063
# putmask doesn't accept a memoryview, so we assign as a separate step
@@ -1073,14 +1084,11 @@ def rank_1d(
10731084
check_mask,
10741085
check_labels,
10751086
keep_na,
1087+
pct,
10761088
N,
10771089
)
1078-
if pct:
1079-
for i in range(N):
1080-
if grp_sizes[i] != 0:
1081-
out[i] = out[i] / grp_sizes[i]
10821090

1083-
return np.array(out)
1091+
return np.asarray(out)
10841092

10851093

10861094
@cython.wraparound(False)
@@ -1097,6 +1105,7 @@ cdef void rank_sorted_1d(
10971105
bint check_mask,
10981106
bint check_labels,
10991107
bint keep_na,
1108+
bint pct,
11001109
Py_ssize_t N,
11011110
) nogil:
11021111
"""
@@ -1108,7 +1117,7 @@ cdef void rank_sorted_1d(
11081117
out : float64_t[::1]
11091118
Array to store computed ranks
11101119
grp_sizes : int64_t[::1]
1111-
Array to store group counts.
1120+
Array to store group counts, only used if pct=True
11121121
labels : See rank_1d.__doc__
11131122
sort_indexer : intp_t[:]
11141123
Array of indices which sorts masked_vals
@@ -1118,12 +1127,14 @@ cdef void rank_sorted_1d(
11181127
Array where entries are True if the value is missing, False otherwise
11191128
tiebreak : TiebreakEnumType
11201129
See rank_1d.__doc__ for the different modes
1121-
check_mask : bint
1130+
check_mask : bool
11221131
If False, assumes the mask is all False to skip mask indexing
1123-
check_labels : bint
1132+
check_labels : bool
11241133
If False, assumes all labels are the same to skip group handling logic
1125-
keep_na : bint
1134+
keep_na : bool
11261135
Whether or not to keep nulls
1136+
pct : bool
1137+
Compute percentage rank of data within each group
11271138
N : Py_ssize_t
11281139
The number of elements to rank. Note: it is not always true that
11291140
N == len(out) or N == len(masked_vals) (see `nancorr_spearman` usage for why)
@@ -1342,6 +1353,11 @@ cdef void rank_sorted_1d(
13421353
grp_start = i + 1
13431354
grp_vals_seen = 1
13441355

1356+
if pct:
1357+
for i in range(N):
1358+
if grp_sizes[i] != 0:
1359+
out[i] = out[i] / grp_sizes[i]
1360+
13451361

13461362
def rank_2d(
13471363
ndarray[rank_t, ndim=2] in_arr,
@@ -1362,11 +1378,11 @@ def rank_2d(
13621378
ndarray[rank_t, ndim=2] values
13631379
ndarray[intp_t, ndim=2] argsort_indexer
13641380
ndarray[uint8_t, ndim=2] mask
1365-
rank_t val, nan_value
1381+
rank_t val, nan_fill_val
13661382
float64_t count, sum_ranks = 0.0
13671383
int tiebreak = 0
13681384
int64_t idx
1369-
bint check_mask, condition, keep_na
1385+
bint check_mask, condition, keep_na, nans_rank_highest
13701386

13711387
tiebreak = tiebreakers[ties_method]
13721388

@@ -1384,27 +1400,9 @@ def rank_2d(
13841400
if values.dtype != np.object_:
13851401
values = values.astype('O')
13861402

1403+
nans_rank_highest = ascending ^ (na_option == 'top')
13871404
if check_mask:
1388-
if ascending ^ (na_option == 'top'):
1389-
if rank_t is object:
1390-
nan_value = Infinity()
1391-
elif rank_t is float64_t:
1392-
nan_value = np.inf
1393-
1394-
# int64 and datetimelike
1395-
else:
1396-
nan_value = util.INT64_MAX
1397-
1398-
else:
1399-
if rank_t is object:
1400-
nan_value = NegInfinity()
1401-
elif rank_t is float64_t:
1402-
nan_value = -np.inf
1403-
1404-
# int64 and datetimelike
1405-
else:
1406-
nan_value = NPY_NAT
1407-
1405+
nan_fill_val = get_rank_nan_fill_val[rank_t](nans_rank_highest)
14081406
if rank_t is object:
14091407
mask = missing.isnaobj2d(values)
14101408
elif rank_t is float64_t:
@@ -1414,7 +1412,7 @@ def rank_2d(
14141412
else:
14151413
mask = values == NPY_NAT
14161414

1417-
np.putmask(values, mask, nan_value)
1415+
np.putmask(values, mask, nan_fill_val)
14181416
else:
14191417
mask = np.zeros_like(values, dtype=bool)
14201418

0 commit comments

Comments
 (0)