Skip to content

Commit 936f4d2

Browse files
TomNicholaschuckwondopre-commit-ci[bot]d-v-b
authored
Make all codecs pickleable (#745)
* make Zlib codec pickleable * add test * show __init_subclass__ can work for Zlib * refactor the BytesBytes Codecs to use __init_subclass__ * remove snake_case function * Chuck's suggestions Co-authored-by: Chuck Daniels <[email protected]> * redefine Bitround * remove debugging prints from test * redefine Shuffle * redefine Delta * redefine FixedScaleOffset * Quantize * PackBits * AsType * redefine checksum codecs * array to bytes codecs * remove todo * remove dynamic constructors * release note * remove unneeded imports * style: pre-commit fixes * remove unused type ignore --------- Co-authored-by: Chuck Daniels <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Davis Bennett <[email protected]>
1 parent 65e16c3 commit 936f4d2

File tree

3 files changed

+110
-125
lines changed

3 files changed

+110
-125
lines changed

docs/release.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ Improvements
2525

2626
* In ``vlen``, define and use ``const`` ``HEADER_LENGTH``.
2727
By :user:`John Kirkham <jakirkham>`, :issue:`723`
28+
* All codecs are now pickleable.
29+
By :user:`Tom Nicholas <TomNicholas>`, :issue:`744`
2830

2931
Fixes
3032
~~~~~

numcodecs/tests/test_zarr3.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

3+
import pickle
34
from typing import TYPE_CHECKING
45

56
import numpy as np
67
import pytest
78

9+
import numcodecs.bitround
10+
811
if TYPE_CHECKING: # pragma: no cover
912
import zarr
1013
else:
@@ -260,7 +263,7 @@ def test_delta_astype(store: StorePath):
260263
dtype=data.dtype,
261264
fill_value=0,
262265
filters=[
263-
numcodecs.zarr3.Delta(dtype="i8", astype="i2"), # type: ignore[arg-type]
266+
numcodecs.zarr3.Delta(dtype="i8", astype="i2"),
264267
],
265268
)
266269

@@ -277,3 +280,39 @@ def test_repr():
277280
def test_to_dict():
278281
codec = numcodecs.zarr3.LZ4(level=5)
279282
assert codec.to_dict() == {"name": "numcodecs.lz4", "configuration": {"level": 5}}
283+
284+
285+
@pytest.mark.parametrize(
286+
"codec_cls",
287+
[
288+
numcodecs.zarr3.Blosc,
289+
numcodecs.zarr3.LZ4,
290+
numcodecs.zarr3.Zstd,
291+
numcodecs.zarr3.Zlib,
292+
numcodecs.zarr3.GZip,
293+
numcodecs.zarr3.BZ2,
294+
numcodecs.zarr3.LZMA,
295+
numcodecs.zarr3.Shuffle,
296+
numcodecs.zarr3.BitRound,
297+
numcodecs.zarr3.Delta,
298+
numcodecs.zarr3.FixedScaleOffset,
299+
numcodecs.zarr3.Quantize,
300+
numcodecs.zarr3.PackBits,
301+
numcodecs.zarr3.AsType,
302+
numcodecs.zarr3.CRC32,
303+
numcodecs.zarr3.CRC32C,
304+
numcodecs.zarr3.Adler32,
305+
numcodecs.zarr3.Fletcher32,
306+
numcodecs.zarr3.JenkinsLookup3,
307+
numcodecs.zarr3.PCodec,
308+
numcodecs.zarr3.ZFPY,
309+
],
310+
)
311+
def test_codecs_pickleable(codec_cls):
312+
codec = codec_cls()
313+
314+
expected = codec
315+
316+
p = pickle.dumps(codec)
317+
actual = pickle.loads(p)
318+
assert actual == expected

numcodecs/zarr3.py

Lines changed: 68 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
import asyncio
2929
import math
3030
from dataclasses import dataclass, replace
31-
from functools import cached_property, partial
32-
from typing import Any, Self, TypeVar
31+
from functools import cached_property
32+
from typing import Any, Self
3333
from warnings import warn
3434

3535
import numpy as np
@@ -79,6 +79,18 @@ class _NumcodecsCodec(Metadata):
7979
codec_name: str
8080
codec_config: dict[str, JSON]
8181

82+
def __init_subclass__(cls, *, codec_name: str | None = None, **kwargs):
83+
"""To be used only when creating the actual public-facing codec class."""
84+
super().__init_subclass__(**kwargs)
85+
if codec_name is not None:
86+
namespace = codec_name
87+
88+
cls_name = f"{CODEC_PREFIX}{namespace}.{cls.__name__}"
89+
cls.codec_name = f"{CODEC_PREFIX}{namespace}"
90+
cls.__doc__ = f"""
91+
See :class:`{cls_name}` for more details and parameters.
92+
"""
93+
8294
def __init__(self, **codec_config: JSON) -> None:
8395
if not self.codec_name:
8496
raise ValueError(
@@ -180,128 +192,55 @@ async def _encode_single(self, chunk_ndbuffer: NDBuffer, chunk_spec: ArraySpec)
180192
return chunk_spec.prototype.buffer.from_bytes(out)
181193

182194

183-
T = TypeVar("T", bound=_NumcodecsCodec)
184-
185-
186-
def _add_docstring(cls: type[T], ref_class_name: str) -> type[T]:
187-
cls.__doc__ = f"""
188-
See :class:`{ref_class_name}` for more details and parameters.
189-
"""
190-
return cls
191-
192-
193-
def _add_docstring_wrapper(ref_class_name: str) -> partial:
194-
return partial(_add_docstring, ref_class_name=ref_class_name)
195-
196-
197-
def _make_bytes_bytes_codec(codec_name: str, cls_name: str) -> type[_NumcodecsBytesBytesCodec]:
198-
# rename for class scope
199-
_codec_name = CODEC_PREFIX + codec_name
200-
201-
class _Codec(_NumcodecsBytesBytesCodec):
202-
codec_name = _codec_name
203-
204-
def __init__(self, **codec_config: JSON) -> None:
205-
super().__init__(**codec_config)
206-
207-
_Codec.__name__ = cls_name
208-
return _Codec
209-
210-
211-
def _make_array_array_codec(codec_name: str, cls_name: str) -> type[_NumcodecsArrayArrayCodec]:
212-
# rename for class scope
213-
_codec_name = CODEC_PREFIX + codec_name
214-
215-
class _Codec(_NumcodecsArrayArrayCodec):
216-
codec_name = _codec_name
217-
218-
def __init__(self, **codec_config: JSON) -> None:
219-
super().__init__(**codec_config)
220-
221-
_Codec.__name__ = cls_name
222-
return _Codec
223-
224-
225-
def _make_array_bytes_codec(codec_name: str, cls_name: str) -> type[_NumcodecsArrayBytesCodec]:
226-
# rename for class scope
227-
_codec_name = CODEC_PREFIX + codec_name
195+
# bytes-to-bytes codecs
196+
class Blosc(_NumcodecsBytesBytesCodec, codec_name="blosc"):
197+
pass
228198

229-
class _Codec(_NumcodecsArrayBytesCodec):
230-
codec_name = _codec_name
231199

232-
def __init__(self, **codec_config: JSON) -> None:
233-
super().__init__(**codec_config)
200+
class LZ4(_NumcodecsBytesBytesCodec, codec_name="lz4"):
201+
pass
234202

235-
_Codec.__name__ = cls_name
236-
return _Codec
237203

204+
class Zstd(_NumcodecsBytesBytesCodec, codec_name="zstd"):
205+
pass
238206

239-
def _make_checksum_codec(codec_name: str, cls_name: str) -> type[_NumcodecsBytesBytesCodec]:
240-
# rename for class scope
241-
_codec_name = CODEC_PREFIX + codec_name
242207

243-
class _ChecksumCodec(_NumcodecsBytesBytesCodec):
244-
codec_name = _codec_name
208+
class Zlib(_NumcodecsBytesBytesCodec, codec_name="zlib"):
209+
pass
245210

246-
def __init__(self, **codec_config: JSON) -> None:
247-
super().__init__(**codec_config)
248211

249-
def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
250-
return input_byte_length + 4 # pragma: no cover
212+
class GZip(_NumcodecsBytesBytesCodec, codec_name="gzip"):
213+
pass
251214

252-
_ChecksumCodec.__name__ = cls_name
253-
return _ChecksumCodec
254215

216+
class BZ2(_NumcodecsBytesBytesCodec, codec_name="bz2"):
217+
pass
255218

256-
# bytes-to-bytes codecs
257-
Blosc = _add_docstring(_make_bytes_bytes_codec("blosc", "Blosc"), "numcodecs.blosc.Blosc")
258-
LZ4 = _add_docstring(_make_bytes_bytes_codec("lz4", "LZ4"), "numcodecs.lz4.LZ4")
259-
Zstd = _add_docstring(_make_bytes_bytes_codec("zstd", "Zstd"), "numcodecs.zstd.Zstd")
260-
Zlib = _add_docstring(_make_bytes_bytes_codec("zlib", "Zlib"), "numcodecs.zlib.Zlib")
261-
GZip = _add_docstring(_make_bytes_bytes_codec("gzip", "GZip"), "numcodecs.gzip.GZip")
262-
BZ2 = _add_docstring(_make_bytes_bytes_codec("bz2", "BZ2"), "numcodecs.bz2.BZ2")
263-
LZMA = _add_docstring(_make_bytes_bytes_codec("lzma", "LZMA"), "numcodecs.lzma.LZMA")
264219

220+
class LZMA(_NumcodecsBytesBytesCodec, codec_name="lzma"):
221+
pass
265222

266-
@_add_docstring_wrapper("numcodecs.shuffle.Shuffle")
267-
class Shuffle(_NumcodecsBytesBytesCodec):
268-
codec_name = f"{CODEC_PREFIX}shuffle"
269-
270-
def __init__(self, **codec_config: JSON) -> None:
271-
super().__init__(**codec_config)
272223

224+
class Shuffle(_NumcodecsBytesBytesCodec, codec_name="shuffle"):
273225
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Shuffle:
274226
if self.codec_config.get("elementsize", None) is None:
275227
return Shuffle(**{**self.codec_config, "elementsize": array_spec.dtype.itemsize})
276228
return self # pragma: no cover
277229

278230

279231
# array-to-array codecs ("filters")
280-
@_add_docstring_wrapper("numcodecs.delta.Delta")
281-
class Delta(_NumcodecsArrayArrayCodec):
282-
codec_name = f"{CODEC_PREFIX}delta"
283-
284-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
285-
super().__init__(**codec_config)
286-
232+
class Delta(_NumcodecsArrayArrayCodec, codec_name="delta"):
287233
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
288234
if astype := self.codec_config.get("astype"):
289235
return replace(chunk_spec, dtype=np.dtype(astype)) # type: ignore[call-overload]
290236
return chunk_spec
291237

292238

293-
BitRound = _add_docstring(
294-
_make_array_array_codec("bitround", "BitRound"), "numcodecs.bitround.BitRound"
295-
)
296-
239+
class BitRound(_NumcodecsArrayArrayCodec, codec_name="bitround"):
240+
pass
297241

298-
@_add_docstring_wrapper("numcodecs.fixedscaleoffset.FixedScaleOffset")
299-
class FixedScaleOffset(_NumcodecsArrayArrayCodec):
300-
codec_name = f"{CODEC_PREFIX}fixedscaleoffset"
301-
302-
def __init__(self, **codec_config: JSON) -> None:
303-
super().__init__(**codec_config)
304242

243+
class FixedScaleOffset(_NumcodecsArrayArrayCodec, codec_name="fixedscaleoffset"):
305244
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
306245
if astype := self.codec_config.get("astype"):
307246
return replace(chunk_spec, dtype=np.dtype(astype)) # type: ignore[call-overload]
@@ -313,10 +252,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> FixedScaleOffset:
313252
return self
314253

315254

316-
@_add_docstring_wrapper("numcodecs.quantize.Quantize")
317-
class Quantize(_NumcodecsArrayArrayCodec):
318-
codec_name = f"{CODEC_PREFIX}quantize"
319-
255+
class Quantize(_NumcodecsArrayArrayCodec, codec_name="quantize"):
320256
def __init__(self, **codec_config: JSON) -> None:
321257
super().__init__(**codec_config)
322258

@@ -326,13 +262,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Quantize:
326262
return self
327263

328264

329-
@_add_docstring_wrapper("numcodecs.packbits.PackBits")
330-
class PackBits(_NumcodecsArrayArrayCodec):
331-
codec_name = f"{CODEC_PREFIX}packbits"
332-
333-
def __init__(self, **codec_config: JSON) -> None:
334-
super().__init__(**codec_config)
335-
265+
class PackBits(_NumcodecsArrayArrayCodec, codec_name="packbits"):
336266
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
337267
return replace(
338268
chunk_spec,
@@ -345,13 +275,7 @@ def validate(self, *, dtype: np.dtype[Any], **_kwargs) -> None:
345275
raise ValueError(f"Packbits filter requires bool dtype. Got {dtype}.")
346276

347277

348-
@_add_docstring_wrapper("numcodecs.astype.AsType")
349-
class AsType(_NumcodecsArrayArrayCodec):
350-
codec_name = f"{CODEC_PREFIX}astype"
351-
352-
def __init__(self, **codec_config: JSON) -> None:
353-
super().__init__(**codec_config)
354-
278+
class AsType(_NumcodecsArrayArrayCodec, codec_name="astype"):
355279
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
356280
return replace(chunk_spec, dtype=np.dtype(self.codec_config["encode_dtype"])) # type: ignore[arg-type]
357281

@@ -362,19 +286,39 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> AsType:
362286

363287

364288
# bytes-to-bytes checksum codecs
365-
CRC32 = _add_docstring(_make_checksum_codec("crc32", "CRC32"), "numcodecs.checksum32.CRC32")
366-
CRC32C = _add_docstring(_make_checksum_codec("crc32c", "CRC32C"), "numcodecs.checksum32.CRC32C")
367-
Adler32 = _add_docstring(_make_checksum_codec("adler32", "Adler32"), "numcodecs.checksum32.Adler32")
368-
Fletcher32 = _add_docstring(
369-
_make_checksum_codec("fletcher32", "Fletcher32"), "numcodecs.fletcher32.Fletcher32"
370-
)
371-
JenkinsLookup3 = _add_docstring(
372-
_make_checksum_codec("jenkins_lookup3", "JenkinsLookup3"), "numcodecs.checksum32.JenkinsLookup3"
373-
)
289+
class _NumcodecsChecksumCodec(_NumcodecsBytesBytesCodec):
290+
def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
291+
return input_byte_length + 4 # pragma: no cover
292+
293+
294+
class CRC32(_NumcodecsChecksumCodec, codec_name="crc32"):
295+
pass
296+
297+
298+
class CRC32C(_NumcodecsChecksumCodec, codec_name="crc32c"):
299+
pass
300+
301+
302+
class Adler32(_NumcodecsChecksumCodec, codec_name="adler32"):
303+
pass
304+
305+
306+
class Fletcher32(_NumcodecsChecksumCodec, codec_name="fletcher32"):
307+
pass
308+
309+
310+
class JenkinsLookup3(_NumcodecsChecksumCodec, codec_name="jenkins_lookup3"):
311+
pass
312+
374313

375314
# array-to-bytes codecs
376-
PCodec = _add_docstring(_make_array_bytes_codec("pcodec", "PCodec"), "numcodecs.pcodec.PCodec")
377-
ZFPY = _add_docstring(_make_array_bytes_codec("zfpy", "ZFPY"), "numcodecs.zfpy.ZFPY")
315+
class PCodec(_NumcodecsArrayBytesCodec, codec_name="pcodec"):
316+
pass
317+
318+
319+
class ZFPY(_NumcodecsArrayBytesCodec, codec_name="zfpy"):
320+
pass
321+
378322

379323
__all__ = [
380324
"BZ2",

0 commit comments

Comments
 (0)