Skip to content

Commit 83379a4

Browse files
sycaishobsi
authored andcommitted
feat: add closed parameter in rolling() (#1539)
* feat: add parameter in rolling() * beautify code
1 parent 7b70e97 commit 83379a4

File tree

6 files changed

+199
-55
lines changed

6 files changed

+199
-55
lines changed

bigframes/core/groupby/__init__.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@
1515
from __future__ import annotations
1616

1717
import typing
18-
from typing import Sequence, Tuple, Union
18+
from typing import Literal, Sequence, Tuple, Union
1919

2020
import bigframes_vendored.constants as constants
2121
import bigframes_vendored.pandas.core.groupby as vendored_pandas_groupby
2222
import jellyfish
2323
import pandas as pd
2424

2525
from bigframes import session
26+
from bigframes.core import expression as ex
2627
from bigframes.core import log_adapter
2728
import bigframes.core.block_transforms as block_ops
2829
import bigframes.core.blocks as blocks
29-
import bigframes.core.expression
3030
import bigframes.core.ordering as order
3131
import bigframes.core.utils as utils
3232
import bigframes.core.validations as validations
@@ -305,13 +305,16 @@ def diff(self, periods=1) -> series.Series:
305305
return self._apply_window_op(agg_ops.DiffOp(periods), window=window)
306306

307307
@validations.requires_ordering()
308-
def rolling(self, window: int, min_periods=None) -> windows.Window:
309-
# To get n size window, need current row and n-1 preceding rows.
310-
window_spec = window_specs.rows(
311-
grouping_keys=tuple(self._by_col_ids),
312-
start=-(window - 1),
313-
end=0,
314-
min_periods=min_periods or window,
308+
def rolling(
309+
self,
310+
window: int,
311+
min_periods=None,
312+
closed: Literal["right", "left", "both", "neither"] = "right",
313+
) -> windows.Window:
314+
window_spec = window_specs.WindowSpec(
315+
bounds=window_specs.RowsWindowBounds.from_window_size(window, closed),
316+
min_periods=min_periods if min_periods is not None else window,
317+
grouping_keys=tuple(ex.deref(col) for col in self._by_col_ids),
315318
)
316319
block = self._block.order_by(
317320
[order.ascending_over(col) for col in self._by_col_ids],
@@ -361,7 +364,7 @@ def _agg_string(self, func: str) -> df.DataFrame:
361364
return dataframe if self._as_index else self._convert_index(dataframe)
362365

363366
def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
364-
aggregations: typing.List[bigframes.core.expression.Aggregation] = []
367+
aggregations: typing.List[ex.Aggregation] = []
365368
column_labels = []
366369

367370
want_aggfunc_level = any(utils.is_list_like(aggs) for aggs in func.values())
@@ -738,13 +741,16 @@ def diff(self, periods=1) -> series.Series:
738741
return self._apply_window_op(agg_ops.DiffOp(periods), window=window)
739742

740743
@validations.requires_ordering()
741-
def rolling(self, window: int, min_periods=None) -> windows.Window:
742-
# To get n size window, need current row and n-1 preceding rows.
743-
window_spec = window_specs.rows(
744-
grouping_keys=tuple(self._by_col_ids),
745-
start=-(window - 1),
746-
end=0,
747-
min_periods=min_periods or window,
744+
def rolling(
745+
self,
746+
window: int,
747+
min_periods=None,
748+
closed: Literal["right", "left", "both", "neither"] = "right",
749+
) -> windows.Window:
750+
window_spec = window_specs.WindowSpec(
751+
bounds=window_specs.RowsWindowBounds.from_window_size(window, closed),
752+
min_periods=min_periods if min_periods is not None else window,
753+
grouping_keys=tuple(ex.deref(col) for col in self._by_col_ids),
748754
)
749755
block = self._block.order_by(
750756
[order.ascending_over(col) for col in self._by_col_ids],
@@ -806,11 +812,9 @@ def _apply_window_op(
806812
return series.Series(block.select_column(result_id))
807813

808814

809-
def agg(input: str, op: agg_ops.AggregateOp) -> bigframes.core.expression.Aggregation:
815+
def agg(input: str, op: agg_ops.AggregateOp) -> ex.Aggregation:
810816
if isinstance(op, agg_ops.UnaryAggregateOp):
811-
return bigframes.core.expression.UnaryAggregation(
812-
op, bigframes.core.expression.deref(input)
813-
)
817+
return ex.UnaryAggregation(op, ex.deref(input))
814818
else:
815819
assert isinstance(op, agg_ops.NullaryAggregateOp)
816-
return bigframes.core.expression.NullaryAggregation(op)
820+
return ex.NullaryAggregation(op)

bigframes/core/window_spec.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from dataclasses import dataclass, replace
1717
import itertools
18-
from typing import Mapping, Optional, Set, Tuple, Union
18+
from typing import Literal, Mapping, Optional, Set, Tuple, Union
1919

2020
import bigframes.core.expression as ex
2121
import bigframes.core.identifiers as ids
@@ -140,6 +140,21 @@ class RowsWindowBounds:
140140
start: Optional[int] = None
141141
end: Optional[int] = None
142142

143+
@classmethod
144+
def from_window_size(
145+
cls, window: int, closed: Literal["right", "left", "both", "neither"]
146+
) -> RowsWindowBounds:
147+
if closed == "right":
148+
return cls(-(window - 1), 0)
149+
elif closed == "left":
150+
return cls(-window, -1)
151+
elif closed == "both":
152+
return cls(-window, 0)
153+
elif closed == "neither":
154+
return cls(-(window - 1), -1)
155+
else:
156+
raise ValueError(f"Unsupported value for 'closed' parameter: {closed}")
157+
143158
def __post_init__(self):
144159
if self.start is None:
145160
return

bigframes/dataframe.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3308,10 +3308,15 @@ def _perform_join_by_index(
33083308
return DataFrame(block)
33093309

33103310
@validations.requires_ordering()
3311-
def rolling(self, window: int, min_periods=None) -> bigframes.core.window.Window:
3312-
# To get n size window, need current row and n-1 preceding rows.
3313-
window_def = windows.rows(
3314-
start=-(window - 1), end=0, min_periods=min_periods or window
3311+
def rolling(
3312+
self,
3313+
window: int,
3314+
min_periods=None,
3315+
closed: Literal["right", "left", "both", "neither"] = "right",
3316+
) -> bigframes.core.window.Window:
3317+
window_def = windows.WindowSpec(
3318+
bounds=windows.RowsWindowBounds.from_window_size(window, closed),
3319+
min_periods=min_periods if min_periods is not None else window,
33153320
)
33163321
return bigframes.core.window.Window(
33173322
self._block, window_def, self._block.value_columns

bigframes/series.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@
3535
import typing_extensions
3636

3737
import bigframes.core
38-
from bigframes.core import log_adapter
38+
from bigframes.core import groupby, log_adapter
3939
import bigframes.core.block_transforms as block_ops
4040
import bigframes.core.blocks as blocks
4141
import bigframes.core.expression as ex
42-
import bigframes.core.groupby as groupby
4342
import bigframes.core.indexers
4443
import bigframes.core.indexes as indexes
4544
import bigframes.core.ordering as order
@@ -1438,10 +1437,15 @@ def sort_index(self, *, axis=0, ascending=True, na_position="last") -> Series:
14381437
return Series(block)
14391438

14401439
@validations.requires_ordering()
1441-
def rolling(self, window: int, min_periods=None) -> bigframes.core.window.Window:
1442-
# To get n size window, need current row and n-1 preceding rows.
1443-
window_spec = windows.rows(
1444-
start=-(window - 1), end=0, min_periods=min_periods or window
1440+
def rolling(
1441+
self,
1442+
window: int,
1443+
min_periods=None,
1444+
closed: Literal["right", "left", "both", "neither"] = "right",
1445+
) -> bigframes.core.window.Window:
1446+
window_spec = windows.WindowSpec(
1447+
bounds=windows.RowsWindowBounds.from_window_size(window, closed),
1448+
min_periods=min_periods if min_periods is not None else window,
14451449
)
14461450
return bigframes.core.window.Window(
14471451
self._block, window_spec, self._block.value_columns, is_series=True

tests/system/small/test_window.py

Lines changed: 117 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,111 @@
1616
import pytest
1717

1818

19+
@pytest.fixture(scope="module")
20+
def rolling_dfs(scalars_dfs):
21+
bf_df, pd_df = scalars_dfs
22+
23+
target_cols = ["int64_too", "float64_col", "bool_col"]
24+
25+
bf_df = bf_df[target_cols].set_index("bool_col")
26+
pd_df = pd_df[target_cols].set_index("bool_col")
27+
28+
return bf_df, pd_df
29+
30+
31+
@pytest.fixture(scope="module")
32+
def rolling_series(scalars_dfs):
33+
bf_df, pd_df = scalars_dfs
34+
target_col = "int64_too"
35+
36+
return bf_df[target_col], pd_df[target_col]
37+
38+
39+
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
40+
def test_dataframe_rolling_closed_param(rolling_dfs, closed):
41+
bf_df, pd_df = rolling_dfs
42+
43+
actual_result = bf_df.rolling(window=3, closed=closed).sum().to_pandas()
44+
45+
expected_result = pd_df.rolling(window=3, closed=closed).sum()
46+
pd.testing.assert_frame_equal(actual_result, expected_result, check_dtype=False)
47+
48+
49+
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
50+
def test_dataframe_groupby_rolling_closed_param(rolling_dfs, closed):
51+
bf_df, pd_df = rolling_dfs
52+
53+
actual_result = (
54+
bf_df.groupby(level=0).rolling(window=3, closed=closed).sum().to_pandas()
55+
)
56+
57+
expected_result = pd_df.groupby(level=0).rolling(window=3, closed=closed).sum()
58+
pd.testing.assert_frame_equal(actual_result, expected_result, check_dtype=False)
59+
60+
61+
def test_dataframe_rolling_default_closed_param(rolling_dfs):
62+
bf_df, pd_df = rolling_dfs
63+
64+
actual_result = bf_df.rolling(window=3).sum().to_pandas()
65+
66+
expected_result = pd_df.rolling(window=3).sum()
67+
pd.testing.assert_frame_equal(actual_result, expected_result, check_dtype=False)
68+
69+
70+
def test_dataframe_groupby_rolling_default_closed_param(rolling_dfs):
71+
bf_df, pd_df = rolling_dfs
72+
73+
actual_result = bf_df.groupby(level=0).rolling(window=3).sum().to_pandas()
74+
75+
expected_result = pd_df.groupby(level=0).rolling(window=3).sum()
76+
pd.testing.assert_frame_equal(actual_result, expected_result, check_dtype=False)
77+
78+
79+
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
80+
def test_series_rolling_closed_param(rolling_series, closed):
81+
bf_series, df_series = rolling_series
82+
83+
actual_result = bf_series.rolling(window=3, closed=closed).sum().to_pandas()
84+
85+
expected_result = df_series.rolling(window=3, closed=closed).sum()
86+
pd.testing.assert_series_equal(actual_result, expected_result, check_dtype=False)
87+
88+
89+
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
90+
def test_series_groupby_rolling_closed_param(rolling_series, closed):
91+
bf_series, df_series = rolling_series
92+
93+
actual_result = (
94+
bf_series.groupby(bf_series % 2)
95+
.rolling(window=3, closed=closed)
96+
.sum()
97+
.to_pandas()
98+
)
99+
100+
expected_result = (
101+
df_series.groupby(df_series % 2).rolling(window=3, closed=closed).sum()
102+
)
103+
pd.testing.assert_series_equal(actual_result, expected_result, check_dtype=False)
104+
105+
106+
def test_series_rolling_default_closed_param(rolling_series):
107+
bf_series, df_series = rolling_series
108+
109+
actual_result = bf_series.rolling(window=3).sum().to_pandas()
110+
111+
expected_result = df_series.rolling(window=3).sum()
112+
pd.testing.assert_series_equal(actual_result, expected_result, check_dtype=False)
113+
114+
115+
def test_series_groupby_rolling_default_closed_param(rolling_series):
116+
bf_series, df_series = rolling_series
117+
118+
actual_result = bf_series.groupby(bf_series % 2).rolling(window=3).sum().to_pandas()
119+
120+
expected_result = df_series.groupby(df_series % 2).rolling(window=3).sum()
121+
pd.testing.assert_series_equal(actual_result, expected_result, check_dtype=False)
122+
123+
19124
@pytest.mark.parametrize(
20125
("windowing"),
21126
[
@@ -41,20 +146,13 @@
41146
pytest.param(lambda x: x.var(), id="var"),
42147
],
43148
)
44-
def test_series_window_agg_ops(
45-
scalars_df_index, scalars_pandas_df_index, windowing, agg_op
46-
):
47-
col_name = "int64_too"
48-
bf_series = agg_op(windowing(scalars_df_index[col_name])).to_pandas()
49-
pd_series = agg_op(windowing(scalars_pandas_df_index[col_name]))
50-
51-
# Pandas always converts to float64, even for min/max/count, which is not desired
52-
pd_series = pd_series.astype(bf_series.dtype)
53-
54-
pd.testing.assert_series_equal(
55-
pd_series,
56-
bf_series,
57-
)
149+
def test_series_window_agg_ops(rolling_series, windowing, agg_op):
150+
bf_series, pd_series = rolling_series
151+
152+
actual_result = agg_op(windowing(bf_series)).to_pandas()
153+
154+
expected_result = agg_op(windowing(pd_series))
155+
pd.testing.assert_series_equal(expected_result, actual_result, check_dtype=False)
58156

59157

60158
@pytest.mark.parametrize(
@@ -83,13 +181,10 @@ def test_series_window_agg_ops(
83181
pytest.param(lambda x: x.var(), id="var"),
84182
],
85183
)
86-
def test_dataframe_window_agg_ops(
87-
scalars_df_index, scalars_pandas_df_index, windowing, agg_op
88-
):
89-
scalars_df_index = scalars_df_index.set_index("bool_col")
90-
scalars_pandas_df_index = scalars_pandas_df_index.set_index("bool_col")
91-
col_names = ["int64_too", "float64_col"]
92-
bf_result = agg_op(windowing(scalars_df_index[col_names])).to_pandas()
93-
pd_result = agg_op(windowing(scalars_pandas_df_index[col_names]))
184+
def test_dataframe_window_agg_ops(rolling_dfs, windowing, agg_op):
185+
bf_df, pd_df = rolling_dfs
186+
187+
bf_result = agg_op(windowing(bf_df)).to_pandas()
94188

189+
pd_result = agg_op(windowing(pd_df))
95190
pd.testing.assert_frame_equal(pd_result, bf_result, check_dtype=False)

tests/unit/core/test_windowspec.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,24 @@ def test_invalid_rows_window_boundary_raise_error(start, end):
2727
def test_invalid_range_window_boundary_raise_error(start, end):
2828
with pytest.raises(ValueError):
2929
window_spec.RangeWindowBounds(start, end)
30+
31+
32+
@pytest.mark.parametrize(
33+
("window", "closed", "start", "end"),
34+
[
35+
pytest.param(3, "left", -3, -1, id="left"),
36+
pytest.param(3, "right", -2, 0, id="right"),
37+
pytest.param(3, "neither", -2, -1, id="neither"),
38+
pytest.param(3, "both", -3, 0, id="both"),
39+
],
40+
)
41+
def test_rows_window_bounds_from_window_size(window, closed, start, end):
42+
actual_result = window_spec.RowsWindowBounds.from_window_size(window, closed)
43+
44+
expected_result = window_spec.RowsWindowBounds(start, end)
45+
assert actual_result == expected_result
46+
47+
48+
def test_rows_window_bounds_from_window_size_invalid_closed_raise_error():
49+
with pytest.raises(ValueError):
50+
window_spec.RowsWindowBounds.from_window_size(3, "whatever") # type:ignore

0 commit comments

Comments
 (0)