diff --git a/changes/3715.misc.md b/changes/3715.misc.md new file mode 100644 index 0000000000..caf06d1c54 --- /dev/null +++ b/changes/3715.misc.md @@ -0,0 +1,11 @@ +Added several performance optimizations to chunk encoding and decoding. Low-latency stores that do not benefit from +`async` operations can now implement synchronous IO methods which will be used when available during chunk processing. +Similarly, codecs can implement a synchronous API which will be used if available during chunk processing. +These changes remove unnecessary interactions with the event loop. + +The synchronous chunk processing path optionally uses a thread pool to parallelize codec work across chunks. +The pool is skipped for single-chunk operations and for pipelines that only contain cheap codecs (e.g. endian +swap, transpose, checksum). + +Use of the thread pool can be disabled in the global configuration. The minimum number of threads +and the maximum number of threads can be set via the configuration as well. diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d41c457b4e..492fbbc27b 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -2,7 +2,8 @@ from abc import abstractmethod from collections.abc import Mapping -from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable from typing_extensions import ReadOnly, TypedDict @@ -19,7 +20,7 @@ from zarr.core.array_spec import ArraySpec from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType - from zarr.core.indexing import SelectorTuple + from zarr.core.indexing import ChunkProjection, SelectorTuple from zarr.core.metadata import ArrayMetadata __all__ = [ @@ -32,6 +33,8 @@ "CodecInput", "CodecOutput", "CodecPipeline", + "PreparedWrite", + "SupportsSyncCodec", ] CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer) @@ -59,6 +62,19 @@ def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]: """The widest type of JSON-like input that could specify a codec.""" +@runtime_checkable +class SupportsSyncCodec(Protocol): + """Protocol for codecs that support synchronous encode/decode.""" + + def _decode_sync( + self, chunk_data: NDBuffer | Buffer, chunk_spec: ArraySpec + ) -> NDBuffer | Buffer: ... + + def _encode_sync( + self, chunk_data: NDBuffer | Buffer, chunk_spec: ArraySpec + ) -> NDBuffer | Buffer | None: ... + + class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]): """Generic base class for codecs. @@ -186,9 +202,188 @@ class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]): """Base class for array-to-array codecs.""" +def _is_complete_selection(selection: Any, shape: tuple[int, ...]) -> bool: + """Check whether a chunk selection covers the entire chunk shape.""" + if not isinstance(selection, tuple): + selection = (selection,) + for sel, dim_len in zip(selection, shape, strict=False): + if isinstance(sel, int): + if dim_len != 1: + return False + elif isinstance(sel, slice): + start, stop, step = sel.indices(dim_len) + if not (start == 0 and stop == dim_len and step == 1): + return False + else: + return False + return True + + +@dataclass +class PreparedWrite: + """Result of prepare_write: existing encoded chunk bytes + selection info.""" + + chunk_dict: dict[tuple[int, ...], Buffer | None] + inner_codec_chain: Any # CodecChain + inner_chunk_spec: ArraySpec + indexer: list[ChunkProjection] + value_selection: SelectorTuple | None = None + # If not None, slice value with this before using inner out_selections. + # For sharding: the outer out_selection from batch_info. + # For non-sharded: None (inner out_selection IS the outer out_selection). + write_full_shard: bool = True + # True when the entire shard blob will be written from scratch (either + # because the shard doesn't exist yet or because the selection is complete). + # Used by ShardingCodec.finalize_write to decide between set vs set_range. + is_complete_shard: bool = False + # True when the outer selection covers the entire shard. When True, + # the indexer is empty and finalize_write receives the shard value + # via shard_data. The codec then encodes the full shard in one shot + # rather than iterating over individual inner chunks. + shard_data: NDBuffer | None = None + # The full shard value for complete-selection writes. Set by the pipeline + # when is_complete_shard is True, before calling finalize_write. + + class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]): """Base class for array-to-bytes codecs.""" + @property + def inner_codec_chain(self) -> Any: + """The codec chain for decoding inner chunks after deserialization. + + Returns None by default — the pipeline should use its own codec_chain. + ShardingCodec overrides to return its inner codec chain. + """ + return None + + def deserialize( + self, raw: Buffer | None, chunk_spec: ArraySpec + ) -> dict[tuple[int, ...], Buffer | None]: + """Pure compute: unpack stored bytes into per-inner-chunk buffers. + + Default implementation: single chunk at (0,). + ShardingCodec overrides to decode shard index and slice blob into per-chunk buffers. + """ + return {(0,): raw} + + def serialize( + self, chunk_dict: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec + ) -> Buffer | None: + """Pure compute: pack per-inner-chunk buffers into a storage blob. + + Default implementation: return the single chunk's bytes (or None if absent). + ShardingCodec overrides to concatenate chunks + build index. + Returns None if all chunks are empty (caller should delete the key). + """ + return chunk_dict.get((0,)) + + def prepare_read_sync( + self, + byte_getter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + codec_chain: Any, + aa_chain: Any, + ab_pair: Any, + bb_chain: Any, + ) -> NDBuffer | None: + """IO + full decode for the selected region. Returns decoded sub-array.""" + raw = byte_getter.get_sync(prototype=chunk_spec.prototype) + chunk_array: NDBuffer | None = codec_chain.decode_chunk( + raw, chunk_spec, aa_chain, ab_pair, bb_chain + ) + if chunk_array is not None: + return chunk_array[chunk_selection] + return None + + def prepare_write_sync( + self, + byte_setter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + codec_chain: Any, + ) -> PreparedWrite: + """IO + deserialize. Returns PreparedWrite for the pipeline to decode/merge/encode.""" + is_complete = _is_complete_selection(chunk_selection, chunk_spec.shape) + existing: Buffer | None = None + if not is_complete: + existing = byte_setter.get_sync(prototype=chunk_spec.prototype) + chunk_dict = self.deserialize(existing, chunk_spec) + inner_chain = self.inner_codec_chain or codec_chain + return PreparedWrite( + chunk_dict=chunk_dict, + inner_codec_chain=inner_chain, + inner_chunk_spec=chunk_spec, + indexer=[((0,), chunk_selection, out_selection, is_complete)], # type: ignore[list-item] + ) + + async def prepare_read( + self, + byte_getter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + codec_chain: Any, + aa_chain: Any, + ab_pair: Any, + bb_chain: Any, + ) -> NDBuffer | None: + """Async IO + full decode for the selected region. Returns decoded sub-array.""" + raw = await byte_getter.get(prototype=chunk_spec.prototype) + chunk_array: NDBuffer | None = codec_chain.decode_chunk( + raw, chunk_spec, aa_chain, ab_pair, bb_chain + ) + if chunk_array is not None: + return chunk_array[chunk_selection] + return None + + async def prepare_write( + self, + byte_setter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + codec_chain: Any, + ) -> PreparedWrite: + """Async IO + deserialize. Returns PreparedWrite for the pipeline to decode/merge/encode.""" + is_complete = _is_complete_selection(chunk_selection, chunk_spec.shape) + existing: Buffer | None = None + if not is_complete: + existing = await byte_setter.get(prototype=chunk_spec.prototype) + chunk_dict = self.deserialize(existing, chunk_spec) + inner_chain = self.inner_codec_chain or codec_chain + return PreparedWrite( + chunk_dict=chunk_dict, + inner_codec_chain=inner_chain, + inner_chunk_spec=chunk_spec, + indexer=[((0,), chunk_selection, out_selection, is_complete)], # type: ignore[list-item] + ) + + def finalize_write_sync( + self, prepared: PreparedWrite, chunk_spec: ArraySpec, byte_setter: Any + ) -> None: + """Serialize prepared chunk_dict and write to store. + + Default: serialize to a single blob and call set (or delete if all empty). + ShardingCodec overrides this for byte-range writes when inner codecs are fixed-size. + """ + blob = self.serialize(prepared.chunk_dict, chunk_spec) + if blob is None: + byte_setter.delete_sync() + else: + byte_setter.set_sync(blob) + + async def finalize_write( + self, prepared: PreparedWrite, chunk_spec: ArraySpec, byte_setter: Any + ) -> None: + """Async version of finalize_write_sync.""" + blob = self.serialize(prepared.chunk_dict, chunk_spec) + if blob is None: + await byte_setter.delete() + else: + await byte_setter.set(blob) + class BytesBytesCodec(BaseCodec[Buffer, Buffer]): """Base class for bytes-to-bytes codecs.""" @@ -459,6 +654,59 @@ async def write( """ ... + # ------------------------------------------------------------------- + # Fully synchronous read/write (opt-in) + # + # When a CodecPipeline subclass can run the entire read/write path + # (store IO + codec compute + buffer scatter) without touching the + # event loop, it overrides these methods and sets supports_sync_io + # to True. This lets Array selection methods bypass sync() entirely. + # + # The default implementations raise NotImplementedError. + # BatchedCodecPipeline overrides these when all codecs support sync. + # ------------------------------------------------------------------- + + @property + def supports_sync_io(self) -> bool: + """Whether this pipeline can run read/write entirely on the calling thread. + + True when: + - All codecs implement ``SupportsSyncCodec`` + - The pipeline's read_sync/write_sync methods are implemented + + Checked by ``Array._can_use_sync_path()`` to decide whether to bypass + the ``sync()`` event-loop bridge. + """ + return False + + def read_sync( + self, + batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + out: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + """Synchronous read: fetch bytes from store, decode, scatter into out. + + Runs entirely on the calling thread. Only available when + ``supports_sync_io`` is True. Called by ``_get_selection_sync`` in + ``array.py`` when the sync bypass is active. + """ + raise NotImplementedError + + def write_sync( + self, + batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + value: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + """Synchronous write: gather from value, encode, persist to store. + + Runs entirely on the calling thread. Only available when + ``supports_sync_io`` is True. Called by ``_set_selection_sync`` in + ``array.py`` when the sync bypass is active. + """ + raise NotImplementedError + async def _batching_helper( func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]], diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 87df89a683..575cd561d8 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -16,7 +16,12 @@ from zarr.core.buffer import Buffer, BufferPrototype -__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"] +__all__ = [ + "ByteGetter", + "ByteSetter", + "Store", + "set_or_delete", +] @dataclass @@ -466,6 +471,21 @@ async def set(self, key: str, value: Buffer) -> None: """ ... + async def set_range(self, key: str, value: Buffer, start: int) -> None: + """Write ``value`` into an existing key beginning at byte offset ``start``. + + The key must already exist and ``start + len(value)`` must not exceed + the current size of the stored value. + + Parameters + ---------- + key : str + value : Buffer + start : int + Byte offset at which to begin writing. + """ + raise NotImplementedError(f"{type(self).__name__} does not support set_range") + async def set_if_not_exists(self, key: str, value: Buffer) -> None: """ Store a key to ``value`` if the key is not already present. @@ -695,6 +715,8 @@ async def get( async def set(self, value: Buffer) -> None: ... + async def set_range(self, value: Buffer, start: int) -> None: ... + async def delete(self) -> None: ... async def set_if_not_exists(self, default: Buffer) -> None: ... diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 6164cda957..6ad92025ac 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -386,7 +386,9 @@ async def open( is_v3_array = zarr_format == 3 and _metadata_dict.get("node_type") == "array" if is_v3_array or zarr_format == 2: return AsyncArray( - store_path=store_path, metadata=_metadata_dict, config=kwargs.get("config") + store_path=store_path, + metadata=_metadata_dict, + config=kwargs.get("config"), ) except (AssertionError, FileNotFoundError, NodeTypeValidationError): pass @@ -1279,7 +1281,10 @@ async def open_array( _warn_write_empty_chunks_kwarg() try: - return await AsyncArray.open(store_path, zarr_format=zarr_format) + return await AsyncArray.open( + store_path, + zarr_format=zarr_format, + ) except FileNotFoundError as err: if not store_path.read_only and mode in _CREATE_MODES: overwrite = _infer_overwrite(mode) diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py index 5b91cfa005..fd1e3d449b 100644 --- a/src/zarr/codecs/blosc.py +++ b/src/zarr/codecs/blosc.py @@ -299,28 +299,29 @@ def _blosc_codec(self) -> Blosc: config_dict["typesize"] = self.typesize return Blosc.from_config(config_dict) + def _decode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer: + return as_numpy_array_wrapper(self._blosc_codec.decode, chunk_bytes, chunk_spec.prototype) + + def _encode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer | None: + # Since blosc only support host memory, we convert the input and output of the encoding + # between numpy array and buffer + return chunk_spec.prototype.buffer.from_bytes( + self._blosc_codec.encode(chunk_bytes.as_numpy_array()) + ) + async def _decode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer: - return await asyncio.to_thread( - as_numpy_array_wrapper, self._blosc_codec.decode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._decode_sync, chunk_bytes, chunk_spec) async def _encode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer | None: - # Since blosc only support host memory, we convert the input and output of the encoding - # between numpy array and buffer - return await asyncio.to_thread( - lambda chunk: chunk_spec.prototype.buffer.from_bytes( - self._blosc_codec.encode(chunk.as_numpy_array()) - ), - chunk_bytes, - ) + return await asyncio.to_thread(self._encode_sync, chunk_bytes, chunk_spec) def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int: raise NotImplementedError diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py index 1fbdeef497..3d62eac2bb 100644 --- a/src/zarr/codecs/bytes.py +++ b/src/zarr/codecs/bytes.py @@ -65,7 +65,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: ) return self - async def _decode_single( + def _decode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -88,7 +88,7 @@ async def _decode_single( ) return chunk_array - async def _encode_single( + def _encode_sync( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, @@ -109,5 +109,19 @@ async def _encode_single( nd_array = nd_array.ravel().view(dtype="B") return chunk_spec.prototype.buffer.from_array_like(nd_array) + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_bytes, chunk_spec) + + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return self._encode_sync(chunk_array, chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length diff --git a/src/zarr/codecs/crc32c_.py b/src/zarr/codecs/crc32c_.py index 9536d0d558..3cd3aef873 100644 --- a/src/zarr/codecs/crc32c_.py +++ b/src/zarr/codecs/crc32c_.py @@ -31,11 +31,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: def to_dict(self) -> dict[str, JSON]: return {"name": "crc32c"} - async def _decode_single( - self, - chunk_bytes: Buffer, - chunk_spec: ArraySpec, - ) -> Buffer: + def _decode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer: data = chunk_bytes.as_numpy_array() crc32_bytes = data[-4:] inner_bytes = data[:-4] @@ -51,11 +47,7 @@ async def _decode_single( ) return chunk_spec.prototype.buffer.from_array_like(inner_bytes) - async def _encode_single( - self, - chunk_bytes: Buffer, - chunk_spec: ArraySpec, - ) -> Buffer | None: + def _encode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer | None: data = chunk_bytes.as_numpy_array() # Calculate the checksum and "cast" it to a numpy array checksum = np.array( @@ -64,5 +56,19 @@ async def _encode_single( # Append the checksum (as bytes) to the data return chunk_spec.prototype.buffer.from_array_like(np.append(data, checksum.view("B"))) + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer: + return self._decode_sync(chunk_bytes, chunk_spec) + + async def _encode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return self._encode_sync(chunk_bytes, chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length + 4 diff --git a/src/zarr/codecs/gzip.py b/src/zarr/codecs/gzip.py index 610ca9dadd..a883b0d640 100644 --- a/src/zarr/codecs/gzip.py +++ b/src/zarr/codecs/gzip.py @@ -2,6 +2,7 @@ import asyncio from dataclasses import dataclass +from functools import cached_property from typing import TYPE_CHECKING from numcodecs.gzip import GZip @@ -48,23 +49,37 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: def to_dict(self) -> dict[str, JSON]: return {"name": "gzip", "configuration": {"level": self.level}} + # Cache the numcodecs GZip instance. GzipCodec is a frozen dataclass, + # so `level` never changes after construction, making this safe. + # This matches the pattern used by ZstdCodec._zstd_codec and + # BloscCodec._blosc_codec. Without caching, a new GZip(level) was + # created on every encode/decode call. + @cached_property + def _gzip_codec(self) -> GZip: + return GZip(self.level) + + def _decode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer: + # Use the cached codec instance instead of creating GZip(self.level) + # each time. The async _decode_single delegates to this method via + # asyncio.to_thread, so both paths benefit from the cache. + return as_numpy_array_wrapper(self._gzip_codec.decode, chunk_bytes, chunk_spec.prototype) + + def _encode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer | None: + return as_numpy_array_wrapper(self._gzip_codec.encode, chunk_bytes, chunk_spec.prototype) + async def _decode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer: - return await asyncio.to_thread( - as_numpy_array_wrapper, GZip(self.level).decode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._decode_sync, chunk_bytes, chunk_spec) async def _encode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer | None: - return await asyncio.to_thread( - as_numpy_array_wrapper, GZip(self.level).encode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._encode_sync, chunk_bytes, chunk_spec) def compute_encoded_size( self, diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 85162c2f74..a2b1226a5b 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping, MutableMapping -from dataclasses import dataclass, replace +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field, replace from enum import Enum from functools import lru_cache from operator import itemgetter @@ -13,14 +13,10 @@ from zarr.abc.codec import ( ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, - ArrayBytesCodecPartialEncodeMixin, Codec, - CodecPipeline, ) from zarr.abc.store import ( ByteGetter, - ByteRequest, - ByteSetter, RangeByteRequest, SuffixByteRequest, ) @@ -42,30 +38,31 @@ parse_shapelike, product, ) +from zarr.core.dtype.common import HasItemSize from zarr.core.dtype.npy.int import UInt64 from zarr.core.indexing import ( BasicIndexer, + ChunkProjection, SelectorTuple, - _morton_order, _morton_order_keys, c_order_iter, get_indexer, morton_order_iter, ) from zarr.core.metadata.v3 import parse_codecs -from zarr.registry import get_ndbuffer_class, get_pipeline_class -from zarr.storage._utils import _normalize_byte_range_index +from zarr.registry import get_ndbuffer_class if TYPE_CHECKING: from collections.abc import Iterator from typing import Self + from zarr.abc.codec import PreparedWrite + from zarr.core.codec_pipeline import CodecChain from zarr.core.common import JSON from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType MAX_UINT_64 = 2**64 - 1 ShardMapping = Mapping[tuple[int, ...], Buffer | None] -ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer | None] class ShardingCodecIndexLocation(Enum): @@ -81,41 +78,6 @@ def parse_index_location(data: object) -> ShardingCodecIndexLocation: return parse_enum(data, ShardingCodecIndexLocation) -@dataclass(frozen=True) -class _ShardingByteGetter(ByteGetter): - shard_dict: ShardMapping - chunk_coords: tuple[int, ...] - - async def get( - self, prototype: BufferPrototype, byte_range: ByteRequest | None = None - ) -> Buffer | None: - assert prototype == default_buffer_prototype(), ( - f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}" - ) - value = self.shard_dict.get(self.chunk_coords) - if value is None: - return None - if byte_range is None: - return value - start, stop = _normalize_byte_range_index(value, byte_range) - return value[start:stop] - - -@dataclass(frozen=True) -class _ShardingByteSetter(_ShardingByteGetter, ByteSetter): - shard_dict: ShardMutableMapping - - async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None: - assert byte_range is None, "byte_range is not supported within shards" - self.shard_dict[self.chunk_coords] = value - - async def delete(self) -> None: - del self.shard_dict[self.chunk_coords] - - async def set_if_not_exists(self, default: Buffer) -> None: - self.shard_dict.setdefault(self.chunk_coords, default) - - class _ShardIndex(NamedTuple): # dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2) offsets_and_lengths: npt.NDArray[np.uint64] @@ -228,9 +190,10 @@ class _ShardReader(ShardMapping): index: _ShardIndex @classmethod - async def from_bytes( + def from_bytes_sync( cls, buf: Buffer, codec: ShardingCodec, chunks_per_shard: tuple[int, ...] ) -> _ShardReader: + """Synchronous version of from_bytes — decodes the shard index inline.""" shard_index_size = codec._shard_index_size(chunks_per_shard) obj = cls() obj.buf = buf @@ -239,7 +202,7 @@ async def from_bytes( else: shard_index_bytes = obj.buf[-shard_index_size:] - obj.index = await codec._decode_shard_index(shard_index_bytes, chunks_per_shard) + obj.index = codec._decode_shard_index_sync(shard_index_bytes, chunks_per_shard) return obj @classmethod @@ -296,15 +259,15 @@ def to_dict_vectorized( @dataclass(frozen=True) -class ShardingCodec( - ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin -): +class ShardingCodec(ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin): """Sharding codec""" chunk_shape: tuple[int, ...] codecs: tuple[Codec, ...] index_codecs: tuple[Codec, ...] index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end + _codec_chain: CodecChain = field(init=False, repr=False, compare=False) + _index_codec_chain: CodecChain = field(init=False, repr=False, compare=False) def __init__( self, @@ -333,6 +296,12 @@ def __init__( object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec)) object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard)) + # Cached CodecChain instances — computed once, used for all sync paths. + from zarr.core.codec_pipeline import CodecChain + + object.__setattr__(self, "_codec_chain", CodecChain.from_codecs(codecs_parsed)) + object.__setattr__(self, "_index_codec_chain", CodecChain.from_codecs(index_codecs_parsed)) + # todo: typedict return type def __getstate__(self) -> dict[str, Any]: return self.to_dict() @@ -349,15 +318,22 @@ def __setstate__(self, state: dict[str, Any]) -> None: object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec)) object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard)) + from zarr.core.codec_pipeline import CodecChain + + object.__setattr__( + self, "_codec_chain", CodecChain.from_codecs(parse_codecs(config["codecs"])) + ) + object.__setattr__( + self, + "_index_codec_chain", + CodecChain.from_codecs(parse_codecs(config["index_codecs"])), + ) + @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: _, configuration_parsed = parse_named_configuration(data, "sharding_indexed") return cls(**configuration_parsed) # type: ignore[arg-type] - @property - def codec_pipeline(self) -> CodecPipeline: - return get_pipeline_class().from_codecs(self.codecs) - def to_dict(self) -> dict[str, JSON]: return { "name": "sharding_indexed", @@ -407,11 +383,35 @@ async def _decode_single( shard_bytes: Buffer, shard_spec: ArraySpec, ) -> NDBuffer: + # _decode_single is pure compute (no IO), same as _decode_sync. + return self._decode_sync(shard_bytes, shard_spec) + + def _decode_sync( + self, + shard_bytes: Buffer, + shard_spec: ArraySpec, + ) -> NDBuffer: + """Synchronous full-shard decode. + + Uses deserialize() to unpack stored bytes into per-inner-chunk buffers, + then decodes each inner chunk via CodecChain.decode_chunk (pure compute). + """ shard_shape = shard_spec.shape chunk_shape = self.chunk_shape - chunks_per_shard = self._get_chunks_per_shard(shard_spec) chunk_spec = self._get_chunk_spec(shard_spec) + chunk_dict = self.deserialize(shard_bytes, shard_spec) + + # Check if all chunks are empty + if all(v is None for v in chunk_dict.values()): + out = chunk_spec.prototype.nd_buffer.empty( + shape=shard_shape, + dtype=shard_spec.dtype.to_native_dtype(), + order=shard_spec.order, + ) + out.fill(shard_spec.fill_value) + return out + indexer = BasicIndexer( tuple(slice(0, s) for s in shard_shape), shape=shard_shape, @@ -424,35 +424,414 @@ async def _decode_single( dtype=shard_spec.dtype.to_native_dtype(), order=shard_spec.order, ) - shard_dict = await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard) - if shard_dict.index.is_all_empty(): - out.fill(shard_spec.fill_value) - return out + # Pre-resolve metadata chain once for all inner chunks. + codec_chain = self._codec_chain + aa_chain, ab_pair, bb_chain = codec_chain.resolve_metadata_chain(chunk_spec) - # decoding chunks and writing them into the output buffer - await self.codec_pipeline.read( - [ - ( - _ShardingByteGetter(shard_dict, chunk_coords), + # Decode each inner chunk directly via CodecChain (pure compute). + from zarr.core.codec_pipeline import fill_value_or_default + + fill_value = fill_value_or_default(shard_spec) + for chunk_coords, chunk_selection, out_selection, _is_complete in indexer: + chunk_bytes = chunk_dict.get(chunk_coords) + if chunk_bytes is not None: + chunk_array = codec_chain.decode_chunk( + chunk_bytes, chunk_spec, aa_chain, ab_pair, bb_chain + ) + if chunk_array is not None: + out[out_selection] = chunk_array[chunk_selection] + else: + out[out_selection] = fill_value + else: + out[out_selection] = fill_value + + return out + + def _encode_sync( + self, + shard_array: NDBuffer, + shard_spec: ArraySpec, + ) -> Buffer | None: + """Synchronous full-shard encode. + + Encodes each inner chunk via CodecChain.encode_chunk (pure compute), + then assembles the shard via serialize(). + """ + shard_shape = shard_spec.shape + chunk_shape = self.chunk_shape + chunks_per_shard = self._get_chunks_per_shard(shard_spec) + chunk_spec = self._get_chunk_spec(shard_spec) + skip_empty = not chunk_spec.config.write_empty_chunks + + if skip_empty: + from zarr.core.codec_pipeline import fill_value_or_default + + fill_value = fill_value_or_default(chunk_spec) + + # Quick check: if the entire shard equals fill value, return None. + if shard_array.all_equal(fill_value): + return None + + # Fast path: vectorized encoding for fixed-size inner codecs (no compression). + # Reorders chunks from C-order to morton order using numpy operations, + # avoiding per-chunk Python function calls entirely. + if self._inner_codecs_fixed_size: + result = self._encode_vectorized( + shard_array, shard_spec, chunks_per_shard, chunk_shape, chunk_spec, skip_empty + ) + if result is not None: + return result + # result is None means either: + # 1. All chunks are fill-value (skip_empty=True) → return None + # 2. Vectorized path not applicable → fall through to per-chunk loop + if skip_empty: + return None + + # Slow path: per-chunk encode loop. + indexer = list( + BasicIndexer( + tuple(slice(0, s) for s in shard_shape), + shape=shard_shape, + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), + ) + ) + + shard_builder: dict[tuple[int, ...], Buffer | None] = dict.fromkeys( + morton_order_iter(chunks_per_shard) + ) + + codec_chain = self._codec_chain + + if skip_empty: + from zarr.core.codec_pipeline import fill_value_or_default + + fill_value = fill_value_or_default(chunk_spec) + + # Vectorized per-chunk fill-value check. + shard_np = shard_array.as_numpy_array() + ndim = len(shard_shape) + reshaped_dims = [] + for s_dim, c_dim in zip(chunks_per_shard, chunk_shape, strict=False): + reshaped_dims.extend([s_dim, c_dim]) + try: + shard_reshaped = shard_np.reshape(reshaped_dims) + chunk_interior_axes = tuple(range(1, 2 * ndim, 2)) + if np.isnan(fill_value) if isinstance(fill_value, float) else False: + is_fill = np.all(np.isnan(shard_reshaped), axis=chunk_interior_axes) + else: + is_fill = np.all(shard_reshaped == fill_value, axis=chunk_interior_axes) + except (ValueError, AttributeError): + is_fill = None + + for chunk_coords, _chunk_selection, out_selection, _is_complete in indexer: + chunk_array = shard_array[out_selection] + if chunk_array is not None: + if is_fill is not None: + chunk_is_fill = bool(is_fill[chunk_coords]) + else: + chunk_is_fill = chunk_array.all_equal(fill_value) + if chunk_is_fill: + shard_builder[chunk_coords] = None + else: + shard_builder[chunk_coords] = codec_chain.encode_chunk( + chunk_array, chunk_spec + ) + else: + for chunk_coords, _chunk_selection, out_selection, _is_complete in indexer: + chunk_array = shard_array[out_selection] + if chunk_array is not None: + shard_builder[chunk_coords] = codec_chain.encode_chunk(chunk_array, chunk_spec) + + return self.serialize(shard_builder, shard_spec) + + def _encode_vectorized( + self, + shard_array: NDBuffer, + shard_spec: ArraySpec, + chunks_per_shard: tuple[int, ...], + chunk_shape: tuple[int, ...], + chunk_spec: ArraySpec, + skip_empty: bool, + ) -> Buffer | None: + """Vectorized shard encoding for fixed-size inner codecs. + + Reorders chunks from C-order to morton order using numpy operations, + building the entire shard blob without per-chunk Python function calls. + Returns None if all chunks are fill-value, or if the fast path is not + applicable (caller should fall through to per-chunk loop). + """ + from zarr.core.indexing import _morton_order + + shard_np = shard_array.as_numpy_array() + ndim = len(chunks_per_shard) + total_chunks = product(chunks_per_shard) + + # Handle endianness at the shard level (BytesCodec normally does this per-chunk) + ab_codec = self._codec_chain.array_bytes_codec + if ( + isinstance(ab_codec, BytesCodec) + and shard_np.dtype.itemsize > 1 + and ab_codec.endian is not None + and ab_codec.endian != shard_array.byteorder + ): + new_dtype = shard_np.dtype.newbyteorder(ab_codec.endian.name) # type: ignore[arg-type] + shard_np = shard_np.astype(new_dtype) + + # Reshape: (shard_shape) → (cps[0], cs[0], cps[1], cs[1], ...) + reshaped_dims: list[int] = [] + for cps, cs in zip(chunks_per_shard, chunk_shape, strict=False): + reshaped_dims.extend([cps, cs]) + + shard_reshaped = shard_np.reshape(reshaped_dims) + + if skip_empty: + from zarr.core.codec_pipeline import fill_value_or_default + + fill_value = fill_value_or_default(chunk_spec) + chunk_interior_axes = tuple(range(1, 2 * ndim, 2)) + if np.isnan(fill_value) if isinstance(fill_value, float) else False: + is_fill = np.all(np.isnan(shard_reshaped), axis=chunk_interior_axes) + else: + is_fill = np.all(shard_reshaped == fill_value, axis=chunk_interior_axes) + + if np.all(is_fill): + return None + if np.any(is_fill): + # Some chunks are fill-value, some are not. + # Fall through to per-chunk loop for this mixed case. + # Return a sentinel that the caller interprets as "not applicable". + # We use a special approach: return _MIXED_FILL sentinel + return self._encode_vectorized_sparse( + shard_np, + shard_spec, + chunks_per_shard, + chunk_shape, chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, + shard_reshaped, + is_fill, ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - out, + + # Transpose to (cps[0], cps[1], ..., cs[0], cs[1], ...) + chunk_grid_axes = tuple(range(0, 2 * ndim, 2)) + chunk_data_axes = tuple(range(1, 2 * ndim, 2)) + transposed = shard_reshaped.transpose(chunk_grid_axes + chunk_data_axes) + + # Reshape to (total_chunks, elements_per_chunk), then reorder to morton + elements_per_chunk = product(chunk_shape) + chunks_2d = transposed.reshape(total_chunks, elements_per_chunk) + + # Reorder from C-order to morton order + morton_coords = _morton_order(chunks_per_shard) # (total_chunks, ndim) + c_order_linear = np.ravel_multi_index( + tuple(morton_coords[:, i] for i in range(ndim)), chunks_per_shard ) + reordered = chunks_2d[c_order_linear] - return out + # Flatten to bytes + chunk_data_bytes = reordered.ravel().view(np.uint8) - async def _decode_partial_single( + # Build deterministic shard index (all chunks present, each chunk_byte_length) + index = _ShardIndex.create_empty(chunks_per_shard) + encoded_chunk_byte_length = self._inner_chunk_byte_length(chunk_spec) + for rank in range(total_chunks): + offset = rank * encoded_chunk_byte_length + chunk_coords = tuple(int(x) for x in morton_coords[rank]) + index.set_chunk_slice(chunk_coords, slice(offset, offset + encoded_chunk_byte_length)) + + index_bytes = self._encode_shard_index_sync(index) + + if self.index_location == ShardingCodecIndexLocation.start: + # Shift non-empty offsets by index size + non_empty = index.offsets_and_lengths[..., 0] != MAX_UINT_64 + index.offsets_and_lengths[non_empty, 0] += len(index_bytes) + index_bytes = self._encode_shard_index_sync(index) + shard_bytes_np = np.concatenate( + [ + np.frombuffer(index_bytes.as_buffer_like(), dtype=np.uint8), + chunk_data_bytes, + ] + ) + else: + shard_bytes_np = np.concatenate( + [ + chunk_data_bytes, + np.frombuffer(index_bytes.as_buffer_like(), dtype=np.uint8), + ] + ) + + return default_buffer_prototype().buffer.from_array_like(shard_bytes_np) + + def _encode_vectorized_sparse( self, - byte_getter: ByteGetter, + shard_np: npt.NDArray[Any], + shard_spec: ArraySpec, + chunks_per_shard: tuple[int, ...], + chunk_shape: tuple[int, ...], + chunk_spec: ArraySpec, + shard_reshaped: npt.NDArray[Any], + is_fill: npt.NDArray[np.bool_], + ) -> Buffer | None: + """Vectorized encoding when some chunks are fill-value (sparse shard). + + Builds the shard blob with only non-fill chunks present. + """ + from zarr.core.indexing import _morton_order + + ndim = len(chunks_per_shard) + total_chunks = product(chunks_per_shard) + + # Transpose to (cps[0], cps[1], ..., cs[0], cs[1], ...) + chunk_grid_axes = tuple(range(0, 2 * ndim, 2)) + chunk_data_axes = tuple(range(1, 2 * ndim, 2)) + transposed = shard_reshaped.transpose(chunk_grid_axes + chunk_data_axes) + + elements_per_chunk = product(chunk_shape) + chunks_2d = transposed.reshape(total_chunks, elements_per_chunk) + + # Reorder from C-order to morton order + morton_coords = _morton_order(chunks_per_shard) # (total_chunks, ndim) + c_order_linear = np.ravel_multi_index( + tuple(morton_coords[:, i] for i in range(ndim)), chunks_per_shard + ) + reordered = chunks_2d[c_order_linear] + + # is_fill is in C-order grid shape, flatten to C-order linear + is_fill_morton = is_fill.ravel()[c_order_linear] + + # Build index and collect non-fill chunk data + index = _ShardIndex.create_empty(chunks_per_shard) + encoded_chunk_byte_length = self._inner_chunk_byte_length(chunk_spec) + + # Select only non-fill chunks + non_fill_mask = ~is_fill_morton + non_fill_data = reordered[non_fill_mask] + + if len(non_fill_data) == 0: + return None + + # Build index: set offsets for non-fill chunks using morton coordinates + offset = 0 + for rank in range(total_chunks): + if non_fill_mask[rank]: + chunk_coords = tuple(int(x) for x in morton_coords[rank]) + index.set_chunk_slice( + chunk_coords, slice(offset, offset + encoded_chunk_byte_length) + ) + offset += encoded_chunk_byte_length + + index_bytes = self._encode_shard_index_sync(index) + + chunk_data_bytes = non_fill_data.ravel().view(np.uint8) + + if self.index_location == ShardingCodecIndexLocation.start: + non_empty = index.offsets_and_lengths[..., 0] != MAX_UINT_64 + index.offsets_and_lengths[non_empty, 0] += len(index_bytes) + index_bytes = self._encode_shard_index_sync(index) + shard_bytes_np = np.concatenate( + [ + np.frombuffer(index_bytes.as_buffer_like(), dtype=np.uint8), + chunk_data_bytes, + ] + ) + else: + shard_bytes_np = np.concatenate( + [ + chunk_data_bytes, + np.frombuffer(index_bytes.as_buffer_like(), dtype=np.uint8), + ] + ) + + return default_buffer_prototype().buffer.from_array_like(shard_bytes_np) + + def _encode_shard_dict_sync( + self, + map: ShardMapping, + chunks_per_shard: tuple[int, ...], + buffer_prototype: BufferPrototype, + ) -> Buffer | None: + """Serialize encoded chunks into a shard blob with index.""" + index = _ShardIndex.create_empty(chunks_per_shard) + + buffers = [] + + template = buffer_prototype.buffer.create_zero_length() + chunk_start = 0 + for chunk_coords in morton_order_iter(chunks_per_shard): + value = map.get(chunk_coords) + if value is None: + continue + + if len(value) == 0: + continue + + chunk_length = len(value) + buffers.append(value) + index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length)) + chunk_start += chunk_length + + if len(buffers) == 0: + return None + + index_bytes = self._encode_shard_index_sync(index) + if self.index_location == ShardingCodecIndexLocation.start: + empty_chunks_mask = index.offsets_and_lengths[..., 0] == MAX_UINT_64 + index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes) + index_bytes = self._encode_shard_index_sync( + index + ) # encode again with corrected offsets + buffers.insert(0, index_bytes) + else: + buffers.append(index_bytes) + + return template.combine(buffers) + + def _load_shard_index_maybe_sync( + self, byte_getter: Any, chunks_per_shard: tuple[int, ...] + ) -> _ShardIndex | None: + """Synchronous version of _load_shard_index_maybe. + + Reads the shard index bytes via byte_getter.get_sync (a sync byte-range + read from the store), then decodes the index inline. + """ + shard_index_size = self._shard_index_size(chunks_per_shard) + if self.index_location == ShardingCodecIndexLocation.start: + index_bytes = byte_getter.get_sync( + prototype=numpy_buffer_prototype(), + byte_range=RangeByteRequest(0, shard_index_size), + ) + else: + index_bytes = byte_getter.get_sync( + prototype=numpy_buffer_prototype(), + byte_range=SuffixByteRequest(shard_index_size), + ) + if index_bytes is not None: + return self._decode_shard_index_sync(index_bytes, chunks_per_shard) + return None + + def _load_full_shard_maybe_sync( + self, + byte_getter: Any, + prototype: BufferPrototype, + chunks_per_shard: tuple[int, ...], + ) -> _ShardReader | None: + """Synchronous version of _load_full_shard_maybe.""" + shard_bytes = byte_getter.get_sync(prototype=prototype) + return ( + _ShardReader.from_bytes_sync(shard_bytes, self, chunks_per_shard) + if shard_bytes + else None + ) + + def _decode_partial_sync( + self, + byte_getter: Any, selection: SelectorTuple, shard_spec: ArraySpec, ) -> NDBuffer | None: + """Synchronous partial decode: fetch shard index + requested chunks + via sync byte-range reads, then decode via CodecChain (pure compute). + """ shard_shape = shard_spec.shape chunk_shape = self.chunk_shape chunks_per_shard = self._get_chunks_per_shard(shard_spec) @@ -478,7 +857,7 @@ async def _decode_partial_single( shard_dict: ShardMapping = {} if self._is_total_shard(all_chunk_coords, chunks_per_shard): # read entire shard - shard_dict_maybe = await self._load_full_shard_maybe( + shard_dict_maybe = self._load_full_shard_maybe_sync( byte_getter=byte_getter, prototype=chunk_spec.prototype, chunks_per_shard=chunks_per_shard, @@ -488,171 +867,132 @@ async def _decode_partial_single( shard_dict = shard_dict_maybe else: # read some chunks within the shard - shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) + shard_index = self._load_shard_index_maybe_sync(byte_getter, chunks_per_shard) if shard_index is None: return None shard_dict = {} for chunk_coords in all_chunk_coords: chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords) if chunk_byte_slice: - chunk_bytes = await byte_getter.get( + chunk_bytes = byte_getter.get_sync( prototype=chunk_spec.prototype, byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]), ) if chunk_bytes: shard_dict[chunk_coords] = chunk_bytes - # decoding chunks and writing them into the output buffer - await self.codec_pipeline.read( - [ - ( - _ShardingByteGetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, + # Decode chunks directly via CodecChain (pure compute, no inner pipeline). + codec_chain = self._codec_chain + aa_chain, ab_pair, bb_chain = codec_chain.resolve_metadata_chain(chunk_spec) + + from zarr.core.codec_pipeline import fill_value_or_default + + fill_value = fill_value_or_default(shard_spec) + for chunk_coords, chunk_selection, out_selection, _is_complete in indexed_chunks: + chunk_bytes = shard_dict.get(chunk_coords) + if chunk_bytes is not None: + chunk_array = codec_chain.decode_chunk( + chunk_bytes, chunk_spec, aa_chain, ab_pair, bb_chain ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - out, - ) + if chunk_array is not None: + out[out_selection] = chunk_array[chunk_selection] + else: + out[out_selection] = fill_value + else: + out[out_selection] = fill_value if hasattr(indexer, "sel_shape"): return out.reshape(indexer.sel_shape) else: return out - async def _encode_single( + async def _decode_partial_single( self, - shard_array: NDBuffer, + byte_getter: ByteGetter, + selection: SelectorTuple, shard_spec: ArraySpec, - ) -> Buffer | None: + ) -> NDBuffer | None: shard_shape = shard_spec.shape chunk_shape = self.chunk_shape chunks_per_shard = self._get_chunks_per_shard(shard_spec) chunk_spec = self._get_chunk_spec(shard_spec) - indexer = list( - BasicIndexer( - tuple(slice(0, s) for s in shard_shape), - shape=shard_shape, - chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), - ) + indexer = get_indexer( + selection, + shape=shard_shape, + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), ) - shard_builder = dict.fromkeys(morton_order_iter(chunks_per_shard)) - - await self.codec_pipeline.write( - [ - ( - _ShardingByteSetter(shard_builder, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - shard_array, + # setup output array + out = shard_spec.prototype.nd_buffer.empty( + shape=indexer.shape, + dtype=shard_spec.dtype.to_native_dtype(), + order=shard_spec.order, ) - return await self._encode_shard_dict( - shard_builder, - chunks_per_shard=chunks_per_shard, - buffer_prototype=default_buffer_prototype(), - ) + indexed_chunks = list(indexer) + all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks} - async def _encode_partial_single( - self, - byte_setter: ByteSetter, - shard_array: NDBuffer, - selection: SelectorTuple, - shard_spec: ArraySpec, - ) -> None: - shard_shape = shard_spec.shape - chunk_shape = self.chunk_shape - chunks_per_shard = self._get_chunks_per_shard(shard_spec) - chunk_spec = self._get_chunk_spec(shard_spec) + # reading bytes of all requested chunks + shard_dict: ShardMapping = {} + if self._is_total_shard(all_chunk_coords, chunks_per_shard): + # read entire shard + shard_dict_maybe = await self._load_full_shard_maybe( + byte_getter=byte_getter, + prototype=chunk_spec.prototype, + chunks_per_shard=chunks_per_shard, + ) + if shard_dict_maybe is None: + return None + shard_dict = shard_dict_maybe + else: + # read some chunks within the shard + shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) + if shard_index is None: + return None + shard_dict = {} + for chunk_coords in all_chunk_coords: + chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords) + if chunk_byte_slice: + chunk_bytes = await byte_getter.get( + prototype=chunk_spec.prototype, + byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]), + ) + if chunk_bytes: + shard_dict[chunk_coords] = chunk_bytes - shard_reader = await self._load_full_shard_maybe( - byte_getter=byte_setter, - prototype=chunk_spec.prototype, - chunks_per_shard=chunks_per_shard, - ) - shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) - # Use vectorized lookup for better performance - shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard))) + # Decode chunks directly via CodecChain (pure compute, no inner pipeline). + codec_chain = self._codec_chain + aa_chain, ab_pair, bb_chain = codec_chain.resolve_metadata_chain(chunk_spec) - indexer = list( - get_indexer( - selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape) - ) - ) + from zarr.core.codec_pipeline import fill_value_or_default - await self.codec_pipeline.write( - [ - ( - _ShardingByteSetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, + fill_value = fill_value_or_default(shard_spec) + for chunk_coords, chunk_selection, out_selection, _is_complete in indexed_chunks: + chunk_bytes = shard_dict.get(chunk_coords) + if chunk_bytes is not None: + chunk_array = codec_chain.decode_chunk( + chunk_bytes, chunk_spec, aa_chain, ab_pair, bb_chain ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - shard_array, - ) - buf = await self._encode_shard_dict( - shard_dict, - chunks_per_shard=chunks_per_shard, - buffer_prototype=default_buffer_prototype(), - ) + if chunk_array is not None: + out[out_selection] = chunk_array[chunk_selection] + else: + out[out_selection] = fill_value + else: + out[out_selection] = fill_value - if buf is None: - await byte_setter.delete() + if hasattr(indexer, "sel_shape"): + return out.reshape(indexer.sel_shape) else: - await byte_setter.set(buf) + return out - async def _encode_shard_dict( + async def _encode_single( self, - map: ShardMapping, - chunks_per_shard: tuple[int, ...], - buffer_prototype: BufferPrototype, + shard_array: NDBuffer, + shard_spec: ArraySpec, ) -> Buffer | None: - index = _ShardIndex.create_empty(chunks_per_shard) - - buffers = [] - - template = buffer_prototype.buffer.create_zero_length() - chunk_start = 0 - for chunk_coords in morton_order_iter(chunks_per_shard): - value = map.get(chunk_coords) - if value is None: - continue - - if len(value) == 0: - continue - - chunk_length = len(value) - buffers.append(value) - index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length)) - chunk_start += chunk_length - - if len(buffers) == 0: - return None - - index_bytes = await self._encode_shard_index(index) - if self.index_location == ShardingCodecIndexLocation.start: - empty_chunks_mask = index.offsets_and_lengths[..., 0] == MAX_UINT_64 - index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes) - index_bytes = await self._encode_shard_index( - index - ) # encode again with corrected offsets - buffers.insert(0, index_bytes) - else: - buffers.append(index_bytes) - - return template.combine(buffers) + # _encode_single is pure compute (no IO), same as _encode_sync. + return self._encode_sync(shard_array, shard_spec) def _is_total_shard( self, all_chunk_coords: set[tuple[int, ...]], chunks_per_shard: tuple[int, ...] @@ -661,48 +1001,27 @@ def _is_total_shard( chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard) ) - async def _decode_shard_index( + def _decode_shard_index_sync( self, index_bytes: Buffer, chunks_per_shard: tuple[int, ...] ) -> _ShardIndex: - index_array = next( - iter( - await get_pipeline_class() - .from_codecs(self.index_codecs) - .decode( - [(index_bytes, self._get_index_chunk_spec(chunks_per_shard))], - ) - ) - ) - # This cannot be None because we have the bytes already - index_array = cast(NDBuffer, index_array) + """Decode shard index synchronously via the cached index CodecChain.""" + index_chunk_spec = self._get_index_chunk_spec(chunks_per_shard) + index_array = self._index_codec_chain.decode_chunk(index_bytes, index_chunk_spec) + assert index_array is not None return _ShardIndex(index_array.as_numpy_array()) - async def _encode_shard_index(self, index: _ShardIndex) -> Buffer: - index_bytes = next( - iter( - await get_pipeline_class() - .from_codecs(self.index_codecs) - .encode( - [ - ( - get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths), - self._get_index_chunk_spec(index.chunks_per_shard), - ) - ], - ) - ) - ) + def _encode_shard_index_sync(self, index: _ShardIndex) -> Buffer: + """Encode shard index synchronously via the cached index CodecChain.""" + index_chunk_spec = self._get_index_chunk_spec(index.chunks_per_shard) + index_nd = get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths) + index_bytes = self._index_codec_chain.encode_chunk(index_nd, index_chunk_spec) assert index_bytes is not None assert isinstance(index_bytes, Buffer) return index_bytes def _shard_index_size(self, chunks_per_shard: tuple[int, ...]) -> int: - return ( - get_pipeline_class() - .from_codecs(self.index_codecs) - .compute_encoded_size( - 16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard) - ) + return self._index_codec_chain.compute_encoded_size( + 16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard) ) def _get_index_chunk_spec(self, chunks_per_shard: tuple[int, ...]) -> ArraySpec: @@ -749,7 +1068,7 @@ async def _load_shard_index_maybe( prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size) ) if index_bytes is not None: - return await self._decode_shard_index(index_bytes, chunks_per_shard) + return self._decode_shard_index_sync(index_bytes, chunks_per_shard) return None async def _load_shard_index( @@ -765,11 +1084,579 @@ async def _load_full_shard_maybe( shard_bytes = await byte_getter.get(prototype=prototype) return ( - await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard) + _ShardReader.from_bytes_sync(shard_bytes, self, chunks_per_shard) if shard_bytes else None ) + # ------------------------------------------------------------------- + # prepare_* overrides — composable building blocks for the pipeline + # ------------------------------------------------------------------- + + @property + def inner_codec_chain(self) -> Any: + return self._codec_chain + + def deserialize( + self, raw: Buffer | None, shard_spec: ArraySpec + ) -> dict[tuple[int, ...], Buffer | None]: + chunks_per_shard = self._get_chunks_per_shard(shard_spec) + if raw is None: + return dict.fromkeys(morton_order_iter(chunks_per_shard)) + shard_reader = _ShardReader.from_bytes_sync(raw, self, chunks_per_shard) + result: dict[tuple[int, ...], Buffer | None] = {} + for coords in morton_order_iter(chunks_per_shard): + chunk_byte_slice = shard_reader.index.get_chunk_slice(coords) + if chunk_byte_slice: + result[coords] = shard_reader.buf[chunk_byte_slice[0] : chunk_byte_slice[1]] + else: + result[coords] = None + return result + + def serialize( + self, chunk_dict: dict[tuple[int, ...], Buffer | None], shard_spec: ArraySpec + ) -> Buffer | None: + chunks_per_shard = self._get_chunks_per_shard(shard_spec) + return self._encode_shard_dict_sync( + chunk_dict, chunks_per_shard, default_buffer_prototype() + ) + + def prepare_read_sync( + self, + byte_getter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + codec_chain: Any, + aa_chain: Any, + ab_pair: Any, + bb_chain: Any, + ) -> NDBuffer | None: + return self._decode_partial_sync(byte_getter, chunk_selection, chunk_spec) + + async def prepare_read( + self, + byte_getter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + codec_chain: Any, + aa_chain: Any, + ab_pair: Any, + bb_chain: Any, + ) -> NDBuffer | None: + return await self._decode_partial_single(byte_getter, chunk_selection, chunk_spec) + + def _prepare_write_partial_fixed_sync( + self, + byte_setter: Any, + chunk_spec: ArraySpec, + chunks_per_shard: tuple[int, ...], + chunk_spec_inner: ArraySpec, + indexer: list[ChunkProjection], + ) -> dict[tuple[int, ...], Buffer | None]: + """For fixed-size partial writes: fetch only the inner chunks that need merging.""" + chunk_byte_length = self._inner_chunk_byte_length(chunk_spec_inner) + chunk_dict: dict[tuple[int, ...], Buffer | None] = {} + for coords, _, _, is_complete_inner in indexer: + if is_complete_inner: + chunk_dict[coords] = None + else: + offset = self._chunk_byte_offset(coords, chunks_per_shard, chunk_byte_length) + chunk_dict[coords] = byte_setter.get_sync( + prototype=chunk_spec_inner.prototype, + byte_range=RangeByteRequest(offset, offset + chunk_byte_length), + ) + return chunk_dict + + async def _prepare_write_partial_fixed( + self, + byte_setter: Any, + chunk_spec: ArraySpec, + chunks_per_shard: tuple[int, ...], + chunk_spec_inner: ArraySpec, + indexer: list[ChunkProjection], + ) -> dict[tuple[int, ...], Buffer | None]: + """Async version: fetch only the inner chunks that need merging.""" + chunk_byte_length = self._inner_chunk_byte_length(chunk_spec_inner) + chunk_dict: dict[tuple[int, ...], Buffer | None] = {} + for coords, _, _, is_complete_inner in indexer: + if is_complete_inner: + chunk_dict[coords] = None + else: + offset = self._chunk_byte_offset(coords, chunks_per_shard, chunk_byte_length) + chunk_dict[coords] = await byte_setter.get( + prototype=chunk_spec_inner.prototype, + byte_range=RangeByteRequest(offset, offset + chunk_byte_length), + ) + return chunk_dict + + def prepare_write_sync( + self, + byte_setter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + codec_chain: Any, + ) -> PreparedWrite: + from zarr.abc.codec import PreparedWrite, _is_complete_selection + + chunks_per_shard = self._get_chunks_per_shard(chunk_spec) + chunk_spec_inner = self._get_chunk_spec(chunk_spec) + + # Build inner indexer first — needed for fixed-size targeted reads. + indexer = list( + get_indexer( + chunk_selection, + shape=chunk_spec.shape, + chunk_grid=RegularChunkGrid(chunk_shape=self.chunk_shape), + ) + ) + + is_complete = _is_complete_selection(chunk_selection, chunk_spec.shape) + if is_complete: + # Complete selection: the pipeline will pass the shard value to + # finalize_write which encodes the full shard in one shot. + return PreparedWrite( + chunk_dict={}, + inner_codec_chain=self._codec_chain, + inner_chunk_spec=chunk_spec_inner, + indexer=[], + value_selection=out_selection, + write_full_shard=True, + is_complete_shard=True, + ) + elif self._inner_codecs_fixed_size: + # Fixed-size partial write: only fetch inner chunks that need merging. + # Check if shard exists first — if not, all chunks are None. + probe = byte_setter.get_sync( + prototype=chunk_spec_inner.prototype, + byte_range=RangeByteRequest(0, 1), + ) + chunk_dict: dict[tuple[int, ...], Buffer | None] + if probe is None: + write_full_shard = True + chunk_dict = {coords: None for coords, _, _, _ in indexer} + else: + write_full_shard = False + chunk_dict = self._prepare_write_partial_fixed_sync( + byte_setter, + chunk_spec, + chunks_per_shard, + chunk_spec_inner, + indexer, + ) + else: + # Variable-size: must fetch entire shard. + shard_reader = self._load_full_shard_maybe_sync( + byte_setter, chunk_spec_inner.prototype, chunks_per_shard + ) + write_full_shard = shard_reader is None + shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) + chunk_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)} + + return PreparedWrite( + chunk_dict=chunk_dict, + inner_codec_chain=self._codec_chain, + inner_chunk_spec=chunk_spec_inner, + indexer=indexer, + value_selection=out_selection, + write_full_shard=write_full_shard, + ) + + async def prepare_write( + self, + byte_setter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + codec_chain: Any, + ) -> PreparedWrite: + from zarr.abc.codec import PreparedWrite, _is_complete_selection + + chunks_per_shard = self._get_chunks_per_shard(chunk_spec) + chunk_spec_inner = self._get_chunk_spec(chunk_spec) + + indexer = list( + get_indexer( + chunk_selection, + shape=chunk_spec.shape, + chunk_grid=RegularChunkGrid(chunk_shape=self.chunk_shape), + ) + ) + + is_complete = _is_complete_selection(chunk_selection, chunk_spec.shape) + if is_complete: + return PreparedWrite( + chunk_dict={}, + inner_codec_chain=self._codec_chain, + inner_chunk_spec=chunk_spec_inner, + indexer=[], + value_selection=out_selection, + write_full_shard=True, + is_complete_shard=True, + ) + elif self._inner_codecs_fixed_size: + probe = await byte_setter.get( + prototype=chunk_spec_inner.prototype, + byte_range=RangeByteRequest(0, 1), + ) + chunk_dict: dict[tuple[int, ...], Buffer | None] + if probe is None: + write_full_shard = True + chunk_dict = {coords: None for coords, _, _, _ in indexer} + else: + write_full_shard = False + chunk_dict = await self._prepare_write_partial_fixed( + byte_setter, + chunk_spec, + chunks_per_shard, + chunk_spec_inner, + indexer, + ) + else: + shard_reader = await self._load_full_shard_maybe( + byte_setter, chunk_spec_inner.prototype, chunks_per_shard + ) + write_full_shard = shard_reader is None + shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) + chunk_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)} + + return PreparedWrite( + chunk_dict=chunk_dict, + inner_codec_chain=self._codec_chain, + inner_chunk_spec=chunk_spec_inner, + indexer=indexer, + value_selection=out_selection, + write_full_shard=write_full_shard, + ) + + @property + def _inner_codecs_fixed_size(self) -> bool: + return all(c.is_fixed_size for c in self._codec_chain) + + @property + def is_fixed_size(self) -> bool: # type: ignore[override] + # ShardingCodec output varies when write_empty_chunks=False causes + # fill-value sub-chunks to be omitted, so it is not fixed-size in general. + return False + + def _inner_chunk_byte_length(self, chunk_spec: ArraySpec) -> int: + """Encoded byte length of a single inner chunk (only valid when _inner_codecs_fixed_size).""" + assert isinstance(chunk_spec.dtype, HasItemSize) + raw_byte_length = product(self.chunk_shape) * chunk_spec.dtype.item_size + return int(self._codec_chain.compute_encoded_size(raw_byte_length, chunk_spec)) + + @staticmethod + @lru_cache(maxsize=16) + def _morton_rank_map(chunks_per_shard: tuple[int, ...]) -> dict[tuple[int, ...], int]: + """Return a dict mapping morton-order coords → rank (0-based). + + Cached because the same chunks_per_shard is used repeatedly for reads and writes. + """ + return {coords: rank for rank, coords in enumerate(morton_order_iter(chunks_per_shard))} + + def _chunk_byte_offset( + self, + chunk_coords: tuple[int, ...], + chunks_per_shard: tuple[int, ...], + chunk_byte_length: int, + ) -> int: + """Byte offset of an inner chunk within a dense shard blob. + + Assumes all chunks are present and laid out in morton order. + """ + rank = self._morton_rank_map(chunks_per_shard)[chunk_coords] + offset = rank * chunk_byte_length + if self.index_location == ShardingCodecIndexLocation.start: + offset += self._shard_index_size(chunks_per_shard) + return offset + + def _shard_index_byte_offset( + self, chunks_per_shard: tuple[int, ...], chunk_byte_length: int + ) -> int: + """Byte offset of the shard index within a dense shard blob.""" + n_chunks = product(chunks_per_shard) + if self.index_location == ShardingCodecIndexLocation.start: + return 0 + return n_chunks * chunk_byte_length + + def _build_dense_shard_index( + self, chunks_per_shard: tuple[int, ...], chunk_byte_length: int + ) -> _ShardIndex: + """Build a shard index for a fully-dense shard with fixed-size chunks.""" + index = _ShardIndex.create_empty(chunks_per_shard) + data_offset = ( + self._shard_index_size(chunks_per_shard) + if self.index_location == ShardingCodecIndexLocation.start + else 0 + ) + for rank, coords in enumerate(morton_order_iter(chunks_per_shard)): + chunk_start = data_offset + rank * chunk_byte_length + index.set_chunk_slice(coords, slice(chunk_start, chunk_start + chunk_byte_length)) + return index + + def _build_dense_shard_blob( + self, + chunk_dict: dict[tuple[int, ...], Buffer | None], + chunks_per_shard: tuple[int, ...], + chunk_byte_length: int, + ) -> Buffer: + """Build a dense shard blob with fixed-size chunks at deterministic offsets. + + Unlike ``_encode_shard_dict_sync`` (used by ``serialize``), this places each + chunk at ``rank * chunk_byte_length`` (plus index offset for start-indexed shards), + producing a layout compatible with ``_chunk_byte_offset`` / ``set_range``. + """ + index = self._build_dense_shard_index(chunks_per_shard, chunk_byte_length) + index_bytes = self._encode_shard_index_sync(index) + + # Allocate the full blob as a flat numpy array + n_chunks = product(chunks_per_shard) + data_size = n_chunks * chunk_byte_length + total_size = data_size + len(index_bytes) + blob_array = np.zeros(total_size, dtype=np.uint8) + + data_offset = ( + len(index_bytes) if self.index_location == ShardingCodecIndexLocation.start else 0 + ) + index_offset = 0 if self.index_location == ShardingCodecIndexLocation.start else data_size + + # Place each chunk at its deterministic offset + for rank, coords in enumerate(morton_order_iter(chunks_per_shard)): + chunk_bytes = chunk_dict.get(coords) + if chunk_bytes is not None: + start = data_offset + rank * chunk_byte_length + blob_array[start : start + len(chunk_bytes)] = chunk_bytes.as_numpy_array() + + # Place the index + blob_array[index_offset : index_offset + len(index_bytes)] = index_bytes.as_numpy_array() + + return default_buffer_prototype().buffer.from_bytes(blob_array.tobytes()) + + def finalize_write_sync(self, prepared: Any, chunk_spec: ArraySpec, byte_setter: Any) -> None: + from zarr.abc.codec import PreparedWrite + + assert isinstance(prepared, PreparedWrite) + + # Complete shard: encode the entire shard value in one shot. + if prepared.is_complete_shard: + assert prepared.shard_data is not None + shard_data = prepared.shard_data + # Expand scalar/broadcast value to shard shape. + if shard_data.shape != chunk_spec.shape: + expanded = chunk_spec.prototype.nd_buffer.create( + shape=chunk_spec.shape, + dtype=chunk_spec.dtype.to_native_dtype(), + order=chunk_spec.order, + fill_value=0, + ) + expanded[()] = shard_data + shard_data = expanded + blob = self._encode_sync(shard_data, chunk_spec) + if blob is None: + byte_setter.delete_sync() + else: + byte_setter.set_sync(blob) + return + + chunks_per_shard = self._get_chunks_per_shard(chunk_spec) + chunk_spec_inner = self._get_chunk_spec(chunk_spec) + + if not self._inner_codecs_fixed_size: + # Fall back to full serialize + set + blob = self.serialize(prepared.chunk_dict, chunk_spec) + if blob is None: + byte_setter.delete_sync() + else: + byte_setter.set_sync(blob) + return + + chunk_byte_length = self._inner_chunk_byte_length(chunk_spec_inner) + # Spec with write_empty_chunks=True — needed to encode fill-value chunks + # into actual bytes for the dense shard layout. + dense_spec = replace( + chunk_spec_inner, + config=ArrayConfig( + order=chunk_spec_inner.config.order, + write_empty_chunks=True, + ), + ) + + if prepared.write_full_shard: + # Full shard write: create a fully-dense blob and write it all at once. + # If all chunks are fill-value (None), delete the shard if it exists. + if all(v is None for v in prepared.chunk_dict.values()): + byte_setter.delete_sync() + return + # Encode fill-value chunks for any coords that are None (either + # unmodified coords or modified coords that equal the fill value). + for coords in morton_order_iter(chunks_per_shard): + if prepared.chunk_dict.get(coords) is None: + fill_chunk = chunk_spec_inner.prototype.nd_buffer.create( + shape=chunk_spec_inner.shape, + dtype=chunk_spec_inner.dtype.to_native_dtype(), + order=chunk_spec_inner.order, + fill_value=chunk_spec_inner.fill_value, + ) + prepared.chunk_dict[coords] = prepared.inner_codec_chain.encode_chunk( + fill_chunk, dense_spec + ) + blob = self._build_dense_shard_blob( + prepared.chunk_dict, chunks_per_shard, chunk_byte_length + ) + byte_setter.set_sync(blob) + return + + # Existing shard with fixed-size chunks: write only modified chunks via set_range. + # If any modified chunk became fill-value (None), fall back to full read-modify-write + # so that shard deletion works correctly. + has_fill_chunks = any( + prepared.chunk_dict.get(coords) is None for coords, _, _, _ in prepared.indexer + ) + if has_fill_chunks: + # Need full read-modify-write for correct shard deletion behavior. + shard_reader = self._load_full_shard_maybe_sync( + byte_setter, chunk_spec_inner.prototype, chunks_per_shard + ) + if shard_reader is not None: + full_dict: dict[tuple[int, ...], Buffer | None] = { + k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard) + } + # Merge modified chunks into the full dict. + for coords, _, _, _ in prepared.indexer: + full_dict[coords] = prepared.chunk_dict.get(coords) + blob = self.serialize(full_dict, chunk_spec) + if blob is None: + byte_setter.delete_sync() + else: + byte_setter.set_sync(blob) + return + + try: + for coords, _, _, _ in prepared.indexer: + chunk_bytes = prepared.chunk_dict.get(coords) + if chunk_bytes is not None: + offset = self._chunk_byte_offset(coords, chunks_per_shard, chunk_byte_length) + byte_setter.set_range_sync(chunk_bytes, offset) + + # Update the shard index (unchanged for dense layout but must be rewritten + # because the index is part of the blob). + index = self._build_dense_shard_index(chunks_per_shard, chunk_byte_length) + index_bytes = self._encode_shard_index_sync(index) + index_offset = self._shard_index_byte_offset(chunks_per_shard, chunk_byte_length) + byte_setter.set_range_sync(index_bytes, index_offset) + except NotImplementedError: + # Store doesn't support set_range — fall back to full serialize + set. + blob = self.serialize(prepared.chunk_dict, chunk_spec) + if blob is None: + byte_setter.delete_sync() + else: + byte_setter.set_sync(blob) + + async def finalize_write(self, prepared: Any, chunk_spec: ArraySpec, byte_setter: Any) -> None: + from zarr.abc.codec import PreparedWrite + + assert isinstance(prepared, PreparedWrite) + + # Complete shard: encode the entire shard value in one shot. + if prepared.is_complete_shard: + assert prepared.shard_data is not None + shard_data = prepared.shard_data + if shard_data.shape != chunk_spec.shape: + expanded = chunk_spec.prototype.nd_buffer.create( + shape=chunk_spec.shape, + dtype=chunk_spec.dtype.to_native_dtype(), + order=chunk_spec.order, + fill_value=0, + ) + expanded[()] = shard_data + shard_data = expanded + blob = self._encode_sync(shard_data, chunk_spec) + if blob is None: + await byte_setter.delete() + else: + await byte_setter.set(blob) + return + + chunks_per_shard = self._get_chunks_per_shard(chunk_spec) + chunk_spec_inner = self._get_chunk_spec(chunk_spec) + + if not self._inner_codecs_fixed_size: + blob = self.serialize(prepared.chunk_dict, chunk_spec) + if blob is None: + await byte_setter.delete() + else: + await byte_setter.set(blob) + return + + chunk_byte_length = self._inner_chunk_byte_length(chunk_spec_inner) + + if prepared.write_full_shard: + # If all chunks are fill-value (None), delete the shard if it exists. + if all(v is None for v in prepared.chunk_dict.values()): + await byte_setter.delete() + return + dense_spec = replace( + chunk_spec_inner, + config=ArrayConfig( + order=chunk_spec_inner.config.order, + write_empty_chunks=True, + ), + ) + for coords in morton_order_iter(chunks_per_shard): + if prepared.chunk_dict.get(coords) is None: + fill_chunk = chunk_spec_inner.prototype.nd_buffer.create( + shape=chunk_spec_inner.shape, + dtype=chunk_spec_inner.dtype.to_native_dtype(), + order=chunk_spec_inner.order, + fill_value=chunk_spec_inner.fill_value, + ) + prepared.chunk_dict[coords] = prepared.inner_codec_chain.encode_chunk( + fill_chunk, dense_spec + ) + blob = self._build_dense_shard_blob( + prepared.chunk_dict, chunks_per_shard, chunk_byte_length + ) + await byte_setter.set(blob) + return + + has_fill_chunks = any( + prepared.chunk_dict.get(coords) is None for coords, _, _, _ in prepared.indexer + ) + if has_fill_chunks: + shard_reader = await self._load_full_shard_maybe( + byte_setter, chunk_spec_inner.prototype, chunks_per_shard + ) + if shard_reader is not None: + full_dict: dict[tuple[int, ...], Buffer | None] = { + k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard) + } + for coords, _, _, _ in prepared.indexer: + full_dict[coords] = prepared.chunk_dict.get(coords) + blob = self.serialize(full_dict, chunk_spec) + if blob is None: + await byte_setter.delete() + else: + await byte_setter.set(blob) + return + + try: + for coords, _, _, _ in prepared.indexer: + chunk_bytes = prepared.chunk_dict.get(coords) + if chunk_bytes is not None: + offset = self._chunk_byte_offset(coords, chunks_per_shard, chunk_byte_length) + await byte_setter.set_range(chunk_bytes, offset) + + index = self._build_dense_shard_index(chunks_per_shard, chunk_byte_length) + index_bytes = self._encode_shard_index_sync(index) + index_offset = self._shard_index_byte_offset(chunks_per_shard, chunk_byte_length) + await byte_setter.set_range(index_bytes, index_offset) + except NotImplementedError: + blob = self.serialize(prepared.chunk_dict, chunk_spec) + if blob is None: + await byte_setter.delete() + else: + await byte_setter.set(blob) + def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int: chunks_per_shard = self._get_chunks_per_shard(shard_spec) return input_byte_length + self._shard_index_size(chunks_per_shard) diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py index a8570b6e8f..8fe4c90409 100644 --- a/src/zarr/codecs/transpose.py +++ b/src/zarr/codecs/transpose.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable -from dataclasses import dataclass, replace +from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, cast import numpy as np @@ -33,11 +33,14 @@ class TransposeCodec(ArrayArrayCodec): is_fixed_size = True order: tuple[int, ...] + _inverse_order: tuple[int, ...] = field(init=False, repr=False, compare=False) def __init__(self, *, order: Iterable[int]) -> None: order_parsed = parse_transpose_order(order) object.__setattr__(self, "order", order_parsed) + # Cache the inverse order to avoid np.argsort on every decode. + object.__setattr__(self, "_inverse_order", tuple(int(i) for i in np.argsort(order_parsed))) @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: @@ -95,20 +98,25 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: prototype=chunk_spec.prototype, ) + def _decode_sync(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer: + return chunk_array.transpose(self._inverse_order) + + def _encode_sync(self, chunk_array: NDBuffer, _chunk_spec: ArraySpec) -> NDBuffer | None: + return chunk_array.transpose(self.order) + async def _decode_single( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, ) -> NDBuffer: - inverse_order = np.argsort(self.order) - return chunk_array.transpose(inverse_order) + return self._decode_sync(chunk_array, chunk_spec) async def _encode_single( self, chunk_array: NDBuffer, _chunk_spec: ArraySpec, ) -> NDBuffer | None: - return chunk_array.transpose(self.order) + return self._encode_sync(chunk_array, _chunk_spec) def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py index fb1fb76126..16de25001c 100644 --- a/src/zarr/codecs/vlen_utf8.py +++ b/src/zarr/codecs/vlen_utf8.py @@ -40,12 +40,7 @@ def to_dict(self) -> dict[str, JSON]: def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return self - # TODO: expand the tests for this function - async def _decode_single( - self, - chunk_bytes: Buffer, - chunk_spec: ArraySpec, - ) -> NDBuffer: + def _decode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> NDBuffer: assert isinstance(chunk_bytes, Buffer) raw_bytes = chunk_bytes.as_array_like() @@ -55,15 +50,25 @@ async def _decode_single( as_string_dtype = decoded.astype(chunk_spec.dtype.to_native_dtype(), copy=False) return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype) + def _encode_sync(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> Buffer | None: + assert isinstance(chunk_array, NDBuffer) + return chunk_spec.prototype.buffer.from_bytes( + _vlen_utf8_codec.encode(chunk_array.as_numpy_array()) + ) + + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_bytes, chunk_spec) + async def _encode_single( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, ) -> Buffer | None: - assert isinstance(chunk_array, NDBuffer) - return chunk_spec.prototype.buffer.from_bytes( - _vlen_utf8_codec.encode(chunk_array.as_numpy_array()) - ) + return self._encode_sync(chunk_array, chunk_spec) def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: # what is input_byte_length for an object dtype? @@ -86,11 +91,7 @@ def to_dict(self) -> dict[str, JSON]: def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return self - async def _decode_single( - self, - chunk_bytes: Buffer, - chunk_spec: ArraySpec, - ) -> NDBuffer: + def _decode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> NDBuffer: assert isinstance(chunk_bytes, Buffer) raw_bytes = chunk_bytes.as_array_like() @@ -99,15 +100,25 @@ async def _decode_single( decoded = _reshape_view(decoded, chunk_spec.shape) return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded) + def _encode_sync(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> Buffer | None: + assert isinstance(chunk_array, NDBuffer) + return chunk_spec.prototype.buffer.from_bytes( + _vlen_bytes_codec.encode(chunk_array.as_numpy_array()) + ) + + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_bytes, chunk_spec) + async def _encode_single( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, ) -> Buffer | None: - assert isinstance(chunk_array, NDBuffer) - return chunk_spec.prototype.buffer.from_bytes( - _vlen_bytes_codec.encode(chunk_array.as_numpy_array()) - ) + return self._encode_sync(chunk_array, chunk_spec) def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: # what is input_byte_length for an object dtype? diff --git a/src/zarr/codecs/zstd.py b/src/zarr/codecs/zstd.py index 27cc9a7777..f8db8da0ca 100644 --- a/src/zarr/codecs/zstd.py +++ b/src/zarr/codecs/zstd.py @@ -38,7 +38,7 @@ def parse_checksum(data: JSON) -> bool: class ZstdCodec(BytesBytesCodec): """zstd codec""" - is_fixed_size = True + is_fixed_size = False level: int = 0 checksum: bool = False @@ -71,23 +71,25 @@ def _zstd_codec(self) -> Zstd: config_dict = {"level": self.level, "checksum": self.checksum} return Zstd.from_config(config_dict) + def _decode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer: + return as_numpy_array_wrapper(self._zstd_codec.decode, chunk_bytes, chunk_spec.prototype) + + def _encode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer | None: + return as_numpy_array_wrapper(self._zstd_codec.encode, chunk_bytes, chunk_spec.prototype) + async def _decode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer: - return await asyncio.to_thread( - as_numpy_array_wrapper, self._zstd_codec.decode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._decode_sync, chunk_bytes, chunk_spec) async def _encode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer | None: - return await asyncio.to_thread( - as_numpy_array_wrapper, self._zstd_codec.encode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._encode_sync, chunk_bytes, chunk_spec) def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int: raise NotImplementedError diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 564d0e915a..805867d92b 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -1973,6 +1973,29 @@ def config(self) -> ArrayConfig: """ return self.async_array.config + def _can_use_sync_path(self) -> bool: + """Check if we can bypass the event loop entirely for read/write. + + Two conditions must hold: + + 1. The codec pipeline supports fully synchronous IO (all codecs + implement ``SupportsSyncCodec``). This is True for + BatchedCodecPipeline when all codecs support sync. + + 2. The store supports synchronous operations (has a ``get_sync`` + method). MemoryStore and LocalStore provide this; remote + stores do not. + + When both hold, the selection methods below call + _get_selection_sync / _set_selection_sync directly, running the + entire read/write path on the calling thread with zero async + overhead. Otherwise, the async path with concurrent IO overlap + is used automatically. + """ + pipeline = self.async_array.codec_pipeline + store = self.async_array.store_path.store + return getattr(pipeline, "supports_sync_io", False) and hasattr(store, "get_sync") + @classmethod @deprecated("Use zarr.create_array instead.", category=ZarrDeprecationWarning) def create( @@ -3049,9 +3072,28 @@ def get_basic_selection( if prototype is None: prototype = default_buffer_prototype() + indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid) + # Sync bypass: when the codec pipeline and store both support + # synchronous operation, skip the sync() → event loop bridge and + # run the entire read path on the calling thread. This pattern is + # repeated in all 10 get_*/set_* methods below. + if self._can_use_sync_path(): + return _get_selection_sync( + self.async_array.store_path, + self.async_array.metadata, + self.async_array.codec_pipeline, + self.async_array.config, + indexer, + out=out, + fields=fields, + prototype=prototype, + ) + # Fallback: submit the async coroutine to the background event loop + # thread via sync(). Used for remote stores or when the sync bypass + # is not active. return sync( self.async_array._get_selection( - BasicIndexer(selection, self.shape, self.metadata.chunk_grid), + indexer, out=out, fields=fields, prototype=prototype, @@ -3159,6 +3201,18 @@ def set_basic_selection( if prototype is None: prototype = default_buffer_prototype() indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid) + if self._can_use_sync_path(): + _set_selection_sync( + self.async_array.store_path, + self.async_array.metadata, + self.async_array.codec_pipeline, + self.async_array.config, + indexer, + value, + fields=fields, + prototype=prototype, + ) + return sync(self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) def get_orthogonal_selection( @@ -3287,6 +3341,17 @@ def get_orthogonal_selection( if prototype is None: prototype = default_buffer_prototype() indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) + if self._can_use_sync_path(): + return _get_selection_sync( + self.async_array.store_path, + self.async_array.metadata, + self.async_array.codec_pipeline, + self.async_array.config, + indexer, + out=out, + fields=fields, + prototype=prototype, + ) return sync( self.async_array._get_selection( indexer=indexer, out=out, fields=fields, prototype=prototype @@ -3406,9 +3471,19 @@ def set_orthogonal_selection( if prototype is None: prototype = default_buffer_prototype() indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) - return sync( - self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype) - ) + if self._can_use_sync_path(): + _set_selection_sync( + self.async_array.store_path, + self.async_array.metadata, + self.async_array.codec_pipeline, + self.async_array.config, + indexer, + value, + fields=fields, + prototype=prototype, + ) + return + sync(self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) def get_mask_selection( self, @@ -3494,6 +3569,17 @@ def get_mask_selection( if prototype is None: prototype = default_buffer_prototype() indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) + if self._can_use_sync_path(): + return _get_selection_sync( + self.async_array.store_path, + self.async_array.metadata, + self.async_array.codec_pipeline, + self.async_array.config, + indexer, + out=out, + fields=fields, + prototype=prototype, + ) return sync( self.async_array._get_selection( indexer=indexer, out=out, fields=fields, prototype=prototype @@ -3584,6 +3670,18 @@ def set_mask_selection( if prototype is None: prototype = default_buffer_prototype() indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) + if self._can_use_sync_path(): + _set_selection_sync( + self.async_array.store_path, + self.async_array.metadata, + self.async_array.codec_pipeline, + self.async_array.config, + indexer, + value, + fields=fields, + prototype=prototype, + ) + return sync(self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) def get_coordinate_selection( @@ -3672,11 +3770,23 @@ def get_coordinate_selection( if prototype is None: prototype = default_buffer_prototype() indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) - out_array = sync( - self.async_array._get_selection( - indexer=indexer, out=out, fields=fields, prototype=prototype + if self._can_use_sync_path(): + out_array = _get_selection_sync( + self.async_array.store_path, + self.async_array.metadata, + self.async_array.codec_pipeline, + self.async_array.config, + indexer, + out=out, + fields=fields, + prototype=prototype, + ) + else: + out_array = sync( + self.async_array._get_selection( + indexer=indexer, out=out, fields=fields, prototype=prototype + ) ) - ) if hasattr(out_array, "shape"): # restore shape @@ -3786,6 +3896,18 @@ def set_coordinate_selection( f"elements with an array of {value.shape[0]} elements." ) + if self._can_use_sync_path(): + _set_selection_sync( + self.async_array.store_path, + self.async_array.metadata, + self.async_array.codec_pipeline, + self.async_array.config, + indexer, + value, + fields=fields, + prototype=prototype, + ) + return sync(self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) def get_block_selection( @@ -3887,6 +4009,17 @@ def get_block_selection( if prototype is None: prototype = default_buffer_prototype() indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid) + if self._can_use_sync_path(): + return _get_selection_sync( + self.async_array.store_path, + self.async_array.metadata, + self.async_array.codec_pipeline, + self.async_array.config, + indexer, + out=out, + fields=fields, + prototype=prototype, + ) return sync( self.async_array._get_selection( indexer=indexer, out=out, fields=fields, prototype=prototype @@ -3988,6 +4121,18 @@ def set_block_selection( if prototype is None: prototype = default_buffer_prototype() indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid) + if self._can_use_sync_path(): + _set_selection_sync( + self.async_array.store_path, + self.async_array.metadata, + self.async_array.codec_pipeline, + self.async_array.config, + indexer, + value, + fields=fields, + prototype=prototype, + ) + return sync(self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) @property @@ -5619,6 +5764,174 @@ async def _get_selection( return out_buffer.as_ndarray_like() +def _get_selection_sync( + store_path: StorePath, + metadata: ArrayMetadata, + codec_pipeline: CodecPipeline, + config: ArrayConfig, + indexer: Indexer, + *, + prototype: BufferPrototype, + out: NDBuffer | None = None, + fields: Fields | None = None, +) -> NDArrayLikeOrScalar: + """Synchronous version of _get_selection — bypasses the event loop entirely. + + This function mirrors ``_get_selection`` (the async version defined above) + exactly, with one critical difference: it calls ``codec_pipeline.read_sync()`` + instead of ``await codec_pipeline.read()``. This means the entire operation + — store IO, codec decode, buffer scatter — runs on the calling thread with + no event loop involvement. + + Called by ``Array.get_basic_selection``, ``get_orthogonal_selection``, etc. + when ``Array._can_use_sync_path()`` returns True. + + The setup logic (dtype resolution, output buffer creation, field checks) is + duplicated from the async version rather than extracted into a shared helper. + This keeps the hot path simple and avoids adding indirection. The two + versions should be kept in sync manually. + """ + # Get dtype from metadata — same logic as async _get_selection + if metadata.zarr_format == 2: + zdtype = metadata.dtype + else: + zdtype = metadata.data_type + dtype = zdtype.to_native_dtype() + + # Determine memory order + if metadata.zarr_format == 2: + order = metadata.order + else: + order = config.order + + # check fields are sensible + out_dtype = check_fields(fields, dtype) + + # setup output buffer + if out is not None: + if isinstance(out, NDBuffer): + out_buffer = out + else: + raise TypeError(f"out argument needs to be an NDBuffer. Got {type(out)!r}") + if out_buffer.shape != indexer.shape: + raise ValueError( + f"shape of out argument doesn't match. Expected {indexer.shape}, got {out.shape}" + ) + else: + out_buffer = prototype.nd_buffer.empty( + shape=indexer.shape, + dtype=out_dtype, + order=order, + ) + if product(indexer.shape) > 0: + _config = config + if metadata.zarr_format == 2: + _config = replace(_config, order=order) + + # This is the key difference from the async version: read_sync() + # runs the entire pipeline (store fetch → codec decode → scatter) + # on this thread. Each entry in the list is a (StorePath, ArraySpec, + # chunk_selection, out_selection, is_complete_chunk) tuple. + # StorePath acts as the ByteGetter — its get_sync() method is called + # by the pipeline to fetch raw chunk bytes from the store. + codec_pipeline.read_sync( + [ + ( + store_path / metadata.encode_chunk_key(chunk_coords), + metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype), + chunk_selection, + out_selection, + is_complete_chunk, + ) + for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer + ], + out_buffer, + drop_axes=indexer.drop_axes, + ) + if isinstance(indexer, BasicIndexer) and indexer.shape == (): + return out_buffer.as_scalar() + return out_buffer.as_ndarray_like() + + +def _set_selection_sync( + store_path: StorePath, + metadata: ArrayMetadata, + codec_pipeline: CodecPipeline, + config: ArrayConfig, + indexer: Indexer, + value: npt.ArrayLike, + *, + prototype: BufferPrototype, + fields: Fields | None = None, +) -> None: + """Synchronous version of _set_selection — bypasses the event loop entirely. + + Mirrors ``_set_selection`` (the async version) with the same setup logic + (dtype coercion, value shape validation, buffer wrapping) but calls + ``codec_pipeline.write_sync()`` instead of ``await codec_pipeline.write()``. + + Called by ``Array.set_basic_selection``, ``set_orthogonal_selection``, etc. + when ``Array._can_use_sync_path()`` returns True. + """ + # Get dtype from metadata + if metadata.zarr_format == 2: + zdtype = metadata.dtype + else: + zdtype = metadata.data_type + dtype = zdtype.to_native_dtype() + + # check fields are sensible + check_fields(fields, dtype) + fields = check_no_multi_fields(fields) + + # check value shape + if np.isscalar(value): + array_like = prototype.buffer.create_zero_length().as_array_like() + if isinstance(array_like, np._typing._SupportsArrayFunc): + array_like_ = cast("np._typing._SupportsArrayFunc", array_like) + value = np.asanyarray(value, dtype=dtype, like=array_like_) + else: + if not hasattr(value, "shape"): + value = np.asarray(value, dtype) + if not hasattr(value, "dtype") or value.dtype.name != dtype.name: + if hasattr(value, "astype"): + value = value.astype(dtype=dtype, order="A") + else: + value = np.array(value, dtype=dtype, order="A") + value = cast("NDArrayLike", value) + + value_buffer = prototype.nd_buffer.from_ndarray_like(value) + + # Determine memory order + if metadata.zarr_format == 2: + order = metadata.order + else: + order = config.order + + _config = config + if metadata.zarr_format == 2: + _config = replace(_config, order=order) + + # Key difference from async version: write_sync() runs the entire + # pipeline (read existing → decode → merge → encode → store write) + # on this thread. StorePath acts as ByteSetter — its set_sync() and + # delete_sync() methods persist/remove chunk bytes directly. + codec_pipeline.write_sync( + [ + ( + store_path / metadata.encode_chunk_key(chunk_coords), + metadata.get_chunk_spec(chunk_coords, _config, prototype), + chunk_selection, + out_selection, + is_complete_chunk, + ) + for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer + ], + value_buffer, + drop_axes=indexer.drop_axes, + ) + + async def _getitem( store_path: StorePath, metadata: ArrayMetadata, diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index fd557ac43e..a4ca488f99 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -1,8 +1,11 @@ from __future__ import annotations -from dataclasses import dataclass -from itertools import islice, pairwise -from typing import TYPE_CHECKING, Any, TypeVar +import os +import threading +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from itertools import pairwise +from typing import TYPE_CHECKING, Any, TypeVar, cast from warnings import warn from zarr.abc.codec import ( @@ -13,15 +16,16 @@ BytesBytesCodec, Codec, CodecPipeline, + SupportsSyncCodec, ) -from zarr.core.common import concurrent_map +from zarr.core.common import concurrent_map, product from zarr.core.config import config from zarr.core.indexing import SelectorTuple, is_scalar from zarr.errors import ZarrUserWarning from zarr.registry import register_pipeline if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Callable, Iterable, Iterator from typing import Self from zarr.abc.store import ByteGetter, ByteSetter @@ -43,14 +47,6 @@ def _unzip2(iterable: Iterable[tuple[T, U]]) -> tuple[list[T], list[U]]: return (out0, out1) -def batched(iterable: Iterable[T], n: int) -> Iterable[tuple[T, ...]]: - if n < 1: - raise ValueError("n must be at least one") - it = iter(iterable) - while batch := tuple(islice(it, n)): - yield batch - - def resolve_batched(codec: Codec, chunk_specs: Iterable[ArraySpec]) -> Iterable[ArraySpec]: return [codec.resolve_metadata(chunk_spec) for chunk_spec in chunk_specs] @@ -68,70 +64,426 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any: return fill_value +# --------------------------------------------------------------------------- +# Thread pool for parallel codec compute +# --------------------------------------------------------------------------- + +# Minimum chunk size (in bytes) to consider using the thread pool. +# Below this, per-chunk codec work is too small to offset dispatch overhead. +_MIN_CHUNK_NBYTES_FOR_POOL = 100_000 # 100 KB + + +def _get_codec_worker_config() -> tuple[bool, int, int]: + """Read the ``threading.codec_workers`` config. + + Returns (enabled, min_workers, max_workers). + """ + codec_workers = config.get("threading.codec_workers") + enabled: bool = codec_workers.get("enabled", True) + min_workers: int = codec_workers.get("min", 0) + max_workers: int = max(codec_workers.get("max") or os.cpu_count() or 4, min_workers) + return enabled, min_workers, max_workers + + +def _choose_workers(n_chunks: int, chunk_nbytes: int, codecs: Iterable[Codec]) -> int: + """Decide how many thread pool workers to use (0 = don't use pool). + + Respects ``threading.codec_workers`` config: + - ``enabled``: if False, always returns 0. + - ``min``: floor for the number of workers. + - ``max``: ceiling for the number of workers (default: ``os.cpu_count()``). + + Returns 0 if already running on a pool worker thread (prevents deadlock). + """ + # Prevent nested pool usage: if we're already on a pool worker, don't + # submit more work to the same pool (classic nested-pool deadlock). + if getattr(_thread_local, "in_pool_worker", False): + return 0 + + enabled, min_workers, max_workers = _get_codec_worker_config() + if not enabled: + return 0 + + if n_chunks < 2: + return min_workers + + # Only use the pool when at least one codec does real work + # (BytesBytesCodec = compression/checksum, which releases the GIL in C) + # and the chunks are large enough to offset dispatch overhead. + if not any(isinstance(c, BytesBytesCodec) for c in codecs) and min_workers == 0: + return 0 + if chunk_nbytes < _MIN_CHUNK_NBYTES_FOR_POOL and min_workers == 0: + return 0 + + return max(min_workers, min(n_chunks, max_workers)) + + +def _get_pool() -> ThreadPoolExecutor: + """Get the module-level thread pool, creating it lazily.""" + global _pool + if _pool is None: + _, _, max_workers = _get_codec_worker_config() + _pool = ThreadPoolExecutor(max_workers=max_workers) + return _pool + + +_pool: ThreadPoolExecutor | None = None + +# Thread-local flag to prevent nested thread pool deadlock. +# When a pool worker is running codec compute, inner pipelines (e.g. sharding) +# must not submit work to the same pool. +_thread_local = threading.local() + + +def _mark_pool_worker(fn: Callable[..., T]) -> Callable[..., T]: + """Wrap *fn* so that ``_thread_local.in_pool_worker`` is ``True`` while it runs. + + Used around functions dispatched to the thread pool so that nested + ``_choose_workers`` calls (e.g. from sharding) return 0 instead of + deadlocking by submitting more work to the same pool. + """ + + def wrapper(*args: Any, **kwargs: Any) -> T: + _thread_local.in_pool_worker = True + try: + return fn(*args, **kwargs) + finally: + _thread_local.in_pool_worker = False + + return wrapper + + +# Sentinel to distinguish "delete this key" from None. +_DELETED = object() + + @dataclass(frozen=True) -class BatchedCodecPipeline(CodecPipeline): - """Default codec pipeline. +class CodecChain: + """Lightweight codec chain: array-array -> array-bytes -> bytes-bytes. - This batched codec pipeline divides the chunk batches into batches of a configurable - batch size ("mini-batch"). Fetching, decoding, encoding and storing are performed in - lock step for each mini-batch. Multiple mini-batches are processing concurrently. + Pure compute only — no IO methods, no threading, no batching. + The pipeline accesses IO methods (prepare_read, prepare_write) + via ``codec_chain.array_bytes_codec`` directly. """ array_array_codecs: tuple[ArrayArrayCodec, ...] array_bytes_codec: ArrayBytesCodec bytes_bytes_codecs: tuple[BytesBytesCodec, ...] - batch_size: int - def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: - return type(self).from_codecs(c.evolve_from_array_spec(array_spec=array_spec) for c in self) + _all_sync: bool = field(default=False, init=False, repr=False, compare=False) + + def __post_init__(self) -> None: + object.__setattr__( + self, + "_all_sync", + all(isinstance(c, SupportsSyncCodec) for c in self), + ) + + def __iter__(self) -> Iterator[Codec]: + yield from self.array_array_codecs + yield self.array_bytes_codec + yield from self.bytes_bytes_codecs @classmethod - def from_codecs(cls, codecs: Iterable[Codec], *, batch_size: int | None = None) -> Self: - array_array_codecs, array_bytes_codec, bytes_bytes_codecs = codecs_from_list(codecs) - - return cls( - array_array_codecs=array_array_codecs, - array_bytes_codec=array_bytes_codec, - bytes_bytes_codecs=bytes_bytes_codecs, - batch_size=batch_size or config.get("codec_pipeline.batch_size"), + def from_codecs(cls, codecs: Iterable[Codec]) -> CodecChain: + aa, ab, bb = codecs_from_list(list(codecs)) + return cls(array_array_codecs=aa, array_bytes_codec=ab, bytes_bytes_codecs=bb) + + def resolve_metadata_chain( + self, chunk_spec: ArraySpec + ) -> tuple[ + list[tuple[ArrayArrayCodec, ArraySpec]], + tuple[ArrayBytesCodec, ArraySpec], + list[tuple[BytesBytesCodec, ArraySpec]], + ]: + """Resolve metadata through the codec chain for a single chunk_spec.""" + aa_codecs_with_spec: list[tuple[ArrayArrayCodec, ArraySpec]] = [] + spec = chunk_spec + for aa_codec in self.array_array_codecs: + aa_codecs_with_spec.append((aa_codec, spec)) + spec = aa_codec.resolve_metadata(spec) + + ab_codec_with_spec = (self.array_bytes_codec, spec) + spec = self.array_bytes_codec.resolve_metadata(spec) + + bb_codecs_with_spec: list[tuple[BytesBytesCodec, ArraySpec]] = [] + for bb_codec in self.bytes_bytes_codecs: + bb_codecs_with_spec.append((bb_codec, spec)) + spec = bb_codec.resolve_metadata(spec) + + return (aa_codecs_with_spec, ab_codec_with_spec, bb_codecs_with_spec) + + def decode_chunk( + self, + chunk_bytes: Buffer | None, + chunk_spec: ArraySpec, + aa_chain: list[tuple[ArrayArrayCodec, ArraySpec]] | None = None, + ab_pair: tuple[ArrayBytesCodec, ArraySpec] | None = None, + bb_chain: list[tuple[BytesBytesCodec, ArraySpec]] | None = None, + ) -> NDBuffer | None: + """Decode a single chunk through the full codec chain, synchronously. + + Pure compute — no IO. Only callable when all codecs support sync. + + The optional ``aa_chain``, ``ab_pair``, ``bb_chain`` parameters allow + pre-resolved metadata to be reused across many chunks with the same spec. + If not provided, ``resolve_metadata_chain`` is called internally. + """ + if chunk_bytes is None: + return None + + if aa_chain is None or ab_pair is None or bb_chain is None: + aa_chain, ab_pair, bb_chain = self.resolve_metadata_chain(chunk_spec) + + bb_out: Any = chunk_bytes + for bb_codec, spec in reversed(bb_chain): + bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, spec) + + ab_codec, ab_spec = ab_pair + ab_out: Any = cast("SupportsSyncCodec", ab_codec)._decode_sync(bb_out, ab_spec) + + for aa_codec, spec in reversed(aa_chain): + ab_out = cast("SupportsSyncCodec", aa_codec)._decode_sync(ab_out, spec) + + return ab_out # type: ignore[no-any-return] + + def encode_chunk( + self, + chunk_array: NDBuffer | None, + chunk_spec: ArraySpec, + ) -> Buffer | None: + """Encode a single chunk through the full codec chain, synchronously. + + Pure compute — no IO. Only callable when all codecs support sync. + """ + if chunk_array is None: + return None + + spec = chunk_spec + aa_out: Any = chunk_array + + for aa_codec in self.array_array_codecs: + if aa_out is None: + return None + aa_out = cast("SupportsSyncCodec", aa_codec)._encode_sync(aa_out, spec) + spec = aa_codec.resolve_metadata(spec) + + if aa_out is None: + return None + bb_out: Any = cast("SupportsSyncCodec", self.array_bytes_codec)._encode_sync(aa_out, spec) + spec = self.array_bytes_codec.resolve_metadata(spec) + + for bb_codec in self.bytes_bytes_codecs: + if bb_out is None: + return None + bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, spec) + spec = bb_codec.resolve_metadata(spec) + + return bb_out # type: ignore[no-any-return] + + def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: + for codec in self: + byte_length = codec.compute_encoded_size(byte_length, array_spec) + array_spec = codec.resolve_metadata(array_spec) + return byte_length + + def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: + for codec in self: + chunk_spec = codec.resolve_metadata(chunk_spec) + return chunk_spec + + +# --------------------------------------------------------------------------- +# Module-level helpers used by both BatchedCodecPipeline and ArrayBytesCodec +# --------------------------------------------------------------------------- + + +def _merge_chunk_array( + existing_chunk_array: NDBuffer | None, + value: NDBuffer, + out_selection: SelectorTuple, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + is_complete_chunk: bool, + drop_axes: tuple[int, ...], +) -> NDBuffer: + """Merge new data into an existing (or freshly-created) chunk array.""" + if ( + is_complete_chunk + and value.shape == chunk_spec.shape + # Guard that this is not a partial chunk at the end with is_complete_chunk=True + and value[out_selection].shape == chunk_spec.shape + ): + return value + if existing_chunk_array is None: + chunk_array = chunk_spec.prototype.nd_buffer.create( + shape=chunk_spec.shape, + dtype=chunk_spec.dtype.to_native_dtype(), + order=chunk_spec.order, + fill_value=fill_value_or_default(chunk_spec), + ) + else: + chunk_array = existing_chunk_array.copy() # make a writable copy + if chunk_selection == () or is_scalar( + value.as_ndarray_like(), chunk_spec.dtype.to_native_dtype() + ): + chunk_value = value + else: + chunk_value = value[out_selection] + # handle missing singleton dimensions + if drop_axes != (): + item = tuple( + None # equivalent to np.newaxis + if idx in drop_axes + else slice(None) + for idx in range(chunk_spec.ndim) + ) + chunk_value = chunk_value[item] + chunk_array[chunk_selection] = chunk_value + return chunk_array + + +def _write_chunk_compute_default( + existing_bytes: Buffer | None, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + is_complete_chunk: bool, + value: NDBuffer, + drop_axes: tuple[int, ...], + codec_chain: CodecChain, + aa_chain: list[tuple[ArrayArrayCodec, ArraySpec]] | None = None, + ab_pair: tuple[ArrayBytesCodec, ArraySpec] | None = None, + bb_chain: list[tuple[BytesBytesCodec, ArraySpec]] | None = None, +) -> Buffer | None | object: + """Per-chunk compute for write: decode existing -> merge -> encode. + + Returns the encoded chunk bytes, or ``_DELETED`` if the chunk should be + removed from the store. + """ + existing_array: NDBuffer | None = None + if existing_bytes is not None: + if aa_chain is None or ab_pair is None or bb_chain is None: + aa_chain, ab_pair, bb_chain = codec_chain.resolve_metadata_chain(chunk_spec) + existing_array = codec_chain.decode_chunk( + existing_bytes, chunk_spec, aa_chain, ab_pair, bb_chain ) + chunk_array: NDBuffer | None = _merge_chunk_array( + existing_array, + value, + out_selection, + chunk_spec, + chunk_selection, + is_complete_chunk, + drop_axes, + ) + + if ( + chunk_array is not None + and not chunk_spec.config.write_empty_chunks + and chunk_array.all_equal(fill_value_or_default(chunk_spec)) + ): + chunk_array = None + + if chunk_array is None: + return _DELETED + chunk_bytes = codec_chain.encode_chunk(chunk_array, chunk_spec) + if chunk_bytes is None: + return _DELETED + return chunk_bytes + + +@dataclass(frozen=True) +class BatchedCodecPipeline(CodecPipeline): + """Codec pipeline that automatically selects the optimal execution strategy. + + When all codecs support synchronous operations and the store supports + sync IO, this pipeline runs the entire read/write path on the calling + thread with zero async overhead, using a thread pool for parallel codec + compute on multi-chunk operations. + + When the store requires async IO (e.g. cloud stores), this pipeline uses + the async path with concurrent IO overlap via ``concurrent_map``. + + This automatic dispatch eliminates the need for users to choose between + pipeline implementations — the right strategy is selected based on codec + and store capabilities. + """ + + codec_chain: CodecChain + batch_size: int | None = None + + def __init__( + self, + *, + codec_chain: CodecChain | None = None, + array_array_codecs: tuple[ArrayArrayCodec, ...] | None = None, + array_bytes_codec: ArrayBytesCodec | None = None, + bytes_bytes_codecs: tuple[BytesBytesCodec, ...] | None = None, + batch_size: int | None = None, + ) -> None: + if batch_size is not None: + warn( + "The 'batch_size' parameter is deprecated and has no effect. " + "Batch size is now determined automatically.", + FutureWarning, + stacklevel=2, + ) + object.__setattr__(self, "batch_size", batch_size) + + if codec_chain is not None: + object.__setattr__(self, "codec_chain", codec_chain) + elif array_bytes_codec is not None: + object.__setattr__( + self, + "codec_chain", + CodecChain( + array_array_codecs=array_array_codecs or (), + array_bytes_codec=array_bytes_codec, + bytes_bytes_codecs=bytes_bytes_codecs or (), + ), + ) + else: + raise ValueError("Either codec_chain or array_bytes_codec must be provided.") + @property - def supports_partial_decode(self) -> bool: - """Determines whether the codec pipeline supports partial decoding. + def array_array_codecs(self) -> tuple[ArrayArrayCodec, ...]: + return self.codec_chain.array_array_codecs + + @property + def array_bytes_codec(self) -> ArrayBytesCodec: + return self.codec_chain.array_bytes_codec + + @property + def bytes_bytes_codecs(self) -> tuple[BytesBytesCodec, ...]: + return self.codec_chain.bytes_bytes_codecs + + @property + def _all_sync(self) -> bool: + return self.codec_chain._all_sync - Currently, only codec pipelines with a single ArrayBytesCodec that supports - partial decoding can support partial decoding. This limitation is due to the fact - that ArrayArrayCodecs can change the slice selection leading to non-contiguous - slices and BytesBytesCodecs can change the chunk bytes in a way that slice - selections cannot be attributed to byte ranges anymore which renders partial - decoding infeasible. + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: + return type(self).from_codecs(c.evolve_from_array_spec(array_spec=array_spec) for c in self) + + @classmethod + def from_codecs(cls, codecs: Iterable[Codec]) -> Self: + return cls(codec_chain=CodecChain.from_codecs(codecs)) - This limitation may softened in the future.""" + @property + def supports_partial_decode(self) -> bool: return (len(self.array_array_codecs) + len(self.bytes_bytes_codecs)) == 0 and isinstance( self.array_bytes_codec, ArrayBytesCodecPartialDecodeMixin ) @property def supports_partial_encode(self) -> bool: - """Determines whether the codec pipeline supports partial encoding. - - Currently, only codec pipelines with a single ArrayBytesCodec that supports - partial encoding can support partial encoding. This limitation is due to the fact - that ArrayArrayCodecs can change the slice selection leading to non-contiguous - slices and BytesBytesCodecs can change the chunk bytes in a way that slice - selections cannot be attributed to byte ranges anymore which renders partial - encoding infeasible. - - This limitation may softened in the future.""" return (len(self.array_array_codecs) + len(self.bytes_bytes_codecs)) == 0 and isinstance( self.array_bytes_codec, ArrayBytesCodecPartialEncodeMixin ) def __iter__(self) -> Iterator[Codec]: - yield from self.array_array_codecs - yield self.array_bytes_codec - yield from self.bytes_bytes_codecs + yield from self.codec_chain def validate( self, @@ -144,10 +496,11 @@ def validate( codec.validate(shape=shape, dtype=dtype, chunk_grid=chunk_grid) def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: - for codec in self: - byte_length = codec.compute_encoded_size(byte_length, array_spec) - array_spec = codec.resolve_metadata(array_spec) - return byte_length + return self.codec_chain.compute_encoded_size(byte_length, array_spec) + + # ------------------------------------------------------------------- + # Batched async decode/encode (layer-by-layer across all chunks) + # ------------------------------------------------------------------- def _codecs_with_resolved_metadata_batched( self, chunk_specs: Iterable[ArraySpec] @@ -203,14 +556,6 @@ async def decode_batch( return chunk_array_batch - async def decode_partial_batch( - self, - batch_info: Iterable[tuple[ByteGetter, SelectorTuple, ArraySpec]], - ) -> Iterable[NDBuffer | None]: - assert self.supports_partial_decode - assert isinstance(self.array_bytes_codec, ArrayBytesCodecPartialDecodeMixin) - return await self.array_bytes_codec.decode_partial(batch_info) - async def encode_batch( self, chunk_arrays_and_specs: Iterable[tuple[NDBuffer | None, ArraySpec]], @@ -238,13 +583,51 @@ async def encode_batch( return chunk_bytes_batch - async def encode_partial_batch( + # ------------------------------------------------------------------- + # Top-level decode / encode + # ------------------------------------------------------------------- + + async def decode( self, - batch_info: Iterable[tuple[ByteSetter, NDBuffer, SelectorTuple, ArraySpec]], - ) -> None: - assert self.supports_partial_encode - assert isinstance(self.array_bytes_codec, ArrayBytesCodecPartialEncodeMixin) - await self.array_bytes_codec.encode_partial(batch_info) + chunk_bytes_and_specs: Iterable[tuple[Buffer | None, ArraySpec]], + ) -> Iterable[NDBuffer | None]: + items = list(chunk_bytes_and_specs) + if not items: + return [] + + if self._all_sync: + # All codecs support sync -- run the full chain inline (no threading). + _, first_spec = items[0] + aa_chain, ab_pair, bb_chain = self.codec_chain.resolve_metadata_chain(first_spec) + return [ + self.codec_chain.decode_chunk(chunk_bytes, chunk_spec, aa_chain, ab_pair, bb_chain) + for chunk_bytes, chunk_spec in items + ] + + # Async fallback: layer-by-layer across all chunks. + return list(await self.decode_batch(items)) + + async def encode( + self, + chunk_arrays_and_specs: Iterable[tuple[NDBuffer | None, ArraySpec]], + ) -> Iterable[Buffer | None]: + items = list(chunk_arrays_and_specs) + if not items: + return [] + + if self._all_sync: + # All codecs support sync -- run the full chain inline (no threading). + return [ + self.codec_chain.encode_chunk(chunk_array, chunk_spec) + for chunk_array, chunk_spec in items + ] + + # Async fallback: layer-by-layer across all chunks. + return list(await self.encode_batch(items)) + + # ------------------------------------------------------------------- + # Async read / write (IO overlap via concurrent_map) + # ------------------------------------------------------------------- async def read_batch( self, @@ -252,47 +635,93 @@ async def read_batch( out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: - if self.supports_partial_decode: - chunk_array_batch = await self.decode_partial_batch( - [ - (byte_getter, chunk_selection, chunk_spec) - for byte_getter, chunk_spec, chunk_selection, *_ in batch_info - ] - ) - for chunk_array, (_, chunk_spec, _, out_selection, _) in zip( - chunk_array_batch, batch_info, strict=False - ): - if chunk_array is not None: - out[out_selection] = chunk_array + batch_info = list(batch_info) + + if self._all_sync: + _, first_spec, *_ = batch_info[0] + aa_chain, ab_pair, bb_chain = self.codec_chain.resolve_metadata_chain(first_spec) + ab_codec = self.array_bytes_codec + codec_chain = self.codec_chain + + async def _read_chunk( + byte_getter: ByteGetter, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + ) -> None: + result = await ab_codec.prepare_read( + byte_getter, + chunk_spec, + chunk_selection, + codec_chain, + aa_chain, + ab_pair, + bb_chain, + ) + if result is not None: + if drop_axes != (): + result = result.squeeze(axis=drop_axes) + out[out_selection] = result else: out[out_selection] = fill_value_or_default(chunk_spec) + + await concurrent_map( + [ + (byte_getter, chunk_spec, chunk_selection, out_selection) + for byte_getter, chunk_spec, chunk_selection, out_selection, _ in batch_info + ], + _read_chunk, + config.get("async.concurrency"), + ) else: + # Async fallback: fetch all → decode all (async codec API) → scatter. + # Used for codecs that don't implement _decode_sync (e.g. numcodecs). + + async def _fetch(byte_getter: ByteGetter, prototype: BufferPrototype) -> Buffer | None: + return await byte_getter.get(prototype=prototype) + chunk_bytes_batch = await concurrent_map( - [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info], - lambda byte_getter, prototype: byte_getter.get(prototype), + [(byte_getter, chunk_spec.prototype) for byte_getter, chunk_spec, *_ in batch_info], + _fetch, config.get("async.concurrency"), ) chunk_array_batch = await self.decode_batch( - [ - (chunk_bytes, chunk_spec) - for chunk_bytes, (_, chunk_spec, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) - ], + zip( + chunk_bytes_batch, + [chunk_spec for _, chunk_spec, *_ in batch_info], + strict=False, + ) ) - for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip( - chunk_array_batch, batch_info, strict=False - ): - if chunk_array is not None: - tmp = chunk_array[chunk_selection] - if drop_axes != (): - tmp = tmp.squeeze(axis=drop_axes) - out[out_selection] = tmp - else: - out[out_selection] = fill_value_or_default(chunk_spec) + self._scatter(chunk_array_batch, batch_info, out, drop_axes) - def _merge_chunk_array( + @staticmethod + def _scatter( + chunk_array_batch: Iterable[NDBuffer | None], + batch_info: list[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + out: NDBuffer, + drop_axes: tuple[int, ...], + ) -> None: + for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip( + chunk_array_batch, batch_info, strict=False + ): + if chunk_array is not None: + tmp = chunk_array[chunk_selection] + if drop_axes != (): + tmp = tmp.squeeze(axis=drop_axes) + out[out_selection] = tmp + else: + out[out_selection] = fill_value_or_default(chunk_spec) + + async def read( self, + batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + out: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + await self.read_batch(batch_info, out, drop_axes) + + @staticmethod + def _merge_chunk_array( existing_chunk_array: NDBuffer | None, value: NDBuffer, out_selection: SelectorTuple, @@ -301,39 +730,15 @@ def _merge_chunk_array( is_complete_chunk: bool, drop_axes: tuple[int, ...], ) -> NDBuffer: - if ( - is_complete_chunk - and value.shape == chunk_spec.shape - # Guard that this is not a partial chunk at the end with is_complete_chunk=True - and value[out_selection].shape == chunk_spec.shape - ): - return value - if existing_chunk_array is None: - chunk_array = chunk_spec.prototype.nd_buffer.create( - shape=chunk_spec.shape, - dtype=chunk_spec.dtype.to_native_dtype(), - order=chunk_spec.order, - fill_value=fill_value_or_default(chunk_spec), - ) - else: - chunk_array = existing_chunk_array.copy() # make a writable copy - if chunk_selection == () or is_scalar( - value.as_ndarray_like(), chunk_spec.dtype.to_native_dtype() - ): - chunk_value = value - else: - chunk_value = value[out_selection] - # handle missing singleton dimensions - if drop_axes != (): - item = tuple( - None # equivalent to np.newaxis - if idx in drop_axes - else slice(None) - for idx in range(chunk_spec.ndim) - ) - chunk_value = chunk_value[item] - chunk_array[chunk_selection] = chunk_value - return chunk_array + return _merge_chunk_array( + existing_chunk_array, + value, + out_selection, + chunk_spec, + chunk_selection, + is_complete_chunk, + drop_axes, + ) async def write_batch( self, @@ -341,95 +746,120 @@ async def write_batch( value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: - if self.supports_partial_encode: - # Pass scalar values as is - if len(value.shape) == 0: - await self.encode_partial_batch( - [ - (byte_setter, value, chunk_selection, chunk_spec) - for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info - ], - ) - else: - await self.encode_partial_batch( - [ - (byte_setter, value[out_selection], chunk_selection, chunk_spec) - for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info - ], + batch_info = list(batch_info) + + if self._all_sync: + ab_codec = self.array_bytes_codec + codec_chain = self.codec_chain + + async def _write_chunk( + byte_setter: ByteSetter, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + _is_complete_chunk: bool, + ) -> None: + prepared = await ab_codec.prepare_write( + byte_setter, + chunk_spec, + chunk_selection, + out_selection, + codec_chain, ) - else: - # Read existing bytes if not total slice - async def _read_key( - byte_setter: ByteSetter | None, prototype: BufferPrototype - ) -> Buffer | None: - if byte_setter is None: - return None - return await byte_setter.get(prototype=prototype) + if prepared.is_complete_shard: + if prepared.value_selection is not None and not is_scalar( + value.as_ndarray_like(), + prepared.inner_chunk_spec.dtype.to_native_dtype(), + ): + prepared.shard_data = value[prepared.value_selection] + else: + prepared.shard_data = value + await ab_codec.finalize_write(prepared, chunk_spec, byte_setter) + return + + inner_chain = prepared.inner_codec_chain + inner_spec = prepared.inner_chunk_spec + inner_aa, inner_ab, inner_bb = inner_chain.resolve_metadata_chain(inner_spec) + + if prepared.value_selection is not None and not is_scalar( + value.as_ndarray_like(), inner_spec.dtype.to_native_dtype() + ): + write_value = value[prepared.value_selection] + else: + write_value = value + + for coords, chunk_sel, out_sel, _is_complete in prepared.indexer: + existing_bytes_inner = prepared.chunk_dict.get(coords) + if existing_bytes_inner is not None: + existing_array = inner_chain.decode_chunk( + existing_bytes_inner, + inner_spec, + inner_aa, + inner_ab, + inner_bb, + ) + else: + existing_array = None + merged = _merge_chunk_array( + existing_array, + write_value, + out_sel, + inner_spec, + chunk_sel, + _is_complete, + drop_axes, + ) + if not inner_spec.config.write_empty_chunks and merged.all_equal( + fill_value_or_default(inner_spec) + ): + prepared.chunk_dict[coords] = None + else: + prepared.chunk_dict[coords] = inner_chain.encode_chunk(merged, inner_spec) - chunk_bytes_batch: Iterable[Buffer | None] - chunk_bytes_batch = await concurrent_map( + await ab_codec.finalize_write(prepared, chunk_spec, byte_setter) + + await concurrent_map( [ - ( - None if is_complete_chunk else byte_setter, - chunk_spec.prototype, - ) - for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info + (byte_setter, chunk_spec, chunk_selection, out_selection, is_complete_chunk) + for byte_setter, chunk_spec, chunk_selection, out_selection, is_complete_chunk in batch_info ], - _read_key, + _write_chunk, config.get("async.concurrency"), ) - chunk_array_decoded = await self.decode_batch( + else: + # Async fallback: phased approach for codecs without sync support. + # Phase 1: Fetch existing chunks for partial writes. + + async def _fetch_existing( + byte_setter: ByteSetter, chunk_spec: ArraySpec, is_complete_chunk: bool + ) -> Buffer | None: + if is_complete_chunk: + return None + return await byte_setter.get(prototype=chunk_spec.prototype) + + existing_bytes_list: list[Buffer | None] = await concurrent_map( [ - (chunk_bytes, chunk_spec) - for chunk_bytes, (_, chunk_spec, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) + (byte_setter, chunk_spec, is_complete_chunk) + for byte_setter, chunk_spec, _, _, is_complete_chunk in batch_info ], + _fetch_existing, + config.get("async.concurrency"), ) - chunk_array_merged = [ - self._merge_chunk_array( - chunk_array, - value, - out_selection, - chunk_spec, - chunk_selection, - is_complete_chunk, - drop_axes, + # Phase 2: Decode → merge → encode (async codec API). + decode_items: list[tuple[Buffer | None, ArraySpec]] = [ + (existing_bytes if not is_complete_chunk else None, chunk_spec) + for existing_bytes, (_, chunk_spec, _, _, is_complete_chunk) in zip( + existing_bytes_list, batch_info, strict=False ) - for chunk_array, ( - _, - chunk_spec, - chunk_selection, - out_selection, - is_complete_chunk, - ) in zip(chunk_array_decoded, batch_info, strict=False) ] - chunk_array_batch: list[NDBuffer | None] = [] - for chunk_array, (_, chunk_spec, *_) in zip( - chunk_array_merged, batch_info, strict=False - ): - if chunk_array is None: - chunk_array_batch.append(None) # type: ignore[unreachable] - else: - if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal( - fill_value_or_default(chunk_spec) - ): - chunk_array_batch.append(None) - else: - chunk_array_batch.append(chunk_array) - - chunk_bytes_batch = await self.encode_batch( - [ - (chunk_array, chunk_spec) - for chunk_array, (_, chunk_spec, *_) in zip( - chunk_array_batch, batch_info, strict=False - ) - ], + encoded_list = await self._write_batch_compute( + decode_items, batch_info, value, drop_axes ) - async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> None: + # Phase 3: Write encoded chunks to store. + async def _write_out(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> None: if chunk_bytes is None: await byte_setter.delete() else: @@ -438,61 +868,290 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non await concurrent_map( [ (byte_setter, chunk_bytes) - for chunk_bytes, (byte_setter, *_) in zip( - chunk_bytes_batch, batch_info, strict=False + for (byte_setter, *_), chunk_bytes in zip( + batch_info, encoded_list, strict=False ) ], - _write_key, + _write_out, config.get("async.concurrency"), ) - async def decode( + async def _write_batch_compute( self, - chunk_bytes_and_specs: Iterable[tuple[Buffer | None, ArraySpec]], - ) -> Iterable[NDBuffer | None]: - output: list[NDBuffer | None] = [] - for batch_info in batched(chunk_bytes_and_specs, self.batch_size): - output.extend(await self.decode_batch(batch_info)) - return output + decode_items: list[tuple[Buffer | None, ArraySpec]], + batch_info: list[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + value: NDBuffer, + drop_axes: tuple[int, ...], + ) -> list[Buffer | None]: + chunk_array_decoded: Iterable[NDBuffer | None] = await self.decode(decode_items) - async def encode( + chunk_array_batch = self._merge_and_filter( + chunk_array_decoded, batch_info, value, drop_axes + ) + + encoded_batch: Iterable[Buffer | None] = await self.encode( + [ + (chunk_array, chunk_spec) + for chunk_array, (_, chunk_spec, *_) in zip( + chunk_array_batch, batch_info, strict=False + ) + ] + ) + return list(encoded_batch) + + def _merge_and_filter( self, - chunk_arrays_and_specs: Iterable[tuple[NDBuffer | None, ArraySpec]], - ) -> Iterable[Buffer | None]: - output: list[Buffer | None] = [] - for single_batch_info in batched(chunk_arrays_and_specs, self.batch_size): - output.extend(await self.encode_batch(single_batch_info)) - return output + chunk_array_decoded: Iterable[NDBuffer | None], + batch_info: list[tuple[Any, ArraySpec, SelectorTuple, SelectorTuple, bool]], + value: NDBuffer, + drop_axes: tuple[int, ...], + ) -> list[NDBuffer | None]: + chunk_array_merged = [ + self._merge_chunk_array( + chunk_array, + value, + out_selection, + chunk_spec, + chunk_selection, + is_complete_chunk, + drop_axes, + ) + for chunk_array, ( + _, + chunk_spec, + chunk_selection, + out_selection, + is_complete_chunk, + ) in zip(chunk_array_decoded, batch_info, strict=False) + ] + chunk_array_batch: list[NDBuffer | None] = [] + for chunk_array, (_, chunk_spec, *_) in zip(chunk_array_merged, batch_info, strict=False): + if chunk_array is None: + chunk_array_batch.append(None) # type: ignore[unreachable] + else: + if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal( + fill_value_or_default(chunk_spec) + ): + chunk_array_batch.append(None) + else: + chunk_array_batch.append(chunk_array) + return chunk_array_batch - async def read( + async def write( self, - batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + value: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + await self.write_batch(batch_info, value, drop_axes) + + # ------------------------------------------------------------------- + # Fully synchronous read / write (no event loop) + # ------------------------------------------------------------------- + + @property + def supports_sync_io(self) -> bool: + return self._all_sync + + def read_sync( + self, + batch_info: Iterable[tuple[Any, ArraySpec, SelectorTuple, SelectorTuple, bool]], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: - await concurrent_map( - [ - (single_batch_info, out, drop_axes) - for single_batch_info in batched(batch_info, self.batch_size) - ], - self.read_batch, - config.get("async.concurrency"), + batch_info_list = list(batch_info) + if not batch_info_list: + return + + _, first_spec, *_ = batch_info_list[0] + aa_chain, ab_pair, bb_chain = self.codec_chain.resolve_metadata_chain(first_spec) + + chunk_nbytes = product(first_spec.shape) * getattr(first_spec.dtype, "item_size", 1) + n_workers = _choose_workers(len(batch_info_list), chunk_nbytes, self) + if n_workers > 0: + # Threaded: fetch all, decode in parallel, scatter. + chunk_bytes_list: list[Buffer | None] = [ + byte_getter.get_sync(prototype=chunk_spec.prototype) + for byte_getter, chunk_spec, *_ in batch_info_list + ] + pool = _get_pool() + chunk_arrays: list[NDBuffer | None] = list( + pool.map( + _mark_pool_worker(self.codec_chain.decode_chunk), + chunk_bytes_list, + [chunk_spec for _, chunk_spec, *_ in batch_info_list], + [aa_chain] * len(batch_info_list), + [ab_pair] * len(batch_info_list), + [bb_chain] * len(batch_info_list), + ) + ) + self._scatter(chunk_arrays, batch_info_list, out, drop_axes) + else: + # Non-threaded: prepare_read_sync handles IO + decode. + # ShardingCodec overrides for optimized partial IO (byte-range reads). + ab_codec = self.array_bytes_codec + for byte_getter, chunk_spec, chunk_selection, out_selection, _ in batch_info_list: + result = ab_codec.prepare_read_sync( + byte_getter, + chunk_spec, + chunk_selection, + self.codec_chain, + aa_chain, + ab_pair, + bb_chain, + ) + if result is not None: + if drop_axes != (): + result = result.squeeze(axis=drop_axes) + out[out_selection] = result + else: + out[out_selection] = fill_value_or_default(chunk_spec) + + def _write_chunk_compute( + self, + existing_bytes: Buffer | None, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + is_complete_chunk: bool, + value: NDBuffer, + drop_axes: tuple[int, ...], + aa_chain: list[tuple[ArrayArrayCodec, ArraySpec]] | None = None, + ab_pair: tuple[ArrayBytesCodec, ArraySpec] | None = None, + bb_chain: list[tuple[BytesBytesCodec, ArraySpec]] | None = None, + ) -> Buffer | None | object: + """Per-chunk compute for write: decode existing -> merge -> encode.""" + return _write_chunk_compute_default( + existing_bytes, + chunk_spec, + chunk_selection, + out_selection, + is_complete_chunk, + value, + drop_axes, + self.codec_chain, + aa_chain, + ab_pair, + bb_chain, ) - async def write( + def write_sync( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[tuple[Any, ArraySpec, SelectorTuple, SelectorTuple, bool]], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: - await concurrent_map( - [ - (single_batch_info, value, drop_axes) - for single_batch_info in batched(batch_info, self.batch_size) - ], - self.write_batch, - config.get("async.concurrency"), - ) + batch_info_list = list(batch_info) + if not batch_info_list: + return + + _, first_spec, *_ = batch_info_list[0] + aa_chain, ab_pair, bb_chain = self.codec_chain.resolve_metadata_chain(first_spec) + chunk_nbytes = product(first_spec.shape) * getattr(first_spec.dtype, "item_size", 1) + n_workers = _choose_workers(len(batch_info_list), chunk_nbytes, self) + if n_workers > 0: + # Threaded: fetch all, compute in parallel, write all. + existing_bytes_list: list[Buffer | None] = [ + byte_setter.get_sync(prototype=chunk_spec.prototype) + if not is_complete_chunk + else None + for byte_setter, chunk_spec, _, _, is_complete_chunk in batch_info_list + ] + pool = _get_pool() + n = len(batch_info_list) + encoded_list: list[Buffer | None | object] = list( + pool.map( + _mark_pool_worker(self._write_chunk_compute), + existing_bytes_list, + [chunk_spec for _, chunk_spec, *_ in batch_info_list], + [chunk_selection for _, _, chunk_selection, _, _ in batch_info_list], + [out_selection for _, _, _, out_selection, _ in batch_info_list], + [is_complete for _, _, _, _, is_complete in batch_info_list], + [value] * n, + [drop_axes] * n, + [aa_chain] * n, + [ab_pair] * n, + [bb_chain] * n, + ) + ) + for encoded, (byte_setter, *_) in zip(encoded_list, batch_info_list, strict=False): + if encoded is _DELETED: + byte_setter.delete_sync() + else: + byte_setter.set_sync(encoded) + else: + # Non-threaded: prepare_write_sync handles IO + deserialize. + # Pipeline does decode/merge/encode loop, then serialize + write. + ab_codec = self.array_bytes_codec + for ( + byte_setter, + chunk_spec, + chunk_selection, + out_selection, + _, + ) in batch_info_list: + prepared = ab_codec.prepare_write_sync( + byte_setter, + chunk_spec, + chunk_selection, + out_selection, + self.codec_chain, + ) + + if prepared.is_complete_shard: + # Complete shard: pass the shard value to finalize_write + # which encodes and writes in one shot, bypassing the + # per-inner-chunk loop. + if prepared.value_selection is not None and not is_scalar( + value.as_ndarray_like(), + prepared.inner_chunk_spec.dtype.to_native_dtype(), + ): + prepared.shard_data = value[prepared.value_selection] + else: + prepared.shard_data = value + ab_codec.finalize_write_sync(prepared, chunk_spec, byte_setter) + continue + + inner_chain = prepared.inner_codec_chain + inner_spec = prepared.inner_chunk_spec + inner_aa, inner_ab, inner_bb = inner_chain.resolve_metadata_chain(inner_spec) + + if prepared.value_selection is not None and not is_scalar( + value.as_ndarray_like(), inner_spec.dtype.to_native_dtype() + ): + write_value = value[prepared.value_selection] + else: + write_value = value + + for coords, chunk_sel, out_sel, _is_complete in prepared.indexer: + existing_bytes_inner = prepared.chunk_dict.get(coords) + if existing_bytes_inner is not None: + existing_array = inner_chain.decode_chunk( + existing_bytes_inner, + inner_spec, + inner_aa, + inner_ab, + inner_bb, + ) + else: + existing_array = None + merged = _merge_chunk_array( + existing_array, + write_value, + out_sel, + inner_spec, + chunk_sel, + _is_complete, + drop_axes, + ) + if not inner_spec.config.write_empty_chunks and merged.all_equal( + fill_value_or_default(inner_spec) + ): + prepared.chunk_dict[coords] = None + else: + prepared.chunk_dict[coords] = inner_chain.encode_chunk(merged, inner_spec) + + ab_codec.finalize_write_sync(prepared, chunk_spec, byte_setter) def codecs_from_list( @@ -500,11 +1159,12 @@ def codecs_from_list( ) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]: from zarr.codecs.sharding import ShardingCodec + codecs = list(codecs) array_array: tuple[ArrayArrayCodec, ...] = () array_bytes_maybe: ArrayBytesCodec | None = None bytes_bytes: tuple[BytesBytesCodec, ...] = () - if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(tuple(codecs)) > 1: + if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(codecs) > 1: warn( "Combining a `sharding_indexed` codec disables partial reads and " "writes, which may lead to inefficient performance.", diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index f8f8ea4f5f..f21637c495 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -99,11 +99,13 @@ def enable_gpu(self) -> ConfigSet: "target_shard_size_bytes": None, }, "async": {"concurrency": 10, "timeout": None}, - "threading": {"max_workers": None}, + "threading": { + "max_workers": None, + "codec_workers": {"enabled": True, "min": 0, "max": None}, + }, "json_indent": 2, "codec_pipeline": { "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", - "batch_size": 1, }, "codecs": { "blosc": "zarr.codecs.blosc.BloscCodec", diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 4bea04f024..384b01c822 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -228,6 +228,47 @@ async def is_empty(self) -> bool: """ return await self.store.is_empty(self.path) + # ------------------------------------------------------------------- + # Synchronous IO delegation + # + # StorePath is what gets passed to the codec pipeline as a ByteGetter / + # ByteSetter. The async path uses get() / set() / delete(); the sync + # bypass uses these sync variants instead. They simply prepend + # self.path to the key and delegate to the underlying Store's sync + # methods. + # + # Note: These methods are only available when the underlying Store + # also has get_sync / set_sync / delete_sync (e.g. MemoryStore, + # LocalStore). Callers check ``hasattr(store, 'get_sync')`` + # before invoking these. + # ------------------------------------------------------------------- + + def get_sync( + self, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + """Synchronous read — delegates to ``self.store.get_sync(self.path, ...)``.""" + if prototype is None: + prototype = default_buffer_prototype() + return self.store.get_sync(self.path, prototype=prototype, byte_range=byte_range) # type: ignore[attr-defined, no-any-return] + + def set_sync(self, value: Buffer) -> None: + """Synchronous write — delegates to ``self.store.set_sync(self.path, value)``.""" + self.store.set_sync(self.path, value) # type: ignore[attr-defined] + + async def set_range(self, value: Buffer, start: int) -> None: + """Write ``value`` at byte offset ``start`` within the existing key.""" + await self.store.set_range(self.path, value, start) + + def set_range_sync(self, value: Buffer, start: int) -> None: + """Synchronous byte-range write.""" + self.store.set_range_sync(self.path, value, start) # type: ignore[attr-defined] + + def delete_sync(self) -> None: + """Synchronous delete — delegates to ``self.store.delete_sync(self.path)``.""" + self.store.delete_sync(self.path) # type: ignore[attr-defined] + def __truediv__(self, other: str) -> StorePath: """Combine this store path with another path""" return self.__class__(self.store, _dereference_path(self.path, other)) diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 80233a112d..6002e616aa 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -85,6 +85,13 @@ def _put(path: Path, value: Buffer, exclusive: bool = False) -> int: return f.write(view) +def _put_range(path: Path, value: Buffer, start: int) -> None: + view = value.as_buffer_like() + with path.open("r+b") as f: + f.seek(start) + f.write(view) + + class LocalStore(Store): """ Store for the local file system. @@ -187,6 +194,74 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root + # ------------------------------------------------------------------- + # Synchronous store methods + # + # LocalStore's async get/set wrap the synchronous helpers _get() and + # _put() (defined at module level) in asyncio.to_thread(). These sync + # methods call _get/_put directly, removing the thread-hop overhead. + # + # The open-guard logic is inlined from _open(): create root dir if + # writable, check existence, set _is_open. We can't call the async + # _open() from a sync context, so we replicate its logic here. + # ------------------------------------------------------------------- + + def get_sync( + self, + key: str, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + if prototype is None: + prototype = default_buffer_prototype() + # Inline open guard: mirrors async _open() but without await. + if not self._is_open: + if not self.read_only: + self.root.mkdir(parents=True, exist_ok=True) + if not self.root.exists(): + raise FileNotFoundError(f"{self.root} does not exist") + self._is_open = True + assert isinstance(key, str) + path = self.root / key + try: + # Call _get() directly — the async version wraps this same + # function in asyncio.to_thread(). + return _get(path, prototype, byte_range) + except (FileNotFoundError, IsADirectoryError, NotADirectoryError): + return None + + def set_sync(self, key: str, value: Buffer) -> None: + if not self._is_open: + if not self.read_only: + self.root.mkdir(parents=True, exist_ok=True) + if not self.root.exists(): + raise FileNotFoundError(f"{self.root} does not exist") + self._is_open = True + self._check_writable() + assert isinstance(key, str) + if not isinstance(value, Buffer): + raise TypeError( + f"LocalStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." + ) + path = self.root / key + # Call _put() directly — the async version wraps this in + # asyncio.to_thread(). + _put(path, value) + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + self._check_writable() + path = self.root / key + _put_range(path, value, start) + + def delete_sync(self, key: str) -> None: + self._check_writable() + path = self.root / key + # Same logic as async delete(), but without await. + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink(missing_ok=True) + async def get( self, key: str, @@ -223,6 +298,14 @@ async def set(self, key: str, value: Buffer) -> None: # docstring inherited return await self._set(key, value) + async def set_range(self, key: str, value: Buffer, start: int) -> None: + # docstring inherited + if not self._is_open: + await self._open() + self._check_writable() + path = self.root / key + await asyncio.to_thread(_put_range, path, value, start) + async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited try: diff --git a/src/zarr/storage/_logging.py b/src/zarr/storage/_logging.py index 98dca6b23d..0a135ceb2a 100644 --- a/src/zarr/storage/_logging.py +++ b/src/zarr/storage/_logging.py @@ -195,6 +195,11 @@ async def set(self, key: str, value: Buffer) -> None: with self.log(key): return await self._store.set(key=key, value=value) + async def set_range(self, key: str, value: Buffer, start: int) -> None: + # docstring inherited + with self.log(key): + return await self._store.set_range(key=key, value=value, start=start) + async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited with self.log(key): diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index e6f9b7a512..5b9264fc5e 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -77,6 +77,60 @@ def __eq__(self, other: object) -> bool: and self.read_only == other.read_only ) + # ------------------------------------------------------------------- + # Synchronous store methods + # + # MemoryStore is a thin wrapper around a Python dict. The async get/set + # methods are already synchronous in substance — they just happen to be + # ``async def``. These sync variants let the codec pipeline's read_sync / + # write_sync access the dict directly without going through the event + # loop, eliminating the dominant source of overhead for in-memory arrays. + # + # The logic mirrors the async counterparts exactly, except: + # - We set _is_open = True inline instead of ``await self._open()``, + # since MemoryStore._open() is a no-op beyond setting the flag. + # ------------------------------------------------------------------- + + def get_sync( + self, + key: str, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + if prototype is None: + prototype = default_buffer_prototype() + # Inline open: MemoryStore._open() just sets _is_open = True. + if not self._is_open: + self._is_open = True + assert isinstance(key, str) + try: + # Direct dict lookup — this is what async get() does too, + # but without the event loop round-trip. + value = self._store_dict[key] + start, stop = _normalize_byte_range_index(value, byte_range) + return prototype.buffer.from_buffer(value[start:stop]) + except KeyError: + return None + + def set_sync(self, key: str, value: Buffer) -> None: + self._check_writable() + if not self._is_open: + self._is_open = True + assert isinstance(key, str) + if not isinstance(value, Buffer): + raise TypeError( + f"MemoryStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." + ) + # Direct dict assignment — no event loop overhead. + self._store_dict[key] = value + + def delete_sync(self, key: str) -> None: + self._check_writable() + try: + del self._store_dict[key] + except KeyError: + logger.debug("Key %s does not exist.", key) + async def get( self, key: str, @@ -113,7 +167,7 @@ async def exists(self, key: str) -> bool: # docstring inherited return key in self._store_dict - async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + async def set(self, key: str, value: Buffer) -> None: # docstring inherited self._check_writable() await self._ensure_open() @@ -122,13 +176,28 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None raise TypeError( f"MemoryStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." ) + self._store_dict[key] = value - if byte_range is not None: - buf = self._store_dict[key] - buf[byte_range[0] : byte_range[1]] = value - self._store_dict[key] = buf - else: - self._store_dict[key] = value + def _set_range_impl(self, key: str, value: Buffer, start: int) -> None: + buf = self._store_dict[key] + target = buf.as_numpy_array() + if not target.flags.writeable: + target = target.copy() + self._store_dict[key] = buf.__class__(target) + target[start : start + len(value)] = value.as_numpy_array() + + async def set_range(self, key: str, value: Buffer, start: int) -> None: + # docstring inherited + self._check_writable() + await self._ensure_open() + self._set_range_impl(key, value, start) + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + """Synchronous byte-range write.""" + self._check_writable() + if not self._is_open: + self._is_open = True + self._set_range_impl(key, value, start) async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited @@ -464,7 +533,7 @@ def from_dict(cls, store_dict: MutableMapping[str, Buffer]) -> Self: gpu_store_dict = {k: gpu.Buffer.from_buffer(v) for k, v in store_dict.items()} return cls(gpu_store_dict) - async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + async def set(self, key: str, value: Buffer) -> None: # docstring inherited self._check_writable() assert isinstance(key, str) @@ -474,4 +543,4 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None ) # Convert to gpu.Buffer gpu_value = value if isinstance(value, gpu.Buffer) else gpu.Buffer.from_buffer(value) - await super().set(key, gpu_value, byte_range=byte_range) + await super().set(key, gpu_value) diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index e8a2859abc..6e2f5b4536 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -111,6 +111,9 @@ async def exists(self, key: str) -> bool: async def set(self, key: str, value: Buffer) -> None: await self._store.set(key, value) + async def set_range(self, key: str, value: Buffer, start: int) -> None: + await self._store.set_range(key, value, start) + async def set_if_not_exists(self, key: str, value: Buffer) -> None: return await self._store.set_if_not_exists(key, value) diff --git a/src/zarr/testing/buffer.py b/src/zarr/testing/buffer.py index 6096ece2f8..93bc99ece5 100644 --- a/src/zarr/testing/buffer.py +++ b/src/zarr/testing/buffer.py @@ -67,10 +67,15 @@ class StoreExpectingTestBuffer(MemoryStore): We assume that keys containing "json" is metadata """ - async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + async def set(self, key: str, value: Buffer) -> None: if "json" not in key: assert isinstance(value, TestBuffer) - await super().set(key, value, byte_range) + await super().set(key, value) + + def set_sync(self, key: str, value: Buffer) -> None: + if "json" not in key: + assert isinstance(value, TestBuffer) + super().set_sync(key, value) async def get( self, @@ -84,3 +89,16 @@ async def get( if ret is not None: assert isinstance(ret, prototype.buffer) return ret + + def get_sync( + self, + key: str, + prototype: BufferPrototype | None = None, + byte_range: Any = None, + ) -> Buffer | None: + if "json" not in key and prototype is not None: + assert prototype.buffer is TestBuffer + ret = super().get_sync(key=key, prototype=prototype, byte_range=byte_range) + if ret is not None and prototype is not None: + assert isinstance(ret, prototype.buffer) + return ret diff --git a/tests/package_with_entrypoint/__init__.py b/tests/package_with_entrypoint/__init__.py index 7b5dfb5a1e..7394b2e5c8 100644 --- a/tests/package_with_entrypoint/__init__.py +++ b/tests/package_with_entrypoint/__init__.py @@ -40,7 +40,7 @@ def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> class TestEntrypointCodecPipeline(CodecPipeline): - def __init__(self, batch_size: int = 1) -> None: + def __init__(self) -> None: pass async def encode( diff --git a/tests/test_array.py b/tests/test_array.py index 01a82e1938..bae07a165d 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -2259,9 +2259,8 @@ def test_create_array_with_data_num_gets( data = zarr.zeros(shape, dtype="int64") zarr.create_array(store, data=data, chunks=chunk_shape, shards=shard_shape, fill_value=-1) # type: ignore[arg-type] - # one get for the metadata and one per shard. - # Note: we don't actually need one get per shard, but this is the current behavior - assert store.counter["get"] == 1 + num_shards + # one get for the metadata only — complete shard writes skip fetching existing data + assert store.counter["get"] == 1 @pytest.mark.parametrize("config", [{}, {"write_empty_chunks": True}, {"order": "C"}]) diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index d0e2d09b7c..7c008eb898 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -19,6 +19,7 @@ ) from zarr.core.buffer import NDArrayLike, default_buffer_prototype from zarr.storage import StorePath, ZipStore +from zarr.storage._logging import LoggingStore from ..conftest import ArrayRequest from .test_codecs import _AsyncArrayProxy, order_from_dim @@ -238,6 +239,105 @@ def test_sharding_partial_overwrite( assert np.array_equal(data, read_data) +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) +async def test_sharding_subchunk_writes_are_independent( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """Writing to different inner chunks of a shard in separate operations + must preserve all previously written data. + + For uncompressed shards, partial writes should use set_range (byte-range + writes) rather than rewriting the entire shard blob. + """ + logging_store = LoggingStore(store) + # 1 shard of shape (4, 4) containing 4 inner chunks of shape (2, 2) + a = await zarr.api.asynchronous.create_array( + StorePath(logging_store), + shape=(4, 4), + chunks=(2, 2), + shards={"shape": (4, 4), "index_location": index_location}, + compressors=None, + dtype="uint16", + fill_value=0, + ) + + # Write each inner chunk separately with distinct values. + # First write creates the shard (uses set). + logging_store.counter.clear() + await _AsyncArrayProxy(a)[0:2, 0:2].set(np.full((2, 2), 1, dtype="uint16")) + assert logging_store.counter["set"] == 1, "first write should create shard via set" + assert logging_store.counter["set_range"] == 0 + + # Subsequent writes to existing shard should use set_range, not set. + logging_store.counter.clear() + await _AsyncArrayProxy(a)[0:2, 2:4].set(np.full((2, 2), 2, dtype="uint16")) + assert logging_store.counter["set"] == 0, "partial write should not use set" + assert logging_store.counter["set_range"] >= 1, "partial write should use set_range" + + logging_store.counter.clear() + await _AsyncArrayProxy(a)[2:4, 0:2].set(np.full((2, 2), 3, dtype="uint16")) + assert logging_store.counter["set"] == 0, "partial write should not use set" + assert logging_store.counter["set_range"] >= 1, "partial write should use set_range" + + logging_store.counter.clear() + await _AsyncArrayProxy(a)[2:4, 2:4].set(np.full((2, 2), 4, dtype="uint16")) + assert logging_store.counter["set"] == 0, "partial write should not use set" + assert logging_store.counter["set_range"] >= 1, "partial write should use set_range" + + # Every inner chunk must still contain its value + expected = np.array([[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]], dtype="uint16") + np.testing.assert_array_equal(await a.getitem(...), expected) + + +@pytest.mark.parametrize("outer_index_location", ["start", "end"]) +@pytest.mark.parametrize("inner_index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) +def test_nested_sharding_subchunk_writes_are_independent( + store: Store, + outer_index_location: ShardingCodecIndexLocation, + inner_index_location: ShardingCodecIndexLocation, +) -> None: + """Writing to different leaf chunks of a nested-sharded array in separate + operations must preserve all previously written data. + + Layout (1-D for clarity): + outer shard shape = 8 + inner shard shape = 4 (2 inner shards per outer shard) + chunk shape = 2 (2 chunks per inner shard) + total shape = 8 (1 outer shard) + + Four separate writes, one per leaf chunk, then read back the whole array. + """ + a = zarr.create_array( + StorePath(store), + shape=(8,), + dtype="uint16", + fill_value=0, + serializer=ShardingCodec( + chunk_shape=(4,), + codecs=[ + ShardingCodec( + chunk_shape=(2,), + index_location=inner_index_location, + ), + ], + index_location=outer_index_location, + ), + filters=None, + compressors=None, + ) + + # Write each leaf chunk independently + a[0:2] = np.full((2,), 1, dtype="uint16") + a[2:4] = np.full((2,), 2, dtype="uint16") + a[4:6] = np.full((2,), 3, dtype="uint16") + a[6:8] = np.full((2,), 4, dtype="uint16") + + expected = np.array([1, 1, 2, 2, 3, 3, 4, 4], dtype="uint16") + np.testing.assert_array_equal(a[:], expected) + + # Zip storage raises a warning about a duplicate name, which we ignore. @pytest.mark.filterwarnings("ignore:Duplicate name.*:UserWarning") @pytest.mark.parametrize( diff --git a/tests/test_config.py b/tests/test_config.py index c3102e8efe..ba74140b75 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -56,11 +56,13 @@ def test_config_defaults_set() -> None: "target_shard_size_bytes": None, }, "async": {"concurrency": 10, "timeout": None}, - "threading": {"max_workers": None}, + "threading": { + "max_workers": None, + "codec_workers": {"enabled": True, "min": 0, "max": None}, + }, "json_indent": 2, "codec_pipeline": { "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", - "batch_size": 1, }, "codecs": { "blosc": "zarr.codecs.blosc.BloscCodec", @@ -103,7 +105,6 @@ def test_config_defaults_set() -> None: assert config.get("array.order") == "C" assert config.get("async.concurrency") == 10 assert config.get("async.timeout") is None - assert config.get("codec_pipeline.batch_size") == 1 assert config.get("json_indent") == 2 @@ -132,7 +133,7 @@ def test_config_codec_pipeline_class(store: Store) -> None: # has default value assert get_pipeline_class().__name__ != "" - config.set({"codec_pipeline.name": "zarr.core.codec_pipeline.BatchedCodecPipeline"}) + config.set({"codec_pipeline.path": "zarr.core.codec_pipeline.BatchedCodecPipeline"}) assert get_pipeline_class() == zarr.core.codec_pipeline.BatchedCodecPipeline _mock = Mock() @@ -146,6 +147,14 @@ async def write( ) -> None: _mock.call() + def write_sync( + self, + batch_info: Any, + value: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + _mock.call() + register_pipeline(MockCodecPipeline) config.set({"codec_pipeline.path": fully_qualified_name(MockCodecPipeline)}) @@ -191,6 +200,10 @@ async def _encode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Bu _mock.call() return None + def _encode_sync(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer | None: + _mock.call() + return None + register_codec("blosc", MockBloscCodec) with config.set({"codecs.blosc": fully_qualified_name(MockBloscCodec)}): assert get_codec_class("blosc") == MockBloscCodec diff --git a/tests/test_indexing.py b/tests/test_indexing.py index c0bf7dd270..193b0b98f4 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator + from zarr.abc.store import ByteRequest from zarr.core.buffer import BufferPrototype from zarr.core.buffer.core import Buffer @@ -78,10 +79,25 @@ async def get( self.counter["__getitem__", key_suffix] += 1 return await super().get(key, prototype, byte_range) - async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + async def set(self, key: str, value: Buffer) -> None: key_suffix = "/".join(key.split("/")[1:]) self.counter["__setitem__", key_suffix] += 1 - return await super().set(key, value, byte_range) + return await super().set(key, value) + + def get_sync( + self, + key: str, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + key_suffix = "/".join(key.split("/")[1:]) + self.counter["__getitem__", key_suffix] += 1 + return super().get_sync(key, prototype, byte_range) + + def set_sync(self, key: str, value: Buffer) -> None: + key_suffix = "/".join(key.split("/")[1:]) + self.counter["__setitem__", key_suffix] += 1 + return super().set_sync(key, value) def test_normalize_integer_selection() -> None: diff --git a/tests/test_sync_codec_pipeline.py b/tests/test_sync_codec_pipeline.py new file mode 100644 index 0000000000..8fac3d54f6 --- /dev/null +++ b/tests/test_sync_codec_pipeline.py @@ -0,0 +1,305 @@ +"""Tests for sync codec capabilities in BatchedCodecPipeline.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +import zarr +from zarr.abc.codec import SupportsSyncCodec +from zarr.codecs.bytes import BytesCodec +from zarr.codecs.gzip import GzipCodec +from zarr.codecs.transpose import TransposeCodec +from zarr.codecs.zstd import ZstdCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import default_buffer_prototype +from zarr.core.codec_pipeline import BatchedCodecPipeline +from zarr.core.dtype import get_data_type_from_native_dtype +from zarr.storage import MemoryStore + + +def _make_array_spec(shape: tuple[int, ...], dtype: np.dtype[Any]) -> ArraySpec: + zdtype = get_data_type_from_native_dtype(dtype) + return ArraySpec( + shape=shape, + dtype=zdtype, + fill_value=zdtype.default_scalar(), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + + +def _make_nd_buffer(arr: np.ndarray[Any, Any]) -> zarr.core.buffer.NDBuffer: + return default_buffer_prototype().nd_buffer.from_numpy_array(arr) + + +# --------------------------------------------------------------------------- +# Unit tests: SupportsSyncCodec protocol +# --------------------------------------------------------------------------- + + +class TestSupportsSync: + def test_gzip_supports_sync(self) -> None: + assert isinstance(GzipCodec(), SupportsSyncCodec) + + def test_zstd_supports_sync(self) -> None: + assert isinstance(ZstdCodec(), SupportsSyncCodec) + + def test_bytes_supports_sync(self) -> None: + assert isinstance(BytesCodec(), SupportsSyncCodec) + + def test_transpose_supports_sync(self) -> None: + assert isinstance(TransposeCodec(order=(0, 1)), SupportsSyncCodec) + + def test_sharding_supports_sync(self) -> None: + from zarr.codecs.sharding import ShardingCodec + + assert isinstance(ShardingCodec(chunk_shape=(8,)), SupportsSyncCodec) + + +# --------------------------------------------------------------------------- +# Unit tests: individual codec sync roundtrips +# --------------------------------------------------------------------------- + + +class TestGzipCodecSync: + def test_roundtrip(self) -> None: + codec = GzipCodec(level=1) + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) + + +class TestZstdCodecSync: + def test_roundtrip(self) -> None: + codec = ZstdCodec(level=1) + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) + + +class TestBytesCodecSync: + def test_roundtrip(self) -> None: + codec = BytesCodec() + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + nd_buf = _make_nd_buffer(arr) + + # Evolve from array spec (handles endianness) + codec = codec.evolve_from_array_spec(spec) + + encoded = codec._encode_sync(nd_buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + +class TestTransposeCodecSync: + def test_roundtrip(self) -> None: + codec = TransposeCodec(order=(1, 0)) + arr = np.arange(12, dtype="float64").reshape(3, 4) + spec = _make_array_spec(arr.shape, arr.dtype) + nd_buf = _make_nd_buffer(arr) + + encoded = codec._encode_sync(nd_buf, spec) + assert encoded is not None + resolved_spec = codec.resolve_metadata(spec) + decoded = codec._decode_sync(encoded, resolved_spec) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + +# --------------------------------------------------------------------------- +# Unit tests: pipeline construction +# --------------------------------------------------------------------------- + + +class TestPipelineConstruction: + def test_from_codecs_valid(self) -> None: + pipeline = BatchedCodecPipeline.from_codecs([BytesCodec(), GzipCodec(level=1)]) + assert isinstance(pipeline, BatchedCodecPipeline) + assert len(pipeline.bytes_bytes_codecs) == 1 + assert isinstance(pipeline.array_bytes_codec, BytesCodec) + + def test_from_codecs_accepts_sharding(self) -> None: + from zarr.codecs.sharding import ShardingCodec + + pipeline = BatchedCodecPipeline.from_codecs([ShardingCodec(chunk_shape=(8,))]) + assert isinstance(pipeline, BatchedCodecPipeline) + assert pipeline._all_sync + + def test_from_codecs_rejects_missing_array_bytes(self) -> None: + with pytest.raises(ValueError, match="Required ArrayBytesCodec"): + BatchedCodecPipeline.from_codecs([GzipCodec()]) + + def test_from_codecs_with_transpose(self) -> None: + pipeline = BatchedCodecPipeline.from_codecs( + [ + TransposeCodec(order=(1, 0)), + BytesCodec(), + GzipCodec(level=1), + ] + ) + assert len(pipeline.array_array_codecs) == 1 + assert isinstance(pipeline.array_array_codecs[0], TransposeCodec) + + +# --------------------------------------------------------------------------- +# Unit tests: pipeline encode/decode roundtrip +# --------------------------------------------------------------------------- + + +class TestPipelineRoundtrip: + @pytest.mark.asyncio + async def test_encode_decode_single_chunk(self) -> None: + pipeline = BatchedCodecPipeline.from_codecs([BytesCodec(), GzipCodec(level=1)]) + arr = np.random.default_rng(42).standard_normal((32, 32)).astype("float64") + spec = _make_array_spec(arr.shape, arr.dtype) + pipeline = pipeline.evolve_from_array_spec(spec) + nd_buf = _make_nd_buffer(arr) + + encoded = await pipeline.encode([(nd_buf, spec)]) + decoded = await pipeline.decode([(next(iter(encoded)), spec)]) + result = next(iter(decoded)) + assert result is not None + np.testing.assert_array_equal(arr, result.as_numpy_array()) + + @pytest.mark.asyncio + async def test_encode_decode_multiple_chunks(self) -> None: + pipeline = BatchedCodecPipeline.from_codecs([BytesCodec(), GzipCodec(level=1)]) + rng = np.random.default_rng(42) + spec = _make_array_spec((16, 16), np.dtype("float64")) + pipeline = pipeline.evolve_from_array_spec(spec) + chunks = [rng.standard_normal((16, 16)).astype("float64") for _ in range(10)] + nd_bufs = [_make_nd_buffer(c) for c in chunks] + + encoded = list(await pipeline.encode([(buf, spec) for buf in nd_bufs])) + decoded = list(await pipeline.decode([(enc, spec) for enc in encoded])) + for original, dec in zip(chunks, decoded, strict=False): + assert dec is not None + np.testing.assert_array_equal(original, dec.as_numpy_array()) + + @pytest.mark.asyncio + async def test_encode_decode_empty_batch(self) -> None: + pipeline = BatchedCodecPipeline.from_codecs([BytesCodec(), GzipCodec(level=1)]) + encoded = await pipeline.encode([]) + assert list(encoded) == [] + decoded = await pipeline.decode([]) + assert list(decoded) == [] + + @pytest.mark.asyncio + async def test_encode_decode_none_chunk(self) -> None: + pipeline = BatchedCodecPipeline.from_codecs([BytesCodec(), GzipCodec(level=1)]) + spec = _make_array_spec((8,), np.dtype("float64")) + pipeline = pipeline.evolve_from_array_spec(spec) + + encoded = list(await pipeline.encode([(None, spec)])) + assert encoded[0] is None + + decoded = list(await pipeline.decode([(None, spec)])) + assert decoded[0] is None + + +# --------------------------------------------------------------------------- +# Integration tests: default pipeline has sync capabilities +# --------------------------------------------------------------------------- + + +class TestDefaultPipelineSync: + def test_create_array_uses_batched_pipeline(self) -> None: + store = MemoryStore() + arr = zarr.create_array( + store, + shape=(100, 100), + chunks=(32, 32), + dtype="float64", + ) + assert isinstance(arr.async_array.codec_pipeline, BatchedCodecPipeline) + + data = np.random.default_rng(42).standard_normal((100, 100)) + arr[:] = data + np.testing.assert_array_equal(arr[:], data) + + def test_open_uses_batched_pipeline(self) -> None: + store = MemoryStore() + arr = zarr.create_array( + store, + shape=(50, 50), + chunks=(25, 25), + dtype="float64", + ) + data = np.random.default_rng(42).standard_normal((50, 50)) + arr[:] = data + + arr2 = zarr.open_array(store=store) + assert isinstance(arr2.async_array.codec_pipeline, BatchedCodecPipeline) + np.testing.assert_array_equal(arr2[:], data) + + def test_from_array_uses_batched_pipeline(self) -> None: + store1 = MemoryStore() + arr1 = zarr.create_array( + store1, + shape=(20, 20), + chunks=(10, 10), + dtype="float64", + ) + data = np.random.default_rng(42).standard_normal((20, 20)) + arr1[:] = data + + store2 = MemoryStore() + arr2 = zarr.from_array(store2, data=arr1) + assert isinstance(arr2.async_array.codec_pipeline, BatchedCodecPipeline) + np.testing.assert_array_equal(arr2[:], data) + + def test_partial_write(self) -> None: + store = MemoryStore() + arr = zarr.create_array( + store, + shape=(100,), + chunks=(10,), + dtype="int32", + fill_value=0, + ) + arr[5:15] = np.arange(10, dtype="int32") + 1 + result = arr[:] + expected = np.zeros(100, dtype="int32") + expected[5:15] = np.arange(10, dtype="int32") + 1 + np.testing.assert_array_equal(result, expected) + + def test_zstd_codec(self) -> None: + store = MemoryStore() + arr = zarr.create_array( + store, + shape=(50,), + chunks=(10,), + dtype="float32", + compressors=ZstdCodec(level=3), + ) + data = np.random.default_rng(42).standard_normal(50).astype("float32") + arr[:] = data + np.testing.assert_array_equal(arr[:], data) + + def test_supports_sync_io(self) -> None: + """Default pipeline supports sync IO when all codecs are sync.""" + pipeline = BatchedCodecPipeline.from_codecs([BytesCodec(), GzipCodec(level=1)]) + assert pipeline.supports_sync_io + + def test_supports_sync_io_default(self) -> None: + """Default BatchedCodecPipeline is the sync pipeline — no config switch needed.""" + store = MemoryStore() + arr = zarr.create_array(store, shape=(10,), dtype="float64") + assert isinstance(arr.async_array.codec_pipeline, BatchedCodecPipeline) + assert arr.async_array.codec_pipeline.supports_sync_io