Skip to content

Commit 14a7788

Browse files
authored
Cached write transaction (#1434)
* Allow transaction writing with local cache
* Writes cached and uploaded in single op
1 parent d1d2268 commit 14a7788

File tree

2 files changed

+80
-13
lines changed

2 files changed

+80
-13
lines changed

fsspec/implementations/cached.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from fsspec.implementations.cache_mapper import create_cache_mapper
1818
from fsspec.implementations.cache_metadata import CacheMetadata
1919
from fsspec.spec import AbstractBufferedFile
20+
from fsspec.transaction import Transaction
2021
from fsspec.utils import infer_compression
2122

2223
if TYPE_CHECKING:
@@ -25,6 +26,16 @@
2526
logger = logging.getLogger("fsspec.cached")
2627

2728

29+
class WriteCachedTransaction(Transaction):
    """Transaction that uploads all deferred files with one ``put`` call.

    Every entry of ``self.files`` must expose ``path`` (the remote target)
    and ``fn`` (the local file that holds the written data).
    """

    def complete(self, commit=True):
        """Finish the transaction, uploading the deferred files if requested.

        Parameters
        ----------
        commit: bool
            If True, copy every local file to its remote path in a single
            ``self.fs.put`` call. If False, nothing is uploaded; the local
            files are not removed here.

        Any exception raised by ``self.fs.put`` propagates to the caller.
        """
        rpaths = [f.path for f in self.files]
        lpaths = [f.fn for f in self.files]
        try:
            if commit:
                self.fs.put(lpaths, rpaths)
            # else remove?
        finally:
            # Forget the deferred files so that completing this transaction
            # object again does not upload the same files a second time, and
            # always leave transaction mode, even when the upload raised.
            self.files.clear()
            self.fs._intrans = False
37+
38+
2839
class CachingFileSystem(AbstractFileSystem):
2940
"""Locally caching filesystem, layer over any other FS
3041
@@ -415,6 +426,10 @@ def __getattribute__(self, item):
415426
"__eq__",
416427
"to_json",
417428
"cache_size",
429+
"pipe_file",
430+
"pipe",
431+
"start_transaction",
432+
"end_transaction",
418433
]:
419434
# all the methods defined in this class. Note `open` here, since
420435
# it calls `_open`, but is actually in superclass
@@ -423,7 +438,10 @@ def __getattribute__(self, item):
423438
)
424439
if item in ["__reduce_ex__"]:
425440
raise AttributeError
426-
if item in ["_cache"]:
441+
if item in ["transaction"]:
442+
# property
443+
return type(self).transaction.__get__(self)
444+
if item in ["_cache", "transaction_type"]:
427445
# class attributes
428446
return getattr(type(self), item)
429447
if item == "__class__":
@@ -512,7 +530,13 @@ def open_many(self, open_files):
512530
self._mkcache()
513531
else:
514532
return [
515-
LocalTempFile(self.fs, path, mode=open_files.mode) for path in paths
533+
LocalTempFile(
534+
self.fs,
535+
path,
536+
mode=open_files.mode,
537+
fn=os.path.join(self.storage[-1], self._mapper(path)),
538+
)
539+
for path in paths
516540
]
517541

518542
if self.compression:
@@ -625,7 +649,8 @@ def cat(
625649
def _open(self, path, mode="rb", **kwargs):
626650
path = self._strip_protocol(path)
627651
if "r" not in mode:
628-
return LocalTempFile(self, path, mode=mode)
652+
fn = self._make_local_details(path)
653+
return LocalTempFile(self, path, mode=mode, fn=fn)
629654
detail = self._check_file(path)
630655
if detail:
631656
detail, fn = detail
@@ -692,6 +717,7 @@ class SimpleCacheFileSystem(WholeFileCacheFileSystem):
692717

693718
protocol = "simplecache"
694719
local_file = True
720+
transaction_type = WriteCachedTransaction
695721

696722
def __init__(self, **kwargs):
697723
kw = kwargs.copy()
@@ -716,6 +742,22 @@ def save_cache(self):
716742
def load_cache(self):
717743
pass
718744

745+
def pipe_file(self, path, value=None, **kwargs):
    """Write ``value`` to ``path``.

    Outside a transaction the call is delegated to the superclass
    implementation. Inside a transaction (``self._intrans`` is true) the
    data is written through ``self.open(path, "wb")`` instead.
    """
    if not self._intrans:
        super().pipe_file(path, value)
        return
    with self.open(path, "wb") as handle:
        handle.write(value)
751+
752+
def pipe(self, path, value=None, **kwargs):
    """Write one or several files via ``self.pipe_file``.

    Parameters
    ----------
    path: str or dict
        A single path, written with ``value``; or a mapping of
        path -> data, in which case ``value`` is not used.
    value:
        Data for the single-path form.
    **kwargs:
        Forwarded unchanged to every ``self.pipe_file`` call.

    Raises
    ------
    ValueError
        If ``path`` is neither a str nor a dict.
    """
    if isinstance(path, str):
        payload = {path: value}
    elif isinstance(path, dict):
        payload = path
    else:
        raise ValueError("path must be str or dict")
    # Each key has its protocol stripped before being handed on.
    for target, data in payload.items():
        self.pipe_file(self._strip_protocol(target), data, **kwargs)
760+
719761
def cat_ranges(
720762
self, paths, starts, ends, max_gap=None, on_error="return", **kwargs
721763
):
@@ -729,14 +771,17 @@ def cat_ranges(
729771

730772
def _open(self, path, mode="rb", **kwargs):
731773
path = self._strip_protocol(path)
774+
sha = self._mapper(path)
732775

733776
if "r" not in mode:
734-
return LocalTempFile(self, path, mode=mode)
777+
fn = os.path.join(self.storage[-1], sha)
778+
return LocalTempFile(
779+
self, path, mode=mode, autocommit=not self._intrans, fn=fn
780+
)
735781
fn = self._check_file(path)
736782
if fn:
737783
return open(fn, mode)
738784

739-
sha = self._mapper(path)
740785
fn = os.path.join(self.storage[-1], sha)
741786
logger.debug("Copying %s to local cache", path)
742787
kwargs["mode"] = mode
@@ -767,13 +812,9 @@ def _open(self, path, mode="rb", **kwargs):
767812
class LocalTempFile:
768813
"""A temporary local file, which will be uploaded on commit"""
769814

770-
def __init__(self, fs, path, fn=None, mode="wb", autocommit=True, seek=0):
771-
if fn:
772-
self.fn = fn
773-
self.fh = open(fn, mode)
774-
else:
775-
fd, self.fn = tempfile.mkstemp()
776-
self.fh = open(fd, mode)
815+
def __init__(self, fs, path, fn, mode="wb", autocommit=True, seek=0):
816+
self.fn = fn
817+
self.fh = open(fn, mode)
777818
self.mode = mode
778819
if seek:
779820
self.fh.seek(seek)

fsspec/implementations/tests/test_cached.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,10 @@ def test_workflow(ftp_writable, impl):
289289
with fs.open("/out", "wb") as f:
290290
f.write(b"changed")
291291

292-
assert fs.cat("/out") == b"test" # old value
292+
if impl == "filecache":
293+
assert (
294+
fs.cat("/out") == b"changed"
295+
) # new value, because we overwrote the cached location
293296

294297

295298
@pytest.mark.parametrize("impl", ["simplecache", "blockcache"])
@@ -1272,3 +1275,26 @@ def test_spurious_directory_issue1410(tmpdir):
12721275
# would be created and the next assertion would fail.
12731276
assert len(os.listdir()) == 1
12741277
assert fs._parent("/any/path") == "any" # correct for ZIP, which has no leading /
1278+
1279+
1280+
def test_write_transaction(tmpdir, m, monkeypatch):
    """Writes made inside a simplecache transaction are uploaded together.

    ``m.put`` is wrapped to count its calls. Nothing may be visible in ``m``
    while the transaction is open, and after it ends the piped files must be
    readable from ``m`` with ``put`` having been called exactly once.
    """
    put_calls = 0
    real_put = m.put

    def counting_put(*args, **kwargs):
        nonlocal put_calls
        put_calls += 1
        real_put(*args, **kwargs)

    monkeypatch.setattr(m, "put", counting_put)
    cache_dir = str(tmpdir)
    fs, _ = fsspec.core.url_to_fs("simplecache::memory://", cache_storage=cache_dir)
    with fs.transaction:
        fs.pipe("myfile", b"1")
        fs.pipe("otherfile", b"2")
        with fs.open("blarh", "wb") as f:
            f.write(b"ff")
        # Still inside the transaction: nothing has been uploaded yet.
        assert not m.find("")

    assert m.cat("myfile") == b"1"
    assert m.cat("otherfile") == b"2"
    assert put_calls == 1  # copy was done in one go

0 commit comments

Comments
 (0)