diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 85162c2f74..d3129b6cde 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -415,7 +415,7 @@ async def _decode_single( indexer = BasicIndexer( tuple(slice(0, s) for s in shard_shape), shape=shard_shape, - chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shard_shape), ) # setup output array @@ -461,7 +461,7 @@ async def _decode_partial_single( indexer = get_indexer( selection, shape=shard_shape, - chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shard_shape), ) # setup output array @@ -536,7 +536,7 @@ async def _encode_single( BasicIndexer( tuple(slice(0, s) for s in shard_shape), shape=shard_shape, - chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shard_shape), ) ) @@ -585,7 +585,9 @@ async def _encode_partial_single( indexer = list( get_indexer( - selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape) + selection, + shape=shard_shape, + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shard_shape), ) ) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 564d0e915a..525b4d1465 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -774,7 +774,7 @@ def _create_metadata_v3( else: fill_value_parsed = fill_value - chunk_grid_parsed = RegularChunkGrid(chunk_shape=chunk_shape) + chunk_grid_parsed = RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shape) return ArrayV3Metadata( shape=shape, data_type=dtype, @@ -4694,7 +4694,9 @@ async def init_array( sharding_codec.validate( shape=chunk_shape_parsed, dtype=zdtype, - chunk_grid=RegularChunkGrid(chunk_shape=shard_shape_parsed), + chunk_grid=RegularChunkGrid( + chunk_shape=shard_shape_parsed, array_shape=chunk_shape_parsed + ), ) codecs_out = (sharding_codec,) chunks_out = shard_shape_parsed @@ -5995,8 +5997,8 @@ async def _resize( if delete_outside_chunks and not only_growing: # Remove all chunks outside of the new shape - old_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords(array.metadata.shape)) - new_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords(new_shape)) + old_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords()) + new_chunk_coords = set(new_metadata.chunk_grid.all_chunk_coords()) async def _delete_key(key: str) -> None: await (array.store_path / key).delete() diff --git a/src/zarr/core/chunk_grids/__init__.py b/src/zarr/core/chunk_grids/__init__.py new file mode 100644 index 0000000000..febda89169 --- /dev/null +++ b/src/zarr/core/chunk_grids/__init__.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import Any + +from zarr.core.chunk_grids.common import ( + ChunkGrid, + _auto_partition, + _guess_chunks, + _guess_num_chunks_per_axis_shard, + normalize_chunks, +) +from zarr.core.chunk_grids.regular import RegularChunkGrid +from zarr.core.common import JSON, NamedConfig, parse_named_configuration +from zarr.registry import get_chunk_grid_class, register_chunk_grid + +register_chunk_grid("regular", RegularChunkGrid) + + +def parse_chunk_grid( + data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any], + *, + array_shape: tuple[int, ...], +) -> ChunkGrid: + """Parse a chunk grid from a dictionary, returning existing ChunkGrid instances as-is. + + Uses the chunk grid registry to look up the appropriate class by name. + + Parameters + ---------- + data : dict[str, JSON] | ChunkGrid | NamedConfig[str, Any] + Either a ChunkGrid instance (returned as-is) or a dictionary with + 'name' and 'configuration' keys. + array_shape : tuple[int, ...] + The shape of the array this chunk grid is bound to. + + Returns + ------- + ChunkGrid + + Raises + ------ + ValueError + If the chunk grid name is not found in the registry. + """ + if isinstance(data, ChunkGrid): + return data + + name_parsed, _ = parse_named_configuration(data) + try: + chunk_grid_cls = get_chunk_grid_class(name_parsed) + except KeyError as e: + raise ValueError(f"Unknown chunk grid. Got {name_parsed}.") from e + return chunk_grid_cls.from_dict(data, array_shape=array_shape) # type: ignore[arg-type, call-arg] + + +__all__ = [ + "ChunkGrid", + "RegularChunkGrid", + "_auto_partition", + "_guess_chunks", + "_guess_num_chunks_per_axis_shard", + "normalize_chunks", + "parse_chunk_grid", +] diff --git a/src/zarr/core/chunk_grids.py b/src/zarr/core/chunk_grids/common.py similarity index 80% rename from src/zarr/core/chunk_grids.py rename to src/zarr/core/chunk_grids/common.py index 2c7945fa64..12c93e910f 100644 --- a/src/zarr/core/chunk_grids.py +++ b/src/zarr/core/chunk_grids/common.py @@ -1,34 +1,77 @@ from __future__ import annotations -import itertools import math import numbers -import operator import warnings -from abc import abstractmethod -from dataclasses import dataclass -from functools import reduce -from typing import TYPE_CHECKING, Any, Literal +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Literal, Self import numpy as np +import numpy.typing as npt import zarr from zarr.abc.metadata import Metadata -from zarr.core.common import ( - JSON, - NamedConfig, - ShapeLike, - ceildiv, - parse_named_configuration, - parse_shapelike, -) from zarr.errors import ZarrUserWarning if TYPE_CHECKING: from collections.abc import Iterator - from typing import Self from zarr.core.array import ShardsLike + from zarr.core.common import JSON + + +class ChunkGrid(ABC, Metadata): + @property + @abstractmethod + def array_shape(self) -> tuple[int, ...]: + """The shape of the array this chunk grid is bound to.""" + ... + + @abstractmethod + def to_dict(self) -> dict[str, JSON]: ... + + @abstractmethod + def update_shape(self, new_shape: tuple[int, ...]) -> Self: + """Return a new ChunkGrid with the given array_shape.""" + ... + + @abstractmethod + def all_chunk_coords(self) -> Iterator[tuple[int, ...]]: ... + + @abstractmethod + def get_nchunks(self) -> int: ... + + @abstractmethod + def get_chunk_shape(self, chunk_coord: tuple[int, ...]) -> tuple[int, ...]: + """Get the shape of a specific chunk.""" + ... + + @abstractmethod + def get_chunk_start(self, chunk_coord: tuple[int, ...]) -> tuple[int, ...]: + """Get the starting position of a chunk in the array.""" + ... + + @abstractmethod + def array_index_to_chunk_coord(self, array_index: tuple[int, ...]) -> tuple[int, ...]: + """Map an array index to the chunk coordinates that contain it.""" + ... + + @abstractmethod + def array_indices_to_chunk_dim( + self, dim: int, indices: npt.NDArray[np.intp] + ) -> npt.NDArray[np.intp]: + """Map an array of indices along one dimension to chunk coordinates (vectorized).""" + ... + + @abstractmethod + def chunks_per_dim(self, dim: int) -> int: + """Get the number of chunks along a specific dimension.""" + ... + + @abstractmethod + def get_chunk_grid_shape(self) -> tuple[int, ...]: + """Get the shape of the chunk grid (number of chunks along each dimension).""" + ... def _guess_chunks( @@ -153,58 +196,6 @@ def normalize_chunks(chunks: Any, shape: tuple[int, ...], typesize: int) -> tupl return tuple(int(c) for c in chunks) -@dataclass(frozen=True) -class ChunkGrid(Metadata): - @classmethod - def from_dict(cls, data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any]) -> ChunkGrid: - if isinstance(data, ChunkGrid): - return data - - name_parsed, _ = parse_named_configuration(data) - if name_parsed == "regular": - return RegularChunkGrid._from_dict(data) - raise ValueError(f"Unknown chunk grid. Got {name_parsed}.") - - @abstractmethod - def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: - pass - - @abstractmethod - def get_nchunks(self, array_shape: tuple[int, ...]) -> int: - pass - - -@dataclass(frozen=True) -class RegularChunkGrid(ChunkGrid): - chunk_shape: tuple[int, ...] - - def __init__(self, *, chunk_shape: ShapeLike) -> None: - chunk_shape_parsed = parse_shapelike(chunk_shape) - - object.__setattr__(self, "chunk_shape", chunk_shape_parsed) - - @classmethod - def _from_dict(cls, data: dict[str, JSON] | NamedConfig[str, Any]) -> Self: - _, configuration_parsed = parse_named_configuration(data, "regular") - - return cls(**configuration_parsed) # type: ignore[arg-type] - - def to_dict(self) -> dict[str, JSON]: - return {"name": "regular", "configuration": {"chunk_shape": tuple(self.chunk_shape)}} - - def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: - return itertools.product( - *(range(ceildiv(s, c)) for s, c in zip(array_shape, self.chunk_shape, strict=False)) - ) - - def get_nchunks(self, array_shape: tuple[int, ...]) -> int: - return reduce( - operator.mul, - itertools.starmap(ceildiv, zip(array_shape, self.chunk_shape, strict=True)), - 1, - ) - - def _guess_num_chunks_per_axis_shard( chunk_shape: tuple[int, ...], item_size: int, max_bytes: int, array_shape: tuple[int, ...] ) -> int: diff --git a/src/zarr/core/chunk_grids/regular.py b/src/zarr/core/chunk_grids/regular.py new file mode 100644 index 0000000000..380cb86ed5 --- /dev/null +++ b/src/zarr/core/chunk_grids/regular.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import itertools +import operator +from dataclasses import dataclass +from functools import reduce +from typing import TYPE_CHECKING, Any, Self + +import numpy as np +import numpy.typing as npt + +from zarr.core.chunk_grids.common import ChunkGrid +from zarr.core.common import ( + JSON, + NamedConfig, + ShapeLike, + ceildiv, + parse_named_configuration, + parse_shapelike, +) + +if TYPE_CHECKING: + from collections.abc import Iterator + + +@dataclass(frozen=True) +class RegularChunkGrid(ChunkGrid): + _array_shape: tuple[int, ...] + chunk_shape: tuple[int, ...] + + def __init__(self, *, chunk_shape: ShapeLike, array_shape: ShapeLike) -> None: + chunk_shape_parsed = parse_shapelike(chunk_shape) + array_shape_parsed = parse_shapelike(array_shape) + + object.__setattr__(self, "chunk_shape", chunk_shape_parsed) + object.__setattr__(self, "_array_shape", array_shape_parsed) + + @property + def array_shape(self) -> tuple[int, ...]: + return self._array_shape + + @classmethod + def from_dict( # type: ignore[override] + cls, data: dict[str, JSON] | NamedConfig[str, Any], *, array_shape: ShapeLike + ) -> Self: + _, configuration_parsed = parse_named_configuration(data, "regular") + + return cls(**configuration_parsed, array_shape=array_shape) # type: ignore[arg-type] + + def to_dict(self) -> dict[str, JSON]: + return {"name": "regular", "configuration": {"chunk_shape": tuple(self.chunk_shape)}} + + def update_shape(self, new_shape: tuple[int, ...]) -> Self: + return type(self)(chunk_shape=self.chunk_shape, array_shape=new_shape) + + def all_chunk_coords(self) -> Iterator[tuple[int, ...]]: + return itertools.product( + *( + range(ceildiv(s, c)) + for s, c in zip(self._array_shape, self.chunk_shape, strict=False) + ) + ) + + def get_nchunks(self) -> int: + return reduce( + operator.mul, + itertools.starmap(ceildiv, zip(self._array_shape, self.chunk_shape, strict=True)), + 1, + ) + + def get_chunk_shape(self, chunk_coord: tuple[int, ...]) -> tuple[int, ...]: + """ + Get the shape of a specific chunk. + + For RegularChunkGrid, all chunks have the same shape except possibly + the last chunk in each dimension. + """ + return tuple( + int( + min( + self.chunk_shape[i], + self._array_shape[i] - chunk_coord[i] * self.chunk_shape[i], + ) + ) + for i in range(len(self._array_shape)) + ) + + def get_chunk_start(self, chunk_coord: tuple[int, ...]) -> tuple[int, ...]: + """ + Get the starting position of a chunk in the array. + + For RegularChunkGrid, this is simply chunk_coord * chunk_shape. + """ + return tuple( + coord * size for coord, size in zip(chunk_coord, self.chunk_shape, strict=False) + ) + + def array_index_to_chunk_coord(self, array_index: tuple[int, ...]) -> tuple[int, ...]: + """ + Map an array index to chunk coordinates. + + For RegularChunkGrid, this is simply array_index // chunk_shape. + """ + return tuple( + 0 if size == 0 else idx // size + for idx, size in zip(array_index, self.chunk_shape, strict=False) + ) + + def array_indices_to_chunk_dim( + self, dim: int, indices: npt.NDArray[np.intp] + ) -> npt.NDArray[np.intp]: + """ + Vectorized mapping of array indices to chunk coordinates along one dimension. + + For RegularChunkGrid, this is simply indices // chunk_size. + """ + chunk_size = self.chunk_shape[dim] + if chunk_size == 0: + return np.zeros_like(indices) + return indices // chunk_size + + def chunks_per_dim(self, dim: int) -> int: + """ + Get the number of chunks along a specific dimension. + + For RegularChunkGrid, this is ceildiv(array_shape[dim], chunk_shape[dim]). + """ + return ceildiv(self._array_shape[dim], self.chunk_shape[dim]) + + def get_chunk_grid_shape(self) -> tuple[int, ...]: + """ + Get the shape of the chunk grid (number of chunks along each dimension). + + For RegularChunkGrid, this is computed using ceildiv for each dimension. + """ + return tuple( + ceildiv(array_len, chunk_len) + for array_len, chunk_len in zip(self._array_shape, self.chunk_shape, strict=False) + ) diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 73fd53087d..d6e38d55ac 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -1,7 +1,6 @@ from __future__ import annotations import itertools -import math import numbers import operator from collections.abc import Iterator, Sequence @@ -332,15 +331,6 @@ def is_pure_orthogonal_indexing(selection: Selection, ndim: int) -> TypeGuard[Or ) -def get_chunk_shape(chunk_grid: ChunkGrid) -> tuple[int, ...]: - from zarr.core.chunk_grids import RegularChunkGrid - - assert isinstance(chunk_grid, RegularChunkGrid), ( - "Only regular chunk grid is supported, currently." - ) - return chunk_grid.chunk_shape - - def normalize_integer_selection(dim_sel: int, dim_len: int) -> int: # normalize type to int dim_sel = int(dim_sel) @@ -380,35 +370,70 @@ class ChunkDimProjection(NamedTuple): class IntDimIndexer: dim_sel: int dim_len: int - dim_chunk_len: int + dim: int + array_shape: tuple[int, ...] + chunk_grid: ChunkGrid nitems: int = 1 - def __init__(self, dim_sel: int, dim_len: int, dim_chunk_len: int) -> None: + def __init__( + self, + dim_sel: int, + dim_len: int, + dim: int, + array_shape: tuple[int, ...], + chunk_grid: ChunkGrid, + ) -> None: object.__setattr__(self, "dim_sel", normalize_integer_selection(dim_sel, dim_len)) object.__setattr__(self, "dim_len", dim_len) - object.__setattr__(self, "dim_chunk_len", dim_chunk_len) + object.__setattr__(self, "dim", dim) + object.__setattr__(self, "array_shape", array_shape) + object.__setattr__(self, "chunk_grid", chunk_grid) def __iter__(self) -> Iterator[ChunkDimProjection]: - dim_chunk_ix = self.dim_sel // self.dim_chunk_len - dim_offset = dim_chunk_ix * self.dim_chunk_len + # Create a full array index with zeros except at this dimension + full_index = tuple( + self.dim_sel if i == self.dim else 0 for i in range(len(self.array_shape)) + ) + + # Use chunk grid to find which chunk contains this index + chunk_coords = self.chunk_grid.array_index_to_chunk_coord(full_index) + dim_chunk_ix = chunk_coords[self.dim] + + # Get the starting position of this chunk + chunk_start = self.chunk_grid.get_chunk_start(chunk_coords) + dim_offset = chunk_start[self.dim] + + # Calculate selection within the chunk dim_chunk_sel = self.dim_sel - dim_offset dim_out_sel = None - is_complete_chunk = self.dim_chunk_len == 1 + + # Check if this is a complete chunk (single element in this dimension) + chunk_shape = self.chunk_grid.get_chunk_shape(chunk_coords) + is_complete_chunk = chunk_shape[self.dim] == 1 + yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) @dataclass(frozen=True) class SliceDimIndexer: dim_len: int - dim_chunk_len: int + dim: int + array_shape: tuple[int, ...] + chunk_grid: ChunkGrid nitems: int - nchunks: int start: int stop: int step: int - def __init__(self, dim_sel: slice, dim_len: int, dim_chunk_len: int) -> None: + def __init__( + self, + dim_sel: slice, + dim_len: int, + dim: int, + array_shape: tuple[int, ...], + chunk_grid: ChunkGrid, + ) -> None: # normalize start, stop, step = dim_sel.indices(dim_len) if step < 1: @@ -419,23 +444,46 @@ def __init__(self, dim_sel: slice, dim_len: int, dim_chunk_len: int) -> None: object.__setattr__(self, "step", step) object.__setattr__(self, "dim_len", dim_len) - object.__setattr__(self, "dim_chunk_len", dim_chunk_len) + object.__setattr__(self, "dim", dim) + object.__setattr__(self, "array_shape", array_shape) + object.__setattr__(self, "chunk_grid", chunk_grid) object.__setattr__(self, "nitems", max(0, ceildiv((stop - start), step))) - object.__setattr__(self, "nchunks", ceildiv(dim_len, dim_chunk_len)) def __iter__(self) -> Iterator[ChunkDimProjection]: - # figure out the range of chunks we need to visit - dim_chunk_ix_from = 0 if self.start == 0 else self.start // self.dim_chunk_len - dim_chunk_ix_to = ceildiv(self.stop, self.dim_chunk_len) + # Get number of chunks along this dimension + nchunks = self.chunk_grid.chunks_per_dim(self.dim) - # iterate over chunks in range - for dim_chunk_ix in range(dim_chunk_ix_from, dim_chunk_ix_to): - # compute offsets for chunk within overall array - dim_offset = dim_chunk_ix * self.dim_chunk_len - dim_limit = min(self.dim_len, (dim_chunk_ix + 1) * self.dim_chunk_len) + # Find the range of chunks we need to visit + # Start: find chunk containing self.start + if self.start == 0: + dim_chunk_ix_from = 0 + else: + start_index = tuple( + self.start if i == self.dim else 0 for i in range(len(self.array_shape)) + ) + dim_chunk_ix_from = self.chunk_grid.array_index_to_chunk_coord(start_index)[self.dim] - # determine chunk length, accounting for trailing chunk - dim_chunk_len = dim_limit - dim_offset + # End: find chunk containing self.stop-1 (last index we need) + if self.stop == 0: + dim_chunk_ix_to = 0 + else: + end_index = tuple( + self.stop - 1 if i == self.dim else 0 for i in range(len(self.array_shape)) + ) + dim_chunk_ix_to = self.chunk_grid.array_index_to_chunk_coord(end_index)[self.dim] + 1 + + # Iterate over chunks in range + for dim_chunk_ix in range(dim_chunk_ix_from, min(dim_chunk_ix_to, nchunks)): + # Get chunk boundaries from chunk grid + chunk_coords = tuple( + dim_chunk_ix if i == self.dim else 0 for i in range(len(self.array_shape)) + ) + chunk_start = self.chunk_grid.get_chunk_start(chunk_coords) + chunk_shape = self.chunk_grid.get_chunk_shape(chunk_coords) + + dim_offset = chunk_start[self.dim] + dim_chunk_len = chunk_shape[self.dim] + dim_limit = dim_offset + dim_chunk_len if self.start < dim_offset: # selection starts before current chunk @@ -588,21 +636,18 @@ def __init__( shape: tuple[int, ...], chunk_grid: ChunkGrid, ) -> None: - chunk_shape = get_chunk_shape(chunk_grid) # handle ellipsis selection_normalized = replace_ellipsis(selection, shape) # setup per-dimension indexers dim_indexers: list[IntDimIndexer | SliceDimIndexer] = [] - for dim_sel, dim_len, dim_chunk_len in zip( - selection_normalized, shape, chunk_shape, strict=True - ): + for dim, (dim_sel, dim_len) in enumerate(zip(selection_normalized, shape, strict=True)): dim_indexer: IntDimIndexer | SliceDimIndexer if is_integer(dim_sel): - dim_indexer = IntDimIndexer(dim_sel, dim_len, dim_chunk_len) + dim_indexer = IntDimIndexer(dim_sel, dim_len, dim, shape, chunk_grid) elif is_slice(dim_sel): - dim_indexer = SliceDimIndexer(dim_sel, dim_len, dim_chunk_len) + dim_indexer = SliceDimIndexer(dim_sel, dim_len, dim, shape, chunk_grid) else: raise IndexError( @@ -635,15 +680,23 @@ def __iter__(self) -> Iterator[ChunkProjection]: class BoolArrayDimIndexer: dim_sel: npt.NDArray[np.bool_] dim_len: int - dim_chunk_len: int - nchunks: int + dim: int + array_shape: tuple[int, ...] + chunk_grid: ChunkGrid chunk_nitems: npt.NDArray[Any] chunk_nitems_cumsum: npt.NDArray[Any] nitems: int dim_chunk_ixs: npt.NDArray[np.intp] - def __init__(self, dim_sel: npt.NDArray[np.bool_], dim_len: int, dim_chunk_len: int) -> None: + def __init__( + self, + dim_sel: npt.NDArray[np.bool_], + dim_len: int, + dim: int, + array_shape: tuple[int, ...], + chunk_grid: ChunkGrid, + ) -> None: # check number of dimensions if not is_bool_array(dim_sel, 1): raise IndexError("Boolean arrays in an orthogonal selection must be 1-dimensional only") @@ -655,22 +708,32 @@ def __init__(self, dim_sel: npt.NDArray[np.bool_], dim_len: int, dim_chunk_len: ) # precompute number of selected items for each chunk - nchunks = ceildiv(dim_len, dim_chunk_len) + nchunks = chunk_grid.chunks_per_dim(dim) chunk_nitems = np.zeros(nchunks, dtype="i8") + for dim_chunk_ix in range(nchunks): - dim_offset = dim_chunk_ix * dim_chunk_len + # Get chunk boundaries from chunk grid + chunk_coords = tuple(dim_chunk_ix if i == dim else 0 for i in range(len(array_shape))) + chunk_start = chunk_grid.get_chunk_start(chunk_coords) + chunk_shape = chunk_grid.get_chunk_shape(chunk_coords) + + dim_offset = chunk_start[dim] + dim_chunk_len = chunk_shape[dim] + chunk_nitems[dim_chunk_ix] = np.count_nonzero( dim_sel[dim_offset : dim_offset + dim_chunk_len] ) + chunk_nitems_cumsum = np.cumsum(chunk_nitems) - nitems = chunk_nitems_cumsum[-1] + nitems = int(chunk_nitems_cumsum[-1]) if len(chunk_nitems_cumsum) > 0 else 0 dim_chunk_ixs = np.nonzero(chunk_nitems)[0] # store attributes object.__setattr__(self, "dim_sel", dim_sel) object.__setattr__(self, "dim_len", dim_len) - object.__setattr__(self, "dim_chunk_len", dim_chunk_len) - object.__setattr__(self, "nchunks", nchunks) + object.__setattr__(self, "dim", dim) + object.__setattr__(self, "array_shape", array_shape) + object.__setattr__(self, "chunk_grid", chunk_grid) object.__setattr__(self, "chunk_nitems", chunk_nitems) object.__setattr__(self, "chunk_nitems_cumsum", chunk_nitems_cumsum) object.__setattr__(self, "nitems", nitems) @@ -679,13 +742,22 @@ def __init__(self, dim_sel: npt.NDArray[np.bool_], dim_len: int, dim_chunk_len: def __iter__(self) -> Iterator[ChunkDimProjection]: # iterate over chunks with at least one item for dim_chunk_ix in self.dim_chunk_ixs: + # Get chunk boundaries from chunk grid + chunk_coords = tuple( + int(dim_chunk_ix) if i == self.dim else 0 for i in range(len(self.array_shape)) + ) + chunk_start = self.chunk_grid.get_chunk_start(chunk_coords) + chunk_shape = self.chunk_grid.get_chunk_shape(chunk_coords) + + dim_offset = chunk_start[self.dim] + dim_chunk_len = chunk_shape[self.dim] + # find region in chunk - dim_offset = dim_chunk_ix * self.dim_chunk_len - dim_chunk_sel = self.dim_sel[dim_offset : dim_offset + self.dim_chunk_len] + dim_chunk_sel = self.dim_sel[dim_offset : dim_offset + dim_chunk_len] # pad out if final chunk - if dim_chunk_sel.shape[0] < self.dim_chunk_len: - tmp = np.zeros(self.dim_chunk_len, dtype=bool) + if dim_chunk_sel.shape[0] < dim_chunk_len: + tmp = np.zeros(dim_chunk_len, dtype=bool) tmp[: dim_chunk_sel.shape[0]] = dim_chunk_sel dim_chunk_sel = tmp @@ -693,12 +765,14 @@ def __iter__(self) -> Iterator[ChunkDimProjection]: if dim_chunk_ix == 0: start = 0 else: - start = self.chunk_nitems_cumsum[dim_chunk_ix - 1] - stop = self.chunk_nitems_cumsum[dim_chunk_ix] + start = int(self.chunk_nitems_cumsum[dim_chunk_ix - 1]) + stop = int(self.chunk_nitems_cumsum[dim_chunk_ix]) dim_out_sel = slice(start, stop) is_complete_chunk = False # TODO - yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) + yield ChunkDimProjection( + int(dim_chunk_ix), dim_chunk_sel, dim_out_sel, is_complete_chunk + ) class Order(Enum): @@ -744,7 +818,9 @@ class IntArrayDimIndexer: """Integer array selection against a single dimension.""" dim_len: int - dim_chunk_len: int + dim: int + array_shape: tuple[int, ...] + chunk_grid: ChunkGrid nchunks: int nitems: int order: Order @@ -758,7 +834,9 @@ def __init__( self, dim_sel: npt.NDArray[np.intp], dim_len: int, - dim_chunk_len: int, + dim: int, + array_shape: tuple[int, ...], + chunk_grid: ChunkGrid, wraparound: bool = True, boundscheck: bool = True, order: Order = Order.UNKNOWN, @@ -769,7 +847,7 @@ def __init__( raise IndexError("integer arrays in an orthogonal selection must be 1-dimensional only") nitems = len(dim_sel) - nchunks = ceildiv(dim_len, dim_chunk_len) + nchunks = chunk_grid.chunks_per_dim(dim) # handle wraparound if wraparound: @@ -780,9 +858,7 @@ def __init__( boundscheck_indices(dim_sel, dim_len) # determine which chunk is needed for each selection item - # note: for dense integer selections, the division operation here is the - # bottleneck - dim_sel_chunk = dim_sel // dim_chunk_len + dim_sel_chunk = chunk_grid.array_indices_to_chunk_dim(dim, dim_sel) # determine order of indices if order == Order.UNKNOWN: @@ -811,7 +887,9 @@ def __init__( # store attributes object.__setattr__(self, "dim_len", dim_len) - object.__setattr__(self, "dim_chunk_len", dim_chunk_len) + object.__setattr__(self, "dim", dim) + object.__setattr__(self, "array_shape", array_shape) + object.__setattr__(self, "chunk_grid", chunk_grid) object.__setattr__(self, "nchunks", nchunks) object.__setattr__(self, "nitems", nitems) object.__setattr__(self, "order", order) @@ -835,8 +913,12 @@ def __iter__(self) -> Iterator[ChunkDimProjection]: else: dim_out_sel = self.dim_out_sel[start:stop] - # find region in chunk - dim_offset = dim_chunk_ix * self.dim_chunk_len + # find region in chunk - use chunk grid to get chunk boundaries + chunk_coords = tuple( + int(dim_chunk_ix) if i == self.dim else 0 for i in range(len(self.array_shape)) + ) + chunk_start = self.chunk_grid.get_chunk_start(chunk_coords) + dim_offset = chunk_start[self.dim] dim_chunk_sel = self.dim_sel[start:stop] - dim_offset is_complete_chunk = False # TODO yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) @@ -897,13 +979,12 @@ def oindex_set(a: npt.NDArray[Any], selection: Selection, value: Any) -> None: class OrthogonalIndexer(Indexer): dim_indexers: list[IntDimIndexer | SliceDimIndexer | IntArrayDimIndexer | BoolArrayDimIndexer] shape: tuple[int, ...] - chunk_shape: tuple[int, ...] + chunk_grid: ChunkGrid + array_shape: tuple[int, ...] is_advanced: bool drop_axes: tuple[int, ...] def __init__(self, selection: Selection, shape: tuple[int, ...], chunk_grid: ChunkGrid) -> None: - chunk_shape = get_chunk_shape(chunk_grid) - # handle ellipsis selection = replace_ellipsis(selection, shape) @@ -914,19 +995,19 @@ def __init__(self, selection: Selection, shape: tuple[int, ...], chunk_grid: Chu dim_indexers: list[ IntDimIndexer | SliceDimIndexer | IntArrayDimIndexer | BoolArrayDimIndexer ] = [] - for dim_sel, dim_len, dim_chunk_len in zip(selection, shape, chunk_shape, strict=True): + for dim, (dim_sel, dim_len) in enumerate(zip(selection, shape, strict=True)): dim_indexer: IntDimIndexer | SliceDimIndexer | IntArrayDimIndexer | BoolArrayDimIndexer if is_integer(dim_sel): - dim_indexer = IntDimIndexer(dim_sel, dim_len, dim_chunk_len) + dim_indexer = IntDimIndexer(dim_sel, dim_len, dim, shape, chunk_grid) elif isinstance(dim_sel, slice): - dim_indexer = SliceDimIndexer(dim_sel, dim_len, dim_chunk_len) + dim_indexer = SliceDimIndexer(dim_sel, dim_len, dim, shape, chunk_grid) elif is_integer_array(dim_sel): - dim_indexer = IntArrayDimIndexer(dim_sel, dim_len, dim_chunk_len) + dim_indexer = IntArrayDimIndexer(dim_sel, dim_len, dim, shape, chunk_grid) elif is_bool_array(dim_sel): - dim_indexer = BoolArrayDimIndexer(dim_sel, dim_len, dim_chunk_len) + dim_indexer = BoolArrayDimIndexer(dim_sel, dim_len, dim, shape, chunk_grid) else: raise IndexError( @@ -937,7 +1018,7 @@ def __init__(self, selection: Selection, shape: tuple[int, ...], chunk_grid: Chu dim_indexers.append(dim_indexer) - shape = tuple(s.nitems for s in dim_indexers if not isinstance(s, IntDimIndexer)) + output_shape = tuple(s.nitems for s in dim_indexers if not isinstance(s, IntDimIndexer)) is_advanced = not is_basic_selection(selection) if is_advanced: drop_axes = tuple( @@ -949,8 +1030,9 @@ def __init__(self, selection: Selection, shape: tuple[int, ...], chunk_grid: Chu drop_axes = () object.__setattr__(self, "dim_indexers", dim_indexers) - object.__setattr__(self, "shape", shape) - object.__setattr__(self, "chunk_shape", chunk_shape) + object.__setattr__(self, "shape", output_shape) + object.__setattr__(self, "chunk_grid", chunk_grid) + object.__setattr__(self, "array_shape", shape) object.__setattr__(self, "is_advanced", is_advanced) object.__setattr__(self, "drop_axes", drop_axes) @@ -970,7 +1052,9 @@ def __iter__(self) -> Iterator[ChunkProjection]: # so need to work around via np.ix_. Also np.ix_ does not support a # mixture of arrays and slices or integers, so need to convert slices # and integers into ranges. - chunk_selection = ix_(chunk_selection, self.chunk_shape) + # Query the actual chunk shape for this specific chunk + chunk_shape = self.chunk_grid.get_chunk_shape(chunk_coords) + chunk_selection = ix_(chunk_selection, chunk_shape) # special case for non-monotonic indices if not is_basic_selection(out_selection): @@ -1036,8 +1120,6 @@ class BlockIndexer(Indexer): def __init__( self, selection: BasicSelection, shape: tuple[int, ...], chunk_grid: ChunkGrid ) -> None: - chunk_shape = get_chunk_shape(chunk_grid) - # handle ellipsis selection_normalized = replace_ellipsis(selection, shape) @@ -1046,22 +1128,24 @@ def __init__( # setup per-dimension indexers dim_indexers = [] - for dim_sel, dim_len, dim_chunk_size in zip( - selection_normalized, shape, chunk_shape, strict=True - ): - dim_numchunks = int(np.ceil(dim_len / dim_chunk_size)) + for dim, (dim_sel, dim_len) in enumerate(zip(selection_normalized, shape, strict=True)): + dim_numchunks = chunk_grid.chunks_per_dim(dim) if is_integer(dim_sel): if dim_sel < 0: dim_sel = dim_numchunks + dim_sel - start = dim_sel * dim_chunk_size - stop = start + dim_chunk_size + # Use chunk grid to get the boundaries of this chunk (block) + chunk_coords = tuple(dim_sel if i == dim else 0 for i in range(len(shape))) + chunk_start_pos = chunk_grid.get_chunk_start(chunk_coords) + chunk_shape_here = chunk_grid.get_chunk_shape(chunk_coords) + start = chunk_start_pos[dim] + stop = start + chunk_shape_here[dim] slice_ = slice(start, stop) elif is_slice(dim_sel): - start = dim_sel.start if dim_sel.start is not None else 0 - stop = dim_sel.stop if dim_sel.stop is not None else dim_numchunks + start_block = dim_sel.start if dim_sel.start is not None else 0 + stop_block = dim_sel.stop if dim_sel.stop is not None else dim_numchunks if dim_sel.step not in {1, None}: raise IndexError( @@ -1071,13 +1155,26 @@ def __init__( # Can't reuse wraparound_indices because it expects a numpy array # We have integers here. - if start < 0: - start = dim_numchunks + start - if stop < 0: - stop = dim_numchunks + stop + if start_block < 0: + start_block = dim_numchunks + start_block + if stop_block < 0: + stop_block = dim_numchunks + stop_block + + # Convert block indices to array positions using chunk grid + start_chunk_coords = tuple( + start_block if i == dim else 0 for i in range(len(shape)) + ) + start_pos_tuple = chunk_grid.get_chunk_start(start_chunk_coords) + start = start_pos_tuple[dim] + + # For stop, get the end of the last chunk in the range + stop_chunk_coords = tuple( + stop_block - 1 if i == dim else 0 for i in range(len(shape)) + ) + stop_pos_tuple = chunk_grid.get_chunk_start(stop_chunk_coords) + stop_chunk_shape = chunk_grid.get_chunk_shape(stop_chunk_coords) + stop = stop_pos_tuple[dim] + stop_chunk_shape[dim] - start *= dim_chunk_size - stop *= dim_chunk_size slice_ = slice(start, stop) else: @@ -1086,17 +1183,17 @@ def __init__( f"expected integer or slice, got {type(dim_sel)!r}" ) - dim_indexer = SliceDimIndexer(slice_, dim_len, dim_chunk_size) + dim_indexer = SliceDimIndexer(slice_, dim_len, dim, shape, chunk_grid) dim_indexers.append(dim_indexer) if start >= dim_len or start < 0: msg = f"index out of bounds for dimension with length {dim_len}" raise BoundsCheckError(msg) - shape = tuple(s.nitems for s in dim_indexers) + output_shape = tuple(s.nitems for s in dim_indexers) object.__setattr__(self, "dim_indexers", dim_indexers) - object.__setattr__(self, "shape", shape) + object.__setattr__(self, "shape", output_shape) object.__setattr__(self, "drop_axes", ()) def __iter__(self) -> Iterator[ChunkProjection]: @@ -1157,19 +1254,19 @@ class CoordinateIndexer(Indexer): chunk_rixs: npt.NDArray[np.intp] chunk_mixs: tuple[npt.NDArray[np.intp], ...] shape: tuple[int, ...] - chunk_shape: tuple[int, ...] + chunk_grid: ChunkGrid + array_shape: tuple[int, ...] drop_axes: tuple[int, ...] def __init__( self, selection: CoordinateSelection, shape: tuple[int, ...], chunk_grid: ChunkGrid ) -> None: - chunk_shape = get_chunk_shape(chunk_grid) - + # Get chunk grid shape cdata_shape: tuple[int, ...] if shape == (): cdata_shape = (1,) else: - cdata_shape = tuple(math.ceil(s / c) for s, c in zip(shape, chunk_shape, strict=True)) + cdata_shape = chunk_grid.get_chunk_grid_shape() nchunks = reduce(operator.mul, cdata_shape, 1) # some initial normalization @@ -1197,23 +1294,19 @@ def __init__( # handle out of bounds boundscheck_indices(dim_sel, dim_len) - # compute chunk index for each point in the selection - chunks_multi_index = tuple( - dim_sel // dim_chunk_len - for (dim_sel, dim_chunk_len) in zip(selection_normalized, chunk_shape, strict=True) - ) - # broadcast selection - this will raise error if array dimensions don't match selection_broadcast = tuple(np.broadcast_arrays(*selection_normalized)) - chunks_multi_index_broadcast = np.broadcast_arrays(*chunks_multi_index) # remember shape of selection, because we will flatten indices for processing sel_shape = selection_broadcast[0].shape or (1,) # flatten selection selection_broadcast = tuple(dim_sel.reshape(-1) for dim_sel in selection_broadcast) + + # compute chunk index for each point in the selection using chunk grid chunks_multi_index_broadcast = tuple( - dim_chunks.reshape(-1) for dim_chunks in chunks_multi_index_broadcast + chunk_grid.array_indices_to_chunk_dim(dim, selection_broadcast[dim]) + for dim in range(len(shape)) ) # ravel chunk indices @@ -1229,7 +1322,7 @@ def __init__( else: sel_sort = None - shape = selection_broadcast[0].shape or (1,) + output_shape = selection_broadcast[0].shape or (1,) # precompute number of selected items for each chunk chunk_nitems = np.bincount(chunks_raveled_indices, minlength=nchunks) @@ -1246,8 +1339,9 @@ def __init__( object.__setattr__(self, "chunk_nitems_cumsum", chunk_nitems_cumsum) object.__setattr__(self, "chunk_rixs", chunk_rixs) object.__setattr__(self, "chunk_mixs", chunk_mixs) - object.__setattr__(self, "chunk_shape", chunk_shape) - object.__setattr__(self, "shape", shape) + object.__setattr__(self, "chunk_grid", chunk_grid) + object.__setattr__(self, "array_shape", shape) + object.__setattr__(self, "shape", output_shape) object.__setattr__(self, "drop_axes", ()) def __iter__(self) -> Iterator[ChunkProjection]: @@ -1265,13 +1359,11 @@ def __iter__(self) -> Iterator[ChunkProjection]: else: out_selection = self.sel_sort[start:stop] - chunk_offsets = tuple( - dim_chunk_ix * dim_chunk_len - for dim_chunk_ix, dim_chunk_len in zip(chunk_coords, self.chunk_shape, strict=True) - ) + # Use chunk grid to get chunk offsets (start positions) + chunk_start = self.chunk_grid.get_chunk_start(chunk_coords) chunk_selection = tuple( - dim_sel[start:stop] - dim_chunk_offset - for (dim_sel, dim_chunk_offset) in zip(self.selection, chunk_offsets, strict=True) + dim_sel[start:stop] - chunk_offset + for (dim_sel, chunk_offset) in zip(self.selection, chunk_start, strict=True) ) is_complete_chunk = False # TODO diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 3204543426..0e2d5c6e04 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -119,7 +119,7 @@ def ndim(self) -> int: @cached_property def chunk_grid(self) -> RegularChunkGrid: - return RegularChunkGrid(chunk_shape=self.chunks) + return RegularChunkGrid(chunk_shape=self.chunks, array_shape=self.shape) @property def shards(self) -> tuple[int, ...] | None: diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 5ce155bd9a..81aa4388a8 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -24,7 +24,7 @@ from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.core.array_spec import ArrayConfig, ArraySpec -from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid +from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid, parse_chunk_grid from zarr.core.chunk_key_encodings import ( ChunkKeyEncoding, ChunkKeyEncodingLike, @@ -229,7 +229,7 @@ def __init__( """ shape_parsed = parse_shapelike(shape) - chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) + chunk_grid_parsed = parse_chunk_grid(chunk_grid, array_shape=shape_parsed) chunk_key_encoding_parsed = parse_chunk_key_encoding(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) # Note: relying on a type method is numpy-specific @@ -436,7 +436,7 @@ def to_dict(self) -> dict[str, JSON]: return out_dict def update_shape(self, shape: tuple[int, ...]) -> Self: - return replace(self, shape=shape) + return replace(self, shape=shape, chunk_grid=self.chunk_grid.update_shape(shape)) def update_attributes(self, attributes: dict[str, JSON]) -> Self: return replace(self, attributes=attributes) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index d0850a1387..07a8b5c3dc 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -22,17 +22,20 @@ ) from zarr.abc.numcodec import Numcodec from zarr.core.buffer import Buffer, NDBuffer + from zarr.core.chunk_grids.common import ChunkGrid from zarr.core.chunk_key_encodings import ChunkKeyEncoding from zarr.core.common import JSON __all__ = [ "Registry", "get_buffer_class", + "get_chunk_grid_class", "get_chunk_key_encoding_class", "get_codec_class", "get_ndbuffer_class", "get_pipeline_class", "register_buffer", + "register_chunk_grid", "register_chunk_key_encoding", "register_codec", "register_ndbuffer", @@ -63,6 +66,7 @@ def register(self, cls: type[T], qualname: str | None = None) -> None: __pipeline_registry: Registry[CodecPipeline] = Registry() __buffer_registry: Registry[Buffer] = Registry() __ndbuffer_registry: Registry[NDBuffer] = Registry() +__chunk_grid_registry: Registry[ChunkGrid] = Registry() __chunk_key_encoding_registry: Registry[ChunkKeyEncoding] = Registry() """ @@ -103,6 +107,11 @@ def _collect_entrypoints() -> list[Registry[Any]]: data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr.data_type")) data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr", name="data_type")) + __chunk_grid_registry.lazy_load_list.extend(entry_points.select(group="zarr.chunk_grid")) + __chunk_grid_registry.lazy_load_list.extend( + entry_points.select(group="zarr", name="chunk_grid") + ) + __chunk_key_encoding_registry.lazy_load_list.extend( entry_points.select(group="zarr.chunk_key_encoding") ) @@ -125,6 +134,7 @@ def _collect_entrypoints() -> list[Registry[Any]]: __pipeline_registry, __buffer_registry, __ndbuffer_registry, + __chunk_grid_registry, __chunk_key_encoding_registry, ] @@ -156,6 +166,20 @@ def register_buffer(cls: type[Buffer], qualname: str | None = None) -> None: __buffer_registry.register(cls, qualname) +def register_chunk_grid(key: str, cls: type[ChunkGrid]) -> None: + __chunk_grid_registry.register(cls, key) + + +def get_chunk_grid_class(key: str) -> type[ChunkGrid]: + __chunk_grid_registry.lazy_load(use_entrypoint_name=True) + if key not in __chunk_grid_registry: + raise KeyError( + f"Chunk grid '{key}' not found in registered chunk grids: " + f"{list(__chunk_grid_registry)}." + ) + return __chunk_grid_registry[key] + + def register_chunk_key_encoding(key: str, cls: type) -> None: __chunk_key_encoding_registry.register(cls, key) diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 330f220b56..0cd565398b 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -160,7 +160,7 @@ def array_metadata( return ArrayV3Metadata( shape=shape, data_type=dtype, - chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape, array_shape=shape), fill_value=fill_value, attributes=draw(attributes), # type: ignore[arg-type] dimension_names=draw(dimension_names(ndim=ndim)), diff --git a/tests/conftest.py b/tests/conftest.py index 23a1e87d0a..6b21f6cec8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -379,7 +379,9 @@ def create_array_metadata( sharding_codec.validate( shape=chunk_shape_parsed, dtype=dtype_parsed, - chunk_grid=RegularChunkGrid(chunk_shape=shard_shape_parsed), + chunk_grid=RegularChunkGrid( + chunk_shape=shard_shape_parsed, array_shape=chunk_shape_parsed + ), ) codecs_out = (sharding_codec,) chunks_out = shard_shape_parsed @@ -390,7 +392,7 @@ def create_array_metadata( return ArrayV3Metadata( shape=shape_parsed, data_type=dtype_parsed, - chunk_grid=RegularChunkGrid(chunk_shape=chunks_out), + chunk_grid=RegularChunkGrid(chunk_shape=chunks_out, array_shape=shape_parsed), chunk_key_encoding=chunk_key_encoding_parsed, fill_value=fill_value, codecs=codecs_out, diff --git a/tests/test_cli/test_migrate_v3.py b/tests/test_cli/test_migrate_v3.py index 8bda31d208..563595c6d7 100644 --- a/tests/test_cli/test_migrate_v3.py +++ b/tests/test_cli/test_migrate_v3.py @@ -63,7 +63,7 @@ def test_migrate_array(local_store: LocalStore) -> None: expected_metadata = ArrayV3Metadata( shape=shape, data_type=UInt16(endianness="little"), - chunk_grid=RegularChunkGrid(chunk_shape=chunks), + chunk_grid=RegularChunkGrid(chunk_shape=chunks, array_shape=shape), chunk_key_encoding=V2ChunkKeyEncoding(separator="."), fill_value=fill_value, codecs=( diff --git a/tests/test_indexing.py b/tests/test_indexing.py index c0bf7dd270..11c0a49e7f 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -1219,8 +1219,8 @@ def test_get_block_selection_1d(store: StorePath) -> None: _test_get_block_selection(a, z, selection, expected_idx) bad_selections = block_selections_1d_bad + [ - z.metadata.chunk_grid.get_nchunks(z.shape) + 1, # out of bounds - -(z.metadata.chunk_grid.get_nchunks(z.shape) + 1), # out of bounds + z.metadata.chunk_grid.get_nchunks() + 1, # out of bounds + -(z.metadata.chunk_grid.get_nchunks() + 1), # out of bounds ] for selection_bad in bad_selections: