diff --git a/CHANGES/11681.feature.rst b/CHANGES/11681.feature.rst new file mode 100644 index 00000000000..21b0ab1f7c7 --- /dev/null +++ b/CHANGES/11681.feature.rst @@ -0,0 +1,6 @@ +Started accepting :term:`asynchronous context managers ` for cleanup contexts. +Legacy single-yield :term:`asynchronous generator` cleanup contexts continue to be +supported; async context managers are adapted internally so they are +entered at startup and exited during cleanup. + +-- by :user:`MannXo`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 7f81d3e5dd6..3a570dc90e7 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -291,6 +291,7 @@ Pahaz Blinov Panagiotis Kolokotronis Pankaj Pandey Parag Jain +Parman Mohammadalizadeh Patrick Lee Pau Freixes Paul Colomiets diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py index 3a34311f845..095d5e806cd 100644 --- a/aiohttp/web_app.py +++ b/aiohttp/web_app.py @@ -11,6 +11,7 @@ MutableMapping, Sequence, ) +from contextlib import AbstractAsyncContextManager, asynccontextmanager from functools import lru_cache, partial, update_wrapper from typing import Any, TypeVar, cast, final, overload @@ -405,31 +406,34 @@ def exceptions(self) -> list[BaseException]: return cast(list[BaseException], self.args[1]) -_CleanupContextBase = FrozenList[Callable[[Application], AsyncIterator[None]]] +_CleanupContextCallable = ( + Callable[[Application], AbstractAsyncContextManager[None]] + | Callable[[Application], AsyncIterator[None]] +) -class CleanupContext(_CleanupContextBase): +class CleanupContext(FrozenList[_CleanupContextCallable]): def __init__(self) -> None: super().__init__() - self._exits: list[AsyncIterator[None]] = [] + self._exits: list[AbstractAsyncContextManager[None]] = [] async def _on_startup(self, app: Application) -> None: for cb in self: - it = cb(app).__aiter__() - await it.__anext__() - self._exits.append(it) + ctx = cb(app) + + if not isinstance(ctx, AbstractAsyncContextManager): + ctx = asynccontextmanager(cb)(app) # type: ignore[arg-type] + + await ctx.__aenter__() + self._exits.append(ctx) async def _on_cleanup(self, app: Application) -> None: errors = [] for it in reversed(self._exits): try: - await it.__anext__() - except StopAsyncIteration: - pass + await it.__aexit__(None, None, None) except (Exception, asyncio.CancelledError) as exc: errors.append(exc) - else: - errors.append(RuntimeError(f"{it!r} has more than one 'yield'")) if errors: if len(errors) == 1: raise errors[0] diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 09ec0f1f356..ebec0eef5a8 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -805,19 +805,21 @@ performance improvements. If you plan on reusing the session, a.k.a. creating :ref:`aiohttp-web-cleanup-ctx`. If possible we advise using :ref:`aiohttp-web-cleanup-ctx`, as it results in more compact code:: - app.cleanup_ctx.append(persistent_session) - persistent_session = aiohttp.web.AppKey("persistent_session", aiohttp.ClientSession) + session = aiohttp.web.AppKey("session", aiohttp.ClientSession) + @contextlib.asynccontextmanager async def persistent_session(app): app[persistent_session] = session = aiohttp.ClientSession() yield await session.close() async def my_request_handler(request): - session = request.app[persistent_session] - async with session.get("http://python.org") as resp: + sess = request.app[session] + async with sess.get("http://python.org") as resp: print(resp.status) + app.cleanup_ctx.append(persistent_session) + This approach can be successfully used to define numerous sessions given certain requirements. It benefits from having a single location where :class:`aiohttp.ClientSession` diff --git a/docs/faq.rst b/docs/faq.rst index 2166f1775a7..3f50b855588 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -290,6 +290,7 @@ database object, do it explicitly:: This can also be done from a :ref:`cleanup context`:: + @contextlib.asynccontextmanager async def db_context(app: web.Application) -> AsyncIterator[None]: async with create_db() as db: mainapp[db_key] = mainapp[subapp_key][db_key] = db diff --git a/docs/web_advanced.rst b/docs/web_advanced.rst index 81fa384d55b..171ab809724 100644 --- a/docs/web_advanced.rst +++ b/docs/web_advanced.rst @@ -857,6 +857,7 @@ knowledge about startup/cleanup pairs and their execution state. The solution is :attr:`Application.cleanup_ctx` usage:: + @contextlib.asynccontextmanager async def pg_engine(app: web.Application): app[pg_engine] = await create_async_engine( "postgresql+asyncpg://postgre:@localhost:5432/postgre" @@ -1168,6 +1169,7 @@ below:: await ws.send_str("{}: {}".format(channel, msg)) + @contextlib.asynccontextmanager async def background_tasks(app): app[redis_listener] = asyncio.create_task(listen_to_redis(app)) @@ -1207,6 +1209,7 @@ For example, running a long-lived task alongside the :class:`Application` can be done with a :ref:`aiohttp-web-cleanup-ctx` function like:: + @contextlib.asynccontextmanager async def run_other_task(_app): task = asyncio.create_task(other_long_task()) @@ -1222,6 +1225,7 @@ can be done with a :ref:`aiohttp-web-cleanup-ctx` function like:: Or a separate process can be run with something like:: + @contextlib.asynccontextmanager async def run_process(_app): proc = await asyncio.create_subprocess_exec(path) diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 01b237f1b0a..51a8dae189a 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -1532,7 +1532,8 @@ Application and Router Signal handlers should have the following signature:: - async def context(app): + @contextlib.asynccontextmanager + async def context(app: web.Application) -> AsyncIterator[None]: # do startup stuff yield # do cleanup diff --git a/examples/background_tasks.py b/examples/background_tasks.py index a4b37a02bb3..3d5c43e48de 100755 --- a/examples/background_tasks.py +++ b/examples/background_tasks.py @@ -2,7 +2,7 @@ """Example of aiohttp.web.Application.on_startup signal handler""" import asyncio from collections.abc import AsyncIterator -from contextlib import suppress +from contextlib import asynccontextmanager, suppress import valkey.asyncio as valkey @@ -44,6 +44,7 @@ async def listen_to_valkey(app: web.Application) -> None: print(f"message in {channel}: {msg}") +@asynccontextmanager async def background_tasks(app: web.Application) -> AsyncIterator[None]: app[valkey_listener] = asyncio.create_task(listen_to_valkey(app)) diff --git a/tests/test_web_app.py b/tests/test_web_app.py index 2d2d21dbc42..77022a63c67 100644 --- a/tests/test_web_app.py +++ b/tests/test_web_app.py @@ -1,6 +1,7 @@ import asyncio import sys from collections.abc import AsyncIterator, Callable, Iterator +from contextlib import asynccontextmanager from typing import NoReturn from unittest import mock @@ -401,12 +402,85 @@ async def inner(app: web.Application) -> AsyncIterator[None]: app.freeze() await app.startup() assert out == ["pre_1"] - with pytest.raises(RuntimeError) as ctx: + with pytest.raises(RuntimeError): await app.cleanup() - assert "has more than one 'yield'" in str(ctx.value) assert out == ["pre_1", "post_1"] +async def test_cleanup_ctx_with_async_generator_and_asynccontextmanager() -> None: + entered = [] + + async def gen_ctx(app: web.Application) -> AsyncIterator[None]: + entered.append("enter-gen") + try: + yield + finally: + entered.append("exit-gen") + + @asynccontextmanager + async def cm_ctx(app: web.Application) -> AsyncIterator[None]: + entered.append("enter-cm") + try: + yield + finally: + entered.append("exit-cm") + + app = web.Application() + app.cleanup_ctx.append(gen_ctx) + app.cleanup_ctx.append(cm_ctx) + app.freeze() + await app.startup() + assert "enter-gen" in entered and "enter-cm" in entered + await app.cleanup() + assert "exit-gen" in entered and "exit-cm" in entered + + +async def test_cleanup_ctx_exception_in_cm_exit() -> None: + app = web.Application() + + exc = RuntimeError("exit failed") + + @asynccontextmanager + async def failing_exit_ctx(app: web.Application) -> AsyncIterator[None]: + yield + raise exc + + app.cleanup_ctx.append(failing_exit_ctx) + app.freeze() + await app.startup() + with pytest.raises(RuntimeError) as ctx: + await app.cleanup() + assert ctx.value is exc + + +async def test_cleanup_ctx_mixed_with_exception_in_cm_exit() -> None: + app = web.Application() + out = [] + + async def working_gen(app: web.Application) -> AsyncIterator[None]: + out.append("pre_gen") + yield + out.append("post_gen") + + exc = RuntimeError("cm exit failed") + + @asynccontextmanager + async def failing_exit_cm(app: web.Application) -> AsyncIterator[None]: + out.append("pre_cm") + yield + out.append("post_cm") + raise exc + + app.cleanup_ctx.append(working_gen) + app.cleanup_ctx.append(failing_exit_cm) + app.freeze() + await app.startup() + with pytest.raises(RuntimeError) as ctx: + await app.cleanup() + assert ctx.value is exc + assert out == ["pre_gen", "pre_cm", "post_cm", "post_gen"] + + async def test_subapp_chained_config_dict_visibility( aiohttp_client: AiohttpClient, ) -> None: