Skip to content

Commit ad298a7

Browse files
committed
Declare enums explicitly, fix hints
Signed-off-by: Vasily Litvinov <[email protected]>
1 parent 81be345 commit ad298a7

File tree

1 file changed

+39
-37
lines changed

1 file changed

+39
-37
lines changed

protocol/dataframe_protocol.py

Lines changed: 39 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,32 @@
1+
from typing import Tuple, Optional, Dict, Any, Iterable, Sequence
2+
import enum
3+
4+
class DlpackDeviceType(enum.IntEnum):
    """
    Device type codes used in the result of ``__dlpack_device__``.

    The integer values match the device type codes defined by DLPack, so a
    member can be passed wherever a plain DLPack device code is expected.
    The codes are not contiguous (5 and 6 are not declared here).
    """

    CPU = 1
    CUDA = 2
    CPU_PINNED = 3
    OPENCL = 4
    VULKAN = 7
    METAL = 8
    VPI = 9
    ROCM = 10
13+
14+
class DtypeKind(enum.IntEnum):
    """
    Kind of a column's data type.

    This is the first element of the ``(kind, bit-width, format string,
    endianness)`` tuple returned by the ``dtype`` property of a column.
    """

    INT = 0
    UINT = 1
    FLOAT = 2
    BOOL = 20
    STRING = 21  # UTF-8
    DATETIME = 22
    CATEGORICAL = 23
22+
23+
class ColumnNullType(enum.IntEnum):
    """
    Kind of missing-value ("null") representation used by a column.

    This is the first element of the ``(kind, value)`` tuple returned by the
    ``describe_null`` property of a column.

    Declared as an ``IntEnum`` so that members are real ``ColumnNullType``
    instances (matching the ``describe_null`` return annotation) while still
    comparing equal to the plain integer codes 0..4.
    """

    NON_NULLABLE = 0  # column cannot hold missing values
    USE_NAN = 1       # missing values are represented as NaN/NaT
    USE_SENTINEL = 2  # missing values are represented by a sentinel value
    USE_BITMASK = 3   # missing values are flagged in a bit mask
    USE_BYTEMASK = 4  # missing values are flagged in a byte mask
29+
130
class Buffer:
231
"""
332
Data in the buffer is guaranteed to be contiguous in memory.
@@ -41,20 +70,11 @@ def __dlpack__(self):
4170
"""
4271
raise NotImplementedError("__dlpack__")
4372

44-
def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
73+
def __dlpack_device__(self) -> Tuple[DlpackDeviceType, int]:
4574
"""
4675
Device type and device ID for where the data in the buffer resides.
4776
48-
Uses device type codes matching DLPack. Enum members are::
49-
50-
- CPU = 1
51-
- CUDA = 2
52-
- CPU_PINNED = 3
53-
- OPENCL = 4
54-
- VULKAN = 7
55-
- METAL = 8
56-
- VPI = 9
57-
- ROCM = 10
77+
Uses device type codes matching DLPack.
5878
5979
Note: must be implemented even if ``__dlpack__`` is not.
6080
"""
@@ -128,20 +148,10 @@ def offset(self) -> int:
128148
pass
129149

130150
@property
131-
def dtype(self) -> Tuple[enum.IntEnum, int, str, str]:
151+
def dtype(self) -> Tuple[DtypeKind, int, str, str]:
132152
"""
133153
Dtype description as a tuple ``(kind, bit-width, format string, endianness)``.
134154
135-
Kind :
136-
137-
- INT = 0
138-
- UINT = 1
139-
- FLOAT = 2
140-
- BOOL = 20
141-
- STRING = 21 # UTF-8
142-
- DATETIME = 22
143-
- CATEGORICAL = 23
144-
145155
Bit-width : the number of bits as an integer
146156
Format string : data type description format string in Apache Arrow C
147157
Data Interface format.
@@ -170,7 +180,7 @@ def dtype(self) -> Tuple[enum.IntEnum, int, str, str]:
170180
pass
171181

172182
@property
173-
def describe_categorical(self) -> dict[bool, bool, Optional[dict]]:
183+
def describe_categorical(self) -> Dict[bool, bool, Optional[dict]]:
174184
"""
175185
If the dtype is categorical, there are two options:
176186
@@ -193,19 +203,11 @@ def describe_categorical(self) -> dict[bool, bool, Optional[dict]]:
193203
pass
194204

195205
@property
196-
def describe_null(self) -> Tuple[int, Any]:
206+
def describe_null(self) -> Tuple[ColumnNullType, Any]:
197207
"""
198208
Return the missing value (or "null") representation the column dtype
199209
uses, as a tuple ``(kind, value)``.
200210
201-
Kind:
202-
203-
- 0 : non-nullable
204-
- 1 : NaN/NaT
205-
- 2 : sentinel value
206-
- 3 : bit mask
207-
- 4 : byte mask
208-
209211
Value : if kind is "sentinel value", the actual value. If kind is a bit
210212
mask or a byte mask, the value (0 or 1) indicating a missing value. None
211213
otherwise.
@@ -234,15 +236,15 @@ def num_chunks(self) -> int:
234236
"""
235237
pass
236238

237-
def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable[Column]:
239+
def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable["Column"]:
238240
"""
239241
Return an iterator yielding the chunks.
240242
241243
See `DataFrame.get_chunks` for details on ``n_chunks``.
242244
"""
243245
pass
244246

245-
def get_buffers(self) -> dict[Tuple[Buffer, Any], Optional[Tuple[Buffer, Any]], Optional[Tuple[Buffer, Any]]]:
247+
def get_buffers(self) -> Dict[Tuple[Buffer, Any], Optional[Tuple[Buffer, Any]], Optional[Tuple[Buffer, Any]]]:
246248
"""
247249
Return a dictionary containing the underlying buffers.
248250
@@ -367,19 +369,19 @@ def get_columns(self) -> Iterable[Column]:
367369
"""
368370
pass
369371

370-
def select_columns(self, indices: Sequence[int]) -> DataFrame:
372+
def select_columns(self, indices: Sequence[int]) -> "DataFrame":
371373
"""
372374
Create a new DataFrame by selecting a subset of columns by index.
373375
"""
374376
pass
375377

376-
def select_columns_by_name(self, names: Sequence[str]) -> DataFrame:
378+
def select_columns_by_name(self, names: Sequence[str]) -> "DataFrame":
377379
"""
378380
Create a new DataFrame by selecting a subset of columns by name.
379381
"""
380382
pass
381383

382-
def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable[DataFrame]:
384+
def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable["DataFrame"]:
383385
"""
384386
Return an iterator yielding the chunks.
385387

0 commit comments

Comments
 (0)