11
11
from functools import partial
12
12
from hashlib import md5
13
13
from importlib .metadata import version
14
+ from typing import (
15
+ IO ,
16
+ TYPE_CHECKING ,
17
+ Any ,
18
+ Callable ,
19
+ Iterable ,
20
+ Iterator ,
21
+ Sequence ,
22
+ TypeVar ,
23
+ )
14
24
from urllib .parse import urlsplit
15
25
26
+ if TYPE_CHECKING :
27
+ from typing_extensions import TypeGuard
28
+
29
+ from fsspec .spec import AbstractFileSystem
30
+
31
+
16
32
DEFAULT_BLOCK_SIZE = 5 * 2 ** 20
17
33
34
+ T = TypeVar ("T" )
18
35
19
- def infer_storage_options (urlpath , inherit_storage_options = None ):
36
+
37
+ def infer_storage_options (
38
+ urlpath : str , inherit_storage_options : dict [str , Any ] | None = None
39
+ ) -> dict [str , Any ]:
20
40
"""Infer storage options from URL path and merge it with existing storage
21
41
options.
22
42
@@ -68,7 +88,7 @@ def infer_storage_options(urlpath, inherit_storage_options=None):
68
88
# for HTTP, we don't want to parse, as requests will anyway
69
89
return {"protocol" : protocol , "path" : urlpath }
70
90
71
- options = {"protocol" : protocol , "path" : path }
91
+ options : dict [ str , Any ] = {"protocol" : protocol , "path" : path }
72
92
73
93
if parsed_path .netloc :
74
94
# Parse `hostname` from netloc manually because `parsed_path.hostname`
@@ -98,7 +118,9 @@ def infer_storage_options(urlpath, inherit_storage_options=None):
98
118
return options
99
119
100
120
101
- def update_storage_options (options , inherited = None ):
121
+ def update_storage_options (
122
+ options : dict [str , Any ], inherited : dict [str , Any ] | None = None
123
+ ) -> None :
102
124
if not inherited :
103
125
inherited = {}
104
126
collisions = set (options ) & set (inherited )
@@ -116,7 +138,7 @@ def update_storage_options(options, inherited=None):
116
138
compressions : dict [str , str ] = {}
117
139
118
140
119
- def infer_compression (filename ) :
141
+ def infer_compression (filename : str ) -> str | None :
120
142
"""Infer compression, if available, from filename.
121
143
122
144
Infer a named compression type, if registered and available, from filename
@@ -126,9 +148,10 @@ def infer_compression(filename):
126
148
extension = os .path .splitext (filename )[- 1 ].strip ("." ).lower ()
127
149
if extension in compressions :
128
150
return compressions [extension ]
151
+ return None
129
152
130
153
131
- def build_name_function (max_int ) :
154
+ def build_name_function (max_int : float ) -> Callable [[ int ], str ] :
132
155
"""Returns a function that receives a single integer
133
156
and returns it as a string padded by enough zero characters
134
157
to align with maximum possible integer
@@ -151,13 +174,13 @@ def build_name_function(max_int):
151
174
152
175
pad_length = int (math .ceil (math .log10 (max_int )))
153
176
154
- def name_function (i ) :
177
+ def name_function (i : int ) -> str :
155
178
return str (i ).zfill (pad_length )
156
179
157
180
return name_function
158
181
159
182
160
- def seek_delimiter (file , delimiter , blocksize ) :
183
+ def seek_delimiter (file : IO [ bytes ] , delimiter : bytes , blocksize : int ) -> bool :
161
184
r"""Seek current file to file start, file end, or byte after delimiter seq.
162
185
163
186
Seeks file to next chunk delimiter, where chunks are defined on file start,
@@ -186,7 +209,7 @@ def seek_delimiter(file, delimiter, blocksize):
186
209
187
210
# Interface is for binary IO, with delimiter as bytes, but initialize last
188
211
# with result of file.read to preserve compatibility with text IO.
189
- last = None
212
+ last : bytes | None = None
190
213
while True :
191
214
current = file .read (blocksize )
192
215
if not current :
@@ -206,7 +229,13 @@ def seek_delimiter(file, delimiter, blocksize):
206
229
last = full [- len (delimiter ) :]
207
230
208
231
209
- def read_block (f , offset , length , delimiter = None , split_before = False ):
232
+ def read_block (
233
+ f : IO [bytes ],
234
+ offset : int ,
235
+ length : int | None ,
236
+ delimiter : bytes | None = None ,
237
+ split_before : bool = False ,
238
+ ) -> bytes :
210
239
"""Read a block of bytes from a file
211
240
212
241
Parameters
@@ -267,11 +296,14 @@ def read_block(f, offset, length, delimiter=None, split_before=False):
267
296
length = end - start
268
297
269
298
f .seek (offset )
270
- b = f .read (length )
299
+ if length is not None :
300
+ b = f .read (length )
301
+ else :
302
+ b = f .read ()
271
303
return b
272
304
273
305
274
- def tokenize (* args , ** kwargs ) :
306
+ def tokenize (* args : Any , ** kwargs : Any ) -> str :
275
307
"""Deterministic token
276
308
277
309
(modified from dask.base)
@@ -285,13 +317,14 @@ def tokenize(*args, **kwargs):
285
317
if kwargs :
286
318
args += (kwargs ,)
287
319
try :
288
- return md5 (str (args ).encode ()). hexdigest ( )
320
+ h = md5 (str (args ).encode ())
289
321
except ValueError :
290
322
# FIPS systems: https://github.com/fsspec/filesystem_spec/issues/380
291
- return md5 (str (args ).encode (), usedforsecurity = False ).hexdigest ()
323
+ h = md5 (str (args ).encode (), usedforsecurity = False ) # type: ignore[call-arg]
324
+ return h .hexdigest ()
292
325
293
326
294
- def stringify_path (filepath ) :
327
+ def stringify_path (filepath : str | os . PathLike [ str ] | pathlib . Path ) -> str :
295
328
"""Attempt to convert a path-like object to a string.
296
329
297
330
Parameters
@@ -315,7 +348,7 @@ def stringify_path(filepath):
315
348
"""
316
349
if isinstance (filepath , str ):
317
350
return filepath
318
- elif hasattr (filepath , "__fspath__" ):
351
+ elif hasattr (filepath , "__fspath__" ) or isinstance ( filepath , os . PathLike ) :
319
352
return filepath .__fspath__ ()
320
353
elif isinstance (filepath , pathlib .Path ):
321
354
return str (filepath )
@@ -325,13 +358,15 @@ def stringify_path(filepath):
325
358
return filepath
326
359
327
360
328
def make_instance(
    cls: Callable[..., T], args: Sequence[Any], kwargs: dict[str, Any]
) -> T:
    """Instantiate *cls* with the given arguments and return the object.

    After construction, the new object's ``_determine_worker`` hook is
    invoked before it is returned.
    """
    obj = cls(*args, **kwargs)
    obj._determine_worker()  # type: ignore[attr-defined]
    return obj
332
367
333
368
334
- def common_prefix (paths ) :
369
+ def common_prefix (paths : Iterable [ str ]) -> str :
335
370
"""For a list of paths, find the shortest prefix common to all"""
336
371
parts = [p .split ("/" ) for p in paths ]
337
372
lmax = min (len (p ) for p in parts )
@@ -344,7 +379,12 @@ def common_prefix(paths):
344
379
return "/" .join (parts [0 ][:i ])
345
380
346
381
347
- def other_paths (paths , path2 , exists = False , flatten = False ):
382
+ def other_paths (
383
+ paths : list [str ],
384
+ path2 : str | list [str ],
385
+ exists : bool = False ,
386
+ flatten : bool = False ,
387
+ ) -> list [str ]:
348
388
"""In bulk file operations, construct a new file tree from a list of files
349
389
350
390
Parameters
@@ -384,25 +424,25 @@ def other_paths(paths, path2, exists=False, flatten=False):
384
424
return path2
385
425
386
426
387
def is_exception(obj: Any) -> bool:
    """Return ``True`` when *obj* is an instance of ``BaseException``."""
    return isinstance(obj, BaseException)
389
429
390
430
391
def isfilelike(f: Any) -> TypeGuard[IO[bytes]]:
    """Duck-type check: does *f* expose ``read``, ``close`` and ``tell``?"""
    required = ("read", "close", "tell")
    return all(hasattr(f, attr) for attr in required)
396
436
397
437
398
def get_protocol(url: str) -> str:
    """Return the protocol part of *url*.

    The protocol is whatever precedes the first ``::`` or ``://``
    separator; a URL with neither separator defaults to ``"file"``.
    """
    # Pass maxsplit by keyword: positional maxsplit for re.split is
    # deprecated (DeprecationWarning on Python 3.13+).
    parts = re.split(r"(\:\:|\://)", url, maxsplit=1)
    if len(parts) > 1:
        return parts[0]
    return "file"
403
443
404
444
405
- def can_be_local (path ) :
445
+ def can_be_local (path : str ) -> bool :
406
446
"""Can the given URL be used with open_local?"""
407
447
from fsspec import get_filesystem_class
408
448
@@ -413,7 +453,7 @@ def can_be_local(path):
413
453
return False
414
454
415
455
416
- def get_package_version_without_import (name ) :
456
+ def get_package_version_without_import (name : str ) -> str | None :
417
457
"""For given package name, try to find the version without importing it
418
458
419
459
Import and package.__version__ is still the backup here, so an import
@@ -439,7 +479,12 @@ def get_package_version_without_import(name):
439
479
return None
440
480
441
481
442
- def setup_logging (logger = None , logger_name = None , level = "DEBUG" , clear = True ):
482
+ def setup_logging (
483
+ logger : logging .Logger | None = None ,
484
+ logger_name : str | None = None ,
485
+ level : str = "DEBUG" ,
486
+ clear : bool = True ,
487
+ ) -> logging .Logger :
443
488
if logger is None and logger_name is None :
444
489
raise ValueError ("Provide either logger object or logger name" )
445
490
logger = logger or logging .getLogger (logger_name )
@@ -455,20 +500,22 @@ def setup_logging(logger=None, logger_name=None, level="DEBUG", clear=True):
455
500
return logger
456
501
457
502
458
- def _unstrip_protocol (name , fs ) :
503
+ def _unstrip_protocol (name : str , fs : AbstractFileSystem ) -> str :
459
504
return fs .unstrip_protocol (name )
460
505
461
506
462
- def mirror_from (origin_name , methods ):
507
+ def mirror_from (
508
+ origin_name : str , methods : Iterable [str ]
509
+ ) -> Callable [[type [T ]], type [T ]]:
463
510
"""Mirror attributes and methods from the given
464
511
origin_name attribute of the instance to the
465
512
decorated class"""
466
513
467
- def origin_getter (method , self ) :
514
+ def origin_getter (method : str , self : Any ) -> Any :
468
515
origin = getattr (self , origin_name )
469
516
return getattr (origin , method )
470
517
471
- def wrapper (cls ) :
518
+ def wrapper (cls : type [ T ]) -> type [ T ] :
472
519
for method in methods :
473
520
wrapped_method = partial (origin_getter , method )
474
521
setattr (cls , method , property (wrapped_method ))
@@ -478,11 +525,18 @@ def wrapper(cls):
478
525
479
526
480
527
@contextlib.contextmanager
def nullcontext(obj: T) -> Iterator[T]:
    """A no-op context manager that yields *obj* unchanged."""
    yield obj
483
530
484
531
485
- def merge_offset_ranges (paths , starts , ends , max_gap = 0 , max_block = None , sort = True ):
532
+ def merge_offset_ranges (
533
+ paths : list [str ],
534
+ starts : list [int ] | int ,
535
+ ends : list [int ] | int ,
536
+ max_gap : int = 0 ,
537
+ max_block : int | None = None ,
538
+ sort : bool = True ,
539
+ ) -> tuple [list [str ], list [int ], list [int ]]:
486
540
"""Merge adjacent byte-offset ranges when the inter-range
487
541
gap is <= `max_gap`, and when the merged byte range does not
488
542
exceed `max_block` (if specified). By default, this function
@@ -496,7 +550,7 @@ def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=Tru
496
550
if not isinstance (starts , list ):
497
551
starts = [starts ] * len (paths )
498
552
if not isinstance (ends , list ):
499
- ends = [starts ] * len (paths )
553
+ ends = [ends ] * len (paths )
500
554
if len (starts ) != len (paths ) or len (ends ) != len (paths ):
501
555
raise ValueError
502
556
@@ -545,7 +599,7 @@ def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=Tru
545
599
return paths , starts , ends
546
600
547
601
548
- def file_size (filelike ) :
602
+ def file_size (filelike : IO [ bytes ]) -> int :
549
603
"""Find length of any open read-mode file-like"""
550
604
pos = filelike .tell ()
551
605
try :
0 commit comments