Skip to content

Commit 8eab8a2

Browse files
vnlitvinovrgommers
authored andcommitted
Declare enums explicitly, fix hints
Signed-off-by: Vasily Litvinov <[email protected]>
1 parent ee9be04 commit 8eab8a2

File tree

1 file changed

+38
-36
lines changed

1 file changed

+38
-36
lines changed

protocol/dataframe_protocol.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,32 @@
1+
from typing import Tuple, Optional, Dict, Any, Iterable, Sequence
2+
import enum
3+
4+
class DlpackDeviceType(enum.IntEnum):
5+
CPU = 1
6+
CUDA = 2
7+
CPU_PINNED = 3
8+
OPENCL = 4
9+
VULKAN = 7
10+
METAL = 8
11+
VPI = 9
12+
ROCM = 10
13+
14+
class DtypeKind(enum.IntEnum):
15+
INT = 0
16+
UINT = 1
17+
FLOAT = 2
18+
BOOL = 20
19+
STRING = 21 # UTF-8
20+
DATETIME = 22
21+
CATEGORICAL = 23
22+
23+
class ColumnNullType:
24+
NON_NULLABLE = 0
25+
USE_NAN = 1
26+
USE_SENTINEL = 2
27+
USE_BITMASK = 3
28+
USE_BYTEMASK = 4
29+
130
class Buffer:
231
"""
332
Data in the buffer is guaranteed to be contiguous in memory.
@@ -41,20 +70,11 @@ def __dlpack__(self):
4170
"""
4271
raise NotImplementedError("__dlpack__")
4372

44-
def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
73+
def __dlpack_device__(self) -> Tuple[DlpackDeviceType, int]:
4574
"""
4675
Device type and device ID for where the data in the buffer resides.
4776
48-
Uses device type codes matching DLPack. Enum members are::
49-
50-
- CPU = 1
51-
- CUDA = 2
52-
- CPU_PINNED = 3
53-
- OPENCL = 4
54-
- VULKAN = 7
55-
- METAL = 8
56-
- VPI = 9
57-
- ROCM = 10
77+
Uses device type codes matching DLPack.
5878
5979
Note: must be implemented even if ``__dlpack__`` is not.
6080
"""
@@ -128,20 +148,10 @@ def offset(self) -> int:
128148
pass
129149

130150
@property
131-
def dtype(self) -> Tuple[enum.IntEnum, int, str, str]:
151+
def dtype(self) -> Tuple[DtypeKind, int, str, str]:
132152
"""
133153
Dtype description as a tuple ``(kind, bit-width, format string, endianness)``.
134154
135-
Kind :
136-
137-
- INT = 0
138-
- UINT = 1
139-
- FLOAT = 2
140-
- BOOL = 20
141-
- STRING = 21 # UTF-8
142-
- DATETIME = 22
143-
- CATEGORICAL = 23
144-
145155
Bit-width : the number of bits as an integer
146156
Format string : data type description format string in Apache Arrow C
147157
Data Interface format.
@@ -194,19 +204,11 @@ def describe_categorical(self) -> dict[bool, bool, Optional[Column]]:
194204
pass
195205

196206
@property
197-
def describe_null(self) -> Tuple[int, Any]:
207+
def describe_null(self) -> Tuple[ColumnNullType, Any]:
198208
"""
199209
Return the missing value (or "null") representation the column dtype
200210
uses, as a tuple ``(kind, value)``.
201211
202-
Kind:
203-
204-
- 0 : non-nullable
205-
- 1 : NaN/NaT
206-
- 2 : sentinel value
207-
- 3 : bit mask
208-
- 4 : byte mask
209-
210212
Value : if kind is "sentinel value", the actual value. If kind is a bit
211213
mask or a byte mask, the value (0 or 1) indicating a missing value. None
212214
otherwise.
@@ -235,15 +237,15 @@ def num_chunks(self) -> int:
235237
"""
236238
pass
237239

238-
def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable[Column]:
240+
def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable["Column"]:
239241
"""
240242
Return an iterator yielding the chunks.
241243
242244
See `DataFrame.get_chunks` for details on ``n_chunks``.
243245
"""
244246
pass
245247

246-
def get_buffers(self) -> dict[Tuple[Buffer, Any], Optional[Tuple[Buffer, Any]], Optional[Tuple[Buffer, Any]]]:
248+
def get_buffers(self) -> Dict[Tuple[Buffer, Any], Optional[Tuple[Buffer, Any]], Optional[Tuple[Buffer, Any]]]:
247249
"""
248250
Return a dictionary containing the underlying buffers.
249251
@@ -368,19 +370,19 @@ def get_columns(self) -> Iterable[Column]:
368370
"""
369371
pass
370372

371-
def select_columns(self, indices: Sequence[int]) -> DataFrame:
373+
def select_columns(self, indices: Sequence[int]) -> "DataFrame":
372374
"""
373375
Create a new DataFrame by selecting a subset of columns by index.
374376
"""
375377
pass
376378

377-
def select_columns_by_name(self, names: Sequence[str]) -> DataFrame:
379+
def select_columns_by_name(self, names: Sequence[str]) -> "DataFrame":
378380
"""
379381
Create a new DataFrame by selecting a subset of columns by name.
380382
"""
381383
pass
382384

383-
def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable[DataFrame]:
385+
def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable["DataFrame"]:
384386
"""
385387
Return an iterator yielding the chunks.
386388

0 commit comments

Comments
 (0)