
Commit d69899d

When writing references, don't cut end off (#1283)
1 parent 697d0f8 commit d69899d

File tree: 1 file changed, +17 −46 lines

fsspec/implementations/reference.py

Lines changed: 17 additions & 46 deletions
@@ -99,11 +99,7 @@ def pd(self):
         return pd
 
     def __init__(
-        self,
-        root,
-        fs=None,
-        out_root=None,
-        cache_size=128,
+        self, root, fs=None, out_root=None, cache_size=128, categorical_threshold=10
     ):
         """
         Parameters
@@ -128,6 +124,7 @@ def __init__(
         self.zmetadata = met["metadata"]
         self.url = self.root + "/{field}/refs.{record}.parq"
         self.out_root = out_root or self.root
+        self.cat_thresh = categorical_threshold
 
         # Define function to open and decompress refs
         @lru_cache(maxsize=cache_size)
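
These first two hunks thread one new knob through the constructor: categorical_threshold arrives as a keyword argument (default 10) and is stashed as self.cat_thresh for use in write() further down. A minimal usage sketch, assuming the class being patched is LazyReferenceMapper — the diff itself never names the class, so treat the import as an assumption:

def open_refs_mapper(root: str):
    # Illustrative only: constructing the mapper needs existing
    # refs.*.parq data under `root`, so this function is not called here.
    from fsspec.implementations.reference import LazyReferenceMapper

    # A path repeated more than ~25x on average within one record will be
    # stored as a pandas Categorical when that parquet record is written.
    return LazyReferenceMapper(root, cache_size=128, categorical_threshold=25)
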
@@ -322,10 +319,10 @@ def __getitem__(self, key):
     def __setitem__(self, key, value):
         if "/" in key and not self._is_meta(key):
             field, chunk = key.split("/")
-            record, _, _ = self._key_to_record(key)
+            record, i, _ = self._key_to_record(key)
             subdict = self._items.setdefault((field, record), {})
-            subdict[chunk] = value
-            if len(subdict) == self._output_size(field, record):
+            subdict[i] = value
+            if len(subdict) == self.record_size:
                 self.write(field, record)
         else:
             # metadata or top-level
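
__setitem__ now keys the pending dict by i, the chunk's integer offset within its record as returned by _key_to_record, instead of by the chunk-name string, and it flushes as soon as a full record_size batch has accumulated rather than consulting the removed _output_size helper. A standalone model of the indexing, with the (record, offset, total) return shape assumed from how the diff unpacks it:

import math


def key_to_record(key, shape, chunks, record_size):
    # Model of _key_to_record: flatten the chunk id row-major, then
    # split the linear index into (record number, offset within record).
    _, chunk_id = key.split("/")
    chunk_ints = [int(c) for c in chunk_id.split(".")]
    nchunks_per_dim = [math.ceil(s / c) for s, c in zip(shape, chunks)]
    i = 0
    for ix, n in zip(chunk_ints, nchunks_per_dim):
        i = i * n + ix
    return i // record_size, i % record_size, math.prod(nchunks_per_dim)


# 4 x 5 chunk grid, 8 references per parquet record:
print(key_to_record("x/2.3", shape=(40, 50), chunks=(10, 10), record_size=8))
# -> (1, 5, 20): chunk 2.3 is linear index 13, i.e. record 1, slot 5
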
@@ -349,24 +346,12 @@ def __delitem__(self, key):
             record, _, _ = self._key_to_record(key)
             subdict = self._items.setdefault((field, record), {})
             subdict[chunk] = None
-            if len(subdict) == self._output_size(field, record):
+            if len(subdict) == self.record_size:
                 self.write(field, record)
         else:
             # metadata or top-level
             self._items[key] = None
 
-    @lru_cache(4096)
-    def _output_size(self, field, record):
-        zarray = json.loads(self[f"{field}/.zarray"])
-        nchunks = 1
-        for s, ch in zip(zarray["shape"], zarray["chunks"]):
-            nchunks *= math.ceil(s / ch)
-        nrec = nchunks // self.record_size
-        rem = nchunks % self.record_size
-        if rem != 0:
-            nrec += 1
-        return self.record_size if record < nrec - 1 else rem
-
     def write(self, field, record, base_url=None, storage_options=None):
         # extra requirements if writing
         import kerchunk.df
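
The deleted _output_size helper is where the "cut end off" of the commit title appears to originate: when the total chunk count is an exact multiple of record_size, rem is 0, nrec is not bumped, and the final record reports a size of 0, so the [:output_size] slice in write() (removed below) dropped that record's rows entirely. The arithmetic, lifted out of the deleted method so it runs standalone:

import math


def output_size(shape, chunks, record_size, record):
    # Verbatim logic of the removed helper, minus the self/json plumbing
    nchunks = 1
    for s, ch in zip(shape, chunks):
        nchunks *= math.ceil(s / ch)
    nrec = nchunks // record_size
    rem = nchunks % record_size
    if rem != 0:
        nrec += 1
    return record_size if record < nrec - 1 else rem


# 4 chunks, 2 per record: the last record should hold 2 references...
print(output_size(shape=(40,), chunks=(10,), record_size=2, record=1))  # 0
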
@@ -376,29 +361,15 @@ def write(self, field, record, base_url=None, storage_options=None):
         # TODO: if the dict is incomplete, also load records and merge in
         partition = self._items[(field, record)]
         fn = f"{base_url or self.out_root}/{field}/refs.{record}.parq"
-        output_size = self._output_size(field, record)
 
         ####
         paths = np.full(self.record_size, np.nan, dtype="O")
         offsets = np.zeros(self.record_size, dtype="int64")
         sizes = np.zeros(self.record_size, dtype="int64")
         raws = np.full(self.record_size, np.nan, dtype="O")
-        zarray = json.loads(self[f"{field}/.zarray"])
-        shape = np.array(zarray["shape"])
-        chunks = np.array(zarray["chunks"])
         nraw = 0
         npath = 0
-        for key, data in partition.items():
-            chunk_id = key.rsplit("/", 1)[-1]
-            chunk_ints = [int(ch) for ch in chunk_id.split(".")]
-            i = 0
-            mult = 1
-            for chunk_int, sh, ch in zip(chunk_ints[::-1], shape[::-1], chunks[::-1]):
-                i += chunk_int * mult
-                mult *= sh // ch
-            j = i % self.record_size
-            # Make note if expected number of chunks differs from actual
-            # number found in references
+        for j, data in partition.items():
             if isinstance(data, list):
                 npath += 1
                 paths[j] = data[0]
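
write() previously re-derived each chunk's slot j from the chunk-name string; now the offset computed once in __setitem__ is the dict key, so the loop reads it directly. Note also that the deleted derivation grew its stride with sh // ch (floor), although the number of chunks along a dimension is ceil(sh / ch), so grids with partial edge chunks could linearise two chunks to the same slot. A sketch of that failure mode — my reading of the deleted code, not necessarily the exact case #1283 targets:

import math


def linearise(chunk_ints, per_dim):
    # The deleted loop's arithmetic: least-significant axis last
    i, mult = 0, 1
    for ix, n in zip(chunk_ints[::-1], per_dim[::-1]):
        i += ix * mult
        mult *= n
    return i


shape, chunks = (50, 25), (10, 10)  # 5 x 3 chunk grid (25/10 -> 3 chunks)
floor_dims = [s // c for s, c in zip(shape, chunks)]           # [5, 2]
ceil_dims = [math.ceil(s / c) for s, c in zip(shape, chunks)]  # [5, 3]

print(linearise([0, 2], floor_dims), linearise([1, 0], floor_dims))  # 2 2 -- collision
print(linearise([0, 2], ceil_dims), linearise([1, 0], ceil_dims))    # 2 3 -- distinct
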
@@ -409,7 +380,6 @@ def write(self, field, record, base_url=None, storage_options=None):
                 nraw += 1
                 raws[j] = kerchunk.df._proc_raw(data)
         # TODO: only save needed columns
-        # TODO: maybe categorize paths column
         df = pd.DataFrame(
             dict(
                 path=paths,
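
The first TODO stays; the second can go because the next hunk implements exactly what it proposed, categorising the path column.
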
@@ -418,7 +388,9 @@ def write(self, field, record, base_url=None, storage_options=None):
                 raw=raws,
             ),
             copy=False,
-        )[:output_size]
+        )
+        if df.path.count() / (df.path.nunique() or 1) > self.cat_thresh:
+            df["path"] = df["path"].astype("category")
         object_encoding = dict(raw="bytes", path="utf8")
         has_nulls = ["path", "raw"]
 
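Two things change here. Dropping the [:output_size] slice means every record is written with a full record_size rows, unfilled slots left as nulls — the direct fix for truncated tails. And the old TODO becomes real: if each unique path is repeated more than cat_thresh times on average, the column is converted to a pandas Categorical, which dictionary-encodes the URLs in the parquet output. The heuristic in isolation, runnable with just numpy and pandas:

import numpy as np
import pandas as pd

cat_thresh = 10
paths = np.full(16, np.nan, dtype="O")          # one record's worth of slots
paths[:12] = ["s3://bucket/data.nc"] * 12       # 12 refs into one file, 4 empty

df = pd.DataFrame({"path": paths})
# count() skips the null padding; "or 1" guards an all-null column
if df.path.count() / (df.path.nunique() or 1) > cat_thresh:
    df["path"] = df["path"].astype("category")

print(df.path.dtype)  # category: 12 repeats of 1 unique path exceeds 10
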
@@ -447,16 +419,15 @@ def flush(self, base_url=None, storage_options=None):
             Location of the output
         """
         # write what we have so far and clear sub chunks
-        for thing in self._items:
+        for thing in list(self._items):
             if isinstance(thing, tuple):
                 field, record = thing
-                if self._items.get((record, field)):
-                    self.write(
-                        field,
-                        record,
-                        base_url=base_url,
-                        storage_options=storage_options,
-                    )
+                self.write(
+                    field,
+                    record,
+                    base_url=base_url,
+                    storage_options=storage_options,
+                )
 
         # gather .zmetadata from self._items and write that too
         for k in list(self._items):
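
Two fixes hide in this last hunk. Iterating over list(self._items) takes a snapshot, presumably because write() clears the entries it flushes (not shown in this diff) and mutating a dict mid-iteration raises. And the removed guard looked up (record, field) with the tuple reversed, while the keys are stored as (field, record), so it appears it would never match and pending partial records were silently skipped at flush time — another way the tail got cut off. The iteration hazard in miniature:

# Keys mirror self._items: (field, record) tuples plus plain metadata keys
items = {("x", 0): {}, ("y", 0): {}, ".zmetadata": b"{}"}

try:
    for thing in items:
        if isinstance(thing, tuple):
            del items[thing]  # stand-in for self.write() clearing state
except RuntimeError as err:
    print(err)  # dictionary changed size during iteration

for thing in list(items):  # snapshot first, then mutate freely
    if isinstance(thing, tuple):
        del items[thing]
print(items)  # {'.zmetadata': b'{}'}
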
