Skip to content

Commit 734cc65

Browse files
authored
feat: support inplace=True in rename and rename_axis (#1744)
* feat: support inplace=True in rename and rename_axis
* fix typing issues
* do not consider tuples single labels in rename
1 parent acba032 commit 734cc65

File tree

13 files changed: +391 additions, -45 deletions

13 files changed: +391 additions, -45 deletions

bigframes/core/compile/polars/compiler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def _(
9292
return args[0] < args[1]
9393
if isinstance(op, ops.eq_op.__class__):
9494
return args[0] == args[1]
95+
if isinstance(op, ops.ne_op.__class__):
96+
return args[0] != args[1]
9597
if isinstance(op, ops.mod_op.__class__):
9698
return args[0] % args[1]
9799
if isinstance(op, ops.coalesce_op.__class__):
@@ -101,6 +103,9 @@ def _(
101103
for pred, result in zip(args[2::2], args[3::2]):
102104
return expr.when(pred).then(result)
103105
return expr
106+
if isinstance(op, ops.where_op.__class__):
107+
original, condition, otherwise = args
108+
return pl.when(condition).then(original).otherwise(otherwise)
104109
raise NotImplementedError(f"Polars compiler hasn't implemented {op}")
105110

106111
@dataclasses.dataclass(frozen=True)

bigframes/core/indexes/base.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,7 @@ def names(self) -> typing.Sequence[blocks.Label]:
145145

146146
@names.setter
147147
def names(self, values: typing.Sequence[blocks.Label]):
148-
new_block = self._block.with_index_labels(values)
149-
if self._linked_frame is not None:
150-
self._linked_frame._set_block(
151-
self._linked_frame._block.with_index_labels(values)
152-
)
153-
self._block = new_block
148+
self.rename(values, inplace=True)
154149

155150
@property
156151
def nlevels(self) -> int:
@@ -411,11 +406,62 @@ def fillna(self, value=None) -> Index:
411406
ops.fillna_op.as_expr(ex.free_var("arg"), ex.const(value))
412407
)
413408

414-
def rename(self, name: Union[str, Sequence[str]]) -> Index:
415-
names = [name] if isinstance(name, str) else list(name)
409+
@overload
410+
def rename(
411+
self,
412+
name: Union[blocks.Label, Sequence[blocks.Label]],
413+
) -> Index:
414+
...
415+
416+
@overload
417+
def rename(
418+
self,
419+
name: Union[blocks.Label, Sequence[blocks.Label]],
420+
*,
421+
inplace: Literal[False],
422+
) -> Index:
423+
...
424+
425+
@overload
426+
def rename(
427+
self,
428+
name: Union[blocks.Label, Sequence[blocks.Label]],
429+
*,
430+
inplace: Literal[True],
431+
) -> None:
432+
...
433+
434+
def rename(
435+
self,
436+
name: Union[blocks.Label, Sequence[blocks.Label]],
437+
*,
438+
inplace: bool = False,
439+
) -> Optional[Index]:
440+
# Tuples are allowed as a label, but we specifically exclude them here.
441+
# This is because tuples are hashable, but we want to treat them as a
442+
# sequence. If name is iterable, we want to assume we're working with a
443+
# MultiIndex. Unfortunately, strings are iterable and we don't want a
444+
# list of all the characters, so specifically exclude the non-tuple
445+
# hashables.
446+
if isinstance(name, blocks.Label) and not isinstance(name, tuple):
447+
names = [name]
448+
else:
449+
names = list(name)
450+
416451
if len(names) != self.nlevels:
417452
raise ValueError("'name' must be same length as levels")
418-
return Index(self._block.with_index_labels(names))
453+
454+
new_block = self._block.with_index_labels(names)
455+
456+
if inplace:
457+
if self._linked_frame is not None:
458+
self._linked_frame._set_block(
459+
self._linked_frame._block.with_index_labels(names)
460+
)
461+
self._block = new_block
462+
return None
463+
else:
464+
return Index(new_block)
419465

420466
def drop(
421467
self,

bigframes/dataframe.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2082,15 +2082,67 @@ def reorder_levels(self, order: LevelsType, axis: int | str = 0):
20822082
def _resolve_levels(self, level: LevelsType) -> typing.Sequence[str]:
20832083
return self._block.index.resolve_level(level)
20842084

2085+
@overload
20852086
def rename(self, *, columns: Mapping[blocks.Label, blocks.Label]) -> DataFrame:
2087+
...
2088+
2089+
@overload
2090+
def rename(
2091+
self, *, columns: Mapping[blocks.Label, blocks.Label], inplace: Literal[False]
2092+
) -> DataFrame:
2093+
...
2094+
2095+
@overload
2096+
def rename(
2097+
self, *, columns: Mapping[blocks.Label, blocks.Label], inplace: Literal[True]
2098+
) -> None:
2099+
...
2100+
2101+
def rename(
2102+
self, *, columns: Mapping[blocks.Label, blocks.Label], inplace: bool = False
2103+
) -> Optional[DataFrame]:
20862104
block = self._block.rename(columns=columns)
2087-
return DataFrame(block)
20882105

2106+
if inplace:
2107+
self._block = block
2108+
return None
2109+
else:
2110+
return DataFrame(block)
2111+
2112+
@overload
2113+
def rename_axis(
2114+
self,
2115+
mapper: typing.Union[blocks.Label, typing.Sequence[blocks.Label]],
2116+
) -> DataFrame:
2117+
...
2118+
2119+
@overload
20892120
def rename_axis(
20902121
self,
20912122
mapper: typing.Union[blocks.Label, typing.Sequence[blocks.Label]],
2123+
*,
2124+
inplace: Literal[False],
20922125
**kwargs,
20932126
) -> DataFrame:
2127+
...
2128+
2129+
@overload
2130+
def rename_axis(
2131+
self,
2132+
mapper: typing.Union[blocks.Label, typing.Sequence[blocks.Label]],
2133+
*,
2134+
inplace: Literal[True],
2135+
**kwargs,
2136+
) -> None:
2137+
...
2138+
2139+
def rename_axis(
2140+
self,
2141+
mapper: typing.Union[blocks.Label, typing.Sequence[blocks.Label]],
2142+
*,
2143+
inplace: bool = False,
2144+
**kwargs,
2145+
) -> Optional[DataFrame]:
20942146
if len(kwargs) != 0:
20952147
raise NotImplementedError(
20962148
f"rename_axis does not currently support any keyword arguments. {constants.FEEDBACK_LINK}"
@@ -2100,7 +2152,14 @@ def rename_axis(
21002152
labels = mapper
21012153
else:
21022154
labels = [mapper]
2103-
return DataFrame(self._block.with_index_labels(labels))
2155+
2156+
block = self._block.with_index_labels(labels)
2157+
2158+
if inplace:
2159+
self._block = block
2160+
return None
2161+
else:
2162+
return DataFrame(block)
21042163

21052164
@validations.requires_ordering()
21062165
def equals(self, other: typing.Union[bigframes.series.Series, DataFrame]) -> bool:

bigframes/series.py

Lines changed: 86 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Literal,
3232
Mapping,
3333
Optional,
34+
overload,
3435
Sequence,
3536
Tuple,
3637
Union,
@@ -95,6 +96,10 @@ class Series(bigframes.operations.base.SeriesMethods, vendored_pandas_series.Ser
9596
# Must be above 5000 for pandas to delegate to bigframes for binops
9697
__pandas_priority__ = 13000
9798

99+
# Ensure mypy can more robustly determine the type of self._block since it
100+
# gets set in various places.
101+
_block: blocks.Block
102+
98103
def __init__(self, *args, **kwargs):
99104
self._query_job: Optional[bigquery.QueryJob] = None
100105
super().__init__(*args, **kwargs)
@@ -254,22 +259,45 @@ def __iter__(self) -> typing.Iterator:
254259
def copy(self) -> Series:
255260
return Series(self._block)
256261

262+
@overload
257263
def rename(
258-
self, index: Union[blocks.Label, Mapping[Any, Any]] = None, **kwargs
264+
self,
265+
index: Union[blocks.Label, Mapping[Any, Any]] = None,
266+
) -> Series:
267+
...
268+
269+
@overload
270+
def rename(
271+
self,
272+
index: Union[blocks.Label, Mapping[Any, Any]] = None,
273+
*,
274+
inplace: Literal[False],
275+
**kwargs,
259276
) -> Series:
277+
...
278+
279+
@overload
280+
def rename(
281+
self,
282+
index: Union[blocks.Label, Mapping[Any, Any]] = None,
283+
*,
284+
inplace: Literal[True],
285+
**kwargs,
286+
) -> None:
287+
...
288+
289+
def rename(
290+
self,
291+
index: Union[blocks.Label, Mapping[Any, Any]] = None,
292+
*,
293+
inplace: bool = False,
294+
**kwargs,
295+
) -> Optional[Series]:
260296
if len(kwargs) != 0:
261297
raise NotImplementedError(
262298
f"rename does not currently support any keyword arguments. {constants.FEEDBACK_LINK}"
263299
)
264300

265-
# rename the Series name
266-
if index is None or isinstance(
267-
index, str
268-
): # Python 3.9 doesn't allow isinstance of Optional
269-
index = typing.cast(Optional[str], index)
270-
block = self._block.with_column_labels([index])
271-
return Series(block)
272-
273301
# rename the index
274302
if isinstance(index, Mapping):
275303
index = typing.cast(Mapping[Any, Any], index)
@@ -294,22 +322,61 @@ def rename(
294322

295323
block = block.set_index(new_idx_ids, index_labels=block.index.names)
296324

297-
return Series(block)
325+
if inplace:
326+
self._block = block
327+
return None
328+
else:
329+
return Series(block)
298330

299331
# rename the Series name
300332
if isinstance(index, typing.Hashable):
333+
# Python 3.9 doesn't allow isinstance of Optional
301334
index = typing.cast(Optional[str], index)
302335
block = self._block.with_column_labels([index])
303-
return Series(block)
336+
337+
if inplace:
338+
self._block = block
339+
return None
340+
else:
341+
return Series(block)
304342

305343
raise ValueError(f"Unsupported type of parameter index: {type(index)}")
306344

307-
@validations.requires_index
345+
@overload
346+
def rename_axis(
347+
self,
348+
mapper: typing.Union[blocks.Label, typing.Sequence[blocks.Label]],
349+
) -> Series:
350+
...
351+
352+
@overload
308353
def rename_axis(
309354
self,
310355
mapper: typing.Union[blocks.Label, typing.Sequence[blocks.Label]],
356+
*,
357+
inplace: Literal[False],
311358
**kwargs,
312359
) -> Series:
360+
...
361+
362+
@overload
363+
def rename_axis(
364+
self,
365+
mapper: typing.Union[blocks.Label, typing.Sequence[blocks.Label]],
366+
*,
367+
inplace: Literal[True],
368+
**kwargs,
369+
) -> None:
370+
...
371+
372+
@validations.requires_index
373+
def rename_axis(
374+
self,
375+
mapper: typing.Union[blocks.Label, typing.Sequence[blocks.Label]],
376+
*,
377+
inplace: bool = False,
378+
**kwargs,
379+
) -> Optional[Series]:
313380
if len(kwargs) != 0:
314381
raise NotImplementedError(
315382
f"rename_axis does not currently support any keyword arguments. {constants.FEEDBACK_LINK}"
@@ -319,7 +386,13 @@ def rename_axis(
319386
labels = mapper
320387
else:
321388
labels = [mapper]
322-
return Series(self._block.with_index_labels(labels))
389+
390+
block = self._block.with_index_labels(labels)
391+
if inplace:
392+
self._block = block
393+
return None
394+
else:
395+
return Series(block)
323396

324397
def equals(
325398
self, other: typing.Union[Series, bigframes.dataframe.DataFrame]

bigframes/testing/mocks.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import copy
1616
import datetime
17-
from typing import Optional, Sequence
17+
from typing import Any, Dict, Optional, Sequence
1818
import unittest.mock as mock
1919

2020
import google.auth.credentials
@@ -23,12 +23,9 @@
2323

2424
import bigframes
2525
import bigframes.clients
26-
import bigframes.core.ordering
26+
import bigframes.core.global_session
2727
import bigframes.dataframe
28-
import bigframes.series
2928
import bigframes.session.clients
30-
import bigframes.session.executor
31-
import bigframes.session.metrics
3229

3330
"""Utilities for creating test resources."""
3431

@@ -129,7 +126,10 @@ def query_and_wait_mock(query, *args, job_config=None, **kwargs):
129126

130127

131128
def create_dataframe(
132-
monkeypatch: pytest.MonkeyPatch, *, session: Optional[bigframes.Session] = None
129+
monkeypatch: pytest.MonkeyPatch,
130+
*,
131+
session: Optional[bigframes.Session] = None,
132+
data: Optional[Dict[str, Sequence[Any]]] = None,
133133
) -> bigframes.dataframe.DataFrame:
134134
"""[Experimental] Create a mock DataFrame that avoids making Google Cloud API calls.
135135
@@ -138,8 +138,11 @@ def create_dataframe(
138138
if session is None:
139139
session = create_bigquery_session()
140140

141+
if data is None:
142+
data = {"col": []}
143+
141144
# Since this may create a ReadLocalNode, the session we explicitly pass in
142145
# might not actually be used. Mock out the global session, too.
143146
monkeypatch.setattr(bigframes.core.global_session, "_global_session", session)
144147
bigframes.options.bigquery._session_started = True
145-
return bigframes.dataframe.DataFrame({"col": []}, session=session)
148+
return bigframes.dataframe.DataFrame(data, session=session)

0 commit comments

Comments
 (0)