Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/zarr/core/chunk_grids/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

from typing import Any

from zarr.core.chunk_grids.common import (
ChunkGrid,
_auto_partition,
_guess_chunks,
_guess_num_chunks_per_axis_shard,
normalize_chunks,
)
from zarr.core.chunk_grids.regular import RegularChunkGrid
from zarr.core.common import JSON, NamedConfig, parse_named_configuration
from zarr.registry import get_chunk_grid_class, register_chunk_grid

register_chunk_grid("regular", RegularChunkGrid)


def parse_chunk_grid(
data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any],
) -> ChunkGrid:
"""Parse a chunk grid from a dictionary, returning existing ChunkGrid instances as-is.

Uses the chunk grid registry to look up the appropriate class by name.

Parameters
----------
data : dict[str, JSON] | ChunkGrid | NamedConfig[str, Any]
Either a ChunkGrid instance (returned as-is) or a dictionary with
'name' and 'configuration' keys.

Returns
-------
ChunkGrid

Raises
------
ValueError
If the chunk grid name is not found in the registry.
"""
if isinstance(data, ChunkGrid):
return data

name_parsed, _ = parse_named_configuration(data)
try:
chunk_grid_cls = get_chunk_grid_class(name_parsed)
except KeyError as e:
raise ValueError(f"Unknown chunk grid. Got {name_parsed}.") from e
return chunk_grid_cls.from_dict(data) # type: ignore[arg-type]


__all__ = [
"ChunkGrid",
"RegularChunkGrid",
"_auto_partition",
"_guess_chunks",
"_guess_num_chunks_per_axis_shard",
"normalize_chunks",
"parse_chunk_grid",
]
199 changes: 136 additions & 63 deletions src/zarr/core/chunk_grids.py → src/zarr/core/chunk_grids/common.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,159 @@
from __future__ import annotations

import itertools
import math
import numbers
import operator
import warnings
from abc import abstractmethod
from dataclasses import dataclass
from functools import reduce
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import numpy.typing as npt

import zarr
from zarr.abc.metadata import Metadata
from zarr.core.common import (
JSON,
NamedConfig,
ShapeLike,
ceildiv,
parse_named_configuration,
parse_shapelike,
)
from zarr.errors import ZarrUserWarning

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Self

from zarr.core.array import ShardsLike
from zarr.core.common import JSON


@dataclass(frozen=True)
class ChunkGrid(Metadata):
@abstractmethod
def to_dict(self) -> dict[str, JSON]: ...

@abstractmethod
def update_shape(self, new_shape: tuple[int, ...]) -> Self:
pass

@abstractmethod
def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
pass

@abstractmethod
def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
pass

@abstractmethod
def get_chunk_shape(
self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...]
) -> tuple[int, ...]:
"""
Get the shape of a specific chunk.

Parameters
----------
array_shape : tuple[int, ...]
Shape of the full array.
chunk_coord : tuple[int, ...]
Coordinates of the chunk in the chunk grid.

Returns
-------
tuple[int, ...]
Shape of the chunk at the given coordinates.
"""

@abstractmethod
def get_chunk_start(
self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...]
) -> tuple[int, ...]:
"""
Get the starting position of a chunk in the array.

Parameters
----------
array_shape : tuple[int, ...]
Shape of the full array.
chunk_coord : tuple[int, ...]
Coordinates of the chunk in the chunk grid.

Returns
-------
tuple[int, ...]
Starting position (offset) of the chunk in the array.
"""

@abstractmethod
def array_index_to_chunk_coord(
self, array_shape: tuple[int, ...], array_index: tuple[int, ...]
) -> tuple[int, ...]:
"""
Map an array index to the chunk coordinates that contain it.

Parameters
----------
array_shape : tuple[int, ...]
Shape of the full array.
array_index : tuple[int, ...]
Index in the array.

Returns
-------
tuple[int, ...]
Coordinates of the chunk containing the array index.
"""

@abstractmethod
def array_indices_to_chunk_dim(
self, array_shape: tuple[int, ...], dim: int, indices: npt.NDArray[np.intp]
) -> npt.NDArray[np.intp]:
"""
Map an array of indices along one dimension to chunk coordinates (vectorized).

Parameters
----------
array_shape : tuple[int, ...]
Shape of the full array.
dim : int
Dimension index.
indices : np.ndarray
Array of indices along the given dimension.

Returns
-------
np.ndarray
Array of chunk coordinates, same shape as indices.
"""

@abstractmethod
def chunks_per_dim(self, array_shape: tuple[int, ...], dim: int) -> int:
"""
Get the number of chunks along a specific dimension.

Parameters
----------
array_shape : tuple[int, ...]
Shape of the full array.
dim : int
Dimension index.

Returns
-------
int
Number of chunks along the dimension.
"""

@abstractmethod
def get_chunk_grid_shape(self, array_shape: tuple[int, ...]) -> tuple[int, ...]:
"""
Get the shape of the chunk grid (number of chunks along each dimension).

Parameters
----------
array_shape : tuple[int, ...]
Shape of the full array.

Returns
-------
tuple[int, ...]
Number of chunks along each dimension.
"""


def _guess_chunks(
Expand Down Expand Up @@ -153,58 +278,6 @@ def normalize_chunks(chunks: Any, shape: tuple[int, ...], typesize: int) -> tupl
return tuple(int(c) for c in chunks)


@dataclass(frozen=True)
class ChunkGrid(Metadata):
@classmethod
def from_dict(cls, data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any]) -> ChunkGrid:
if isinstance(data, ChunkGrid):
return data

name_parsed, _ = parse_named_configuration(data)
if name_parsed == "regular":
return RegularChunkGrid._from_dict(data)
raise ValueError(f"Unknown chunk grid. Got {name_parsed}.")

@abstractmethod
def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
pass

@abstractmethod
def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
pass


@dataclass(frozen=True)
class RegularChunkGrid(ChunkGrid):
chunk_shape: tuple[int, ...]

def __init__(self, *, chunk_shape: ShapeLike) -> None:
chunk_shape_parsed = parse_shapelike(chunk_shape)

object.__setattr__(self, "chunk_shape", chunk_shape_parsed)

@classmethod
def _from_dict(cls, data: dict[str, JSON] | NamedConfig[str, Any]) -> Self:
_, configuration_parsed = parse_named_configuration(data, "regular")

return cls(**configuration_parsed) # type: ignore[arg-type]

def to_dict(self) -> dict[str, JSON]:
return {"name": "regular", "configuration": {"chunk_shape": tuple(self.chunk_shape)}}

def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
return itertools.product(
*(range(ceildiv(s, c)) for s, c in zip(array_shape, self.chunk_shape, strict=False))
)

def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
return reduce(
operator.mul,
itertools.starmap(ceildiv, zip(array_shape, self.chunk_shape, strict=True)),
1,
)


def _guess_num_chunks_per_axis_shard(
chunk_shape: tuple[int, ...], item_size: int, max_bytes: int, array_shape: tuple[int, ...]
) -> int:
Expand Down
98 changes: 98 additions & 0 deletions src/zarr/core/chunk_grids/regular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from __future__ import annotations

import itertools
import operator
from dataclasses import dataclass
from functools import reduce
from typing import TYPE_CHECKING, Any

import numpy as np
import numpy.typing as npt

from zarr.core.chunk_grids.common import ChunkGrid
from zarr.core.common import (
JSON,
NamedConfig,
ShapeLike,
ceildiv,
parse_named_configuration,
parse_shapelike,
)

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Self


@dataclass(frozen=True)
class RegularChunkGrid(ChunkGrid):
chunk_shape: tuple[int, ...]

def __init__(self, *, chunk_shape: ShapeLike) -> None:
chunk_shape_parsed = parse_shapelike(chunk_shape)

object.__setattr__(self, "chunk_shape", chunk_shape_parsed)

@classmethod
def from_dict(cls, data: dict[str, JSON] | NamedConfig[str, Any]) -> Self:
_, configuration_parsed = parse_named_configuration(data, "regular")

return cls(**configuration_parsed) # type: ignore[arg-type]

def to_dict(self) -> dict[str, JSON]:
return {"name": "regular", "configuration": {"chunk_shape": tuple(self.chunk_shape)}}

def update_shape(self, new_shape: tuple[int, ...]) -> Self:
return self

def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
return itertools.product(
*(range(ceildiv(s, c)) for s, c in zip(array_shape, self.chunk_shape, strict=False))
)

def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
return reduce(
operator.mul,
itertools.starmap(ceildiv, zip(array_shape, self.chunk_shape, strict=True)),
1,
)

def get_chunk_shape(
self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...]
) -> tuple[int, ...]:
return tuple(
int(min(self.chunk_shape[i], array_shape[i] - chunk_coord[i] * self.chunk_shape[i]))
for i in range(len(array_shape))
)

def get_chunk_start(
self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...]
) -> tuple[int, ...]:
return tuple(
coord * size for coord, size in zip(chunk_coord, self.chunk_shape, strict=False)
)

def array_index_to_chunk_coord(
self, array_shape: tuple[int, ...], array_index: tuple[int, ...]
) -> tuple[int, ...]:
return tuple(
0 if size == 0 else idx // size
for idx, size in zip(array_index, self.chunk_shape, strict=False)
)

def array_indices_to_chunk_dim(
self, array_shape: tuple[int, ...], dim: int, indices: npt.NDArray[np.intp]
) -> npt.NDArray[np.intp]:
chunk_size = self.chunk_shape[dim]
if chunk_size == 0:
return np.zeros_like(indices)
return indices // chunk_size

def chunks_per_dim(self, array_shape: tuple[int, ...], dim: int) -> int:
return ceildiv(array_shape[dim], self.chunk_shape[dim])

def get_chunk_grid_shape(self, array_shape: tuple[int, ...]) -> tuple[int, ...]:
return tuple(
ceildiv(array_len, chunk_len)
for array_len, chunk_len in zip(array_shape, self.chunk_shape, strict=False)
)
Loading