Commit fb496f1

Fix test declarations, some impl bugs remain
Signed-off-by: Vasily Litvinov <[email protected]>
1 parent 1bd80f3 commit fb496f1

File tree

3 files changed: +210 -2 lines changed

pandas/core/exchange/dataframe.py
pandas/core/exchange/dataframe_protocol.py
pandas/tests/exchange/test_impl.py

pandas/core/exchange/dataframe.py

Lines changed: 3 additions & 0 deletions
@@ -28,6 +28,9 @@ def __init__(
         self._nan_as_null = nan_as_null
         self._allow_copy = allow_copy
 
+    def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
+        return PandasDataFrameXchg(self._df, nan_as_null, allow_copy)
+
     @property
     def metadata(self):
         # `index` isn't a regular column, and the protocol doesn't support row
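For illustration only (not part of the commit): a minimal sketch of how the added re-wrapping method behaves from the caller's side, assuming a development build where pd.DataFrame.__dataframe__ is already available, as the tests below assume.

import pandas as pd

df = pd.DataFrame({"A": [1, 2, 3]})

# df.__dataframe__() returns the PandasDataFrameXchg wrapper; the method added
# above lets that wrapper itself be re-wrapped, e.g. with copies disallowed.
xchg = df.__dataframe__()
xchg_strict = xchg.__dataframe__(allow_copy=False)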

pandas/core/exchange/dataframe_protocol.py

Lines changed: 16 additions & 2 deletions
@@ -93,6 +93,15 @@ class ColumnBuffers(TypedDict):
     offsets: Optional[Tuple["Buffer", Any]]
 
 
+class CategoricalDescription(TypedDict):
+    # whether the ordering of dictionary indices is semantically meaningful
+    is_ordered: bool
+    # whether a dictionary-style mapping of categorical values to other objects exists
+    is_dictionary: bool
+    # Python-level only (e.g. ``{int: str}``). None if not a dictionary-style categorical.
+    mapping: Optional[dict]
+
+
 class Buffer(ABC):
     """
     Data in the buffer is guaranteed to be contiguous in memory.
@@ -250,15 +259,15 @@ def dtype(self) -> Tuple[DtypeKind, int, str, str]:
 
     @property
     @abstractmethod
-    def describe_categorical(self) -> Tuple[bool, bool, Optional[dict]]:
+    def describe_categorical(self) -> CategoricalDescription:
         """
         If the dtype is categorical, there are two options:
         - There are only values in the data buffer.
         - There is a separate dictionary-style encoding for categorical values.
 
         Raises TypeError if the dtype is not categorical
 
-        Returns the description on how to interpret the data buffer:
+        Returns the dictionary with description on how to interpret the data buffer:
         - "is_ordered" : bool, whether the ordering of dictionary indices is
           semantically meaningful.
         - "is_dictionary" : bool, whether a dictionary-style mapping of
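For illustration only (not part of the commit): a small sketch of what the new return type means for consuming code, assuming a column implementation that already follows the updated protocol; the sample frame and column name "A" are made up for the example.

import pandas as pd

df = pd.DataFrame({"A": pd.Categorical(list("testdata"), ordered=True)})
col = df.__dataframe__().get_column_by_name("A")

# Previously the description was a positional tuple:
#     is_ordered, is_dictionary, mapping = col.describe_categorical
# With CategoricalDescription it is a dict accessed by key instead.
desc = col.describe_categorical
if desc["is_dictionary"]:
    mapping = desc["mapping"]  # mapping of dictionary indices to category values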
@@ -367,6 +376,11 @@ class DataFrame(ABC):
 
     version = 0  # version of the protocol
 
+    @abstractmethod
+    def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
+        """Construct a new exchange object, potentially changing the parameters."""
+        pass
+
     @property
     @abstractmethod
     def metadata(self) -> Dict[str, Any]:
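Again for illustration only: because __dataframe__ is now abstract on the protocol's DataFrame ABC, every exchange-object implementation has to provide it. A hypothetical third-party implementer might satisfy it as sketched below (class and attribute names are invented; the ABC's other abstract members are omitted for brevity).

from pandas.core.exchange.dataframe_protocol import DataFrame as DataFrameXchg


class SomeLibDataFrameXchg(DataFrameXchg):  # hypothetical implementation
    def __init__(self, native_df, nan_as_null=False, allow_copy=True):
        self._native_df = native_df
        self._nan_as_null = nan_as_null
        self._allow_copy = allow_copy

    def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
        # Re-wrap the same underlying data with the requested parameters.
        return SomeLibDataFrameXchg(self._native_df, nan_as_null, allow_copy)

    # ... num_columns, num_rows, column_names, get_column_by_name, etc. go here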

pandas/tests/exchange/test_impl.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
+import pandas as pd
+import numpy as np
+import pytest
+import random
+
+from pandas.testing import assert_frame_equal
+from pandas.core.exchange.dataframe_protocol import DtypeKind, ColumnNullType
+from pandas.core.exchange.from_dataframe import from_dataframe
+
+test_data_categorical = {
+    "ordered": pd.Categorical(list("testdata") * 30, ordered=True),
+    "unordered": pd.Categorical(list("testdata") * 30, ordered=False),
+}
+
+NCOLS, NROWS = 100, 200
+
+int_data = {
+    "col{}".format(int((i - NCOLS / 2) % NCOLS + 1)): [
+        random.randint(0, 100) for _ in range(NROWS)
+    ]
+    for i in range(NCOLS)
+}
+
+bool_data = {
+    "col{}".format(int((i - NCOLS / 2) % NCOLS + 1)): [
+        random.choice([True, False]) for _ in range(NROWS)
+    ]
+    for i in range(NCOLS)
+}
+
+float_data = {
+    "col{}".format(int((i - NCOLS / 2) % NCOLS + 1)): [
+        random.random() for _ in range(NROWS)
+    ]
+    for i in range(NCOLS)
+}
+
+string_data = {
+    "separator data": [
+        "abC|DeF,Hik",
+        "234,3245.67",
+        "gSaf,qWer|Gre",
+        "asd3,4sad|",
+        np.NaN,
+    ]
+}
+
+
+@pytest.mark.parametrize("data", [("ordered", True), ("unordered", False)])
+def test_categorical_dtype(data):
+    df = pd.DataFrame({"A": (test_data_categorical[data[0]])})
+
+    col = df.__dataframe__().get_column_by_name("A")
+    assert col.dtype[0] == DtypeKind.CATEGORICAL
+    assert col.null_count == 0
+    assert col.describe_null == (ColumnNullType.USE_SENTINEL, -1)
+    assert col.num_chunks() == 1
+    assert col.describe_categorical == {
+        "is_ordered": data[1],
+        "is_dictionary": True,
+        "mapping": {4: "s", 2: "d", 3: "e", 1: "t"},
+    }
+
+    assert_frame_equal(df, from_dataframe(df.__dataframe__()))
+
+
+@pytest.mark.parametrize("data", [int_data, float_data, bool_data])
+def test_dataframe(data):
+    df = pd.DataFrame(data)
+
+    df2 = df.__dataframe__()
+
+    assert df2._allow_copy is True
+    assert df2.num_columns() == NCOLS
+    assert df2.num_rows() == NROWS
+
+    assert list(df2.column_names()) == list(data.keys())
+
+    assert_frame_equal(
+        from_dataframe(df2.select_columns((0, 2))),
+        from_dataframe(df2.select_columns_by_name(("col33", "col35"))),
+    )
+    assert_frame_equal(
+        from_dataframe(df2.select_columns((0, 2))),
+        from_dataframe(df2.select_columns_by_name(("col33", "col35"))),
+    )
+
+
+def test_missing_from_masked():
+    df = pd.DataFrame(
+        {
+            "x": np.array([1, 2, 3, 4, 0]),
+            "y": np.array([1.5, 2.5, 3.5, 4.5, 0]),
+            "z": np.array([True, False, True, True, True]),
+        }
+    )
+
+    df2 = df.__dataframe__()
+
+    # for col_name in df.columns:
+    #     assert convert_column_to_array(df2.get_column_by_name(col_name) == df[col_name].tolist()
+    #     assert df[col_name].dtype == convert_column_to_array(df2.get_column_by_name(col_name)).dtype
+
+    rng = np.random.RandomState(42)
+    dict_null = {col: rng.randint(low=0, high=len(df)) for col in df.columns}
+    for col, num_nulls in dict_null.items():
+        null_idx = df.index[
+            rng.choice(np.arange(len(df)), size=num_nulls, replace=False)
+        ]
+        df.loc[null_idx, col] = None
+
+    df2 = df.__dataframe__()
+
+    assert df2.get_column_by_name("x").null_count == dict_null["x"]
+    assert df2.get_column_by_name("y").null_count == dict_null["y"]
+    assert df2.get_column_by_name("z").null_count == dict_null["z"]
+
+
+@pytest.mark.parametrize(
+    "data",
+    [
+        {"x": [1.5, 2.5, 3.5], "y": [9.2, 10.5, 11.8]},
+        {"x": [1, 2, 0], "y": [9.2, 10.5, 11.8]},
+        {
+            "x": np.array([True, True, False]),
+            "y": np.array([1, 2, 0]),
+            "z": np.array([9.2, 10.5, 11.8]),
+        },
+    ],
+)
+def test_mixed_data(data):
+    df = pd.DataFrame(data)
+    df2 = df.__dataframe__()
+
+    for col_name in df.columns:
+        assert df2.get_column_by_name(col_name).null_count == 0
+
+
+def test_mixed_missing():
+    df = pd.DataFrame(
+        {
+            "x": np.array([True, None, False, None, True]),
+            "y": np.array([None, 2, None, 1, 2]),
+            "z": np.array([9.2, 10.5, None, 11.8, None]),
+        }
+    )
+
+    df2 = df.__dataframe__()
+
+    for col_name in df.columns:
+        assert df2.get_column_by_name(col_name).null_count == 2
+
+
+def test_select_columns_error():
+    df = pd.DataFrame(int_data)
+
+    df2 = df.__dataframe__()
+
+    with pytest.raises(ValueError):
+        assert from_dataframe(df2.select_columns(np.array([0, 2]))) == from_dataframe(
+            df2.select_columns_by_name(("col33", "col35"))
+        )
+
+
+def test_select_columns_by_name_error():
+    df = pd.DataFrame(int_data)
+
+    df2 = df.__dataframe__()
+
+    with pytest.raises(ValueError):
+        assert from_dataframe(
+            df2.select_columns_by_name(np.array(["col33", "col35"]))
+        ) == from_dataframe(df2.select_columns((0, 2)))
+
+
+def test_string():
+    test_str_data = string_data["separator data"] + [""]
+    df = pd.DataFrame({"A": test_str_data})
+    col = df.__dataframe__().get_column_by_name("A")
+
+    assert col.size == 6
+    assert col.null_count == 1
+    assert col.dtype[0] == DtypeKind.STRING
+    assert col.describe_null == (ColumnNullType.USE_BYTEMASK, 0)
+
+    df_sliced = df[1:]
+    col = df_sliced.__dataframe__().get_column_by_name("A")
+    assert col.size == 5
+    assert col.null_count == 1
+    assert col.dtype[0] == DtypeKind.STRING
+    assert col.describe_null == (ColumnNullType.USE_BYTEMASK, 0)