Skip to content

Commit a00afa9

Browse files
committed
Add typehints to fsspec.utils
1 parent a8034d3 commit a00afa9

File tree

2 files changed

+88
-34
lines changed

2 files changed

+88
-34
lines changed

fsspec/spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def _strip_protocol(cls, path):
196196
# use of root_marker to make minimum required path, e.g., "/"
197197
return path or cls.root_marker
198198

199-
def unstrip_protocol(self, name):
199+
def unstrip_protocol(self, name: str) -> str:
200200
"""Format FS-specific path to generic, including protocol"""
201201
protos = (self.protocol,) if isinstance(self.protocol, str) else self.protocol
202202
for protocol in protos:

fsspec/utils.py

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,32 @@
1111
from functools import partial
1212
from hashlib import md5
1313
from importlib.metadata import version
14+
from typing import (
15+
IO,
16+
TYPE_CHECKING,
17+
Any,
18+
Callable,
19+
Iterable,
20+
Iterator,
21+
Sequence,
22+
TypeVar,
23+
)
1424
from urllib.parse import urlsplit
1525

26+
if TYPE_CHECKING:
27+
from typing_extensions import TypeGuard
28+
29+
from fsspec.spec import AbstractFileSystem
30+
31+
1632
DEFAULT_BLOCK_SIZE = 5 * 2**20
1733

34+
T = TypeVar("T")
1835

19-
def infer_storage_options(urlpath, inherit_storage_options=None):
36+
37+
def infer_storage_options(
38+
urlpath: str, inherit_storage_options: dict[str, Any] | None = None
39+
) -> dict[str, Any]:
2040
"""Infer storage options from URL path and merge it with existing storage
2141
options.
2242
@@ -68,7 +88,7 @@ def infer_storage_options(urlpath, inherit_storage_options=None):
6888
# for HTTP, we don't want to parse, as requests will anyway
6989
return {"protocol": protocol, "path": urlpath}
7090

71-
options = {"protocol": protocol, "path": path}
91+
options: dict[str, Any] = {"protocol": protocol, "path": path}
7292

7393
if parsed_path.netloc:
7494
# Parse `hostname` from netloc manually because `parsed_path.hostname`
@@ -98,7 +118,9 @@ def infer_storage_options(urlpath, inherit_storage_options=None):
98118
return options
99119

100120

101-
def update_storage_options(options, inherited=None):
121+
def update_storage_options(
122+
options: dict[str, Any], inherited: dict[str, Any] | None = None
123+
) -> None:
102124
if not inherited:
103125
inherited = {}
104126
collisions = set(options) & set(inherited)
@@ -116,7 +138,7 @@ def update_storage_options(options, inherited=None):
116138
compressions: dict[str, str] = {}
117139

118140

119-
def infer_compression(filename):
141+
def infer_compression(filename: str) -> str | None:
120142
"""Infer compression, if available, from filename.
121143
122144
Infer a named compression type, if registered and available, from filename
@@ -126,9 +148,10 @@ def infer_compression(filename):
126148
extension = os.path.splitext(filename)[-1].strip(".").lower()
127149
if extension in compressions:
128150
return compressions[extension]
151+
return None
129152

130153

131-
def build_name_function(max_int):
154+
def build_name_function(max_int: float) -> Callable[[int], str]:
132155
"""Returns a function that receives a single integer
133156
and returns it as a string padded by enough zero characters
134157
to align with maximum possible integer
@@ -151,13 +174,13 @@ def build_name_function(max_int):
151174

152175
pad_length = int(math.ceil(math.log10(max_int)))
153176

154-
def name_function(i):
177+
def name_function(i: int) -> str:
155178
return str(i).zfill(pad_length)
156179

157180
return name_function
158181

159182

160-
def seek_delimiter(file, delimiter, blocksize):
183+
def seek_delimiter(file: IO[bytes], delimiter: bytes, blocksize: int) -> bool:
161184
r"""Seek current file to file start, file end, or byte after delimiter seq.
162185
163186
Seeks file to next chunk delimiter, where chunks are defined on file start,
@@ -186,7 +209,7 @@ def seek_delimiter(file, delimiter, blocksize):
186209

187210
# Interface is for binary IO, with delimiter as bytes, but initialize last
188211
# with result of file.read to preserve compatibility with text IO.
189-
last = None
212+
last: bytes | None = None
190213
while True:
191214
current = file.read(blocksize)
192215
if not current:
@@ -206,7 +229,13 @@ def seek_delimiter(file, delimiter, blocksize):
206229
last = full[-len(delimiter) :]
207230

208231

209-
def read_block(f, offset, length, delimiter=None, split_before=False):
232+
def read_block(
233+
f: IO[bytes],
234+
offset: int,
235+
length: int | None,
236+
delimiter: bytes | None = None,
237+
split_before: bool = False,
238+
) -> bytes:
210239
"""Read a block of bytes from a file
211240
212241
Parameters
@@ -267,11 +296,14 @@ def read_block(f, offset, length, delimiter=None, split_before=False):
267296
length = end - start
268297

269298
f.seek(offset)
270-
b = f.read(length)
299+
if length is not None:
300+
b = f.read(length)
301+
else:
302+
b = f.read()
271303
return b
272304

273305

274-
def tokenize(*args, **kwargs):
306+
def tokenize(*args: Any, **kwargs: Any) -> str:
275307
"""Deterministic token
276308
277309
(modified from dask.base)
@@ -285,13 +317,14 @@ def tokenize(*args, **kwargs):
285317
if kwargs:
286318
args += (kwargs,)
287319
try:
288-
return md5(str(args).encode()).hexdigest()
320+
h = md5(str(args).encode())
289321
except ValueError:
290322
# FIPS systems: https://github.com/fsspec/filesystem_spec/issues/380
291-
return md5(str(args).encode(), usedforsecurity=False).hexdigest()
323+
h = md5(str(args).encode(), usedforsecurity=False) # type: ignore[call-arg]
324+
return h.hexdigest()
292325

293326

294-
def stringify_path(filepath):
327+
def stringify_path(filepath: str | os.PathLike[str] | pathlib.Path) -> str:
295328
"""Attempt to convert a path-like object to a string.
296329
297330
Parameters
@@ -315,7 +348,7 @@ def stringify_path(filepath):
315348
"""
316349
if isinstance(filepath, str):
317350
return filepath
318-
elif hasattr(filepath, "__fspath__"):
351+
elif hasattr(filepath, "__fspath__") or isinstance(filepath, os.PathLike):
319352
return filepath.__fspath__()
320353
elif isinstance(filepath, pathlib.Path):
321354
return str(filepath)
@@ -325,13 +358,15 @@ def stringify_path(filepath):
325358
return filepath
326359

327360

328-
def make_instance(cls, args, kwargs):
361+
def make_instance(
362+
cls: Callable[..., T], args: Sequence[Any], kwargs: dict[str, Any]
363+
) -> T:
329364
inst = cls(*args, **kwargs)
330-
inst._determine_worker()
365+
inst._determine_worker() # type: ignore[attr-defined]
331366
return inst
332367

333368

334-
def common_prefix(paths):
369+
def common_prefix(paths: Iterable[str]) -> str:
335370
"""For a list of paths, find the shortest prefix common to all"""
336371
parts = [p.split("/") for p in paths]
337372
lmax = min(len(p) for p in parts)
@@ -344,7 +379,12 @@ def common_prefix(paths):
344379
return "/".join(parts[0][:i])
345380

346381

347-
def other_paths(paths, path2, exists=False, flatten=False):
382+
def other_paths(
383+
paths: list[str],
384+
path2: str | list[str],
385+
exists: bool = False,
386+
flatten: bool = False,
387+
) -> list[str]:
348388
"""In bulk file operations, construct a new file tree from a list of files
349389
350390
Parameters
@@ -384,25 +424,25 @@ def other_paths(paths, path2, exists=False, flatten=False):
384424
return path2
385425

386426

387-
def is_exception(obj):
427+
def is_exception(obj: Any) -> bool:
388428
return isinstance(obj, BaseException)
389429

390430

391-
def isfilelike(f):
431+
def isfilelike(f: Any) -> TypeGuard[IO[bytes]]:
392432
for attr in ["read", "close", "tell"]:
393433
if not hasattr(f, attr):
394434
return False
395435
return True
396436

397437

398-
def get_protocol(url):
438+
def get_protocol(url: str) -> str:
399439
parts = re.split(r"(\:\:|\://)", url, 1)
400440
if len(parts) > 1:
401441
return parts[0]
402442
return "file"
403443

404444

405-
def can_be_local(path):
445+
def can_be_local(path: str) -> bool:
406446
"""Can the given URL be used with open_local?"""
407447
from fsspec import get_filesystem_class
408448

@@ -413,7 +453,7 @@ def can_be_local(path):
413453
return False
414454

415455

416-
def get_package_version_without_import(name):
456+
def get_package_version_without_import(name: str) -> str | None:
417457
"""For given package name, try to find the version without importing it
418458
419459
Import and package.__version__ is still the backup here, so an import
@@ -439,7 +479,12 @@ def get_package_version_without_import(name):
439479
return None
440480

441481

442-
def setup_logging(logger=None, logger_name=None, level="DEBUG", clear=True):
482+
def setup_logging(
483+
logger: logging.Logger | None = None,
484+
logger_name: str | None = None,
485+
level: str = "DEBUG",
486+
clear: bool = True,
487+
) -> logging.Logger:
443488
if logger is None and logger_name is None:
444489
raise ValueError("Provide either logger object or logger name")
445490
logger = logger or logging.getLogger(logger_name)
@@ -455,20 +500,22 @@ def setup_logging(logger=None, logger_name=None, level="DEBUG", clear=True):
455500
return logger
456501

457502

458-
def _unstrip_protocol(name, fs):
503+
def _unstrip_protocol(name: str, fs: AbstractFileSystem) -> str:
459504
return fs.unstrip_protocol(name)
460505

461506

462-
def mirror_from(origin_name, methods):
507+
def mirror_from(
508+
origin_name: str, methods: Iterable[str]
509+
) -> Callable[[type[T]], type[T]]:
463510
"""Mirror attributes and methods from the given
464511
origin_name attribute of the instance to the
465512
decorated class"""
466513

467-
def origin_getter(method, self):
514+
def origin_getter(method: str, self: Any) -> Any:
468515
origin = getattr(self, origin_name)
469516
return getattr(origin, method)
470517

471-
def wrapper(cls):
518+
def wrapper(cls: type[T]) -> type[T]:
472519
for method in methods:
473520
wrapped_method = partial(origin_getter, method)
474521
setattr(cls, method, property(wrapped_method))
@@ -478,11 +525,18 @@ def wrapper(cls):
478525

479526

480527
@contextlib.contextmanager
481-
def nullcontext(obj):
528+
def nullcontext(obj: T) -> Iterator[T]:
482529
yield obj
483530

484531

485-
def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=True):
532+
def merge_offset_ranges(
533+
paths: list[str],
534+
starts: list[int] | int,
535+
ends: list[int] | int,
536+
max_gap: int = 0,
537+
max_block: int | None = None,
538+
sort: bool = True,
539+
) -> tuple[list[str], list[int], list[int]]:
486540
"""Merge adjacent byte-offset ranges when the inter-range
487541
gap is <= `max_gap`, and when the merged byte range does not
488542
exceed `max_block` (if specified). By default, this function
@@ -496,7 +550,7 @@ def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=Tru
496550
if not isinstance(starts, list):
497551
starts = [starts] * len(paths)
498552
if not isinstance(ends, list):
499-
ends = [starts] * len(paths)
553+
ends = [ends] * len(paths)
500554
if len(starts) != len(paths) or len(ends) != len(paths):
501555
raise ValueError
502556

@@ -545,7 +599,7 @@ def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=Tru
545599
return paths, starts, ends
546600

547601

548-
def file_size(filelike):
602+
def file_size(filelike: IO[bytes]) -> int:
549603
"""Find length of any open read-mode file-like"""
550604
pos = filelike.tell()
551605
try:

0 commit comments

Comments (0)