diff --git a/asyncstdlib/asynctools.py b/asyncstdlib/asynctools.py index dd23948..8b2f9d9 100644 --- a/asyncstdlib/asynctools.py +++ b/asyncstdlib/asynctools.py @@ -1,5 +1,5 @@ -from asyncio import iscoroutinefunction from functools import wraps +from inspect import iscoroutinefunction from typing import ( Union, AsyncContextManager, diff --git a/asyncstdlib/builtins.py b/asyncstdlib/builtins.py index 717dbae..5c8f54c 100644 --- a/asyncstdlib/builtins.py +++ b/asyncstdlib/builtins.py @@ -204,27 +204,35 @@ async def _zip_inner_strict( async def map( function: Union[Callable[..., R], Callable[..., Awaitable[R]]], - *iterable: AnyIterable[Any], + iterable: AnyIterable[Any], + /, + *iterables: AnyIterable[Any], + strict: bool = False, ) -> AsyncIterator[R]: r""" An async iterator mapping an (async) function to items from (async) iterables + :raises ValueError: if the ``iterables`` are not equal length and ``strict`` is set + At each step, ``map`` collects the next item from each iterable and calls - ``function`` with all items; if ``function`` provides an awaitable, + ``function`` with these items; if ``function`` provides an awaitable, it is ``await``\ ed. The result is the next value of ``map``. Barring sync/async translation, ``map`` is equivalent to ``(await function(*args) async for args in zip(iterables))``. It is important that ``func`` receives *one* item from *each* iterable at - every step. For *n* ``iterable``, ``func`` must take *n* positional arguments. - Similar to :py:func:`~.zip`, ``map`` is exhausted as soon as its - *first* argument is exhausted. + every step. For *n* ``iterables``, ``func`` must take *n* positional arguments. + Similar to :py:func:`~.zip`, ``map`` is exhausted as soon as any of its `iterables` + is exhausted. + When called with ``strict=True``, all ``iterables`` must be of same length; + in this mode ``map`` raises :py:exc:`ValueError` if any ``iterables`` are not + exhausted with the others. The ``function`` may be a regular or async callable. - Multiple ``iterable`` may be mixed regular and async iterables. + Multiple ``iterables`` may be mixed regular and async iterables. """ function = _awaitify(function) - async with ScopedIter(zip(*iterable)) as args_iter: + async with ScopedIter(zip(iterable, *iterables, strict=strict)) as args_iter: async for args in args_iter: result = function(*args) yield await result diff --git a/asyncstdlib/builtins.pyi b/asyncstdlib/builtins.pyi index 13d22a9..c4cde97 100644 --- a/asyncstdlib/builtins.pyi +++ b/asyncstdlib/builtins.pyi @@ -82,12 +82,14 @@ def map( function: Callable[[T1], Awaitable[R]], __it1: AnyIterable[T1], /, + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( function: Callable[[T1], R], __it1: AnyIterable[T1], /, + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( @@ -95,6 +97,7 @@ def map( __it1: AnyIterable[T1], __it2: AnyIterable[T2], /, + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( @@ -102,6 +105,7 @@ def map( __it1: AnyIterable[T1], __it2: AnyIterable[T2], /, + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( @@ -110,6 +114,7 @@ def map( __it2: AnyIterable[T2], __it3: AnyIterable[T3], /, + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( @@ -118,6 +123,7 @@ def map( __it2: AnyIterable[T2], __it3: AnyIterable[T3], /, + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( @@ -127,6 +133,7 @@ def map( __it3: AnyIterable[T3], __it4: AnyIterable[T4], /, + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( @@ -136,6 +143,7 @@ def map( __it3: AnyIterable[T3], __it4: AnyIterable[T4], /, + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( @@ -146,6 +154,7 @@ def map( __it4: AnyIterable[T4], __it5: AnyIterable[T5], /, + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( @@ -156,6 +165,7 @@ def map( __it4: AnyIterable[T4], __it5: AnyIterable[T5], /, + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( @@ -167,6 +177,7 @@ def map( __it5: AnyIterable[Any], /, *iterable: AnyIterable[Any], + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload def map( @@ -178,6 +189,7 @@ def map( __it5: AnyIterable[Any], /, *iterable: AnyIterable[Any], + strict: bool = ..., ) -> AsyncIterator[R]: ... @overload async def max(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ... diff --git a/asyncstdlib/functools.py b/asyncstdlib/functools.py index 100c011..ad26875 100644 --- a/asyncstdlib/functools.py +++ b/asyncstdlib/functools.py @@ -1,4 +1,4 @@ -from asyncio import iscoroutinefunction +from inspect import iscoroutinefunction from typing import ( Callable, Awaitable, @@ -281,6 +281,7 @@ def decorator( async def reduce( function: Union[Callable[[T, T], T], Callable[[T, T], Awaitable[T]]], iterable: AnyIterable[T], + /, initial: T = __REDUCE_SENTINEL, # type: ignore ) -> T: """ diff --git a/asyncstdlib/functools.pyi b/asyncstdlib/functools.pyi index 4bedfdc..d222429 100644 --- a/asyncstdlib/functools.pyi +++ b/asyncstdlib/functools.pyi @@ -34,15 +34,27 @@ def cached_property( ) -> Callable[[Callable[[T], Awaitable[R]]], CachedProperty[T, R]]: ... @overload async def reduce( - function: Callable[[T1, T2], Awaitable[T1]], iterable: AnyIterable[T2], initial: T1 + function: Callable[[T1, T2], Awaitable[T1]], + iterable: AnyIterable[T2], + /, + initial: T1, ) -> T1: ... @overload async def reduce( - function: Callable[[T, T], Awaitable[T]], iterable: AnyIterable[T] + function: Callable[[T, T], Awaitable[T]], + iterable: AnyIterable[T], + /, ) -> T: ... @overload async def reduce( - function: Callable[[T1, T2], T1], iterable: AnyIterable[T2], initial: T1 + function: Callable[[T1, T2], T1], + iterable: AnyIterable[T2], + /, + initial: T1, ) -> T1: ... @overload -async def reduce(function: Callable[[T, T], T], iterable: AnyIterable[T]) -> T: ... +async def reduce( + function: Callable[[T, T], T], + iterable: AnyIterable[T], + /, +) -> T: ... diff --git a/asyncstdlib/itertools.py b/asyncstdlib/itertools.py index f6be464..2929024 100644 --- a/asyncstdlib/itertools.py +++ b/asyncstdlib/itertools.py @@ -17,6 +17,7 @@ AsyncGenerator, TYPE_CHECKING, ) + if TYPE_CHECKING: from typing_extensions import TypeAlias diff --git a/unittests/test_builtins.py b/unittests/test_builtins.py index 6231b4b..382fddd 100644 --- a/unittests/test_builtins.py +++ b/unittests/test_builtins.py @@ -1,4 +1,5 @@ import random +from typing import Any, Callable, Coroutine, TypeVar import pytest @@ -6,12 +7,16 @@ from .utility import sync, asyncify, awaitify +COR = TypeVar("COR", bound=Callable[..., Coroutine[Any, Any, Any]]) -def hide_coroutine(corofunc): - def wrapper(*args, **kwargs): + +def hide_coroutine(corofunc: COR) -> COR: + """Make a coroutine function look like a regular function returning a coroutine""" + + def wrapper(*args, **kwargs): # type: ignore return corofunc(*args, **kwargs) - return wrapper + return wrapper # type: ignore @sync @@ -94,7 +99,7 @@ async def __aiter__(self): @sync async def test_map_as(): - async def map_op(value): + async def map_op(value: int) -> int: return value * 2 assert [value async for value in a.map(map_op, range(5))] == list(range(0, 10, 2)) @@ -105,7 +110,7 @@ async def map_op(value): @sync async def test_map_sa(): - def map_op(value): + async def map_op(value: int) -> int: return value * 2 assert [value async for value in a.map(map_op, asyncify(range(5)))] == list( @@ -118,7 +123,7 @@ def map_op(value): @sync async def test_map_aa(): - async def map_op(value): + async def map_op(value: int) -> int: return value * 2 assert [value async for value in a.map(map_op, asyncify(range(5)))] == list( @@ -130,6 +135,28 @@ async def map_op(value): ] == list(range(10, 20, 4)) +@pytest.mark.parametrize( + "itrs", + [ + (range(4), range(5), range(5)), + (range(5), range(4), range(5)), + (range(5), range(5), range(4)), + ], +) +@sync +async def test_map_strict_unequal(itrs: "tuple[range, ...]"): + def triple_sum(x: int, y: int, z: int) -> int: + return x + y + z + + # no error without strict + async for _ in a.map(triple_sum, *itrs): + pass + # error with strict + with pytest.raises(ValueError): + async for _ in a.map(triple_sum, *itrs, strict=True): + pass + + @sync async def test_max_default(): assert await a.max((), default=3) == 3 @@ -142,7 +169,7 @@ async def test_max_default(): @sync async def test_max_sa(): - async def minus(x): + async def minus(x: int) -> int: return -x assert await a.max(asyncify((1, 2, 3, 4))) == 4 @@ -167,7 +194,7 @@ async def test_min_default(): @sync async def test_min_sa(): - async def minus(x): + async def minus(x: int) -> int: return -x assert await a.min(asyncify((1, 2, 3, 4))) == 1 @@ -180,7 +207,7 @@ async def minus(x): @sync async def test_filter_as(): - async def map_op(value): + async def map_op(value: int) -> bool: return value % 2 == 0 assert [value async for value in a.filter(map_op, range(5))] == list(range(0, 5, 2)) @@ -194,7 +221,7 @@ async def map_op(value): @sync async def test_filter_sa(): - def map_op(value): + def map_op(value: int) -> bool: return value % 2 == 0 assert [value async for value in a.filter(map_op, asyncify(range(5)))] == list( @@ -208,7 +235,7 @@ def map_op(value): @sync async def test_filter_aa(): - async def map_op(value): + async def map_op(value: int) -> bool: return value % 2 == 0 assert [value async for value in a.filter(map_op, asyncify(range(5)))] == list( @@ -286,7 +313,7 @@ async def test_types(): @pytest.mark.parametrize("sortable", sortables) @pytest.mark.parametrize("reverse", [True, False]) @sync -async def test_sorted_direct(sortable, reverse): +async def test_sorted_direct(sortable: "list[int] | list[float]", reverse: bool): assert await a.sorted(sortable, reverse=reverse) == sorted( sortable, reverse=reverse ) @@ -305,12 +332,12 @@ async def test_sorted_direct(sortable, reverse): async def test_sorted_stable(): values = [-i for i in range(20)] - def collision_key(x): + def collision_key(x: int) -> int: return x // 2 # test the test... assert sorted(values, key=collision_key) != [ - item for key, item in sorted([(collision_key(i), i) for i in values]) + item for _, item in sorted([(collision_key(i), i) for i in values]) ] # test the implementation assert await a.sorted(values, key=awaitify(collision_key)) == sorted(