Skip to content

Commit 51f715b

Browse files
committed
Add interpolation options to moving quantile
1 parent 0aae181 commit 51f715b

File tree

3 files changed

+86
-16
lines changed

3 files changed

+86
-16
lines changed

pandas/_libs/window.pyx

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,26 +1356,53 @@ cdef _roll_min_max(ndarray[numeric] input, int64_t win, int64_t minp,
13561356
# print("output: {0}".format(output))
13571357
return output
13581358

1359+
def _get_interpolation_id(str interpolation):
    """
    Map an interpolation method name to its integer id.

    Parameters
    ----------
    interpolation : str
        One of 'linear', 'lower', 'higher', 'nearest', 'midpoint'.

    Returns
    -------
    int
        0 for 'linear', 1 for 'lower', 2 for 'higher', 3 for 'nearest'
        and 4 for 'midpoint'.

    Raises
    ------
    ValueError
        If `interpolation` is not one of the names listed above.
    """
    # Name -> id table; the ids are the values compared against
    # ``interpolation_id`` in roll_quantile.
    known_ids = {
        'linear': 0,
        'lower': 1,
        'higher': 2,
        'nearest': 3,
        'midpoint': 4,
    }

    if interpolation not in known_ids:
        raise ValueError("Interpolation {} is not supported"
                         .format(interpolation))

    return known_ids[interpolation]
1380+
13591381

13601382
def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
13611383
int64_t minp, object index, object closed,
1362-
double quantile):
1384+
double quantile, str interpolation):
13631385
"""
13641386
O(N log(window)) implementation using skip list
13651387
"""
13661388
cdef:
1367-
double val, prev, midpoint
1389+
double val, prev, midpoint, idx_with_fraction
13681390
IndexableSkiplist skiplist
13691391
int64_t nobs = 0, i, j, s, e, N
13701392
Py_ssize_t idx
13711393
bint is_variable
13721394
ndarray[int64_t] start, end
13731395
ndarray[double_t] output
13741396
double vlow, vhigh
1397+
int interpolation_id
13751398

13761399
if quantile <= 0.0 or quantile >= 1.0:
13771400
raise ValueError("quantile value {0} not in [0, 1]".format(quantile))
13781401

1402+
# interpolation_id is needed to avoid string comparisons inside the loop
1403+
# Using a callback was tried, but it resulted in worse performance
1404+
interpolation_id = _get_interpolation_id(interpolation)
1405+
13791406
# we use the Fixed/Variable Indexer here as the
13801407
# actual skiplist ops outweigh any window computation costs
13811408
start, end, N, win, minp, is_variable = get_window_indexer(
@@ -1414,18 +1441,31 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
14141441
skiplist.insert(val)
14151442

14161443
if nobs >= minp:
1417-
idx = int(quantile * <double>(nobs - 1))
1418-
1419-
# Single value in skip list
14201444
if nobs == 1:
1445+
# Single value in skip list
14211446
output[i] = skiplist.get(0)
1422-
1423-
# Interpolated quantile
14241447
else:
1425-
vlow = skiplist.get(idx)
1426-
vhigh = skiplist.get(idx + 1)
1427-
output[i] = ((vlow + (vhigh - vlow) *
1428-
(quantile * (nobs - 1) - idx)))
1448+
idx_with_fraction = quantile * <double> (nobs - 1)
1449+
idx = int(idx_with_fraction)
1450+
1451+
if interpolation_id == 0: # linear
1452+
vlow = skiplist.get(idx)
1453+
vhigh = skiplist.get(idx + 1)
1454+
output[i] = ((vlow + (vhigh - vlow) *
1455+
(idx_with_fraction - idx)))
1456+
elif interpolation_id == 1: # lower
1457+
output[i] = skiplist.get(idx)
1458+
elif interpolation_id == 2: # higher
1459+
output[i] = skiplist.get(idx + 1)
1460+
elif interpolation_id == 3: # nearest
1461+
if idx_with_fraction - idx < 0.5:
1462+
output[i] = skiplist.get(idx)
1463+
else:
1464+
output[i] = skiplist.get(idx + 1)
1465+
elif interpolation_id == 4: # midpoint
1466+
vlow = skiplist.get(idx)
1467+
vhigh = skiplist.get(idx + 1)
1468+
output[i] = <double> (vlow + vhigh) / 2
14291469
else:
14301470
output[i] = NaN
14311471

pandas/core/window.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,9 +1246,21 @@ def kurt(self, **kwargs):
12461246
Parameters
12471247
----------
12481248
quantile : float
1249-
0 <= quantile <= 1""")
1249+
0 <= quantile <= 1
1250+
interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
1251+
.. versionadded:: 0.23.0
12501252
1251-
def quantile(self, quantile, **kwargs):
1253+
This optional parameter specifies the interpolation method to use,
1254+
when the desired quantile lies between two data points `i` and `j`:
1255+
1256+
* linear: `i + (j - i) * fraction`, where `fraction` is the
1257+
fractional part of the index surrounded by `i` and `j`.
1258+
* lower: `i`.
1259+
* higher: `j`.
1260+
* nearest: `i` or `j` whichever is nearest.
1261+
* midpoint: (`i` + `j`) / 2.""")
1262+
1263+
def quantile(self, quantile, interpolation='linear', **kwargs):
12521264
window = self._get_window()
12531265
index, indexi = self._get_index()
12541266

@@ -1262,7 +1274,8 @@ def f(arg, *args, **kwargs):
12621274
self.closed)
12631275
else:
12641276
return _window.roll_quantile(arg, window, minp, indexi,
1265-
self.closed, quantile)
1277+
self.closed, quantile,
1278+
interpolation)
12661279

12671280
return self._apply(f, 'quantile', quantile=quantile,
12681281
**kwargs)
@@ -1582,8 +1595,10 @@ def kurt(self, **kwargs):
15821595
@Substitution(name='rolling')
15831596
@Appender(_doc_template)
15841597
@Appender(_shared_docs['quantile'])
1585-
def quantile(self, quantile, **kwargs):
1586-
return super(Rolling, self).quantile(quantile=quantile, **kwargs)
1598+
def quantile(self, quantile, interpolation='linear', **kwargs):
    # Thin override: hand the call to the parent class's quantile with
    # every argument passed by keyword. No docstring is written here on
    # purpose; the Substitution/Appender decorators on this method supply
    # the documentation.
    parent_quantile = super(Rolling, self).quantile
    return parent_quantile(quantile=quantile,
                           interpolation=interpolation,
                           **kwargs)
15871602

15881603
@Substitution(name='rolling')
15891604
@Appender(_doc_template)

pandas/tests/test_window.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,7 +1135,22 @@ def test_rolling_quantile_series(self):
11351135
s = Series(arr)
11361136
q1 = s.quantile(0.1)
11371137
q2 = s.rolling(100).quantile(0.1).iloc[-1]
1138+
tm.assert_almost_equal(q1, q2)
1139+
1140+
q1 = s.quantile(0.1, interpolation='lower')
1141+
q2 = s.rolling(100).quantile(0.1, interpolation='lower').iloc[-1]
1142+
tm.assert_almost_equal(q1, q2)
1143+
1144+
q1 = s.quantile(0.1, interpolation='higher')
1145+
q2 = s.rolling(100).quantile(0.1, interpolation='higher').iloc[-1]
1146+
tm.assert_almost_equal(q1, q2)
1147+
1148+
q1 = s.quantile(0.1, interpolation='nearest')
1149+
q2 = s.rolling(100).quantile(0.1, interpolation='nearest').iloc[-1]
1150+
tm.assert_almost_equal(q1, q2)
11381151

1152+
q1 = s.quantile(0.1, interpolation='midpoint')
1153+
q2 = s.rolling(100).quantile(0.1, interpolation='midpoint').iloc[-1]
11391154
tm.assert_almost_equal(q1, q2)
11401155

11411156
def test_rolling_quantile_param(self):

0 commit comments

Comments
 (0)