|
|
|
@ -1,22 +1,31 @@ |
|
|
|
import contextlib |
|
|
|
import email.message |
|
|
|
import functools |
|
|
|
import inspect |
|
|
|
import json |
|
|
|
import types |
|
|
|
from collections.abc import ( |
|
|
|
AsyncIterator, |
|
|
|
Awaitable, |
|
|
|
Collection, |
|
|
|
Coroutine, |
|
|
|
Generator, |
|
|
|
Mapping, |
|
|
|
Sequence, |
|
|
|
) |
|
|
|
from contextlib import AsyncExitStack, asynccontextmanager |
|
|
|
from contextlib import ( |
|
|
|
AbstractAsyncContextManager, |
|
|
|
AbstractContextManager, |
|
|
|
AsyncExitStack, |
|
|
|
asynccontextmanager, |
|
|
|
) |
|
|
|
from enum import Enum, IntEnum |
|
|
|
from typing import ( |
|
|
|
Annotated, |
|
|
|
Any, |
|
|
|
Callable, |
|
|
|
Optional, |
|
|
|
TypeVar, |
|
|
|
Union, |
|
|
|
) |
|
|
|
|
|
|
|
@ -143,6 +152,50 @@ def websocket_session( |
|
|
|
return app |
|
|
|
|
|
|
|
|
|
|
|
_T = TypeVar("_T") |
|
|
|
|
|
|
|
|
|
|
|
# Vendored from starlette.routing to avoid importing private symbols |
|
|
|
class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]): |
|
|
|
""" |
|
|
|
Wraps a synchronous context manager to make it async. |
|
|
|
|
|
|
|
This is vendored from Starlette to avoid importing private symbols. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, cm: AbstractContextManager[_T]) -> None: |
|
|
|
self._cm = cm |
|
|
|
|
|
|
|
async def __aenter__(self) -> _T: |
|
|
|
return self._cm.__enter__() |
|
|
|
|
|
|
|
async def __aexit__( |
|
|
|
self, |
|
|
|
exc_type: Optional[type[BaseException]], |
|
|
|
exc_value: Optional[BaseException], |
|
|
|
traceback: Optional[types.TracebackType], |
|
|
|
) -> Optional[bool]: |
|
|
|
return self._cm.__exit__(exc_type, exc_value, traceback) |
|
|
|
|
|
|
|
|
|
|
|
# Vendored from starlette.routing to avoid importing private symbols |
|
|
|
def _wrap_gen_lifespan_context( |
|
|
|
lifespan_context: Callable[[Any], Generator[Any, Any, Any]], |
|
|
|
) -> Callable[[Any], AbstractAsyncContextManager[Any]]: |
|
|
|
""" |
|
|
|
Wrap a generator-based lifespan context into an async context manager. |
|
|
|
|
|
|
|
This is vendored from Starlette to avoid importing private symbols. |
|
|
|
""" |
|
|
|
cmgr = contextlib.contextmanager(lifespan_context) |
|
|
|
|
|
|
|
@functools.wraps(cmgr) |
|
|
|
def wrapper(app: Any) -> _AsyncLiftContextManager[Any]: |
|
|
|
return _AsyncLiftContextManager(cmgr(app)) |
|
|
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
|
|
def _merge_lifespan_context( |
|
|
|
original_context: Lifespan[Any], nested_context: Lifespan[Any] |
|
|
|
) -> Lifespan[Any]: |
|
|
|
@ -160,6 +213,30 @@ def _merge_lifespan_context( |
|
|
|
return merged_lifespan # type: ignore[return-value] |
|
|
|
|
|
|
|
|
|
|
|
class _DefaultLifespan: |
|
|
|
""" |
|
|
|
Default lifespan context manager that runs on_startup and on_shutdown handlers. |
|
|
|
|
|
|
|
This is a copy of the Starlette _DefaultLifespan class that was removed |
|
|
|
in Starlette. FastAPI keeps it to maintain backward compatibility with |
|
|
|
on_startup and on_shutdown event handlers. |
|
|
|
|
|
|
|
Ref: https://github.com/Kludex/starlette/pull/3117 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, router: "APIRouter") -> None: |
|
|
|
self._router = router |
|
|
|
|
|
|
|
async def __aenter__(self) -> None: |
|
|
|
await self._router._startup() |
|
|
|
|
|
|
|
async def __aexit__(self, *exc_info: object) -> None: |
|
|
|
await self._router._shutdown() |
|
|
|
|
|
|
|
def __call__(self: _T, app: object) -> _T: |
|
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
# Cache for endpoint context to avoid re-extracting on every request |
|
|
|
_endpoint_context_cache: dict[int, EndpointContext] = {} |
|
|
|
|
|
|
|
@ -903,13 +980,33 @@ class APIRouter(routing.Router): |
|
|
|
), |
|
|
|
] = Default(generate_unique_id), |
|
|
|
) -> None: |
|
|
|
# Handle on_startup/on_shutdown locally since Starlette removed support |
|
|
|
# Ref: https://github.com/Kludex/starlette/pull/3117 |
|
|
|
# TODO: deprecate this once the lifespan (or alternative) interface is improved |
|
|
|
self.on_startup: list[Callable[[], Any]] = ( |
|
|
|
[] if on_startup is None else list(on_startup) |
|
|
|
) |
|
|
|
self.on_shutdown: list[Callable[[], Any]] = ( |
|
|
|
[] if on_shutdown is None else list(on_shutdown) |
|
|
|
) |
|
|
|
|
|
|
|
# Determine the lifespan context to use |
|
|
|
if lifespan is None: |
|
|
|
# Use the default lifespan that runs on_startup/on_shutdown handlers |
|
|
|
lifespan_context: Lifespan[Any] = _DefaultLifespan(self) |
|
|
|
elif inspect.isasyncgenfunction(lifespan): |
|
|
|
lifespan_context = asynccontextmanager(lifespan) |
|
|
|
elif inspect.isgeneratorfunction(lifespan): |
|
|
|
lifespan_context = _wrap_gen_lifespan_context(lifespan) |
|
|
|
else: |
|
|
|
lifespan_context = lifespan |
|
|
|
self.lifespan_context = lifespan_context |
|
|
|
|
|
|
|
super().__init__( |
|
|
|
routes=routes, |
|
|
|
redirect_slashes=redirect_slashes, |
|
|
|
default=default, |
|
|
|
on_startup=on_startup, |
|
|
|
on_shutdown=on_shutdown, |
|
|
|
lifespan=lifespan, |
|
|
|
lifespan=lifespan_context, |
|
|
|
) |
|
|
|
if prefix: |
|
|
|
assert prefix.startswith("/"), "A path prefix must start with '/'" |
|
|
|
@ -4473,6 +4570,58 @@ class APIRouter(routing.Router): |
|
|
|
generate_unique_id_function=generate_unique_id_function, |
|
|
|
) |
|
|
|
|
|
|
|
# TODO: remove this once the lifespan (or alternative) interface is improved |
|
|
|
async def _startup(self) -> None: |
|
|
|
""" |
|
|
|
Run any `.on_startup` event handlers. |
|
|
|
|
|
|
|
This method is kept for backward compatibility after Starlette removed |
|
|
|
support for on_startup/on_shutdown handlers. |
|
|
|
|
|
|
|
Ref: https://github.com/Kludex/starlette/pull/3117 |
|
|
|
""" |
|
|
|
for handler in self.on_startup: |
|
|
|
if is_async_callable(handler): |
|
|
|
await handler() |
|
|
|
else: |
|
|
|
handler() |
|
|
|
|
|
|
|
# TODO: remove this once the lifespan (or alternative) interface is improved |
|
|
|
async def _shutdown(self) -> None: |
|
|
|
""" |
|
|
|
Run any `.on_shutdown` event handlers. |
|
|
|
|
|
|
|
This method is kept for backward compatibility after Starlette removed |
|
|
|
support for on_startup/on_shutdown handlers. |
|
|
|
|
|
|
|
Ref: https://github.com/Kludex/starlette/pull/3117 |
|
|
|
""" |
|
|
|
for handler in self.on_shutdown: |
|
|
|
if is_async_callable(handler): |
|
|
|
await handler() |
|
|
|
else: |
|
|
|
handler() |
|
|
|
|
|
|
|
# TODO: remove this once the lifespan (or alternative) interface is improved |
|
|
|
def add_event_handler( |
|
|
|
self, |
|
|
|
event_type: str, |
|
|
|
func: Callable[[], Any], |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Add an event handler function for startup or shutdown. |
|
|
|
|
|
|
|
This method is kept for backward compatibility after Starlette removed |
|
|
|
support for on_startup/on_shutdown handlers. |
|
|
|
|
|
|
|
Ref: https://github.com/Kludex/starlette/pull/3117 |
|
|
|
""" |
|
|
|
assert event_type in ("startup", "shutdown") |
|
|
|
if event_type == "startup": |
|
|
|
self.on_startup.append(func) |
|
|
|
else: |
|
|
|
self.on_shutdown.append(func) |
|
|
|
|
|
|
|
@deprecated( |
|
|
|
""" |
|
|
|
on_event is deprecated, use lifespan event handlers instead. |
|
|
|
|