Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ async def _decode_single(
indexer = BasicIndexer(
tuple(slice(0, s) for s in shard_shape),
shape=shard_shape,
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shard_shape),
)

# setup output array
Expand Down Expand Up @@ -461,7 +461,7 @@ async def _decode_partial_single(
indexer = get_indexer(
selection,
shape=shard_shape,
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shard_shape),
)

# setup output array
Expand Down Expand Up @@ -536,7 +536,7 @@ async def _encode_single(
BasicIndexer(
tuple(slice(0, s) for s in shard_shape),
shape=shard_shape,
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shard_shape),
)
)

Expand Down Expand Up @@ -585,7 +585,9 @@ async def _encode_partial_single(

indexer = list(
get_indexer(
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
selection,
shape=shard_shape,
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shard_shape),
)
)

Expand Down
10 changes: 6 additions & 4 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def _create_metadata_v3(
else:
fill_value_parsed = fill_value

chunk_grid_parsed = RegularChunkGrid(chunk_shape=chunk_shape)
chunk_grid_parsed = RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shape)
return ArrayV3Metadata(
shape=shape,
data_type=dtype,
Expand Down Expand Up @@ -4694,7 +4694,9 @@ async def init_array(
sharding_codec.validate(
shape=chunk_shape_parsed,
dtype=zdtype,
chunk_grid=RegularChunkGrid(chunk_shape=shard_shape_parsed),
chunk_grid=RegularChunkGrid(
chunk_shape=shard_shape_parsed, array_shape=chunk_shape_parsed
),
)
codecs_out = (sharding_codec,)
chunks_out = shard_shape_parsed
Expand Down Expand Up @@ -5995,8 +5997,8 @@ async def _resize(

if delete_outside_chunks and not only_growing:
# Remove all chunks outside of the new shape
old_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords(array.metadata.shape))
new_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords(new_shape))
old_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords())
new_chunk_coords = set(new_metadata.chunk_grid.all_chunk_coords())

async def _delete_key(key: str) -> None:
await (array.store_path / key).delete()
Expand Down
64 changes: 64 additions & 0 deletions src/zarr/core/chunk_grids/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

from typing import Any

from zarr.core.chunk_grids.common import (
ChunkGrid,
_auto_partition,
_guess_chunks,
_guess_num_chunks_per_axis_shard,
normalize_chunks,
)
from zarr.core.chunk_grids.regular import RegularChunkGrid
from zarr.core.common import JSON, NamedConfig, parse_named_configuration
from zarr.registry import get_chunk_grid_class, register_chunk_grid

# Register the built-in grid class under the name "regular" at import time, so
# registry lookups by name (get_chunk_grid_class, as used by parse_chunk_grid
# in this module) can resolve it.
register_chunk_grid("regular", RegularChunkGrid)


def parse_chunk_grid(
    data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any],
    *,
    array_shape: tuple[int, ...],
) -> ChunkGrid:
    """Build a chunk grid from its named-configuration form.

    The grid class is resolved by name through the chunk grid registry and
    then asked to construct itself from ``data``.

    Parameters
    ----------
    data : dict[str, JSON] | ChunkGrid | NamedConfig[str, Any]
        A dictionary carrying 'name' and 'configuration' keys, or an existing
        ChunkGrid instance. An existing instance is handed back unchanged.
    array_shape : tuple[int, ...]
        Shape of the array the resulting chunk grid is bound to. It is only
        forwarded when a new grid is constructed from a dictionary.

    Returns
    -------
    ChunkGrid

    Raises
    ------
    ValueError
        If no chunk grid class is registered under the parsed name.
    """
    # Pass-through: an already-built grid is returned without consulting
    # ``array_shape`` or the registry.
    if isinstance(data, ChunkGrid):
        return data

    # Only the name is needed here; the configuration is re-read by from_dict.
    name_parsed, _ = parse_named_configuration(data)

    try:
        grid_cls = get_chunk_grid_class(name_parsed)
    except KeyError as err:
        # Surface a registry miss as a ValueError, keeping the KeyError as cause.
        raise ValueError(f"Unknown chunk grid. Got {name_parsed}.") from err

    return grid_cls.from_dict(data, array_shape=array_shape)  # type: ignore[arg-type, call-arg]


# Names exported by this package: the ChunkGrid classes, the parse_chunk_grid
# function defined in this module, and helpers imported above from
# zarr.core.chunk_grids.common. The underscore-prefixed helpers are listed
# explicitly — presumably so existing imports of them keep resolving; confirm.
__all__ = [
"ChunkGrid",
"RegularChunkGrid",
"_auto_partition",
"_guess_chunks",
"_guess_num_chunks_per_axis_shard",
"normalize_chunks",
"parse_chunk_grid",
]
125 changes: 58 additions & 67 deletions src/zarr/core/chunk_grids.py → src/zarr/core/chunk_grids/common.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,77 @@
from __future__ import annotations

import itertools
import math
import numbers
import operator
import warnings
from abc import abstractmethod
from dataclasses import dataclass
from functools import reduce
from typing import TYPE_CHECKING, Any, Literal
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Literal, Self

import numpy as np
import numpy.typing as npt

import zarr
from zarr.abc.metadata import Metadata
from zarr.core.common import (
JSON,
NamedConfig,
ShapeLike,
ceildiv,
parse_named_configuration,
parse_shapelike,
)
from zarr.errors import ZarrUserWarning

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Self

from zarr.core.array import ShardsLike
from zarr.core.common import JSON


class ChunkGrid(ABC, Metadata):
    """Abstract interface for a chunk grid that is bound to an array shape.

    Concrete grids implement every member below; none of them takes the array
    shape as an argument, it is exposed through ``array_shape`` instead.
    """

    @property
    @abstractmethod
    def array_shape(self) -> tuple[int, ...]:
        """Shape of the array this chunk grid is bound to."""

    @abstractmethod
    def to_dict(self) -> dict[str, JSON]:
        """Return the dictionary representation of this chunk grid."""

    @abstractmethod
    def update_shape(self, new_shape: tuple[int, ...]) -> Self:
        """Return a new ChunkGrid whose array_shape is ``new_shape``."""

    @abstractmethod
    def all_chunk_coords(self) -> Iterator[tuple[int, ...]]:
        """Iterate over the coordinates of the chunks in this grid."""

    @abstractmethod
    def get_nchunks(self) -> int:
        """Return the number of chunks in this grid."""

    @abstractmethod
    def get_chunk_shape(self, chunk_coord: tuple[int, ...]) -> tuple[int, ...]:
        """Return the shape of the chunk at ``chunk_coord``."""

    @abstractmethod
    def get_chunk_start(self, chunk_coord: tuple[int, ...]) -> tuple[int, ...]:
        """Return the position in the array at which the chunk at ``chunk_coord`` starts."""

    @abstractmethod
    def array_index_to_chunk_coord(self, array_index: tuple[int, ...]) -> tuple[int, ...]:
        """Return the coordinates of the chunk that contains ``array_index``."""

    @abstractmethod
    def array_indices_to_chunk_dim(
        self, dim: int, indices: npt.NDArray[np.intp]
    ) -> npt.NDArray[np.intp]:
        """Vectorized mapping of indices along dimension ``dim`` to chunk coordinates."""

    @abstractmethod
    def chunks_per_dim(self, dim: int) -> int:
        """Return the number of chunks along dimension ``dim``."""

    @abstractmethod
    def get_chunk_grid_shape(self) -> tuple[int, ...]:
        """Return the chunk grid's shape: the number of chunks along each dimension."""


def _guess_chunks(
Expand Down Expand Up @@ -153,58 +196,6 @@ def normalize_chunks(chunks: Any, shape: tuple[int, ...], typesize: int) -> tupl
return tuple(int(c) for c in chunks)


@dataclass(frozen=True)
class ChunkGrid(Metadata):
    """Base class for chunk grids; the array shape is supplied per call."""

    @classmethod
    def from_dict(cls, data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any]) -> ChunkGrid:
        """Build a chunk grid from its named-configuration dictionary.

        An existing ChunkGrid instance is returned unchanged. Only the name
        "regular" is recognised; any other name raises ValueError.
        """
        if isinstance(data, ChunkGrid):
            return data

        name_parsed, _ = parse_named_configuration(data)
        # Guard first: anything other than the single supported grid is an error.
        if name_parsed != "regular":
            raise ValueError(f"Unknown chunk grid. Got {name_parsed}.")
        return RegularChunkGrid._from_dict(data)

    @abstractmethod
    def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
        """Iterate over the chunk coordinates for an array of shape ``array_shape``."""

    @abstractmethod
    def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
        """Return the number of chunks for an array of shape ``array_shape``."""


@dataclass(frozen=True)
class RegularChunkGrid(ChunkGrid):
    """Chunk grid in which every chunk has the same ``chunk_shape``."""

    chunk_shape: tuple[int, ...]

    def __init__(self, *, chunk_shape: ShapeLike) -> None:
        """Store ``chunk_shape`` after passing it through ``parse_shapelike``."""
        # The dataclass is frozen, so the field is set via object.__setattr__.
        object.__setattr__(self, "chunk_shape", parse_shapelike(chunk_shape))

    @classmethod
    def _from_dict(cls, data: dict[str, JSON] | NamedConfig[str, Any]) -> Self:
        """Construct the grid from a named configuration whose name is "regular"."""
        _, configuration_parsed = parse_named_configuration(data, "regular")
        return cls(**configuration_parsed)  # type: ignore[arg-type]

    def to_dict(self) -> dict[str, JSON]:
        """Return the ``{"name": ..., "configuration": ...}`` form of this grid."""
        configuration = {"chunk_shape": tuple(self.chunk_shape)}
        return {"name": "regular", "configuration": configuration}

    def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
        """Iterate over every chunk coordinate for an array of shape ``array_shape``."""
        # NOTE(review): strict=False means a rank mismatch between array_shape and
        # chunk_shape is silently truncated here, whereas get_nchunks raises on
        # it — confirm whether that asymmetry is intended.
        per_axis = [
            range(ceildiv(extent, chunk))
            for extent, chunk in zip(array_shape, self.chunk_shape, strict=False)
        ]
        return itertools.product(*per_axis)

    def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
        """Return the product of the per-axis chunk counts for ``array_shape``."""
        total = 1
        # strict=True: array_shape and chunk_shape must have the same length.
        for extent, chunk in zip(array_shape, self.chunk_shape, strict=True):
            total = total * ceildiv(extent, chunk)
        return total


def _guess_num_chunks_per_axis_shard(
chunk_shape: tuple[int, ...], item_size: int, max_bytes: int, array_shape: tuple[int, ...]
) -> int:
Expand Down
Loading