Skip to content

Commit 3d24a9a

Browse files
authored
use Union for ListLike in DataFrame __new__ (#198)
1 parent 45a628c commit 3d24a9a

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

pandas-stubs/_typing.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ AxisType = Literal["columns", "index", 0, 1]
9999
DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic])
100100
KeysArgType = Any
101101
ListLike = TypeVar("ListLike", Sequence, np.ndarray, "Series", "Index")
102+
ListLikeU = Union[Sequence, np.ndarray, Series, Index]
102103
StrLike = Union[str, np.str_]
103104
Scalar = Union[
104105
str,

pandas-stubs/core/frame.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ from pandas._typing import (
6161
Label,
6262
Level,
6363
ListLike,
64+
ListLikeU,
6465
MaskType,
6566
Renamer,
6667
Scalar,
@@ -157,10 +158,10 @@ class DataFrame(NDFrame, OpsMixin):
157158

158159
def __new__(
159160
cls,
160-
data: ListLike
161+
data: ListLikeU
161162
| DataFrame
162163
| dict[Any, Any]
163-
| Iterable[tuple[Hashable, ListLike]]
164+
| Iterable[tuple[Hashable, ListLikeU]]
164165
| None = ...,
165166
index: Axes | None = ...,
166167
columns: Axes | None = ...,

tests/test_frame.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
TYPE_CHECKING,
99
Any,
1010
Callable,
11+
Generic,
1112
Hashable,
1213
Iterable,
1314
Iterator,
1415
Tuple,
16+
TypeVar,
1517
Union,
1618
)
1719

@@ -1614,3 +1616,16 @@ def test_dict_items() -> None:
16141616
# GH 180
16151617
x = {"a": [1]}
16161618
check(assert_type(pd.DataFrame(x.items()), pd.DataFrame), pd.DataFrame)
1619+
1620+
1621+
def test_generic() -> None:
1622+
# GH 197
1623+
T = TypeVar("T")
1624+
1625+
class MyDataFrame(pd.DataFrame, Generic[T]):
1626+
...
1627+
1628+
def func() -> MyDataFrame[int]:
1629+
return MyDataFrame[int]({"foo": [1, 2, 3]})
1630+
1631+
func()

0 commit comments

Comments
 (0)