Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion asyncstdlib/asynctools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from asyncio import iscoroutinefunction
from functools import wraps
from inspect import iscoroutinefunction
from typing import (
Union,
AsyncContextManager,
Expand Down
22 changes: 15 additions & 7 deletions asyncstdlib/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,27 +204,35 @@ async def _zip_inner_strict(

async def map(
function: Union[Callable[..., R], Callable[..., Awaitable[R]]],
*iterable: AnyIterable[Any],
iterable: AnyIterable[Any],
/,
*iterables: AnyIterable[Any],
strict: bool = False,
) -> AsyncIterator[R]:
r"""
An async iterator mapping an (async) function to items from (async) iterables

:raises ValueError: if the ``iterables`` are not equal length and ``strict`` is set

At each step, ``map`` collects the next item from each iterable and calls
``function`` with all items; if ``function`` provides an awaitable,
``function`` with these items; if ``function`` provides an awaitable,
it is ``await``\ ed. The result is the next value of ``map``.
Barring sync/async translation, ``map`` is equivalent to
``(await function(*args) async for args in zip(iterables))``.

It is important that ``func`` receives *one* item from *each* iterable at
every step. For *n* ``iterable``, ``func`` must take *n* positional arguments.
Similar to :py:func:`~.zip`, ``map`` is exhausted as soon as its
*first* argument is exhausted.
every step. For *n* ``iterables``, ``func`` must take *n* positional arguments.
Similar to :py:func:`~.zip`, ``map`` is exhausted as soon as any of its `iterables`
is exhausted.
When called with ``strict=True``, all ``iterables`` must be of same length;
in this mode ``map`` raises :py:exc:`ValueError` if any ``iterables`` are not
exhausted with the others.

The ``function`` may be a regular or async callable.
Multiple ``iterable`` may be mixed regular and async iterables.
Multiple ``iterables`` may be mixed regular and async iterables.
"""
function = _awaitify(function)
async with ScopedIter(zip(*iterable)) as args_iter:
async with ScopedIter(zip(iterable, *iterables, strict=strict)) as args_iter:
async for args in args_iter:
result = function(*args)
yield await result
Expand Down
12 changes: 12 additions & 0 deletions asyncstdlib/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,30 @@ def map(
function: Callable[[T1], Awaitable[R]],
__it1: AnyIterable[T1],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1], R],
__it1: AnyIterable[T1],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1, T2], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1, T2], R],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -110,6 +114,7 @@ def map(
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -118,6 +123,7 @@ def map(
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -127,6 +133,7 @@ def map(
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -136,6 +143,7 @@ def map(
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -146,6 +154,7 @@ def map(
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -156,6 +165,7 @@ def map(
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -167,6 +177,7 @@ def map(
__it5: AnyIterable[Any],
/,
*iterable: AnyIterable[Any],
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -178,6 +189,7 @@ def map(
__it5: AnyIterable[Any],
/,
*iterable: AnyIterable[Any],
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
async def max(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...
Expand Down
3 changes: 2 additions & 1 deletion asyncstdlib/functools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asyncio import iscoroutinefunction
from inspect import iscoroutinefunction
from typing import (
Callable,
Awaitable,
Expand Down Expand Up @@ -281,6 +281,7 @@ def decorator(
async def reduce(
function: Union[Callable[[T, T], T], Callable[[T, T], Awaitable[T]]],
iterable: AnyIterable[T],
/,
initial: T = __REDUCE_SENTINEL, # type: ignore
) -> T:
"""
Expand Down
20 changes: 16 additions & 4 deletions asyncstdlib/functools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,27 @@ def cached_property(
) -> Callable[[Callable[[T], Awaitable[R]]], CachedProperty[T, R]]: ...
@overload
async def reduce(
function: Callable[[T1, T2], Awaitable[T1]], iterable: AnyIterable[T2], initial: T1
function: Callable[[T1, T2], Awaitable[T1]],
iterable: AnyIterable[T2],
/,
initial: T1,
) -> T1: ...
@overload
async def reduce(
function: Callable[[T, T], Awaitable[T]], iterable: AnyIterable[T]
function: Callable[[T, T], Awaitable[T]],
iterable: AnyIterable[T],
/,
) -> T: ...
@overload
async def reduce(
function: Callable[[T1, T2], T1], iterable: AnyIterable[T2], initial: T1
function: Callable[[T1, T2], T1],
iterable: AnyIterable[T2],
/,
initial: T1,
) -> T1: ...
@overload
async def reduce(function: Callable[[T, T], T], iterable: AnyIterable[T]) -> T: ...
async def reduce(
function: Callable[[T, T], T],
iterable: AnyIterable[T],
/,
) -> T: ...
1 change: 1 addition & 0 deletions asyncstdlib/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
AsyncGenerator,
TYPE_CHECKING,
)

if TYPE_CHECKING:
from typing_extensions import TypeAlias

Expand Down
55 changes: 41 additions & 14 deletions unittests/test_builtins.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import random
from typing import Any, Callable, Coroutine, TypeVar

import pytest

import asyncstdlib as a

from .utility import sync, asyncify, awaitify

COR = TypeVar("COR", bound=Callable[..., Coroutine[Any, Any, Any]])

def hide_coroutine(corofunc):
def wrapper(*args, **kwargs):

def hide_coroutine(corofunc: COR) -> COR:
"""Make a coroutine function look like a regular function returning a coroutine"""

def wrapper(*args, **kwargs): # type: ignore
return corofunc(*args, **kwargs)

return wrapper
return wrapper # type: ignore


@sync
Expand Down Expand Up @@ -94,7 +99,7 @@ async def __aiter__(self):

@sync
async def test_map_as():
async def map_op(value):
async def map_op(value: int) -> int:
return value * 2

assert [value async for value in a.map(map_op, range(5))] == list(range(0, 10, 2))
Expand All @@ -105,7 +110,7 @@ async def map_op(value):

@sync
async def test_map_sa():
def map_op(value):
async def map_op(value: int) -> int:
return value * 2

assert [value async for value in a.map(map_op, asyncify(range(5)))] == list(
Expand All @@ -118,7 +123,7 @@ def map_op(value):

@sync
async def test_map_aa():
async def map_op(value):
async def map_op(value: int) -> int:
return value * 2

assert [value async for value in a.map(map_op, asyncify(range(5)))] == list(
Expand All @@ -130,6 +135,28 @@ async def map_op(value):
] == list(range(10, 20, 4))


@pytest.mark.parametrize(
"itrs",
[
(range(4), range(5), range(5)),
(range(5), range(4), range(5)),
(range(5), range(5), range(4)),
],
)
@sync
async def test_map_strict_unequal(itrs: "tuple[range, ...]"):
def triple_sum(x: int, y: int, z: int) -> int:
return x + y + z

# no error without strict
async for _ in a.map(triple_sum, *itrs):
pass
# error with strict
with pytest.raises(ValueError):
async for _ in a.map(triple_sum, *itrs, strict=True):
pass


@sync
async def test_max_default():
assert await a.max((), default=3) == 3
Expand All @@ -142,7 +169,7 @@ async def test_max_default():

@sync
async def test_max_sa():
async def minus(x):
async def minus(x: int) -> int:
return -x

assert await a.max(asyncify((1, 2, 3, 4))) == 4
Expand All @@ -167,7 +194,7 @@ async def test_min_default():

@sync
async def test_min_sa():
async def minus(x):
async def minus(x: int) -> int:
return -x

assert await a.min(asyncify((1, 2, 3, 4))) == 1
Expand All @@ -180,7 +207,7 @@ async def minus(x):

@sync
async def test_filter_as():
async def map_op(value):
async def map_op(value: int) -> bool:
return value % 2 == 0

assert [value async for value in a.filter(map_op, range(5))] == list(range(0, 5, 2))
Expand All @@ -194,7 +221,7 @@ async def map_op(value):

@sync
async def test_filter_sa():
def map_op(value):
def map_op(value: int) -> bool:
return value % 2 == 0

assert [value async for value in a.filter(map_op, asyncify(range(5)))] == list(
Expand All @@ -208,7 +235,7 @@ def map_op(value):

@sync
async def test_filter_aa():
async def map_op(value):
async def map_op(value: int) -> bool:
return value % 2 == 0

assert [value async for value in a.filter(map_op, asyncify(range(5)))] == list(
Expand Down Expand Up @@ -286,7 +313,7 @@ async def test_types():
@pytest.mark.parametrize("sortable", sortables)
@pytest.mark.parametrize("reverse", [True, False])
@sync
async def test_sorted_direct(sortable, reverse):
async def test_sorted_direct(sortable: "list[int] | list[float]", reverse: bool):
assert await a.sorted(sortable, reverse=reverse) == sorted(
sortable, reverse=reverse
)
Expand All @@ -305,12 +332,12 @@ async def test_sorted_direct(sortable, reverse):
async def test_sorted_stable():
values = [-i for i in range(20)]

def collision_key(x):
def collision_key(x: int) -> int:
return x // 2

# test the test...
assert sorted(values, key=collision_key) != [
item for key, item in sorted([(collision_key(i), i) for i in values])
item for _, item in sorted([(collision_key(i), i) for i in values])
]
# test the implementation
assert await a.sorted(values, key=awaitify(collision_key)) == sorted(
Expand Down