Skip to content

Commit d572c11

Browse files
authored
type DataFrame.assign (#1176)
* type DataFrame.assign * include Scalar too
1 parent 0abb350 commit d572c11

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed

pandas-stubs/_typing.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,4 +837,9 @@ ExcelWriteEngine: TypeAlias = Literal["openpyxl", "odf", "xlsxwriter"]
837837
# https://github.com/pandas-dev/pandas-stubs/pull/1151#issuecomment-2715130190
838838
TimeZones: TypeAlias = str | tzinfo | None | int
839839

840+
# Evaluates to a DataFrame column in DataFrame.assign context.
841+
IntoColumn: TypeAlias = (
842+
AnyArrayLike | Scalar | Callable[[DataFrame], AnyArrayLike | Scalar]
843+
)
844+
840845
__all__ = ["npt", "type_t"]

pandas-stubs/core/frame.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ from pandas._typing import (
100100
InterpolateOptions,
101101
IntervalClosedType,
102102
IntervalT,
103+
IntoColumn,
103104
JoinHow,
104105
JsonFrameOrient,
105106
Label,
@@ -742,7 +743,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
742743
value: Scalar | ListLikeU | None,
743744
allow_duplicates: _bool = ...,
744745
) -> None: ...
745-
def assign(self, **kwargs) -> Self: ...
746+
def assign(self, **kwargs: IntoColumn) -> Self: ...
746747
def align(
747748
self,
748749
other: NDFrameT,

tests/test_frame.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,38 @@ def test_types_assign() -> None:
308308
df["col3"] = df.sum(axis=1)
309309

310310

311+
def test_assign() -> None:
312+
df = pd.DataFrame({"a": [1, 2, 3], 1: [4, 5, 6]})
313+
314+
my_unnamed_func = lambda df: df["a"] * 2
315+
316+
def my_named_func_1(df: pd.DataFrame) -> pd.Series[str]:
317+
return df["a"]
318+
319+
def my_named_func_2(df: pd.DataFrame) -> pd.Series[Any]:
320+
return df["a"]
321+
322+
check(assert_type(df.assign(c=lambda df: df["a"] * 2), pd.DataFrame), pd.DataFrame)
323+
check(
324+
assert_type(df.assign(c=lambda df: df["a"].index), pd.DataFrame), pd.DataFrame
325+
)
326+
check(
327+
assert_type(df.assign(c=lambda df: df["a"].to_numpy()), pd.DataFrame),
328+
pd.DataFrame,
329+
)
330+
check(
331+
assert_type(df.assign(c=lambda df: df["a"].max()), pd.DataFrame),
332+
pd.DataFrame,
333+
)
334+
check(assert_type(df.assign(c=df["a"] * 2), pd.DataFrame), pd.DataFrame)
335+
check(assert_type(df.assign(c=df["a"].index), pd.DataFrame), pd.DataFrame)
336+
check(assert_type(df.assign(c=df["a"].to_numpy()), pd.DataFrame), pd.DataFrame)
337+
check(assert_type(df.assign(c=2), pd.DataFrame), pd.DataFrame)
338+
check(assert_type(df.assign(c=my_unnamed_func), pd.DataFrame), pd.DataFrame)
339+
check(assert_type(df.assign(c=my_named_func_1), pd.DataFrame), pd.DataFrame)
340+
check(assert_type(df.assign(c=my_named_func_2), pd.DataFrame), pd.DataFrame)
341+
342+
311343
def test_types_sample() -> None:
312344
df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
313345
# GH 67

0 commit comments

Comments
 (0)