17
17
from fsspec .implementations .cache_mapper import create_cache_mapper
18
18
from fsspec .implementations .cache_metadata import CacheMetadata
19
19
from fsspec .spec import AbstractBufferedFile
20
+ from fsspec .transaction import Transaction
20
21
from fsspec .utils import infer_compression
21
22
22
23
if TYPE_CHECKING :
25
26
logger = logging .getLogger ("fsspec.cached" )
26
27
27
28
29
class WriteCachedTransaction(Transaction):
    """Transaction that uploads locally written files in one batch.

    Each entry in ``self.files`` is expected to expose ``path`` (the remote
    target) and ``fn`` (the local file holding the written data).
    """

    def complete(self, commit=True):
        """Finish the transaction.

        Parameters
        ----------
        commit : bool
            If True, every deferred local file is uploaded to its remote
            path with a single ``self.fs.put`` call. If False, nothing is
            uploaded.
        """
        # Skip the upload entirely when nothing was written; there is no
        # reason to hand empty path lists to ``put``.
        if commit and self.files:
            rpaths = [f.path for f in self.files]
            lpaths = [f.fn for f in self.files]
            self.fs.put(lpaths, rpaths)
        # TODO: on rollback the local temporary files are left on disk.
        # Forget the deferred files so that a later transaction on the same
        # object does not upload them a second time.
        self.files = []
        self.fs._intrans = False
37
+
38
+
28
39
class CachingFileSystem (AbstractFileSystem ):
29
40
"""Locally caching filesystem, layer over any other FS
30
41
@@ -415,6 +426,10 @@ def __getattribute__(self, item):
415
426
"__eq__" ,
416
427
"to_json" ,
417
428
"cache_size" ,
429
+ "pipe_file" ,
430
+ "pipe" ,
431
+ "start_transaction" ,
432
+ "end_transaction" ,
418
433
]:
419
434
# all the methods defined in this class. Note `open` here, since
420
435
# it calls `_open`, but is actually in superclass
@@ -423,7 +438,10 @@ def __getattribute__(self, item):
423
438
)
424
439
if item in ["__reduce_ex__" ]:
425
440
raise AttributeError
426
- if item in ["_cache" ]:
441
+ if item in ["transaction" ]:
442
+ # property
443
+ return type (self ).transaction .__get__ (self )
444
+ if item in ["_cache" , "transaction_type" ]:
427
445
# class attributes
428
446
return getattr (type (self ), item )
429
447
if item == "__class__" :
@@ -512,7 +530,13 @@ def open_many(self, open_files):
512
530
self ._mkcache ()
513
531
else :
514
532
return [
515
- LocalTempFile (self .fs , path , mode = open_files .mode ) for path in paths
533
+ LocalTempFile (
534
+ self .fs ,
535
+ path ,
536
+ mode = open_files .mode ,
537
+ fn = os .path .join (self .storage [- 1 ], self ._mapper (path )),
538
+ )
539
+ for path in paths
516
540
]
517
541
518
542
if self .compression :
@@ -625,7 +649,8 @@ def cat(
625
649
def _open (self , path , mode = "rb" , ** kwargs ):
626
650
path = self ._strip_protocol (path )
627
651
if "r" not in mode :
628
- return LocalTempFile (self , path , mode = mode )
652
+ fn = self ._make_local_details (path )
653
+ return LocalTempFile (self , path , mode = mode , fn = fn )
629
654
detail = self ._check_file (path )
630
655
if detail :
631
656
detail , fn = detail
@@ -692,6 +717,7 @@ class SimpleCacheFileSystem(WholeFileCacheFileSystem):
692
717
693
718
protocol = "simplecache"
694
719
local_file = True
720
+ transaction_type = WriteCachedTransaction
695
721
696
722
def __init__ (self , ** kwargs ):
697
723
kw = kwargs .copy ()
@@ -716,6 +742,22 @@ def save_cache(self):
716
742
def load_cache (self ):
717
743
pass
718
744
745
def pipe_file(self, path, value=None, **kwargs):
    """Write the bytes ``value`` to ``path``.

    Inside a transaction the data is written through ``self.open`` in
    ``"wb"`` mode; otherwise the call is delegated to the superclass
    implementation.

    Parameters
    ----------
    path : str
        Target location.
    value : bytes
        Content to write.
    **kwargs
        Forwarded to ``self.open`` or to the superclass ``pipe_file``, so
        options supplied via ``pipe`` are no longer silently dropped.
    """
    if self._intrans:
        with self.open(path, "wb", **kwargs) as f:
            f.write(value)
    else:
        super().pipe_file(path, value, **kwargs)
751
+
752
def pipe(self, path, value=None, **kwargs):
    """Write one or several values via ``pipe_file``.

    Parameters
    ----------
    path : str or dict
        A single target path, written with ``value``; or a mapping of
        path -> content, in which case ``value`` is ignored.
    value : bytes
        Content for the single-path form.
    **kwargs
        Passed through to every ``pipe_file`` call.

    Raises
    ------
    ValueError
        If ``path`` is neither a str nor a dict.
    """
    if isinstance(path, str):
        # Normalise the single-path form so both forms share one loop.
        entries = {path: value}
    elif isinstance(path, dict):
        entries = path
    else:
        raise ValueError("path must be str or dict")

    for target, payload in entries.items():
        self.pipe_file(self._strip_protocol(target), payload, **kwargs)
760
+
719
761
def cat_ranges (
720
762
self , paths , starts , ends , max_gap = None , on_error = "return" , ** kwargs
721
763
):
@@ -729,14 +771,17 @@ def cat_ranges(
729
771
730
772
def _open (self , path , mode = "rb" , ** kwargs ):
731
773
path = self ._strip_protocol (path )
774
+ sha = self ._mapper (path )
732
775
733
776
if "r" not in mode :
734
- return LocalTempFile (self , path , mode = mode )
777
+ fn = os .path .join (self .storage [- 1 ], sha )
778
+ return LocalTempFile (
779
+ self , path , mode = mode , autocommit = not self ._intrans , fn = fn
780
+ )
735
781
fn = self ._check_file (path )
736
782
if fn :
737
783
return open (fn , mode )
738
784
739
- sha = self ._mapper (path )
740
785
fn = os .path .join (self .storage [- 1 ], sha )
741
786
logger .debug ("Copying %s to local cache" , path )
742
787
kwargs ["mode" ] = mode
@@ -767,13 +812,9 @@ def _open(self, path, mode="rb", **kwargs):
767
812
class LocalTempFile :
768
813
"""A temporary local file, which will be uploaded on commit"""
769
814
770
- def __init__ (self , fs , path , fn = None , mode = "wb" , autocommit = True , seek = 0 ):
771
- if fn :
772
- self .fn = fn
773
- self .fh = open (fn , mode )
774
- else :
775
- fd , self .fn = tempfile .mkstemp ()
776
- self .fh = open (fd , mode )
815
+ def __init__ (self , fs , path , fn , mode = "wb" , autocommit = True , seek = 0 ):
816
+ self .fn = fn
817
+ self .fh = open (fn , mode )
777
818
self .mode = mode
778
819
if seek :
779
820
self .fh .seek (seek )
0 commit comments