From 7f758b3919d014b82339a24d066a570590c4b968 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 1 Mar 2026 11:05:34 -0500 Subject: [PATCH] define a module for chunk grids, and a registry --- src/zarr/core/chunk_grids/__init__.py | 60 ++++++ .../{chunk_grids.py => chunk_grids/common.py} | 199 ++++++++++++------ src/zarr/core/chunk_grids/regular.py | 98 +++++++++ src/zarr/core/metadata/v3.py | 4 +- src/zarr/registry.py | 24 +++ 5 files changed, 320 insertions(+), 65 deletions(-) create mode 100644 src/zarr/core/chunk_grids/__init__.py rename src/zarr/core/{chunk_grids.py => chunk_grids/common.py} (72%) create mode 100644 src/zarr/core/chunk_grids/regular.py diff --git a/src/zarr/core/chunk_grids/__init__.py b/src/zarr/core/chunk_grids/__init__.py new file mode 100644 index 0000000000..c8e1f24d76 --- /dev/null +++ b/src/zarr/core/chunk_grids/__init__.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import Any + +from zarr.core.chunk_grids.common import ( + ChunkGrid, + _auto_partition, + _guess_chunks, + _guess_num_chunks_per_axis_shard, + normalize_chunks, +) +from zarr.core.chunk_grids.regular import RegularChunkGrid +from zarr.core.common import JSON, NamedConfig, parse_named_configuration +from zarr.registry import get_chunk_grid_class, register_chunk_grid + +register_chunk_grid("regular", RegularChunkGrid) + + +def parse_chunk_grid( + data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any], +) -> ChunkGrid: + """Parse a chunk grid from a dictionary, returning existing ChunkGrid instances as-is. + + Uses the chunk grid registry to look up the appropriate class by name. + + Parameters + ---------- + data : dict[str, JSON] | ChunkGrid | NamedConfig[str, Any] + Either a ChunkGrid instance (returned as-is) or a dictionary with + 'name' and 'configuration' keys. + + Returns + ------- + ChunkGrid + + Raises + ------ + ValueError + If the chunk grid name is not found in the registry. + """ + if isinstance(data, ChunkGrid): + return data + + name_parsed, _ = parse_named_configuration(data) + try: + chunk_grid_cls = get_chunk_grid_class(name_parsed) + except KeyError as e: + raise ValueError(f"Unknown chunk grid. Got {name_parsed}.") from e + return chunk_grid_cls.from_dict(data) # type: ignore[arg-type] + + +__all__ = [ + "ChunkGrid", + "RegularChunkGrid", + "_auto_partition", + "_guess_chunks", + "_guess_num_chunks_per_axis_shard", + "normalize_chunks", + "parse_chunk_grid", +] diff --git a/src/zarr/core/chunk_grids.py b/src/zarr/core/chunk_grids/common.py similarity index 72% rename from src/zarr/core/chunk_grids.py rename to src/zarr/core/chunk_grids/common.py index 2c7945fa64..22a9db0708 100644 --- a/src/zarr/core/chunk_grids.py +++ b/src/zarr/core/chunk_grids/common.py @@ -1,27 +1,17 @@ from __future__ import annotations -import itertools import math import numbers -import operator import warnings from abc import abstractmethod from dataclasses import dataclass -from functools import reduce from typing import TYPE_CHECKING, Any, Literal import numpy as np +import numpy.typing as npt import zarr from zarr.abc.metadata import Metadata -from zarr.core.common import ( - JSON, - NamedConfig, - ShapeLike, - ceildiv, - parse_named_configuration, - parse_shapelike, -) from zarr.errors import ZarrUserWarning if TYPE_CHECKING: @@ -29,6 +19,141 @@ from typing import Self from zarr.core.array import ShardsLike + from zarr.core.common import JSON + + +@dataclass(frozen=True) +class ChunkGrid(Metadata): + @abstractmethod + def to_dict(self) -> dict[str, JSON]: ... + + @abstractmethod + def update_shape(self, new_shape: tuple[int, ...]) -> Self: + pass + + @abstractmethod + def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: + pass + + @abstractmethod + def get_nchunks(self, array_shape: tuple[int, ...]) -> int: + pass + + @abstractmethod + def get_chunk_shape( + self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...] + ) -> tuple[int, ...]: + """ + Get the shape of a specific chunk. + + Parameters + ---------- + array_shape : tuple[int, ...] + Shape of the full array. + chunk_coord : tuple[int, ...] + Coordinates of the chunk in the chunk grid. + + Returns + ------- + tuple[int, ...] + Shape of the chunk at the given coordinates. + """ + + @abstractmethod + def get_chunk_start( + self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...] + ) -> tuple[int, ...]: + """ + Get the starting position of a chunk in the array. + + Parameters + ---------- + array_shape : tuple[int, ...] + Shape of the full array. + chunk_coord : tuple[int, ...] + Coordinates of the chunk in the chunk grid. + + Returns + ------- + tuple[int, ...] + Starting position (offset) of the chunk in the array. + """ + + @abstractmethod + def array_index_to_chunk_coord( + self, array_shape: tuple[int, ...], array_index: tuple[int, ...] + ) -> tuple[int, ...]: + """ + Map an array index to the chunk coordinates that contain it. + + Parameters + ---------- + array_shape : tuple[int, ...] + Shape of the full array. + array_index : tuple[int, ...] + Index in the array. + + Returns + ------- + tuple[int, ...] + Coordinates of the chunk containing the array index. + """ + + @abstractmethod + def array_indices_to_chunk_dim( + self, array_shape: tuple[int, ...], dim: int, indices: npt.NDArray[np.intp] + ) -> npt.NDArray[np.intp]: + """ + Map an array of indices along one dimension to chunk coordinates (vectorized). + + Parameters + ---------- + array_shape : tuple[int, ...] + Shape of the full array. + dim : int + Dimension index. + indices : np.ndarray + Array of indices along the given dimension. + + Returns + ------- + np.ndarray + Array of chunk coordinates, same shape as indices. + """ + + @abstractmethod + def chunks_per_dim(self, array_shape: tuple[int, ...], dim: int) -> int: + """ + Get the number of chunks along a specific dimension. + + Parameters + ---------- + array_shape : tuple[int, ...] + Shape of the full array. + dim : int + Dimension index. + + Returns + ------- + int + Number of chunks along the dimension. + """ + + @abstractmethod + def get_chunk_grid_shape(self, array_shape: tuple[int, ...]) -> tuple[int, ...]: + """ + Get the shape of the chunk grid (number of chunks along each dimension). + + Parameters + ---------- + array_shape : tuple[int, ...] + Shape of the full array. + + Returns + ------- + tuple[int, ...] + Number of chunks along each dimension. + """ def _guess_chunks( @@ -153,58 +278,6 @@ def normalize_chunks(chunks: Any, shape: tuple[int, ...], typesize: int) -> tupl return tuple(int(c) for c in chunks) -@dataclass(frozen=True) -class ChunkGrid(Metadata): - @classmethod - def from_dict(cls, data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any]) -> ChunkGrid: - if isinstance(data, ChunkGrid): - return data - - name_parsed, _ = parse_named_configuration(data) - if name_parsed == "regular": - return RegularChunkGrid._from_dict(data) - raise ValueError(f"Unknown chunk grid. Got {name_parsed}.") - - @abstractmethod - def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: - pass - - @abstractmethod - def get_nchunks(self, array_shape: tuple[int, ...]) -> int: - pass - - -@dataclass(frozen=True) -class RegularChunkGrid(ChunkGrid): - chunk_shape: tuple[int, ...] - - def __init__(self, *, chunk_shape: ShapeLike) -> None: - chunk_shape_parsed = parse_shapelike(chunk_shape) - - object.__setattr__(self, "chunk_shape", chunk_shape_parsed) - - @classmethod - def _from_dict(cls, data: dict[str, JSON] | NamedConfig[str, Any]) -> Self: - _, configuration_parsed = parse_named_configuration(data, "regular") - - return cls(**configuration_parsed) # type: ignore[arg-type] - - def to_dict(self) -> dict[str, JSON]: - return {"name": "regular", "configuration": {"chunk_shape": tuple(self.chunk_shape)}} - - def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: - return itertools.product( - *(range(ceildiv(s, c)) for s, c in zip(array_shape, self.chunk_shape, strict=False)) - ) - - def get_nchunks(self, array_shape: tuple[int, ...]) -> int: - return reduce( - operator.mul, - itertools.starmap(ceildiv, zip(array_shape, self.chunk_shape, strict=True)), - 1, - ) - - def _guess_num_chunks_per_axis_shard( chunk_shape: tuple[int, ...], item_size: int, max_bytes: int, array_shape: tuple[int, ...] ) -> int: diff --git a/src/zarr/core/chunk_grids/regular.py b/src/zarr/core/chunk_grids/regular.py new file mode 100644 index 0000000000..fc3534e6dc --- /dev/null +++ b/src/zarr/core/chunk_grids/regular.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import itertools +import operator +from dataclasses import dataclass +from functools import reduce +from typing import TYPE_CHECKING, Any + +import numpy as np +import numpy.typing as npt + +from zarr.core.chunk_grids.common import ChunkGrid +from zarr.core.common import ( + JSON, + NamedConfig, + ShapeLike, + ceildiv, + parse_named_configuration, + parse_shapelike, +) + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Self + + +@dataclass(frozen=True) +class RegularChunkGrid(ChunkGrid): + chunk_shape: tuple[int, ...] + + def __init__(self, *, chunk_shape: ShapeLike) -> None: + chunk_shape_parsed = parse_shapelike(chunk_shape) + + object.__setattr__(self, "chunk_shape", chunk_shape_parsed) + + @classmethod + def from_dict(cls, data: dict[str, JSON] | NamedConfig[str, Any]) -> Self: + _, configuration_parsed = parse_named_configuration(data, "regular") + + return cls(**configuration_parsed) # type: ignore[arg-type] + + def to_dict(self) -> dict[str, JSON]: + return {"name": "regular", "configuration": {"chunk_shape": tuple(self.chunk_shape)}} + + def update_shape(self, new_shape: tuple[int, ...]) -> Self: + return self + + def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: + return itertools.product( + *(range(ceildiv(s, c)) for s, c in zip(array_shape, self.chunk_shape, strict=False)) + ) + + def get_nchunks(self, array_shape: tuple[int, ...]) -> int: + return reduce( + operator.mul, + itertools.starmap(ceildiv, zip(array_shape, self.chunk_shape, strict=True)), + 1, + ) + + def get_chunk_shape( + self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...] + ) -> tuple[int, ...]: + return tuple( + int(min(self.chunk_shape[i], array_shape[i] - chunk_coord[i] * self.chunk_shape[i])) + for i in range(len(array_shape)) + ) + + def get_chunk_start( + self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...] + ) -> tuple[int, ...]: + return tuple( + coord * size for coord, size in zip(chunk_coord, self.chunk_shape, strict=False) + ) + + def array_index_to_chunk_coord( + self, array_shape: tuple[int, ...], array_index: tuple[int, ...] + ) -> tuple[int, ...]: + return tuple( + 0 if size == 0 else idx // size + for idx, size in zip(array_index, self.chunk_shape, strict=False) + ) + + def array_indices_to_chunk_dim( + self, array_shape: tuple[int, ...], dim: int, indices: npt.NDArray[np.intp] + ) -> npt.NDArray[np.intp]: + chunk_size = self.chunk_shape[dim] + if chunk_size == 0: + return np.zeros_like(indices) + return indices // chunk_size + + def chunks_per_dim(self, array_shape: tuple[int, ...], dim: int) -> int: + return ceildiv(array_shape[dim], self.chunk_shape[dim]) + + def get_chunk_grid_shape(self, array_shape: tuple[int, ...]) -> tuple[int, ...]: + return tuple( + ceildiv(array_len, chunk_len) + for array_len, chunk_len in zip(array_shape, self.chunk_shape, strict=False) + ) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 5ce155bd9a..f763877edc 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -24,7 +24,7 @@ from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.core.array_spec import ArrayConfig, ArraySpec -from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid +from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid, parse_chunk_grid from zarr.core.chunk_key_encodings import ( ChunkKeyEncoding, ChunkKeyEncodingLike, @@ -229,7 +229,7 @@ def __init__( """ shape_parsed = parse_shapelike(shape) - chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) + chunk_grid_parsed = parse_chunk_grid(chunk_grid) chunk_key_encoding_parsed = parse_chunk_key_encoding(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) # Note: relying on a type method is numpy-specific diff --git a/src/zarr/registry.py b/src/zarr/registry.py index d0850a1387..07a8b5c3dc 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -22,17 +22,20 @@ ) from zarr.abc.numcodec import Numcodec from zarr.core.buffer import Buffer, NDBuffer + from zarr.core.chunk_grids.common import ChunkGrid from zarr.core.chunk_key_encodings import ChunkKeyEncoding from zarr.core.common import JSON __all__ = [ "Registry", "get_buffer_class", + "get_chunk_grid_class", "get_chunk_key_encoding_class", "get_codec_class", "get_ndbuffer_class", "get_pipeline_class", "register_buffer", + "register_chunk_grid", "register_chunk_key_encoding", "register_codec", "register_ndbuffer", @@ -63,6 +66,7 @@ def register(self, cls: type[T], qualname: str | None = None) -> None: __pipeline_registry: Registry[CodecPipeline] = Registry() __buffer_registry: Registry[Buffer] = Registry() __ndbuffer_registry: Registry[NDBuffer] = Registry() +__chunk_grid_registry: Registry[ChunkGrid] = Registry() __chunk_key_encoding_registry: Registry[ChunkKeyEncoding] = Registry() """ @@ -103,6 +107,11 @@ def _collect_entrypoints() -> list[Registry[Any]]: data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr.data_type")) data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr", name="data_type")) + __chunk_grid_registry.lazy_load_list.extend(entry_points.select(group="zarr.chunk_grid")) + __chunk_grid_registry.lazy_load_list.extend( + entry_points.select(group="zarr", name="chunk_grid") + ) + __chunk_key_encoding_registry.lazy_load_list.extend( entry_points.select(group="zarr.chunk_key_encoding") ) @@ -125,6 +134,7 @@ def _collect_entrypoints() -> list[Registry[Any]]: __pipeline_registry, __buffer_registry, __ndbuffer_registry, + __chunk_grid_registry, __chunk_key_encoding_registry, ] @@ -156,6 +166,20 @@ def register_buffer(cls: type[Buffer], qualname: str | None = None) -> None: __buffer_registry.register(cls, qualname) +def register_chunk_grid(key: str, cls: type[ChunkGrid]) -> None: + __chunk_grid_registry.register(cls, key) + + +def get_chunk_grid_class(key: str) -> type[ChunkGrid]: + __chunk_grid_registry.lazy_load(use_entrypoint_name=True) + if key not in __chunk_grid_registry: + raise KeyError( + f"Chunk grid '{key}' not found in registered chunk grids: " + f"{list(__chunk_grid_registry)}." + ) + return __chunk_grid_registry[key] + + def register_chunk_key_encoding(key: str, cls: type) -> None: __chunk_key_encoding_registry.register(cls, key)