Skip to content

Commit f796d67

Browse files
committed
Add typehints to fsspec.utils
1 parent 43df953 commit f796d67

File tree

2 files changed

+87
-33
lines changed

2 files changed

+87
-33
lines changed

fsspec/spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def _strip_protocol(cls, path):
196196
# use of root_marker to make minimum required path, e.g., "/"
197197
return path or cls.root_marker
198198

199-
def unstrip_protocol(self, name):
199+
def unstrip_protocol(self, name: str) -> str:
200200
"""Format FS-specific path to generic, including protocol"""
201201
protos = (self.protocol,) if isinstance(self.protocol, str) else self.protocol
202202
for protocol in protos:

fsspec/utils.py

Lines changed: 86 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,32 @@
1010
from functools import partial
1111
from hashlib import md5
1212
from importlib.metadata import version
13+
from typing import (
14+
IO,
15+
TYPE_CHECKING,
16+
Any,
17+
Callable,
18+
Iterable,
19+
Iterator,
20+
Sequence,
21+
TypeVar,
22+
)
1323
from urllib.parse import urlsplit
1424

25+
if TYPE_CHECKING:
26+
from typing_extensions import TypeGuard
27+
28+
from fsspec.spec import AbstractFileSystem
29+
30+
1531
DEFAULT_BLOCK_SIZE = 5 * 2**20
1632

33+
T = TypeVar("T")
1734

18-
def infer_storage_options(urlpath, inherit_storage_options=None):
35+
36+
def infer_storage_options(
37+
urlpath: str, inherit_storage_options: dict[str, Any] | None = None
38+
) -> dict[str, Any]:
1939
"""Infer storage options from URL path and merge it with existing storage
2040
options.
2141
@@ -67,7 +87,7 @@ def infer_storage_options(urlpath, inherit_storage_options=None):
6787
# for HTTP, we don't want to parse, as requests will anyway
6888
return {"protocol": protocol, "path": urlpath}
6989

70-
options = {"protocol": protocol, "path": path}
90+
options: dict[str, Any] = {"protocol": protocol, "path": path}
7191

7292
if parsed_path.netloc:
7393
# Parse `hostname` from netloc manually because `parsed_path.hostname`
@@ -97,7 +117,9 @@ def infer_storage_options(urlpath, inherit_storage_options=None):
97117
return options
98118

99119

100-
def update_storage_options(options, inherited=None):
120+
def update_storage_options(
121+
options: dict[str, Any], inherited: dict[str, Any] | None = None
122+
) -> None:
101123
if not inherited:
102124
inherited = {}
103125
collisions = set(options) & set(inherited)
@@ -115,7 +137,7 @@ def update_storage_options(options, inherited=None):
115137
compressions: dict[str, str] = {}
116138

117139

118-
def infer_compression(filename):
140+
def infer_compression(filename: str) -> str | None:
119141
"""Infer compression, if available, from filename.
120142
121143
Infer a named compression type, if registered and available, from filename
@@ -125,9 +147,10 @@ def infer_compression(filename):
125147
extension = os.path.splitext(filename)[-1].strip(".").lower()
126148
if extension in compressions:
127149
return compressions[extension]
150+
return None
128151

129152

130-
def build_name_function(max_int):
153+
def build_name_function(max_int: float) -> Callable[[int], str]:
131154
"""Returns a function that receives a single integer
132155
and returns it as a string padded by enough zero characters
133156
to align with maximum possible integer
@@ -150,13 +173,13 @@ def build_name_function(max_int):
150173

151174
pad_length = int(math.ceil(math.log10(max_int)))
152175

153-
def name_function(i):
176+
def name_function(i: int) -> str:
154177
return str(i).zfill(pad_length)
155178

156179
return name_function
157180

158181

159-
def seek_delimiter(file, delimiter, blocksize):
182+
def seek_delimiter(file: IO[bytes], delimiter: bytes, blocksize: int) -> bool:
160183
r"""Seek current file to file start, file end, or byte after delimiter seq.
161184
162185
Seeks file to next chunk delimiter, where chunks are defined on file start,
@@ -185,7 +208,7 @@ def seek_delimiter(file, delimiter, blocksize):
185208

186209
# Interface is for binary IO, with delimiter as bytes, but initialize last
187210
# with result of file.read to preserve compatibility with text IO.
188-
last = None
211+
last: bytes | None = None
189212
while True:
190213
current = file.read(blocksize)
191214
if not current:
@@ -205,7 +228,13 @@ def seek_delimiter(file, delimiter, blocksize):
205228
last = full[-len(delimiter) :]
206229

207230

208-
def read_block(f, offset, length, delimiter=None, split_before=False):
231+
def read_block(
232+
f: IO[bytes],
233+
offset: int,
234+
length: int | None,
235+
delimiter: bytes | None = None,
236+
split_before: bool = False,
237+
) -> bytes:
209238
"""Read a block of bytes from a file
210239
211240
Parameters
@@ -266,11 +295,14 @@ def read_block(f, offset, length, delimiter=None, split_before=False):
266295
length = end - start
267296

268297
f.seek(offset)
269-
b = f.read(length)
298+
if length is not None:
299+
b = f.read(length)
300+
else:
301+
b = f.read()
270302
return b
271303

272304

273-
def tokenize(*args, **kwargs):
305+
def tokenize(*args: Any, **kwargs: Any) -> str:
274306
"""Deterministic token
275307
276308
(modified from dask.base)
@@ -287,10 +319,10 @@ def tokenize(*args, **kwargs):
287319
return md5(str(args).encode()).hexdigest()
288320
except ValueError:
289321
# FIPS systems: https://github.com/fsspec/filesystem_spec/issues/380
290-
return md5(str(args).encode(), usedforsecurity=False).hexdigest()
322+
return md5(str(args).encode(), usedforsecurity=False).hexdigest() # type: ignore[call-arg]
291323

292324

293-
def stringify_path(filepath):
325+
def stringify_path(filepath: str | os.PathLike[str] | pathlib.Path) -> str:
294326
"""Attempt to convert a path-like object to a string.
295327
296328
Parameters
@@ -314,7 +346,7 @@ def stringify_path(filepath):
314346
"""
315347
if isinstance(filepath, str):
316348
return filepath
317-
elif hasattr(filepath, "__fspath__"):
349+
elif hasattr(filepath, "__fspath__") or isinstance(filepath, os.PathLike):
318350
return filepath.__fspath__()
319351
elif isinstance(filepath, pathlib.Path):
320352
return str(filepath)
@@ -324,13 +356,15 @@ def stringify_path(filepath):
324356
return filepath
325357

326358

327-
def make_instance(cls, args, kwargs):
359+
def make_instance(
360+
cls: Callable[..., T], args: Sequence[Any], kwargs: dict[str, Any]
361+
) -> T:
328362
inst = cls(*args, **kwargs)
329-
inst._determine_worker()
363+
inst._determine_worker() # type: ignore[attr-defined]
330364
return inst
331365

332366

333-
def common_prefix(paths):
367+
def common_prefix(paths: Iterable[str]) -> str:
334368
"""For a list of paths, find the shortest prefix common to all"""
335369
parts = [p.split("/") for p in paths]
336370
lmax = min(len(p) for p in parts)
@@ -343,7 +377,13 @@ def common_prefix(paths):
343377
return "/".join(parts[0][:i])
344378

345379

346-
def other_paths(paths, path2, is_dir=None, exists=False, flatten=False):
380+
def other_paths(
381+
paths: list[str],
382+
path2: str | list[str],
383+
is_dir: bool | None = None,
384+
exists: bool = False,
385+
flatten: bool = False,
386+
) -> list[str]:
347387
"""In bulk file operations, construct a new file tree from a list of files
348388
349389
Parameters
@@ -388,25 +428,25 @@ def other_paths(paths, path2, is_dir=None, exists=False, flatten=False):
388428
return path2
389429

390430

391-
def is_exception(obj):
431+
def is_exception(obj: Any) -> bool:
392432
return isinstance(obj, BaseException)
393433

394434

395-
def isfilelike(f):
435+
def isfilelike(f: Any) -> TypeGuard[IO[bytes]]:
396436
for attr in ["read", "close", "tell"]:
397437
if not hasattr(f, attr):
398438
return False
399439
return True
400440

401441

402-
def get_protocol(url):
442+
def get_protocol(url: str) -> str:
403443
parts = re.split(r"(\:\:|\://)", url, 1)
404444
if len(parts) > 1:
405445
return parts[0]
406446
return "file"
407447

408448

409-
def can_be_local(path):
449+
def can_be_local(path: str) -> bool:
410450
"""Can the given URL be used with open_local?"""
411451
from fsspec import get_filesystem_class
412452

@@ -417,7 +457,7 @@ def can_be_local(path):
417457
return False
418458

419459

420-
def get_package_version_without_import(name):
460+
def get_package_version_without_import(name: str) -> str | None:
421461
"""For given package name, try to find the version without importing it
422462
423463
Import and package.__version__ is still the backup here, so an import
@@ -443,7 +483,12 @@ def get_package_version_without_import(name):
443483
return None
444484

445485

446-
def setup_logging(logger=None, logger_name=None, level="DEBUG", clear=True):
486+
def setup_logging(
487+
logger: logging.Logger | None = None,
488+
logger_name: str | None = None,
489+
level: str = "DEBUG",
490+
clear: bool = True,
491+
) -> logging.Logger:
447492
if logger is None and logger_name is None:
448493
raise ValueError("Provide either logger object or logger name")
449494
logger = logger or logging.getLogger(logger_name)
@@ -459,20 +504,22 @@ def setup_logging(logger=None, logger_name=None, level="DEBUG", clear=True):
459504
return logger
460505

461506

462-
def _unstrip_protocol(name, fs):
507+
def _unstrip_protocol(name: str, fs: AbstractFileSystem) -> str:
463508
return fs.unstrip_protocol(name)
464509

465510

466-
def mirror_from(origin_name, methods):
511+
def mirror_from(
512+
origin_name: str, methods: Iterable[str]
513+
) -> Callable[[type[T]], type[T]]:
467514
"""Mirror attributes and methods from the given
468515
origin_name attribute of the instance to the
469516
decorated class"""
470517

471-
def origin_getter(method, self):
518+
def origin_getter(method: str, self: Any) -> Any:
472519
origin = getattr(self, origin_name)
473520
return getattr(origin, method)
474521

475-
def wrapper(cls):
522+
def wrapper(cls: type[T]) -> type[T]:
476523
for method in methods:
477524
wrapped_method = partial(origin_getter, method)
478525
setattr(cls, method, property(wrapped_method))
@@ -482,11 +529,18 @@ def wrapper(cls):
482529

483530

484531
@contextmanager
485-
def nullcontext(obj):
532+
def nullcontext(obj: T) -> Iterator[T]:
486533
yield obj
487534

488535

489-
def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=True):
536+
def merge_offset_ranges(
537+
paths: list[str],
538+
starts: list[int] | int,
539+
ends: list[int] | int,
540+
max_gap: int = 0,
541+
max_block: int | None = None,
542+
sort: bool = True,
543+
) -> tuple[list[str], list[int], list[int]]:
490544
"""Merge adjacent byte-offset ranges when the inter-range
491545
gap is <= `max_gap`, and when the merged byte range does not
492546
exceed `max_block` (if specified). By default, this function
@@ -500,7 +554,7 @@ def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=Tru
500554
if not isinstance(starts, list):
501555
starts = [starts] * len(paths)
502556
if not isinstance(ends, list):
503-
ends = [starts] * len(paths)
557+
ends = [ends] * len(paths)
504558
if len(starts) != len(paths) or len(ends) != len(paths):
505559
raise ValueError
506560

@@ -549,7 +603,7 @@ def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=Tru
549603
return paths, starts, ends
550604

551605

552-
def file_size(filelike):
606+
def file_size(filelike: IO[bytes]) -> int:
553607
"""Find length of any open read-mode file-like"""
554608
pos = filelike.tell()
555609
try:

0 commit comments

Comments
 (0)