-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Add online operations for EWM.mean #41888
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 23 commits
e195c58
3d95167
9354bd0
0ce197d
8096cc6
a5273b9
bab78cc
d72a03e
0b7e773
8444b42
7847373
df13b55
57db06e
329dbd2
9594afe
28be18a
85025ff
3345271
2186ea0
8024a7b
80c8b7f
8a5b0b9
e790947
916e68b
175c4ca
f799a0f
c8b09b6
2cb4019
fea8b0b
04ea064
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10892,7 +10892,7 @@ def ewm( | |
span: float | None = None, | ||
halflife: float | TimedeltaConvertibleTypes | None = None, | ||
alpha: float | None = None, | ||
min_periods: int = 0, | ||
min_periods: int | None = 0, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this (and the changed annotation below) orthogonal to the rest of the PR? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A bit.
It's parent class typed Since I was calling |
||
adjust: bool_t = True, | ||
ignore_na: bool_t = False, | ||
axis: Axis = 0, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,6 +41,10 @@ | |
GroupbyIndexer, | ||
) | ||
from pandas.core.window.numba_ import generate_numba_ewma_func | ||
from pandas.core.window.online import ( | ||
EWMMeanState, | ||
generate_online_numba_ewma_func, | ||
) | ||
from pandas.core.window.rolling import ( | ||
BaseWindow, | ||
BaseWindowGroupby, | ||
|
@@ -263,7 +267,7 @@ def __init__( | |
span: float | None = None, | ||
halflife: float | TimedeltaConvertibleTypes | None = None, | ||
alpha: float | None = None, | ||
min_periods: int = 0, | ||
min_periods: int | None = 0, | ||
adjust: bool = True, | ||
ignore_na: bool = False, | ||
axis: Axis = 0, | ||
|
@@ -273,7 +277,7 @@ def __init__( | |
): | ||
super().__init__( | ||
obj=obj, | ||
min_periods=max(int(min_periods), 1), | ||
min_periods=1 if min_periods is None else max(int(min_periods), 1), | ||
on=None, | ||
center=False, | ||
closed=None, | ||
|
@@ -338,6 +342,48 @@ def _get_window_indexer(self) -> BaseIndexer: | |
""" | ||
return ExponentialMovingWindowIndexer() | ||
|
||
def online(
    self,
    engine: str = "numba",
    engine_kwargs: dict[str, bool] | None = None,
) -> "OnlineExponentialMovingWindow":
    """
    Return an ``OnlineExponentialMovingWindow`` object to calculate
    exponentially moving window aggregations in an online method.

    .. versionadded:: 1.3.0

    Parameters
    ----------
    engine : str, default ``'numba'``
        Execution engine to calculate online aggregations.
        Applies to all supported aggregation methods.

    engine_kwargs : dict, default None
        Applies to all supported aggregation methods.

        * For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil``
          and ``parallel`` dictionary keys. The values must either be ``True`` or
          ``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is
          ``{'nopython': True, 'nogil': False, 'parallel': False}`` and will be
          applied to the function.

    Returns
    -------
    OnlineExponentialMovingWindow
        A new window object constructed from this window's ``obj``, decay
        parameters, ``min_periods``, ``adjust``, ``ignore_na``, ``axis``,
        ``times`` and current selection.
    """
    # Every setting of this window is forwarded unchanged; only the engine
    # arguments come from the caller.
    return OnlineExponentialMovingWindow(
        obj=self.obj,
        com=self.com,
        span=self.span,
        halflife=self.halflife,
        alpha=self.alpha,
        min_periods=self.min_periods,
        adjust=self.adjust,
        ignore_na=self.ignore_na,
        axis=self.axis,
        times=self.times,
        engine=engine,
        engine_kwargs=engine_kwargs,
        selection=self._selection,
    )
|
||
@doc( | ||
_shared_docs["aggregate"], | ||
see_also=dedent( | ||
|
@@ -655,3 +701,167 @@ def _get_window_indexer(self) -> GroupbyIndexer: | |
window_indexer=ExponentialMovingWindowIndexer, | ||
) | ||
return window_indexer | ||
|
||
|
||
class OnlineExponentialMovingWindow(ExponentialMovingWindow):
    """
    Exponentially weighted window whose mean can be continued with new data.

    Only :meth:`mean` is implemented. It keeps its running state in an
    ``EWMMeanState`` object so that later calls can pass ``update`` rows
    instead of recomputing from scratch. ``times`` is rejected and
    ``'numba'`` is the only accepted engine. The other aggregations raise
    ``NotImplementedError``.
    """

    def __init__(
        self,
        obj: FrameOrSeries,
        com: float | None = None,
        span: float | None = None,
        halflife: float | TimedeltaConvertibleTypes | None = None,
        alpha: float | None = None,
        min_periods: int | None = 0,
        adjust: bool = True,
        ignore_na: bool = False,
        axis: Axis = 0,
        times: str | np.ndarray | FrameOrSeries | None = None,
        engine: str = "numba",
        engine_kwargs: dict[str, bool] | None = None,
        *,
        selection=None,
    ):
        """
        Build the window and the state object used by :meth:`mean`.

        Raises
        ------
        NotImplementedError
            If ``times`` is not ``None``.
        ValueError
            If ``maybe_use_numba(engine)`` is falsy.
        """
        if times is not None:
            raise NotImplementedError(
                "times is not implemented with online operations."
            )
        super().__init__(
            obj=obj,
            com=com,
            span=span,
            halflife=halflife,
            alpha=alpha,
            min_periods=min_periods,
            adjust=adjust,
            ignore_na=ignore_na,
            axis=axis,
            times=times,
            selection=selection,
        )
        # Running state shared by successive mean() calls.
        self._mean = EWMMeanState(
            self._com, self.adjust, self.ignore_na, self.axis, obj.shape
        )
        if not maybe_use_numba(engine):
            raise ValueError("'numba' is the only supported engine")
        self.engine = engine
        self.engine_kwargs = engine_kwargs

    def reset(self):
        """
        Reset the state captured by `update` calls.
        """
        self._mean.reset()

    def aggregate(self, func, *args, **kwargs):
        """Not supported for online windows; always raises."""
        raise NotImplementedError("aggregate is not implemented.")

    def std(self, bias: bool = False, *args, **kwargs):
        """Not supported for online windows; always raises."""
        raise NotImplementedError("std is not implemented.")

    def corr(
        self,
        other: FrameOrSeriesUnion | None = None,
        pairwise: bool | None = None,
        **kwargs,
    ):
        """Not supported for online windows; always raises."""
        raise NotImplementedError("corr is not implemented.")

    def cov(
        self,
        other: FrameOrSeriesUnion | None = None,
        pairwise: bool | None = None,
        bias: bool = False,
        **kwargs,
    ):
        """Not supported for online windows; always raises."""
        raise NotImplementedError("cov is not implemented.")

    def var(self, bias: bool = False, *args, **kwargs):
        """Not supported for online windows; always raises."""
        raise NotImplementedError("var is not implemented.")

    def mean(self, *args, update=None, update_times=None, **kwargs):
        """
        Calculate an online exponentially weighted mean.

        Parameters
        ----------
        *args
            Accepted but not used by this method.
        update : DataFrame or Series, default None
            New values to continue calculating the
            exponentially weighted mean from the last values and weights.
            Values should be float64 dtype.

            ``update`` needs to be ``None`` the first time the
            exponentially weighted mean is calculated.

        update_times : Series or 1-D np.ndarray, default None
            New times to continue calculating the
            exponentially weighted mean from the last values and weights.
            If ``None``, values are assumed to be evenly spaced
            in time.
            This feature is currently unsupported.
        **kwargs
            Accepted but not used by this method.

        Returns
        -------
        DataFrame or Series
            Indexed like ``update`` when it is given, otherwise like the
            selected object.

        Raises
        ------
        NotImplementedError
            If ``update_times`` is not ``None``.
        ValueError
            If ``update`` is passed before any call with ``update=None``.

        Examples
        --------
        >>> df = pd.DataFrame({"a": range(5), "b": range(5, 10)})
        >>> online_ewm = df.head(2).ewm(0.5).online()
        >>> online_ewm.mean()
              a     b
        0  0.00  5.00
        1  0.75  5.75
        >>> online_ewm.mean(update=df.tail(3))
                  a         b
        2  1.615385  6.615385
        3  2.550000  7.550000
        4  3.520661  8.520661
        >>> online_ewm.reset()
        >>> online_ewm.mean()
              a     b
        0  0.00  5.00
        1  0.75  5.75
        """
        result_kwargs = {}
        is_frame = self._selected_obj.ndim == 2
        if update_times is not None:
            raise NotImplementedError("update_times is not implemented.")
        # Without explicit times every step counts as one unit, so the deltas
        # are all ones.
        update_deltas = np.ones(
            max(self._selected_obj.shape[self.axis - 1] - 1, 0), dtype=np.float64
        )
        if update is not None:
            # The stored state is only populated by an earlier mean() call
            # made without ``update``; there is nothing to continue from
            # before that.
            if self._mean.last_ewm is None:
                raise ValueError(
                    "Must call mean with update=None first before passing update"
                )
            # The last stored mean is prepended as a seed row and sliced off
            # the result below.
            result_from = 1
            result_kwargs["index"] = update.index
            if is_frame:
                last_value = self._mean.last_ewm[np.newaxis, :]
                result_kwargs["columns"] = update.columns
            else:
                last_value = self._mean.last_ewm
                result_kwargs["name"] = update.name
            np_array = np.concatenate((last_value, update.to_numpy()))
        else:
            result_from = 0
            result_kwargs["index"] = self._selected_obj.index
            if is_frame:
                result_kwargs["columns"] = self._selected_obj.columns
            else:
                result_kwargs["name"] = self._selected_obj.name
            np_array = self._selected_obj.astype(np.float64).to_numpy()
        ewma_func = generate_online_numba_ewma_func(self.engine_kwargs)
        # Series data is given a trailing axis so the state object always
        # receives a 2-D array.
        result = self._mean.run_ewm(
            np_array if is_frame else np_array[:, np.newaxis],
            update_deltas,
            self.min_periods,
            ewma_func,
        )
        if not is_frame:
            # reshape(-1) rather than squeeze(): squeeze() would turn a
            # single-element result into a 0-d array, which cannot be sliced.
            result = result.reshape(-1)
        result = result[result_from:]
        result = self._selected_obj._constructor(result, **result_kwargs)
        return result
Uh oh!
There was an error while loading. Please reload this page.