10
10
from functools import partial
11
11
from hashlib import md5
12
12
from importlib .metadata import version
13
+ from typing import (
14
+ IO ,
15
+ TYPE_CHECKING ,
16
+ Any ,
17
+ Callable ,
18
+ Iterable ,
19
+ Iterator ,
20
+ Sequence ,
21
+ TypeVar ,
22
+ )
13
23
from urllib .parse import urlsplit
14
24
25
+ if TYPE_CHECKING :
26
+ from typing_extensions import TypeGuard
27
+
28
+ from fsspec .spec import AbstractFileSystem
29
+
30
+
15
31
DEFAULT_BLOCK_SIZE = 5 * 2 ** 20
16
32
33
+ T = TypeVar ("T" )
17
34
18
- def infer_storage_options (urlpath , inherit_storage_options = None ):
35
+
36
+ def infer_storage_options (
37
+ urlpath : str , inherit_storage_options : dict [str , Any ] | None = None
38
+ ) -> dict [str , Any ]:
19
39
"""Infer storage options from URL path and merge it with existing storage
20
40
options.
21
41
@@ -67,7 +87,7 @@ def infer_storage_options(urlpath, inherit_storage_options=None):
67
87
# for HTTP, we don't want to parse, as requests will anyway
68
88
return {"protocol" : protocol , "path" : urlpath }
69
89
70
- options = {"protocol" : protocol , "path" : path }
90
+ options : dict [ str , Any ] = {"protocol" : protocol , "path" : path }
71
91
72
92
if parsed_path .netloc :
73
93
# Parse `hostname` from netloc manually because `parsed_path.hostname`
@@ -97,7 +117,9 @@ def infer_storage_options(urlpath, inherit_storage_options=None):
97
117
return options
98
118
99
119
100
- def update_storage_options (options , inherited = None ):
120
+ def update_storage_options (
121
+ options : dict [str , Any ], inherited : dict [str , Any ] | None = None
122
+ ) -> None :
101
123
if not inherited :
102
124
inherited = {}
103
125
collisions = set (options ) & set (inherited )
@@ -115,7 +137,7 @@ def update_storage_options(options, inherited=None):
115
137
compressions : dict [str , str ] = {}
116
138
117
139
118
- def infer_compression (filename ) :
140
+ def infer_compression (filename : str ) -> str | None :
119
141
"""Infer compression, if available, from filename.
120
142
121
143
Infer a named compression type, if registered and available, from filename
@@ -125,9 +147,10 @@ def infer_compression(filename):
125
147
extension = os .path .splitext (filename )[- 1 ].strip ("." ).lower ()
126
148
if extension in compressions :
127
149
return compressions [extension ]
150
+ return None
128
151
129
152
130
- def build_name_function (max_int ) :
153
+ def build_name_function (max_int : float ) -> Callable [[ int ], str ] :
131
154
"""Returns a function that receives a single integer
132
155
and returns it as a string padded by enough zero characters
133
156
to align with maximum possible integer
@@ -150,13 +173,13 @@ def build_name_function(max_int):
150
173
151
174
pad_length = int (math .ceil (math .log10 (max_int )))
152
175
153
- def name_function (i ) :
176
+ def name_function (i : int ) -> str :
154
177
return str (i ).zfill (pad_length )
155
178
156
179
return name_function
157
180
158
181
159
- def seek_delimiter (file , delimiter , blocksize ) :
182
+ def seek_delimiter (file : IO [ bytes ] , delimiter : bytes , blocksize : int ) -> bool :
160
183
r"""Seek current file to file start, file end, or byte after delimiter seq.
161
184
162
185
Seeks file to next chunk delimiter, where chunks are defined on file start,
@@ -185,7 +208,7 @@ def seek_delimiter(file, delimiter, blocksize):
185
208
186
209
# Interface is for binary IO, with delimiter as bytes, but initialize last
187
210
# with result of file.read to preserve compatibility with text IO.
188
- last = None
211
+ last : bytes | None = None
189
212
while True :
190
213
current = file .read (blocksize )
191
214
if not current :
@@ -205,7 +228,13 @@ def seek_delimiter(file, delimiter, blocksize):
205
228
last = full [- len (delimiter ) :]
206
229
207
230
208
- def read_block (f , offset , length , delimiter = None , split_before = False ):
231
+ def read_block (
232
+ f : IO [bytes ],
233
+ offset : int ,
234
+ length : int | None ,
235
+ delimiter : bytes | None = None ,
236
+ split_before : bool = False ,
237
+ ) -> bytes :
209
238
"""Read a block of bytes from a file
210
239
211
240
Parameters
@@ -266,11 +295,14 @@ def read_block(f, offset, length, delimiter=None, split_before=False):
266
295
length = end - start
267
296
268
297
f .seek (offset )
269
- b = f .read (length )
298
+ if length is not None :
299
+ b = f .read (length )
300
+ else :
301
+ b = f .read ()
270
302
return b
271
303
272
304
273
- def tokenize (* args , ** kwargs ) :
305
+ def tokenize (* args : Any , ** kwargs : Any ) -> str :
274
306
"""Deterministic token
275
307
276
308
(modified from dask.base)
@@ -287,10 +319,10 @@ def tokenize(*args, **kwargs):
287
319
return md5 (str (args ).encode ()).hexdigest ()
288
320
except ValueError :
289
321
# FIPS systems: https://github.com/fsspec/filesystem_spec/issues/380
290
- return md5 (str (args ).encode (), usedforsecurity = False ).hexdigest ()
322
+ return md5 (str (args ).encode (), usedforsecurity = False ).hexdigest () # type: ignore[call-arg]
291
323
292
324
293
- def stringify_path (filepath ) :
325
+ def stringify_path (filepath : str | os . PathLike [ str ] | pathlib . Path ) -> str :
294
326
"""Attempt to convert a path-like object to a string.
295
327
296
328
Parameters
@@ -314,7 +346,7 @@ def stringify_path(filepath):
314
346
"""
315
347
if isinstance (filepath , str ):
316
348
return filepath
317
- elif hasattr (filepath , "__fspath__" ):
349
+ elif hasattr (filepath , "__fspath__" ) or isinstance ( filepath , os . PathLike ) :
318
350
return filepath .__fspath__ ()
319
351
elif isinstance (filepath , pathlib .Path ):
320
352
return str (filepath )
@@ -324,13 +356,15 @@ def stringify_path(filepath):
324
356
return filepath
325
357
326
358
327
def make_instance(
    cls: Callable[..., T], args: Sequence[Any], kwargs: dict[str, Any]
) -> T:
    """Instantiate *cls* with the given positional and keyword arguments.

    After construction, the instance's ``_determine_worker`` hook is invoked
    before the instance is returned.
    """
    instance = cls(*args, **kwargs)
    instance._determine_worker()  # type: ignore[attr-defined]
    return instance
331
365
332
366
333
- def common_prefix (paths ) :
367
+ def common_prefix (paths : Iterable [ str ]) -> str :
334
368
"""For a list of paths, find the shortest prefix common to all"""
335
369
parts = [p .split ("/" ) for p in paths ]
336
370
lmax = min (len (p ) for p in parts )
@@ -343,7 +377,13 @@ def common_prefix(paths):
343
377
return "/" .join (parts [0 ][:i ])
344
378
345
379
346
- def other_paths (paths , path2 , is_dir = None , exists = False , flatten = False ):
380
+ def other_paths (
381
+ paths : list [str ],
382
+ path2 : str | list [str ],
383
+ is_dir : bool | None = None ,
384
+ exists : bool = False ,
385
+ flatten : bool = False ,
386
+ ) -> list [str ]:
347
387
"""In bulk file operations, construct a new file tree from a list of files
348
388
349
389
Parameters
@@ -388,25 +428,25 @@ def other_paths(paths, path2, is_dir=None, exists=False, flatten=False):
388
428
return path2
389
429
390
430
391
def is_exception(obj: Any) -> bool:
    """Return ``True`` when *obj* is an exception instance.

    Checks against ``BaseException`` so system-exiting exceptions
    (``KeyboardInterrupt``, ``SystemExit``) are also recognized.
    """
    return isinstance(obj, BaseException)
393
433
394
434
395
def isfilelike(f: Any) -> TypeGuard[IO[bytes]]:
    """Duck-type check: does *f* expose the minimal file-object surface?

    An object counts as file-like when it has ``read``, ``close`` and
    ``tell`` attributes; no types are inspected.
    """
    required = ("read", "close", "tell")
    return all(hasattr(f, attr) for attr in required)
400
440
401
441
402
def get_protocol(url: str) -> str:
    """Return the protocol prefix of *url*.

    The protocol is whatever precedes the first ``"://"`` (plain URL) or
    ``"::"`` (chained URL) separator; ``"file"`` is returned when neither
    separator is present.
    """
    # maxsplit=1: only the first separator delimits the protocol.
    # (Passing maxsplit positionally to re.split is deprecated in 3.13;
    # ":" needs no escaping inside the pattern.)
    parts = re.split(r"(::|://)", url, maxsplit=1)
    if len(parts) > 1:
        return parts[0]
    return "file"
407
447
408
448
409
- def can_be_local (path ) :
449
+ def can_be_local (path : str ) -> bool :
410
450
"""Can the given URL be used with open_local?"""
411
451
from fsspec import get_filesystem_class
412
452
@@ -417,7 +457,7 @@ def can_be_local(path):
417
457
return False
418
458
419
459
420
- def get_package_version_without_import (name ) :
460
+ def get_package_version_without_import (name : str ) -> str | None :
421
461
"""For given package name, try to find the version without importing it
422
462
423
463
Import and package.__version__ is still the backup here, so an import
@@ -443,7 +483,12 @@ def get_package_version_without_import(name):
443
483
return None
444
484
445
485
446
- def setup_logging (logger = None , logger_name = None , level = "DEBUG" , clear = True ):
486
+ def setup_logging (
487
+ logger : logging .Logger | None = None ,
488
+ logger_name : str | None = None ,
489
+ level : str = "DEBUG" ,
490
+ clear : bool = True ,
491
+ ) -> logging .Logger :
447
492
if logger is None and logger_name is None :
448
493
raise ValueError ("Provide either logger object or logger name" )
449
494
logger = logger or logging .getLogger (logger_name )
@@ -459,20 +504,22 @@ def setup_logging(logger=None, logger_name=None, level="DEBUG", clear=True):
459
504
return logger
460
505
461
506
462
- def _unstrip_protocol (name , fs ) :
507
+ def _unstrip_protocol (name : str , fs : AbstractFileSystem ) -> str :
463
508
return fs .unstrip_protocol (name )
464
509
465
510
466
- def mirror_from (origin_name , methods ):
511
+ def mirror_from (
512
+ origin_name : str , methods : Iterable [str ]
513
+ ) -> Callable [[type [T ]], type [T ]]:
467
514
"""Mirror attributes and methods from the given
468
515
origin_name attribute of the instance to the
469
516
decorated class"""
470
517
471
- def origin_getter (method , self ) :
518
+ def origin_getter (method : str , self : Any ) -> Any :
472
519
origin = getattr (self , origin_name )
473
520
return getattr (origin , method )
474
521
475
- def wrapper (cls ) :
522
+ def wrapper (cls : type [ T ]) -> type [ T ] :
476
523
for method in methods :
477
524
wrapped_method = partial (origin_getter , method )
478
525
setattr (cls , method , property (wrapped_method ))
@@ -482,11 +529,18 @@ def wrapper(cls):
482
529
483
530
484
531
@contextmanager
def nullcontext(obj: T) -> Iterator[T]:
    """No-op context manager: yield *obj* unchanged and do nothing on exit."""
    yield obj
487
534
488
535
489
- def merge_offset_ranges (paths , starts , ends , max_gap = 0 , max_block = None , sort = True ):
536
+ def merge_offset_ranges (
537
+ paths : list [str ],
538
+ starts : list [int ] | int ,
539
+ ends : list [int ] | int ,
540
+ max_gap : int = 0 ,
541
+ max_block : int | None = None ,
542
+ sort : bool = True ,
543
+ ) -> tuple [list [str ], list [int ], list [int ]]:
490
544
"""Merge adjacent byte-offset ranges when the inter-range
491
545
gap is <= `max_gap`, and when the merged byte range does not
492
546
exceed `max_block` (if specified). By default, this function
@@ -500,7 +554,7 @@ def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=Tru
500
554
if not isinstance (starts , list ):
501
555
starts = [starts ] * len (paths )
502
556
if not isinstance (ends , list ):
503
- ends = [starts ] * len (paths )
557
+ ends = [ends ] * len (paths )
504
558
if len (starts ) != len (paths ) or len (ends ) != len (paths ):
505
559
raise ValueError
506
560
@@ -549,7 +603,7 @@ def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=Tru
549
603
return paths , starts , ends
550
604
551
605
552
- def file_size (filelike ) :
606
+ def file_size (filelike : IO [ bytes ]) -> int :
553
607
"""Find length of any open read-mode file-like"""
554
608
pos = filelike .tell ()
555
609
try :
0 commit comments