@@ -99,11 +99,7 @@ def pd(self):
99
99
return pd
100
100
101
101
def __init__ (
102
- self ,
103
- root ,
104
- fs = None ,
105
- out_root = None ,
106
- cache_size = 128 ,
102
+ self , root , fs = None , out_root = None , cache_size = 128 , categorical_threshold = 10
107
103
):
108
104
"""
109
105
Parameters
@@ -128,6 +124,7 @@ def __init__(
128
124
self .zmetadata = met ["metadata" ]
129
125
self .url = self .root + "/{field}/refs.{record}.parq"
130
126
self .out_root = out_root or self .root
127
+ self .cat_thresh = categorical_threshold
131
128
132
129
# Define function to open and decompress refs
133
130
@lru_cache (maxsize = cache_size )
@@ -322,10 +319,10 @@ def __getitem__(self, key):
322
319
def __setitem__ (self , key , value ):
323
320
if "/" in key and not self ._is_meta (key ):
324
321
field , chunk = key .split ("/" )
325
- record , _ , _ = self ._key_to_record (key )
322
+ record , i , _ = self ._key_to_record (key )
326
323
subdict = self ._items .setdefault ((field , record ), {})
327
- subdict [chunk ] = value
328
- if len (subdict ) == self ._output_size ( field , record ) :
324
+ subdict [i ] = value
325
+ if len (subdict ) == self .record_size :
329
326
self .write (field , record )
330
327
else :
331
328
# metadata or top-level
@@ -349,24 +346,12 @@ def __delitem__(self, key):
349
346
record , _ , _ = self ._key_to_record (key )
350
347
subdict = self ._items .setdefault ((field , record ), {})
351
348
subdict [chunk ] = None
352
- if len (subdict ) == self ._output_size ( field , record ) :
349
+ if len (subdict ) == self .record_size :
353
350
self .write (field , record )
354
351
else :
355
352
# metadata or top-level
356
353
self ._items [key ] = None
357
354
358
- @lru_cache (4096 )
359
- def _output_size (self , field , record ):
360
- zarray = json .loads (self [f"{ field } /.zarray" ])
361
- nchunks = 1
362
- for s , ch in zip (zarray ["shape" ], zarray ["chunks" ]):
363
- nchunks *= math .ceil (s / ch )
364
- nrec = nchunks // self .record_size
365
- rem = nchunks % self .record_size
366
- if rem != 0 :
367
- nrec += 1
368
- return self .record_size if record < nrec - 1 else rem
369
-
370
355
def write (self , field , record , base_url = None , storage_options = None ):
371
356
# extra requirements if writing
372
357
import kerchunk .df
@@ -376,29 +361,15 @@ def write(self, field, record, base_url=None, storage_options=None):
376
361
# TODO: if the dict is incomplete, also load records and merge in
377
362
partition = self ._items [(field , record )]
378
363
fn = f"{ base_url or self .out_root } /{ field } /refs.{ record } .parq"
379
- output_size = self ._output_size (field , record )
380
364
381
365
####
382
366
paths = np .full (self .record_size , np .nan , dtype = "O" )
383
367
offsets = np .zeros (self .record_size , dtype = "int64" )
384
368
sizes = np .zeros (self .record_size , dtype = "int64" )
385
369
raws = np .full (self .record_size , np .nan , dtype = "O" )
386
- zarray = json .loads (self [f"{ field } /.zarray" ])
387
- shape = np .array (zarray ["shape" ])
388
- chunks = np .array (zarray ["chunks" ])
389
370
nraw = 0
390
371
npath = 0
391
- for key , data in partition .items ():
392
- chunk_id = key .rsplit ("/" , 1 )[- 1 ]
393
- chunk_ints = [int (ch ) for ch in chunk_id .split ("." )]
394
- i = 0
395
- mult = 1
396
- for chunk_int , sh , ch in zip (chunk_ints [::- 1 ], shape [::- 1 ], chunks [::- 1 ]):
397
- i += chunk_int * mult
398
- mult *= sh // ch
399
- j = i % self .record_size
400
- # Make note if expected number of chunks differs from actual
401
- # number found in references
372
+ for j , data in partition .items ():
402
373
if isinstance (data , list ):
403
374
npath += 1
404
375
paths [j ] = data [0 ]
@@ -409,7 +380,6 @@ def write(self, field, record, base_url=None, storage_options=None):
409
380
nraw += 1
410
381
raws [j ] = kerchunk .df ._proc_raw (data )
411
382
# TODO: only save needed columns
412
- # TODO: maybe categorize paths column
413
383
df = pd .DataFrame (
414
384
dict (
415
385
path = paths ,
@@ -418,7 +388,9 @@ def write(self, field, record, base_url=None, storage_options=None):
418
388
raw = raws ,
419
389
),
420
390
copy = False ,
421
- )[:output_size ]
391
+ )
392
+ if df .path .count () / (df .path .nunique () or 1 ) > self .cat_thresh :
393
+ df ["path" ] = df ["path" ].astype ("category" )
422
394
object_encoding = dict (raw = "bytes" , path = "utf8" )
423
395
has_nulls = ["path" , "raw" ]
424
396
@@ -447,16 +419,15 @@ def flush(self, base_url=None, storage_options=None):
447
419
Location of the output
448
420
"""
449
421
# write what we have so far and clear sub chunks
450
- for thing in self ._items :
422
+ for thing in list ( self ._items ) :
451
423
if isinstance (thing , tuple ):
452
424
field , record = thing
453
- if self ._items .get ((record , field )):
454
- self .write (
455
- field ,
456
- record ,
457
- base_url = base_url ,
458
- storage_options = storage_options ,
459
- )
425
+ self .write (
426
+ field ,
427
+ record ,
428
+ base_url = base_url ,
429
+ storage_options = storage_options ,
430
+ )
460
431
461
432
# gather .zmetadata from self._items and write that too
462
433
for k in list (self ._items ):
0 commit comments