From c68d8487c1b3b07ed20bd8cbb180df225afb7f98 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Fri, 13 Mar 2026 11:38:44 +0200 Subject: [PATCH] feat(typing): allow new_1d to take ndarrays --- doc/conf.py | 1 + pytools/obj_array.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index 676da1b9..456cb16d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -45,6 +45,7 @@ "np.dtype": "class:numpy.dtype", "np.ndarray": "class:numpy.ndarray", "np.floating": "class:numpy.floating", + "np.generic[Any]": "class:numpy.generic", # pytools typing "ObjectArray1D": "obj:pytools.obj_array.ObjectArray1D", "ReadableBuffer": "data:pytools.ReadableBuffer", diff --git a/pytools/obj_array.py b/pytools/obj_array.py index 727223ae..9e3afc8e 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -7,6 +7,7 @@ .. autoclass:: T_co .. autoclass:: ResultT .. autoclass:: ShapeT +.. autoclass:: NumpyTypeT .. autoclass:: ObjectArray .. autoclass:: ObjectArray0D @@ -99,6 +100,7 @@ ResultT = TypeVar("ResultT") ShapeT = TypeVar("ShapeT", bound=tuple[int, ...]) +NumpyTypeT = TypeVar("NumpyTypeT", bound="np.generic[Any]") class _ObjectArrayMetaclass(type): @@ -340,7 +342,37 @@ def from_numpy( return cast("ObjectArray[ShapeT, T_co]", cast("object", ary)) -def new_1d(res_list: Sequence[T_co]) -> ObjectArray1D[T_co]: +@overload +def new_1d( # pyright: ignore[reportOverlappingOverload] + res_list: np.ndarray[tuple[int], np.dtype[NumpyTypeT]] + ) -> ObjectArray1D[NumpyTypeT]: ... + +@overload +def new_1d( + res_list: np.ndarray[tuple[int, int], np.dtype[NumpyTypeT]] + ) -> ObjectArray1D[np.ndarray[tuple[int], np.dtype[NumpyTypeT]]]: ... + +@overload +def new_1d( + res_list: np.ndarray[tuple[int, int, int], np.dtype[NumpyTypeT]] + ) -> ObjectArray1D[np.ndarray[tuple[int, int], np.dtype[NumpyTypeT]]]: ... + +@overload +def new_1d( + res_list: np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]] + ) -> ObjectArray1D[np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]]]: ... + +@overload +def new_1d(res_list: Sequence[T_co]) -> ObjectArray1D[T_co]: ... + + +def new_1d( + res_list: ( + Sequence[T_co] + | np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]] + ) + ) -> (ObjectArray1D[T_co] + | ObjectArray1D[np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]]]): """Create a one-dimensional object array from *res_list*. This differs from ``numpy.array(res_list, dtype=object)`` by whether it tries to determine its shape by descending