Skip to content

Fix AsyncGroup.create_dataset() dtype handling and optimize tests #3050 #3059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3050.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Fixed potential error in `AsyncGroup.create_dataset()` where `dtype` argument could be missing when calling `create_array()`
192 changes: 135 additions & 57 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
NodeType,
ShapeLike,
ZarrFormat,
parse_shapelike,
)
from zarr.core.config import config
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
Expand Down Expand Up @@ -441,9 +440,8 @@

metadata: GroupMetadata
store_path: StorePath

# TODO: make this correct and work
# TODO: ensure that this can be bound properly to subclass of AsyncGroup
_sync: Any = field(default=None, init=False)
_async_group: Any = field(default=None, init=False)

@classmethod
async def from_store(
Expand Down Expand Up @@ -991,6 +989,53 @@
return ()
return tuple(await asyncio.gather(*(self.require_group(name) for name in names)))

async def _require_array_async(
    self,
    name: str,
    *,
    shape: ShapeLike,
    dtype: npt.DTypeLike = None,
    exact: bool = False,
    **kwargs: Any,
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
    """Obtain an array, creating if it doesn't exist.

    Other `kwargs` are as per :func:`zarr.AsyncGroup.create_array`; they are
    only used when the array has to be created.

    Parameters
    ----------
    name : str
        Array name.
    shape : int or tuple of ints
        Array shape. An existing array must have exactly this shape.
    dtype : str or dtype, optional
        NumPy dtype. If None, the dtype of an existing array is not checked.
        When the array does not exist yet, `dtype` is forwarded unchanged
        to ``create_array``.
    exact : bool, optional
        If True, require `dtype` to match exactly. If False, require
        `dtype` can be cast from array dtype.

    Returns
    -------
    a : AsyncArray

    Raises
    ------
    TypeError
        If a non-array object already exists under `name`, or if the
        existing array has an incompatible dtype or shape.
    """
    try:
        item = await self.getitem(name)
    except KeyError:
        # Nothing is stored under `name` yet, so create the array.
        return await self.create_array(name, shape=shape, dtype=dtype, **kwargs)

    # Existing item must be an AsyncArray with matching dtype/shape.
    if not isinstance(item, AsyncArray):
        raise TypeError(f"Incompatible object ({item.__class__.__name__}) already exists")

    if dtype is not None:
        requested_dtype = np.dtype(dtype)
        if exact:
            dtype_ok = item.dtype == requested_dtype
        else:
            dtype_ok = np.can_cast(item.dtype, requested_dtype)
        if not dtype_ok:
            raise TypeError(f"Incompatible dtype ({item.dtype} vs {requested_dtype})")

    # `shape` is a ShapeLike: it may be a bare int or any iterable of ints.
    # Normalise it to a tuple before comparing, otherwise e.g. `shape=10` or
    # `shape=[10]` could never equal the array's shape.
    # NOTE(review): assumes `item.shape` is a tuple of ints -- confirm.
    try:
        expected_shape = tuple(shape)
    except TypeError:
        expected_shape = (shape,)
    if item.shape != expected_shape:
        raise TypeError(f"Incompatible shape ({item.shape} vs {expected_shape})")

    return item

Check warning on line 1037 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L1028-L1037

Added lines #L1028 - L1037 were not covered by tests

async def create_array(
self,
name: str,
Expand Down Expand Up @@ -1155,32 +1200,35 @@
# create_dataset in zarr 2.x requires shape but not dtype if data is
# provided. Allow this configuration by inferring dtype from data if
# necessary and passing it to create_array
if "dtype" not in kwargs and data is not None:
kwargs["dtype"] = data.dtype
if "dtype" not in kwargs:
if data is not None:
kwargs["dtype"] = data.dtype

Check warning on line 1205 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L1203-L1205

Added lines #L1203 - L1205 were not covered by tests
else:
raise ValueError("dtype must be provided if data is None")

Check warning on line 1207 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L1207

Added line #L1207 was not covered by tests
array = await self.create_array(name, shape=shape, **kwargs)
if data is not None:
await array.setitem(slice(None), data)
return array

@deprecated("Use AsyncGroup.require_array instead.")
async def require_dataset(
@deprecated("Use Group.require_array instead.")
def require_dataset(
self,
name: str,
*,
shape: ChunkCoords,
shape: ShapeLike,
dtype: npt.DTypeLike = None,
exact: bool = False,
**kwargs: Any,
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
) -> Array:
"""Obtain an array, creating if it doesn't exist.

.. deprecated:: 3.0.0
The h5py compatibility methods will be removed in 3.1.0. Use `AsyncGroup.require_dataset` instead.
The h5py compatibility methods will be removed in 3.1.0. Use `Group.require_array` instead.

Arrays are known as "datasets" in HDF5 terminology. For compatibility
with h5py, Zarr groups also implement the :func:`zarr.AsyncGroup.create_dataset` method.

Other `kwargs` are as per :func:`zarr.AsyncGroup.create_dataset`.
Other `kwargs` are as per :func:`zarr.Group.create_array`.

Parameters
----------
Expand All @@ -1189,29 +1237,35 @@
shape : int or tuple of ints
Array shape.
dtype : str or dtype, optional
NumPy dtype.
NumPy dtype. If None, the dtype will be inferred from the existing array.
exact : bool, optional
If True, require `dtype` to match exactly. If false, require
If True, require `dtype` to match exactly. If False, require
`dtype` can be cast from array dtype.

Returns
-------
a : AsyncArray
a : Array
"""
return await self.require_array(name, shape=shape, dtype=dtype, exact=exact, **kwargs)
return Array(

Check warning on line 1249 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L1249

Added line #L1249 was not covered by tests
self._sync(
self._async_group._require_array_async(
name, shape=shape, dtype=dtype, exact=exact, **kwargs
)
)
)

async def require_array(
def require_array(
self,
name: str,
*,
shape: ShapeLike,
dtype: npt.DTypeLike = None,
exact: bool = False,
**kwargs: Any,
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
) -> Array:
"""Obtain an array, creating if it doesn't exist.

Other `kwargs` are as per :func:`zarr.AsyncGroup.create_dataset`.
Other `kwargs` are as per :func:`zarr.Group.create_array`.

Parameters
----------
Expand All @@ -1220,35 +1274,22 @@
shape : int or tuple of ints
Array shape.
dtype : str or dtype, optional
NumPy dtype.
NumPy dtype. If None, the dtype will be inferred from the existing array.
exact : bool, optional
If True, require `dtype` to match exactly. If false, require
If True, require `dtype` to match exactly. If False, require
`dtype` can be cast from array dtype.

Returns
-------
a : AsyncArray
a : Array
"""
try:
ds = await self.getitem(name)
if not isinstance(ds, AsyncArray):
raise TypeError(f"Incompatible object ({ds.__class__.__name__}) already exists")

shape = parse_shapelike(shape)
if shape != ds.shape:
raise TypeError(f"Incompatible shape ({ds.shape} vs {shape})")

dtype = np.dtype(dtype)
if exact:
if ds.dtype != dtype:
raise TypeError(f"Incompatible dtype ({ds.dtype} vs {dtype})")
else:
if not np.can_cast(ds.dtype, dtype):
raise TypeError(f"Incompatible dtype ({ds.dtype} vs {dtype})")
except KeyError:
ds = await self.create_array(name, shape=shape, dtype=dtype, **kwargs)

return ds
return Array(

Check warning on line 1286 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L1286

Added line #L1286 was not covered by tests
self._sync(
self._async_group._require_array_async(
name, shape=shape, dtype=dtype, exact=exact, **kwargs
)
)
)

async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
"""Update group attributes.
Expand Down Expand Up @@ -2511,49 +2552,75 @@
.. deprecated:: 3.0.0
The h5py compatibility methods will be removed in 3.1.0. Use `Group.create_array` instead.


Arrays are known as "datasets" in HDF5 terminology. For compatibility
with h5py, Zarr groups also implement the :func:`zarr.Group.require_dataset` method.
with h5py, Zarr groups also implement the :func:`zarr.AsyncGroup.require_dataset` method.

Parameters
----------
name : str
Array name.
**kwargs : dict
Additional arguments passed to :func:`zarr.Group.create_array`
Additional arguments passed to :func:`zarr.AsyncGroup.create_array`.

Returns
-------
a : Array
a : AsyncArray
"""
return Array(self._sync(self._async_group.create_dataset(name, **kwargs)))

@deprecated("Use Group.require_array instead.")
def require_dataset(self, name: str, *, shape: ShapeLike, **kwargs: Any) -> Array:
def require_dataset(
self,
name: str,
*,
shape: ShapeLike,
dtype: npt.DTypeLike = None,
exact: bool = False,
**kwargs: Any,
) -> Array:
"""Obtain an array, creating if it doesn't exist.

.. deprecated:: 3.0.0
The h5py compatibility methods will be removed in 3.1.0. Use `Group.require_array` instead.

Arrays are known as "datasets" in HDF5 terminology. For compatibility
with h5py, Zarr groups also implement the :func:`zarr.Group.create_dataset` method.
with h5py, Zarr groups also implement the :func:`zarr.AsyncGroup.create_dataset` method.

Other `kwargs` are as per :func:`zarr.Group.create_dataset`.
Other `kwargs` are as per :func:`zarr.Group.create_array`.

Parameters
----------
name : str
Array name.
**kwargs :
See :func:`zarr.Group.create_dataset`.
shape : int or tuple of ints
Array shape.
dtype : str or dtype, optional
NumPy dtype. If None, the dtype will be inferred from the existing array.
exact : bool, optional
If True, require `dtype` to match exactly. If False, require
`dtype` can be cast from array dtype.

Returns
-------
a : Array
"""
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))
return Array(

Check warning on line 2607 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L2607

Added line #L2607 was not covered by tests
self._sync(
self._async_group._require_array_async(
name, shape=shape, dtype=dtype, exact=exact, **kwargs
)
)
)

def require_array(self, name: str, *, shape: ShapeLike, **kwargs: Any) -> Array:
def require_array(
self,
name: str,
*,
shape: ShapeLike,
dtype: npt.DTypeLike = None,
exact: bool = False,
**kwargs: Any,
) -> Array:
"""Obtain an array, creating if it doesn't exist.

Other `kwargs` are as per :func:`zarr.Group.create_array`.
Expand All @@ -2562,14 +2629,25 @@
----------
name : str
Array name.
**kwargs :
See :func:`zarr.Group.create_array`.
shape : int or tuple of ints
Array shape.
dtype : str or dtype, optional
NumPy dtype. If None, the dtype will be inferred from the existing array.
exact : bool, optional
If True, require `dtype` to match exactly. If False, require
`dtype` can be cast from array dtype.

Returns
-------
a : Array
"""
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))
return Array(

Check warning on line 2644 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L2644

Added line #L2644 was not covered by tests
self._sync(
self._async_group._require_array_async(
name, shape=shape, dtype=dtype, exact=exact, **kwargs
)
)
)

@_deprecate_positional_args
def empty(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array:
Expand Down Expand Up @@ -2918,7 +2996,7 @@
This function will parse its input to ensure that the hierarchy is complete. Any implicit groups
will be inserted as needed. For example, an input like
```{'a/b': GroupMetadata}``` will be parsed to
```{'': GroupMetadata, 'a': GroupMetadata, 'b': Groupmetadata}```
```{'': GroupMetadata, 'a': GroupMetadata, 'a/b': GroupMetadata}```.

After input parsing, this function then creates all the nodes in the hierarchy concurrently.

Expand Down
Loading