diff --git a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md new file mode 100644 index 000000000..ba46da330 --- /dev/null +++ b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md @@ -0,0 +1,111 @@ +# Lifespan Scoped Dependencies + +## Intro + +So far we've used dependencies which are "endpoint scoped". Meaning, they are +called again and again for every incoming request to the endpoint. However, +this is not always ideal: + +* Sometimes dependencies have a large setup/teardown time. Running it for every request will result in bad performance. +* Sometimes dependencies need to have their values shared throughout the lifespan +of the application between multiple requests. + + +An example of this would be a connection to a database. Databases are typically +less efficient when working with lots of connections and would prefer that +clients would create a single connection for their operations. + +For such cases can be solved by using "lifespan scoped dependencies". + + +## What is a lifespan scoped dependency? +Lifespan scoped dependencies work similarly to the (endpoint scoped) +dependencies we've worked with so far. However, unlike endpoint scoped +dependencies, lifespan scoped dependencies are called once and only +once in the application's lifespan: + +* During the application startup process, all lifespan scoped dependencies will +be called. +* Their returned value will be shared across all requests to the application. +* During the application's shutdown process, all lifespan scoped dependencies +will be gracefully teared down. + + +## Create a lifespan scoped dependency + +You may declare a dependency as a lifespan scoped dependency by passing +`dependency_scope="lifespan"` to the `Depends` function: + +{* ../../docs_src/dependencies/tutorial013a_an_py39.py *} + +/// tip + +In the example above we saved the annotation to a separate variable, and then +reused it in our endpoints. This is not a requirement, we could also declare +the exact same annotation in both endpoints. However, it is recommended that you +do save the annotation to a variable so you won't accidentally forget to pass +`dependency_scope="lifespan"` to some of the endpoints (Causing the endpoint +to create a new database connection for every request). + +/// + +In this example, the `get_database_connection` dependency will be executed once, +during the application's startup. **FastAPI** will internally save the resulting +connection object, and whenever the `read_users` and `read_items` endpoints are +called, they will be using the previously saved connection. Once the application +shuts down, **FastAPI** will make sure to gracefully close the connection object. + +## The `use_cache` argument + +The `use_cache` argument works similarly to the way it worked with endpoint +scoped dependencies. Meaning as **FastAPI** gathers lifespan scoped dependencies, it +will cache dependencies it already encountered before. However, you can disable +this behavior by passing `use_cache=False` to `Depends`: + +{* ../../docs_src/dependencies/tutorial013b_an_py39.py *} + +In this example, the `read_users` and `read_groups` endpoints are using +`use_cache=False` whereas the `read_items` and `read_item` are using +`use_cache=True`. +That means that we'll have a total of 3 connections created +for the duration of the application's lifespan: + +* One connection will be shared across all requests for the `read_items` and `read_item` endpoints. +* A second connection will be shared across all requests for the `read_users` endpoint. +* A third and final connection will be shared across all requests for the `read_groups` endpoint. + + +## Lifespan Scoped Sub-Dependencies +Just like with endpoint scoped dependencies, lifespan scoped dependencies may +use other lifespan scoped sub-dependencies themselves: + +{* ../../docs_src/dependencies/tutorial013c_an_py39.py *} + +Endpoint scoped dependencies may use lifespan scoped sub dependencies as well: + +{* ../../docs_src/dependencies/tutorial013d_an_py39.py *} + +/// note + +You can pass `dependency_scope="endpoint"` if you wish to explicitly specify +that a dependency is endpoint scoped. It will work the same as not specifying +a dependency scope at all. + +/// + +As you can see, regardless of the scope, dependencies can use lifespan scoped +sub-dependencies. + +## Dependency Scope Conflicts +By definition, lifespan scoped dependencies are being setup in the application's +startup process, before any request is ever being made to any endpoint. +Therefore, it is not possible for a lifespan scoped dependency to use any +parameters that require the scope of an endpoint. + +That includes but not limited to: + +* Parts of the request (like `Body`, `Query` and `Path`) +* The request/response objects themselves (like `Request`, `Response` and `WebSocket`) +* Endpoint scoped sub-dependencies. + +Defining a dependency with such parameters will raise an `InvalidDependencyScope` error. diff --git a/docs/en/mkdocs.yml b/docs/en/mkdocs.yml index 8a5ea13e0..e4dcf7429 100644 --- a/docs/en/mkdocs.yml +++ b/docs/en/mkdocs.yml @@ -141,6 +141,7 @@ nav: - tutorial/dependencies/dependencies-in-path-operation-decorators.md - tutorial/dependencies/global-dependencies.md - tutorial/dependencies/dependencies-with-yield.md + - tutorial/dependencies/lifespan-scoped-dependencies.md - Security: - tutorial/security/index.md - tutorial/security/first-steps.md diff --git a/docs_src/dependencies/tutorial013a.py b/docs_src/dependencies/tutorial013a.py new file mode 100644 index 000000000..83687af24 --- /dev/null +++ b/docs_src/dependencies/tutorial013a.py @@ -0,0 +1,44 @@ +from typing import List + +from fastapi import Depends, FastAPI +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> List[dict]: + pass + + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") + + +@app.get("/users/") +async def read_users( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, +): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, +): + return await database_connection.get_records("items") diff --git a/docs_src/dependencies/tutorial013a_an_py39.py b/docs_src/dependencies/tutorial013a_an_py39.py new file mode 100644 index 000000000..62f10a6e1 --- /dev/null +++ b/docs_src/dependencies/tutorial013a_an_py39.py @@ -0,0 +1,42 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> list[dict]: + pass + + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Annotated[ + MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") +] + + +@app.get("/users/") +async def read_users(database_connection: GlobalDatabaseConnection): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items(database_connection: GlobalDatabaseConnection): + return await database_connection.get_records("items") diff --git a/docs_src/dependencies/tutorial013b.py b/docs_src/dependencies/tutorial013b.py new file mode 100644 index 000000000..3123b64f5 --- /dev/null +++ b/docs_src/dependencies/tutorial013b.py @@ -0,0 +1,65 @@ +from typing import List + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> List[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") +DedicatedDatabaseConnection = Depends( + get_database_connection, dependency_scope="lifespan", use_cache=False +) + + +@app.get("/groups/") +async def read_groups( + database_connection: MyDatabaseConnection = DedicatedDatabaseConnection, +): + return await database_connection.get_records("groups") + + +@app.get("/users/") +async def read_users( + database_connection: MyDatabaseConnection = DedicatedDatabaseConnection, +): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, +): + return await database_connection.get_records("items") + + +@app.get("/items/{item_id}") +async def read_item( + item_id: str = Path(), + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, +): + return await database_connection.get_record("items", item_id) diff --git a/docs_src/dependencies/tutorial013b_an_py39.py b/docs_src/dependencies/tutorial013b_an_py39.py new file mode 100644 index 000000000..cc7205f40 --- /dev/null +++ b/docs_src/dependencies/tutorial013b_an_py39.py @@ -0,0 +1,61 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> list[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Annotated[ + MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") +] +DedicatedDatabaseConnection = Annotated[ + MyDatabaseConnection, + Depends(get_database_connection, dependency_scope="lifespan", use_cache=False), +] + + +@app.get("/groups/") +async def read_groups(database_connection: DedicatedDatabaseConnection): + return await database_connection.get_records("groups") + + +@app.get("/users/") +async def read_users(database_connection: DedicatedDatabaseConnection): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items(database_connection: GlobalDatabaseConnection): + return await database_connection.get_records("items") + + +@app.get("/items/{item_id}") +async def read_item( + database_connection: GlobalDatabaseConnection, item_id: Annotated[str, Path()] +): + return await database_connection.get_record("items", item_id) diff --git a/docs_src/dependencies/tutorial013c.py b/docs_src/dependencies/tutorial013c.py new file mode 100644 index 000000000..c8814adc0 --- /dev/null +++ b/docs_src/dependencies/tutorial013c.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +@dataclass +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + connection_string: str + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + + +app = FastAPI() + + +async def get_configuration() -> dict: + return { + "database_url": "sqlite:///database.db", + } + + +GlobalConfiguration = Depends(get_configuration, dependency_scope="lifespan") + + +async def get_database_connection(configuration: dict = GlobalConfiguration): + async with MyDatabaseConnection(configuration["database_url"]) as connection: + yield connection + + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") + + +@app.get("/users/{user_id}") +async def read_user( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, + user_id: str = Path(), +): + return await database_connection.get_record("users", user_id) diff --git a/docs_src/dependencies/tutorial013c_an_py39.py b/docs_src/dependencies/tutorial013c_an_py39.py new file mode 100644 index 000000000..a64e72b8a --- /dev/null +++ b/docs_src/dependencies/tutorial013c_an_py39.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import Annotated + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +@dataclass +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + connection_string: str + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + + +app = FastAPI() + + +async def get_configuration() -> dict: + return { + "database_url": "sqlite:///database.db", + } + + +GlobalConfiguration = Annotated[ + dict, Depends(get_configuration, dependency_scope="lifespan") +] + + +async def get_database_connection(configuration: GlobalConfiguration): + async with MyDatabaseConnection(configuration["database_url"]) as connection: + yield connection + + +GlobalDatabaseConnection = Annotated[ + get_database_connection, + Depends(get_database_connection, dependency_scope="lifespan"), +] + + +@app.get("/users/{user_id}") +async def read_user( + database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()] +): + return await database_connection.get_record("users", user_id) diff --git a/docs_src/dependencies/tutorial013d.py b/docs_src/dependencies/tutorial013d.py new file mode 100644 index 000000000..01e2831d7 --- /dev/null +++ b/docs_src/dependencies/tutorial013d.py @@ -0,0 +1,40 @@ +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") + + +async def get_user_record( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, + user_id: str = Path(), +) -> dict: + return await database_connection.get_record("users", user_id) + + +@app.get("/users/{user_id}") +async def read_user(user_record: dict = Depends(get_user_record)): + return user_record diff --git a/docs_src/dependencies/tutorial013d_an_py39.py b/docs_src/dependencies/tutorial013d_an_py39.py new file mode 100644 index 000000000..fa6b0831b --- /dev/null +++ b/docs_src/dependencies/tutorial013d_an_py39.py @@ -0,0 +1,43 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Annotated[ + MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") +] + + +async def get_user_record( + database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()] +) -> dict: + return await database_connection.get_record("users", user_id) + + +@app.get("/users/{user_id}") +async def read_user(user_record: Annotated[dict, Depends(get_user_record)]): + return user_record diff --git a/fastapi/applications.py b/fastapi/applications.py index 05c7bd2be..d24d52b26 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1,6 +1,8 @@ +from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum from typing import ( Any, + AsyncGenerator, Awaitable, Callable, Coroutine, @@ -15,12 +17,14 @@ from typing import ( from fastapi import routing from fastapi.datastructures import Default, DefaultPlaceholder +from fastapi.dependencies.utils import is_coroutine_callable from fastapi.exception_handlers import ( http_exception_handler, request_validation_exception_handler, websocket_request_validation_exception_handler, ) from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError +from fastapi.lifespan import resolve_lifespan_dependants from fastapi.logger import logger from fastapi.openapi.docs import ( get_redoc_html, @@ -29,9 +33,11 @@ from fastapi.openapi.docs import ( ) from fastapi.openapi.utils import get_openapi from fastapi.params import Depends +from fastapi.routing import merge_lifespan_context from fastapi.types import DecoratedCallable, IncEx from fastapi.utils import generate_unique_id from starlette.applications import Starlette +from starlette.concurrency import run_in_threadpool from starlette.datastructures import State from starlette.exceptions import HTTPException from starlette.middleware import Middleware @@ -929,12 +935,26 @@ class FastAPI(Starlette): """ ), ] = {} + if lifespan is None: + lifespan = FastAPI._internal_lifespan + else: + lifespan = merge_lifespan_context(FastAPI._internal_lifespan, lifespan) + + # Since we always use a lifespan, starlette will no longer run event + # handlers which are defined in the scope of the application. + # We therefore need to call them ourselves. + if on_startup is None: + on_startup = [] + + if on_shutdown is None: + on_shutdown = [] + self._on_startup = list(on_startup) + self._on_shutdown = list(on_shutdown) + self.router: routing.APIRouter = routing.APIRouter( routes=routes, redirect_slashes=redirect_slashes, dependency_overrides_provider=self, - on_startup=on_startup, - on_shutdown=on_shutdown, lifespan=lifespan, default_response_class=default_response_class, dependencies=dependencies, @@ -963,6 +983,30 @@ class FastAPI(Starlette): self.middleware_stack: Union[ASGIApp, None] = None self.setup() + @asynccontextmanager + async def _internal_lifespan(self) -> AsyncGenerator[Dict[str, Any], None]: + async with AsyncExitStack() as exit_stack: + lifespan_scoped_dependencies = await resolve_lifespan_dependants( + app=self, async_exit_stack=exit_stack + ) + try: + for handler in self._on_startup: + if is_coroutine_callable(handler): + await handler() + else: + await run_in_threadpool(handler) + yield { + "__fastapi__": { + "lifespan_scoped_dependencies": lifespan_scoped_dependencies + } + } + finally: + for handler in self._on_shutdown: + if is_coroutine_callable(handler): + await handler() + else: + await run_in_threadpool(handler) + def openapi(self) -> Dict[str, Any]: """ Generate the OpenAPI schema of the application. This is called by FastAPI @@ -4492,7 +4536,15 @@ class FastAPI(Starlette): Read more about it in the [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/#alternative-events-deprecated). """ - return self.router.on_event(event_type) + + def decorator(func: DecoratedCallable) -> DecoratedCallable: + if event_type == "startup": + self._on_startup.append(func) + else: + self._on_shutdown.append(func) + return func + + return decorator def middleware( self, diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..1f4192a51 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,8 +1,9 @@ from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast from fastapi._compat import ModelField from fastapi.security.base import SecurityBase +from typing_extensions import TypeAlias @dataclass @@ -11,17 +12,53 @@ class SecurityRequirement: scopes: Optional[Sequence[str]] = None +LifespanDependantCacheKey: TypeAlias = Union[ + Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any] +] + + +@dataclass +class LifespanDependant: + call: Callable[..., Any] + caller: Callable[..., Any] + dependencies: List["LifespanDependant"] = field(default_factory=list) + name: Optional[str] = None + use_cache: bool = True + index: Optional[int] = None + cache_key: LifespanDependantCacheKey = field(init=False) + + def __post_init__(self) -> None: + if self.use_cache: + self.cache_key = self.call + elif self.name is not None: + self.cache_key = (self.caller, self.name) + else: + assert self.index is not None, ( + "Lifespan dependency must have an associated name or index." + ) + self.cache_key = (self.caller, self.index) + + +EndpointDependantCacheKey: TypeAlias = Tuple[ + Optional[Callable[..., Any]], Tuple[str, ...] +] + + @dataclass -class Dependant: +class EndpointDependant: + endpoint_dependencies: List["EndpointDependant"] = field(default_factory=list) + lifespan_dependencies: List[LifespanDependant] = field(default_factory=list) + name: Optional[str] = None + call: Optional[Callable[..., Any]] = None + use_cache: bool = True + index: Optional[int] = None + cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) path_params: List[ModelField] = field(default_factory=list) query_params: List[ModelField] = field(default_factory=list) header_params: List[ModelField] = field(default_factory=list) cookie_params: List[ModelField] = field(default_factory=list) body_params: List[ModelField] = field(default_factory=list) - dependencies: List["Dependant"] = field(default_factory=list) security_requirements: List[SecurityRequirement] = field(default_factory=list) - name: Optional[str] = None - call: Optional[Callable[..., Any]] = None request_param_name: Optional[str] = None websocket_param_name: Optional[str] = None http_connection_param_name: Optional[str] = None @@ -29,9 +66,26 @@ class Dependant: background_tasks_param_name: Optional[str] = None security_scopes_param_name: Optional[str] = None security_scopes: Optional[List[str]] = None - use_cache: bool = True path: Optional[str] = None - cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) def __post_init__(self) -> None: self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) + + # Kept for backwards compatibility + @property + def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]: + lifespan_dependencies = cast( + List[Union[EndpointDependant, LifespanDependant]], + self.lifespan_dependencies, + ) + endpoint_dependencies = cast( + List[Union[EndpointDependant, LifespanDependant]], + self.endpoint_dependencies, + ) + + return tuple(lifespan_dependencies + endpoint_dependencies) + + +# Kept for backwards compatibility +Dependant = EndpointDependant +CacheKey: TypeAlias = Union[EndpointDependantCacheKey, LifespanDependantCacheKey] diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..51411b91e 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -51,7 +51,19 @@ from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, ) -from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.dependencies.models import ( + CacheKey, + EndpointDependant, + EndpointDependantCacheKey, + LifespanDependant, + LifespanDependantCacheKey, + SecurityRequirement, +) +from fastapi.exceptions import ( + DependencyScopeConflict, + InvalidDependencyScope, + UninitializedLifespanDependency, +) from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes @@ -120,8 +132,9 @@ def get_param_sub_dependant( param_name: str, depends: params.Depends, path: str, + caller: Callable[..., Any], security_scopes: Optional[List[str]] = None, -) -> Dependant: +) -> Union[EndpointDependant, LifespanDependant]: assert depends.dependency return get_sub_dependant( depends=depends, @@ -129,14 +142,23 @@ def get_param_sub_dependant( path=path, name=param_name, security_scopes=security_scopes, + caller=caller, ) -def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: +def get_parameterless_sub_dependant( + *, depends: params.Depends, path: str, caller: Callable[..., Any], index: int +) -> Union[EndpointDependant, LifespanDependant]: assert callable(depends.dependency), ( "A parameter-less dependency must have a callable dependency" ) - return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) + return get_sub_dependant( + depends=depends, + dependency=depends.dependency, + path=path, + caller=caller, + index=index, + ) def get_sub_dependant( @@ -144,57 +166,72 @@ def get_sub_dependant( depends: params.Depends, dependency: Callable[..., Any], path: str, + caller: Callable[..., Any], name: Optional[str] = None, security_scopes: Optional[List[str]] = None, -) -> Dependant: - security_requirement = None - security_scopes = security_scopes or [] - if isinstance(depends, params.Security): - dependency_scopes = depends.scopes - security_scopes.extend(dependency_scopes) - if isinstance(dependency, SecurityBase): - use_scopes: List[str] = [] - if isinstance(dependency, (OAuth2, OpenIdConnect)): - use_scopes = security_scopes - security_requirement = SecurityRequirement( - security_scheme=dependency, scopes=use_scopes - ) - sub_dependant = get_dependant( - path=path, - call=dependency, - name=name, - security_scopes=security_scopes, - use_cache=depends.use_cache, - ) - if security_requirement: - sub_dependant.security_requirements.append(security_requirement) - return sub_dependant - - -CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] + index: Optional[int] = None, +) -> Union[EndpointDependant, LifespanDependant]: + if depends.dependency_scope == "lifespan": + return get_lifespan_dependant( + caller=caller, + call=dependency, + name=name, + use_cache=depends.use_cache, + index=index, + ) + elif depends.dependency_scope == "endpoint": + security_requirement = None + security_scopes = security_scopes or [] + if isinstance(depends, params.Security): + dependency_scopes = depends.scopes + security_scopes.extend(dependency_scopes) + if isinstance(dependency, SecurityBase): + use_scopes: List[str] = [] + if isinstance(dependency, (OAuth2, OpenIdConnect)): + use_scopes = security_scopes + security_requirement = SecurityRequirement( + security_scheme=dependency, scopes=use_scopes + ) + sub_dependant = get_endpoint_dependant( + path=path, + call=dependency, + name=name, + security_scopes=security_scopes, + use_cache=depends.use_cache, + index=index, + ) + if security_requirement: + sub_dependant.security_requirements.append(security_requirement) + return sub_dependant + else: + raise InvalidDependencyScope( + f'Dependency "{name}" of {caller} has an invalid ' + f'scope: "{depends.dependency_scope}"' + ) def get_flat_dependant( - dependant: Dependant, + dependant: EndpointDependant, *, skip_repeats: bool = False, visited: Optional[List[CacheKey]] = None, -) -> Dependant: +) -> EndpointDependant: if visited is None: visited = [] visited.append(dependant.cache_key) - flat_dependant = Dependant( + flat_dependant = EndpointDependant( path_params=dependant.path_params.copy(), query_params=dependant.query_params.copy(), header_params=dependant.header_params.copy(), cookie_params=dependant.cookie_params.copy(), body_params=dependant.body_params.copy(), security_requirements=dependant.security_requirements.copy(), + lifespan_dependencies=dependant.lifespan_dependencies.copy(), use_cache=dependant.use_cache, path=dependant.path, ) - for sub_dependant in dependant.dependencies: + for sub_dependant in dependant.endpoint_dependencies: if skip_repeats and sub_dependant.cache_key in visited: continue flat_sub = get_flat_dependant( @@ -206,6 +243,7 @@ def get_flat_dependant( flat_dependant.cookie_params.extend(flat_sub.cookie_params) flat_dependant.body_params.extend(flat_sub.body_params) flat_dependant.security_requirements.extend(flat_sub.security_requirements) + flat_dependant.lifespan_dependencies.extend(flat_sub.lifespan_dependencies) return flat_dependant @@ -219,7 +257,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]: return fields -def get_flat_params(dependant: Dependant) -> List[ModelField]: +def get_flat_params(dependant: EndpointDependant) -> List[ModelField]: flat_dependant = get_flat_dependant(dependant, skip_repeats=True) path_params = _get_flat_fields_from_params(flat_dependant.path_params) query_params = _get_flat_fields_from_params(flat_dependant.query_params) @@ -262,23 +300,75 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: return get_typed_annotation(annotation, globalns) -def get_dependant( +def get_lifespan_dependant( + *, + caller: Callable[..., Any], + call: Callable[..., Any], + name: Optional[str] = None, + use_cache: bool = True, + index: Optional[int] = None, +) -> LifespanDependant: + dependency_signature = get_typed_signature(call) + signature_params = dependency_signature.parameters + dependant = LifespanDependant( + call=call, name=name, use_cache=use_cache, caller=caller, index=index + ) + for param_name, param in signature_params.items(): + param_details = analyze_param( + param_name=param_name, + annotation=param.annotation, + value=param.default, + is_path_param=False, + ) + if param_details.depends is None: + raise DependencyScopeConflict( + f'Lifespan scoped dependency "{dependant.name}" was defined ' + f'with an invalid argument: "{param_name}" which is ' + f'"endpoint" scoped. Lifespan scoped dependencies may only ' + f"use lifespan scoped sub-dependencies." + ) + + if param_details.depends.dependency_scope != "lifespan": + raise DependencyScopeConflict( + f"Lifespan scoped dependency {dependant.name} was defined with the " + f'sub-dependency "{param_name}" which is ' + f'"{param_details.depends.dependency_scope}" scoped. ' + f"Lifespan scoped dependencies may only use lifespan scoped " + f"sub-dependencies." + ) + + assert param_details.depends.dependency is not None + + sub_dependant = get_lifespan_dependant( + name=param_name, + call=param_details.depends.dependency, + use_cache=param_details.depends.use_cache, + caller=call, + ) + dependant.dependencies.append(sub_dependant) + + return dependant + + +def get_endpoint_dependant( *, path: str, call: Callable[..., Any], name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, -) -> Dependant: + index: Optional[int] = None, +) -> EndpointDependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters - dependant = Dependant( + dependant = EndpointDependant( call=call, name=name, path=path, security_scopes=security_scopes, use_cache=use_cache, + index=index, ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names @@ -294,8 +384,13 @@ def get_dependant( depends=param_details.depends, path=path, security_scopes=security_scopes, + caller=call, ) - dependant.dependencies.append(sub_dependant) + if isinstance(sub_dependant, EndpointDependant): + dependant.endpoint_dependencies.append(sub_dependant) + else: + assert isinstance(sub_dependant, LifespanDependant) + dependant.lifespan_dependencies.append(sub_dependant) continue if add_non_field_param_to_dependency( param_name=param_name, @@ -314,8 +409,12 @@ def get_dependant( return dependant +# Kept for backwards compatibility +get_dependant = get_endpoint_dependant + + def add_non_field_param_to_dependency( - *, param_name: str, type_annotation: Any, dependant: Dependant + *, param_name: str, type_annotation: Any, dependant: EndpointDependant ) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): dependant.request_param_name = param_name @@ -511,7 +610,7 @@ def analyze_param( return ParamDetails(type_annotation=type_annotation, depends=depends, field=field) -def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: +def add_param_to_fields(*, field: ModelField, dependant: EndpointDependant) -> None: field_info = field.field_info field_info_in = getattr(field_info, "in_", None) if field_info_in == params.ParamTypes.path: @@ -560,36 +659,132 @@ async def solve_generator( return await stack.enter_async_context(cm) +@dataclass +class SolvedLifespanDependant: + value: Any + dependency_cache: Dict[LifespanDependantCacheKey, Any] + + +async def solve_lifespan_dependant( + *, + dependant: LifespanDependant, + dependency_overrides_provider: Optional[Any] = None, + dependency_cache: Optional[ + Dict[LifespanDependantCacheKey, Callable[..., Any]] + ] = None, + async_exit_stack: AsyncExitStack, +) -> SolvedLifespanDependant: + dependency_cache = dependency_cache or {} + if dependant.use_cache and dependant.cache_key in dependency_cache: + return SolvedLifespanDependant( + value=dependency_cache[dependant.cache_key], + dependency_cache=dependency_cache, + ) + + call = dependant.call + dependant_to_solve = dependant + if ( + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides + ): + call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get( + dependant.call, dependant.call + ) + dependant_to_solve = get_lifespan_dependant( + caller=dependant.caller, + call=call, + name=dependant.name, + use_cache=dependant.use_cache, + index=dependant.index, + ) + + dependency_arguments: Dict[str, Any] = {} + for sub_dependant in dependant_to_solve.dependencies: + assert sub_dependant.name, ( + "Lifespan scoped dependencies should not be able to have " + "subdependencies with no name" + ) + solved_sub_dependant = await solve_lifespan_dependant( + dependant=sub_dependant, + dependency_overrides_provider=dependency_overrides_provider, + dependency_cache=dependency_cache, + async_exit_stack=async_exit_stack, + ) + dependency_cache.update(solved_sub_dependant.dependency_cache) + dependency_arguments[sub_dependant.name] = solved_sub_dependant.value + + if is_gen_callable(call) or is_async_gen_callable(call): + value = await solve_generator( + call=call, stack=async_exit_stack, sub_values=dependency_arguments + ) + elif is_coroutine_callable(call): + value = await call(**dependency_arguments) + else: + value = await run_in_threadpool(call, **dependency_arguments) + + if dependant.cache_key not in dependency_cache: + dependency_cache[dependant.cache_key] = value + + return SolvedLifespanDependant( + value=value, + dependency_cache=dependency_cache, + ) + + @dataclass class SolvedDependency: values: Dict[str, Any] errors: List[Any] background_tasks: Optional[StarletteBackgroundTasks] response: Response - dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] + dependency_cache: Dict[EndpointDependantCacheKey, Any] async def solve_dependencies( *, request: Union[Request, WebSocket], - dependant: Dependant, + dependant: EndpointDependant, body: Optional[Union[Dict[str, Any], FormData]] = None, background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, + dependency_cache: Optional[Dict[EndpointDependantCacheKey, Any]] = None, async_exit_stack: AsyncExitStack, embed_body_fields: bool, ) -> SolvedDependency: values: Dict[str, Any] = {} errors: List[Any] = [] + + for lifespan_sub_dependant in dependant.lifespan_dependencies: + if lifespan_sub_dependant.name is None: + continue + + try: + lifespan_scoped_dependencies = request.state.__fastapi__[ + "lifespan_scoped_dependencies" + ] + except (AttributeError, KeyError) as e: + raise UninitializedLifespanDependency( + "FastAPI's internal lifespan was not initialized correctly." + ) from e + + try: + value = lifespan_scoped_dependencies[lifespan_sub_dependant.cache_key] + except KeyError as e: + raise UninitializedLifespanDependency( + f'Dependency "{lifespan_sub_dependant.name}" of ' + f"`{dependant.call}` was not initialized correctly." + ) from e + + values[lifespan_sub_dependant.name] = value + if response is None: response = Response() del response.headers["content-length"] response.status_code = None # type: ignore + dependency_cache = dependency_cache or {} - sub_dependant: Dependant - for sub_dependant in dependant.dependencies: + for sub_dependant in dependant.endpoint_dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.cache_key = cast( Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key @@ -605,7 +800,7 @@ async def solve_dependencies( dependency_overrides_provider, "dependency_overrides", {} ).get(original_call, original_call) use_path: str = sub_dependant.path # type: ignore - use_sub_dependant = get_dependant( + use_sub_dependant = get_endpoint_dependant( path=use_path, call=call, name=sub_dependant.name, @@ -949,7 +1144,7 @@ async def request_body_to_args( def get_body_field( - *, flat_dependant: Dependant, name: str, embed_body_fields: bool + *, flat_dependant: EndpointDependant, name: str, embed_body_fields: bool ) -> Optional[ModelField]: """ Get a ModelField representing the request body for a path operation, combining diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index 44d4ada86..95fd60477 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -146,6 +146,22 @@ class FastAPIError(RuntimeError): """ +class DependencyError(FastAPIError): + pass + + +class InvalidDependencyScope(DependencyError): + pass + + +class DependencyScopeConflict(DependencyError): + pass + + +class UninitializedLifespanDependency(DependencyError): + pass + + class ValidationException(Exception): def __init__(self, errors: Sequence[Any]) -> None: self._errors = errors diff --git a/fastapi/lifespan.py b/fastapi/lifespan.py new file mode 100644 index 000000000..7d53fc00a --- /dev/null +++ b/fastapi/lifespan.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any, Callable, Dict, List + +from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey +from fastapi.dependencies.utils import solve_lifespan_dependant +from fastapi.routing import APIRoute, APIWebSocketRoute + +if TYPE_CHECKING: # pragma: nocover + from fastapi import FastAPI + + +def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]: + lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {} + for route in app.router.routes: + if not isinstance(route, (APIWebSocketRoute, APIRoute)): + continue + + for sub_dependant in route.lifespan_dependencies: + if sub_dependant.cache_key in lifespan_dependants_cache: + continue + + lifespan_dependants_cache[sub_dependant.cache_key] = sub_dependant + + return list(lifespan_dependants_cache.values()) + + +async def resolve_lifespan_dependants( + *, app: FastAPI, async_exit_stack: AsyncExitStack +) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]: + lifespan_dependants = _get_lifespan_dependants(app) + dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {} + for lifespan_dependant in lifespan_dependants: + solved_dependency = await solve_lifespan_dependant( + dependant=lifespan_dependant, + dependency_overrides_provider=app, + dependency_cache=dependency_cache, + async_exit_stack=async_exit_stack, + ) + + dependency_cache.update(solved_dependency.dependency_cache) + + return dependency_cache diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 808646cc2..07bfa0f4b 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -15,7 +15,7 @@ from fastapi._compat import ( lenient_issubclass, ) from fastapi.datastructures import DefaultPlaceholder -from fastapi.dependencies.models import Dependant +from fastapi.dependencies.models import EndpointDependant from fastapi.dependencies.utils import ( _get_flat_fields_from_params, get_flat_dependant, @@ -76,7 +76,7 @@ status_code_ranges: Dict[str, str] = { def get_openapi_security_definitions( - flat_dependant: Dependant, + flat_dependant: EndpointDependant, ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: security_definitions = {} operation_security = [] @@ -94,7 +94,7 @@ def get_openapi_security_definitions( def _get_openapi_operation_parameters( *, - dependant: Dependant, + dependant: EndpointDependant, schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, field_mapping: Dict[ diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index b3621626c..4112e90e2 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -1,8 +1,11 @@ +from __future__ import annotations + from typing import Any, Callable, Dict, List, Optional, Sequence, Union from fastapi import params from fastapi._compat import Undefined from fastapi.openapi.models import Example +from fastapi.params import DependencyScope from typing_extensions import Annotated, Doc, deprecated _Unset: Any = Undefined @@ -2244,6 +2247,33 @@ def Depends( # noqa: N802 """ ), ] = True, + dependency_scope: Annotated[ + DependencyScope, + Doc( + """ + The scope in which the dependency value should be evaluated. Can be + either `"endpoint"` or `"lifespan"`. + + If `dependency_scope` is set to "endpoint" (the default), the + dependency will be setup and teardown for every request. + + If `dependency_scope` is set to `"lifespan"` the dependency would + be setup at the start of the entire application's lifespan. The + evaluated dependency would be then reused across all endpoints. + The dependency would be teared down as a part of the application's + shutdown process. + + Note that dependencies defined with the `"endpoint"` scope may use + sub-dependencies defined with the `"lifespan"` scope, but not the + other way around; + Dependencies defined with the `"lifespan"` scope may not use + sub-dependencies with `"endpoint"` scope, nor can they use + other "endpoint scoped" arguments such as "Path", "Body", "Query", + or any other annotation which does not make sense in a scope of an + application's entire lifespan. + """ + ), + ] = "endpoint", ) -> Any: """ Declare a FastAPI dependency. @@ -2274,7 +2304,9 @@ def Depends( # noqa: N802 return commons ``` """ - return params.Depends(dependency=dependency, use_cache=use_cache) + return params.Depends( + dependency=dependency, use_cache=use_cache, dependency_scope=dependency_scope + ) def Security( # noqa: N802 diff --git a/fastapi/params.py b/fastapi/params.py index 8f5601dd3..17af800b7 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union from fastapi.openapi.models import Example from pydantic.fields import FieldInfo -from typing_extensions import Annotated, deprecated +from typing_extensions import Annotated, Literal, TypeAlias, deprecated from ._compat import ( PYDANTIC_V2, @@ -13,6 +13,7 @@ from ._compat import ( ) _Unset: Any = Undefined +DependencyScope: TypeAlias = Literal["endpoint", "lifespan"] class ParamTypes(Enum): @@ -763,15 +764,25 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True + self, + dependency: Optional[Callable[..., Any]] = None, + *, + use_cache: bool = True, + dependency_scope: DependencyScope = "endpoint", ): self.dependency = dependency self.use_cache = use_cache + self.dependency_scope = dependency_scope def __repr__(self) -> str: attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) cache = "" if self.use_cache else ", use_cache=False" - return f"{self.__class__.__name__}({attr}{cache})" + if self.dependency_scope == "endpoint": + dependency_scope = "" + else: + dependency_scope = f', dependency_scope="{self.dependency_scope}"' + + return f"{self.__class__.__name__}({attr}{cache}{dependency_scope})" class Security(Depends): @@ -782,5 +793,7 @@ class Security(Depends): scopes: Optional[Sequence[str]] = None, use_cache: bool = True, ): - super().__init__(dependency=dependency, use_cache=use_cache) + super().__init__( + dependency=dependency, use_cache=use_cache, dependency_scope="endpoint" + ) self.scopes = scopes or [] diff --git a/fastapi/routing.py b/fastapi/routing.py index 54c75a027..aaceacc1b 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -32,11 +32,11 @@ from fastapi._compat import ( lenient_issubclass, ) from fastapi.datastructures import Default, DefaultPlaceholder -from fastapi.dependencies.models import Dependant +from fastapi.dependencies.models import EndpointDependant, LifespanDependant from fastapi.dependencies.utils import ( _should_embed_body_fields, get_body_field, - get_dependant, + get_endpoint_dependant, get_flat_dependant, get_parameterless_sub_dependant, get_typed_return_annotation, @@ -124,7 +124,7 @@ def _prepare_response_content( return res -def _merge_lifespan_context( +def merge_lifespan_context( original_context: Lifespan[Any], nested_context: Lifespan[Any] ) -> Lifespan[Any]: @asynccontextmanager @@ -203,7 +203,7 @@ async def serialize_response( async def run_endpoint_function( - *, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool + *, dependant: EndpointDependant, values: Dict[str, Any], is_coroutine: bool ) -> Any: # Only called by get_request_handler. Has been split into its own function to # facilitate profiling endpoints, since inner functions are harder to profile. @@ -216,7 +216,7 @@ async def run_endpoint_function( def get_request_handler( - dependant: Dependant, + dependant: EndpointDependant, body_field: Optional[ModelField] = None, status_code: Optional[int] = None, response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse), @@ -359,7 +359,7 @@ def get_request_handler( def get_websocket_app( - dependant: Dependant, + dependant: EndpointDependant, dependency_overrides_provider: Optional[Any] = None, embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: @@ -401,12 +401,20 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) - for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), + self.dependant = get_endpoint_dependant( + path=self.path_format, call=self.endpoint + ) + for i, depends in list(enumerate(self.dependencies))[::-1]: + sub_dependant = get_parameterless_sub_dependant( + depends=depends, path=self.path_format, caller=self.__call__, index=i ) + if isinstance(sub_dependant, EndpointDependant): + assert isinstance(sub_dependant, EndpointDependant) + self.dependant.endpoint_dependencies.insert(0, sub_dependant) + else: + assert isinstance(sub_dependant, LifespanDependant) + self.dependant.lifespan_dependencies.insert(0, sub_dependant) + self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( self._flat_dependant.body_params @@ -425,6 +433,10 @@ class APIWebSocketRoute(routing.WebSocketRoute): child_scope["route"] = self return match, child_scope + @property + def lifespan_dependencies(self) -> List[LifespanDependant]: + return self._flat_dependant.lifespan_dependencies + class APIRoute(routing.Route): def __init__( @@ -552,12 +564,19 @@ class APIRoute(routing.Route): self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) - for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), + self.dependant = get_endpoint_dependant( + path=self.path_format, call=self.endpoint + ) + for i, depends in list(enumerate(self.dependencies))[::-1]: + sub_dependant = get_parameterless_sub_dependant( + depends=depends, path=self.path_format, caller=self.__call__, index=i ) + if isinstance(sub_dependant, EndpointDependant): + self.dependant.endpoint_dependencies.insert(0, sub_dependant) + else: + assert isinstance(sub_dependant, LifespanDependant) + self.dependant.lifespan_dependencies.insert(0, sub_dependant) + self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( self._flat_dependant.body_params @@ -592,6 +611,10 @@ class APIRoute(routing.Route): child_scope["route"] = self return match, child_scope + @property + def lifespan_dependencies(self) -> List[LifespanDependant]: + return self._flat_dependant.lifespan_dependencies + class APIRouter(routing.Router): """ @@ -1359,7 +1382,7 @@ class APIRouter(routing.Router): self.add_event_handler("startup", handler) for handler in router.on_shutdown: self.add_event_handler("shutdown", handler) - self.lifespan_context = _merge_lifespan_context( + self.lifespan_context = merge_lifespan_context( self.lifespan_context, router.lifespan_context, ) diff --git a/tests/test_lifespan_scoped_dependencies/__init__.py b/tests/test_lifespan_scoped_dependencies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py new file mode 100644 index 000000000..eb61ed248 --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py @@ -0,0 +1,624 @@ +from typing import Any, AsyncGenerator, List, Tuple + +import pytest +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + Cookie, + Depends, + FastAPI, + File, + Form, + Header, + Path, + Query, + Request, + WebSocket, +) +from fastapi.exceptions import DependencyScopeConflict +from fastapi.params import Security +from fastapi.security import SecurityScopes +from fastapi.testclient import TestClient +from typing_extensions import Annotated, Literal + +from tests.test_lifespan_scoped_dependencies.testing_utilities import ( + DependencyFactory, + DependencyStyle, + IntentionallyBadDependency, + create_endpoint_0_annotations, + create_endpoint_1_annotation, + create_endpoint_3_annotations, + use_endpoint, + use_websocket, +) + + +def expect_correct_amount_of_dependency_activations( + *, + app: FastAPI, + dependency_factory: DependencyFactory, + override_dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, + is_websocket: bool, +) -> None: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + assert override_dependency_factory.activation_times == 0 + assert override_dependency_factory.deactivation_times == 0 + + with TestClient(app) as client: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + assert override_dependency_factory.activation_times == expected_activation_times + assert override_dependency_factory.deactivation_times == 0 + + for url, expected_response in urls_and_responses: + if is_websocket: + response = use_websocket(client, url) + else: + response = use_endpoint(client, url) + + assert response == expected_response + + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + assert ( + override_dependency_factory.activation_times + == expected_activation_times + ) + assert override_dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == 0 + assert override_dependency_factory.activation_times == expected_activation_times + if dependency_factory.dependency_style not in ( + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION, + ): + assert dependency_factory.deactivation_times == 0 + assert ( + override_dependency_factory.deactivation_times == expected_activation_times + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoint_dependencies( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, + Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ), + ], + expected_value=11, + ) + if routing_style == "router_endpoint": + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + urls_and_responses=[("/test", 11)] * 2, + expected_activation_times=1, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("dependency_duplication", [1, 2]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_router_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + dependency_duplication, + is_websocket, +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + if routing_style == "app": + app = FastAPI(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=app, path="/test", is_websocket=is_websocket + ) + else: + app = FastAPI() + router = APIRouter(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=router, path="/test", is_websocket=is_websocket + ) + + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + urls_and_responses=[("/test", None)] * 2, + expected_activation_times=1 if use_cache else dependency_duplication, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) +def test_dependency_cache_in_same_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], + is_websocket, +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def dependency( + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], + ) -> List[int]: + return [sub_dependency1, sub_dependency2] + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + List[int], + Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + ), + ], + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [11, 11]), + ("/test", [11, 11]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket, + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [11, 12]), + ("/test", [11, 12]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=2, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_same_endpoint( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test1", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 11, 11]), + ("/test1", [11, 11, 11]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket, + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 12, 13]), + ("/test1", [11, 12, 13]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=3, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_different_endpoints( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test1", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + create_endpoint_3_annotations( + router=router, + path="/test2", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 11, 11]), + ("/test2", [11, 11, 11]), + ("/test1", [11, 11, 11]), + ("/test2", [11, 11, 11]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket, + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 12, 13]), + ("/test2", [14, 15, 13]), + ("/test1", [11, 12, 13]), + ("/test2", [14, 15, 13]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=5, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_no_cached_dependency( + dependency_style: DependencyStyle, routing_style, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=False, + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + urls_and_responses=[("/test", 11)] * 2, + expected_activation_times=1, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize( + "annotation", + [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, + Request, + WebSocket, + ], +) +def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( + annotation, is_websocket +): + async def dependency_func() -> None: + yield # pragma: nocover + + async def override_dependency_func(param: annotation) -> None: + yield # pragma: nocover + + app = FastAPI() + app.dependency_overrides[dependency_func] = override_dependency_func + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], + ) + + with pytest.raises(DependencyScopeConflict): + with TestClient(app): + pass + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( + dependency_style: DependencyStyle, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) + + async def lifespan_scoped_dependency( + param: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], + ) -> AsyncGenerator[int, None]: + yield param + + app = FastAPI() + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + int, Depends(lifespan_scoped_dependency, dependency_scope="lifespan") + ], + ) + + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 11)] * 2, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("depends_class", [Depends, Security]) +def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( + depends_class, is_websocket +): + async def sub_dependency() -> None: + pass # pragma: nocover + + async def dependency_func() -> None: + yield # pragma: nocover + + async def override_dependency_func( + param: Annotated[None, depends_class(sub_dependency)], + ) -> None: + yield # pragma: nocover + + app = FastAPI() + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], + ) + + app.dependency_overrides[dependency_func] = override_dependency_func + + with pytest.raises(DependencyScopeConflict): + with TestClient(app): + pass + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_bad_override_lifespan_scoped_dependencies( + use_cache, dependency_style: DependencyStyle, routing_style, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, should_error=True) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) + + with pytest.raises(IntentionallyBadDependency) as exception_info: + with TestClient(app): + pass + + assert exception_info.value.args == (1,) diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py new file mode 100644 index 000000000..0770c07fb --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -0,0 +1,920 @@ +import warnings +from contextlib import asynccontextmanager +from time import sleep +from typing import Any, AsyncGenerator, Dict, List, Tuple + +import pytest +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + Cookie, + Depends, + FastAPI, + File, + Form, + Header, + Path, + Query, + Request, + WebSocket, +) +from fastapi.dependencies.utils import get_endpoint_dependant +from fastapi.exceptions import ( + DependencyScopeConflict, + InvalidDependencyScope, + UninitializedLifespanDependency, +) +from fastapi.params import Security +from fastapi.security import SecurityScopes +from fastapi.testclient import TestClient +from typing_extensions import Annotated, Literal + +from tests.test_lifespan_scoped_dependencies.testing_utilities import ( + DependencyFactory, + DependencyStyle, + IntentionallyBadDependency, + create_endpoint_0_annotations, + create_endpoint_1_annotation, + create_endpoint_2_annotations, + create_endpoint_3_annotations, + use_endpoint, + use_websocket, +) + + +def expect_correct_amount_of_dependency_activations( + *, + app: FastAPI, + dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, + is_websocket: bool, +) -> None: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + with TestClient(app) as client: + assert dependency_factory.activation_times == expected_activation_times + assert dependency_factory.deactivation_times == 0 + + for url, expected_response in urls_and_responses: + if is_websocket: + assert use_websocket(client, url) == expected_response + else: + assert use_endpoint(client, url) == expected_response + + assert dependency_factory.activation_times == expected_activation_times + assert dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == expected_activation_times + if dependency_factory.dependency_style not in ( + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION, + ): + assert dependency_factory.deactivation_times == expected_activation_times + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize( + "use_cache", [True, False], ids=["With Cache", "Without Cache"] +) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoint_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket: bool, +): + dependency_factory = DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, + Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ), + ], + expected_value=1, + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", 1)] * 2, + expected_activation_times=1, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("dependency_duplication", [1, 2]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_router_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + dependency_duplication, + is_websocket: bool, +): + dependency_factory = DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + if routing_style == "app": + app = FastAPI(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=app, path="/test", is_websocket=is_websocket + ) + else: + app = FastAPI() + router = APIRouter(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=router, path="/test", is_websocket=is_websocket + ) + + app.include_router(router) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", None)] * 2, + expected_activation_times=1 if use_cache else dependency_duplication, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) +def test_dependency_cache_in_same_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], + is_websocket: bool, +): + dependency_factory = DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def dependency( + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], + ) -> List[int]: + return [sub_dependency1, sub_dependency2] + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + List[int], + Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + ), + ], + ) + + if routing_style == "router": + app.include_router(router) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 1]), + ("/test", [1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket, + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 2]), + ("/test", [1, 2]), + ], + dependency_factory=dependency_factory, + expected_activation_times=2, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_same_endpoint( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + if routing_style == "router": + app.include_router(router) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 1, 1]), + ("/test", [1, 1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket, + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 2, 3]), + ("/test", [1, 2, 3]), + ], + dependency_factory=dependency_factory, + expected_activation_times=3, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_different_endpoints( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test1", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + create_endpoint_3_annotations( + router=router, + path="/test2", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + if routing_style == "router": + app.include_router(router) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 1, 1]), + ("/test2", [1, 1, 1]), + ("/test1", [1, 1, 1]), + ("/test2", [1, 1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket, + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 2, 3]), + ("/test2", [4, 5, 3]), + ("/test1", [1, 2, 3]), + ("/test2", [4, 5, 3]), + ], + dependency_factory=dependency_factory, + expected_activation_times=5, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_no_cached_dependency( + dependency_style: DependencyStyle, + routing_style, + is_websocket, +): + dependency_factory = DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=False, + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1, + ) + + if routing_style == "router": + app.include_router(router) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", 1)] * 2, + expected_activation_times=1, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize( + "annotation", + [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, + Request, + WebSocket, + ], +) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( + annotation, is_websocket +): + async def dependency_func(param: annotation) -> None: + yield # pragma: nocover + + app = FastAPI() + + with pytest.raises(DependencyScopeConflict): + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( + dependency_style: DependencyStyle, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + async def lifespan_scoped_dependency( + param: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], + ) -> AsyncGenerator[int, None]: + yield param + + app = FastAPI() + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, Depends(lifespan_scoped_dependency)], + expected_value=1, + ) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 1)] * 2, + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize( + ["dependency_style", "supports_teardown"], + [ + (DependencyStyle.SYNC_FUNCTION, False), + (DependencyStyle.ASYNC_FUNCTION, False), + (DependencyStyle.SYNC_GENERATOR, True), + (DependencyStyle.ASYNC_GENERATOR, True), + ], +) +def test_the_same_dependency_can_work_in_different_scopes( + dependency_style: DependencyStyle, supports_teardown, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + app = FastAPI() + + create_endpoint_2_annotations( + router=app, + path="/test", + is_websocket=is_websocket, + annotation1=Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="endpoint"), + ], + annotation2=Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], + ) + if is_websocket: + get_response = use_websocket + else: + get_response = use_endpoint + + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + with TestClient(app) as client: + assert dependency_factory.activation_times == 1 + assert dependency_factory.deactivation_times == 0 + + assert get_response(client, "/test") == [2, 1] + assert dependency_factory.activation_times == 2 + if supports_teardown: + if is_websocket: + # Websockets teardown might take some time after the test client + # has disconnected + sleep(0.1) + assert dependency_factory.deactivation_times == 1 + else: + assert dependency_factory.deactivation_times == 0 + + assert get_response(client, "/test") == [3, 1] + assert dependency_factory.activation_times == 3 + if supports_teardown: + if is_websocket: + # Websockets teardown might take some time after the test client + # has disconnected + sleep(0.1) + assert dependency_factory.deactivation_times == 2 + else: + assert dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == 3 + if supports_teardown: + assert dependency_factory.deactivation_times == 3 + else: + assert dependency_factory.deactivation_times == 0 + + +@pytest.mark.parametrize( + "lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"] +) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( + dependency_style: DependencyStyle, + is_websocket, + lifespan_style: Literal["lifespan_function", "lifespan_events"], +): + lifespan_started = False + lifespan_ended = False + if lifespan_style == "lifespan_generator": + + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]: + nonlocal lifespan_started + nonlocal lifespan_ended + lifespan_started = True + yield + lifespan_ended = True + + app = FastAPI(lifespan=lifespan) + elif lifespan_style == "events_decorator": + app = FastAPI() + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + + @app.on_event("startup") + async def startup() -> None: + nonlocal lifespan_started + lifespan_started = True + + @app.on_event("shutdown") + async def shutdown() -> None: + nonlocal lifespan_ended + lifespan_ended = True + else: + assert lifespan_style == "events_constructor" + + async def startup() -> None: + nonlocal lifespan_started + lifespan_started = True + + async def shutdown() -> None: + nonlocal lifespan_ended + lifespan_ended = True + + app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) + + dependency_factory = DependencyFactory(dependency_style) + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], + expected_value=1, + ) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 1)] * 2, + is_websocket=is_websocket, + ) + assert lifespan_started and lifespan_ended + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("depends_class", [Depends, Security]) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( + depends_class, is_websocket +): + async def sub_dependency() -> None: + pass # pragma: nocover + + async def dependency_func( + param: Annotated[None, depends_class(sub_dependency)], + ) -> None: + pass # pragma: nocover + + app = FastAPI() + + with pytest.raises(DependencyScopeConflict): + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_dependencies_must_provide_correct_dependency_scope( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + with pytest.raises( + InvalidDependencyScope, + match=r'Dependency "value" of .* has an invalid scope: ' r'"incorrect"', + ): + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, + Depends( + dependency_factory.get_dependency(), + dependency_scope="incorrect", + use_cache=use_cache, + ), + ], + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_incorrect_dependency_scope( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + # We intentionally change the dependency scope here to bypass the + # validation at the function level. + depends.dependency_scope = "asdad" + + with pytest.raises(InvalidDependencyScope): + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_endpoints_report_incorrect_dependency_scope_at_router_scope( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) + + depends = Depends(dependency_factory.get_dependency(), dependency_scope="lifespan") + + # We intentionally change the dependency scope here to bypass the + # validation at the function level. + depends.dependency_scope = "asdad" + + if routing_style == "app": + app = FastAPI(dependencies=[depends]) + router = app + else: + router = APIRouter(dependencies=[depends]) + + with pytest.raises(InvalidDependencyScope): + create_endpoint_0_annotations( + router=router, + path="/test", + is_websocket=is_websocket, + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_uninitialized_dependency( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1, + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + with TestClient(app) as client: + dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} + + try: + with pytest.raises(UninitializedLifespanDependency): + if is_websocket: + with client.websocket_connect("/test"): + pass # pragma: nocover + else: + client.post("/test") + finally: + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = ( + dependencies + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_uninitialized_internal_lifespan( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1, + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + with TestClient(app) as client: + internal_state = client.app_state["__fastapi__"] + del client.app_state["__fastapi__"] + + try: + with pytest.raises(UninitializedLifespanDependency): + if is_websocket: + with client.websocket_connect("/test"): + pass # pragma: nocover + else: + client.post("/test") + finally: + client.app_state["__fastapi__"] = internal_state + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_bad_lifespan_scoped_dependencies( + use_cache, dependency_style: DependencyStyle, routing_style, is_websocket +): + dependency_factory = DependencyFactory(dependency_style, should_error=True) + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1, + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + with pytest.raises(IntentionallyBadDependency) as exception_info: + with TestClient(app): + pass + + assert exception_info.value.args == (1,) + + +def test_endpoint_dependant_backwards_compatibility(): + dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) + + def endpoint( + dependency1: Annotated[int, Depends(dependency_factory.get_dependency())], + dependency2: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], + ): + pass # pragma: nocover + + dependant = get_endpoint_dependant( + path="/test", + call=endpoint, + name="endpoint", + ) + + assert dependant.dependencies == tuple( + dependant.lifespan_dependencies + dependant.endpoint_dependencies + ) diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py new file mode 100644 index 000000000..88e0925bc --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -0,0 +1,193 @@ +from enum import Enum +from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union + +from fastapi import APIRouter, FastAPI, WebSocket +from fastapi.testclient import TestClient +from typing_extensions import assert_never + +T = TypeVar("T") + + +class DependencyStyle(str, Enum): + SYNC_FUNCTION = "sync_function" + ASYNC_FUNCTION = "async_function" + SYNC_GENERATOR = "sync_generator" + ASYNC_GENERATOR = "async_generator" + + +class IntentionallyBadDependency(Exception): + pass + + +class DependencyFactory: + def __init__( + self, + dependency_style: DependencyStyle, + *, + should_error: bool = False, + value_offset: int = 0, + ): + self.activation_times = 0 + self.deactivation_times = 0 + self.dependency_style = dependency_style + self._should_error = should_error + self._value_offset = value_offset + + def get_dependency(self): + if self.dependency_style == DependencyStyle.SYNC_FUNCTION: + return self._synchronous_function_dependency + + if self.dependency_style == DependencyStyle.SYNC_GENERATOR: + return self._synchronous_generator_dependency + + if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: + return self._asynchronous_function_dependency + + if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: + return self._asynchronous_generator_dependency + + assert_never(self.dependency_style) # pragma: nocover + + async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + yield self.activation_times + self._value_offset + self.deactivation_times += 1 + + def _synchronous_generator_dependency(self) -> Generator[T, None, None]: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + yield self.activation_times + self._value_offset + self.deactivation_times += 1 + + async def _asynchronous_function_dependency(self) -> T: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + return self.activation_times + self._value_offset + + def _synchronous_function_dependency(self) -> T: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + return self.activation_times + self._value_offset + + +def use_endpoint(client: TestClient, url: str) -> Any: + response = client.post(url) + response.raise_for_status() + return response.json() + + +def use_websocket(client: TestClient, url: str) -> Any: + with client.websocket_connect(url) as connection: + return connection.receive_json() + + +def create_endpoint_0_annotations( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, +) -> None: + if is_websocket: + + @router.websocket(path) + async def endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.send_json(None) + else: + + @router.post(path) + async def endpoint() -> None: + return None + + +def create_endpoint_1_annotation( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation: Any, + expected_value: Any = None, +) -> None: + if is_websocket: + + @router.websocket(path) + async def endpoint(websocket: WebSocket, value: annotation) -> None: + if expected_value is not None: + assert value == expected_value + + await websocket.accept() + await websocket.send_json(value) + else: + + @router.post(path) + async def endpoint(value: annotation) -> Any: + if expected_value is not None: + assert value == expected_value + + return value + + +def create_endpoint_2_annotations( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation1: Any, + annotation2: Any, +) -> None: + if is_websocket: + + @router.websocket(path) + async def endpoint( + websocket: WebSocket, + value1: annotation1, + value2: annotation2, + ) -> None: + await websocket.accept() + await websocket.send_json([value1, value2]) + else: + + @router.post(path) + async def endpoint( + value1: annotation1, + value2: annotation2, + ) -> List[Any]: + return [value1, value2] + + +def create_endpoint_3_annotations( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation1: Any, + annotation2: Any, + annotation3: Any, +) -> None: + if is_websocket: + + @router.websocket(path) + async def endpoint( + websocket: WebSocket, + value1: annotation1, + value2: annotation2, + value3: annotation3, + ) -> None: + await websocket.accept() + await websocket.send_json([value1, value2, value3]) + else: + + @router.post(path) + async def endpoint( + value1: annotation1, value2: annotation2, value3: annotation3 + ) -> List[Any]: + return [value1, value2, value3] diff --git a/tests/test_params_repr.py b/tests/test_params_repr.py index bfc7bed09..8921026b2 100644 --- a/tests/test_params_repr.py +++ b/tests/test_params_repr.py @@ -1,5 +1,6 @@ from typing import Any, List +import pytest from dirty_equals import IsOneOf from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query @@ -143,10 +144,30 @@ def test_body_repr_list(): assert repr(Body([])) == "Body([])" -def test_depends_repr(): - assert repr(Depends()) == "Depends(NoneType)" - assert repr(Depends(get_user)) == "Depends(get_user)" - assert repr(Depends(use_cache=False)) == "Depends(NoneType, use_cache=False)" - assert ( - repr(Depends(get_user, use_cache=False)) == "Depends(get_user, use_cache=False)" - ) +@pytest.mark.parametrize( + ["depends", "expected_repr"], + [ + [Depends(), "Depends(NoneType)"], + [Depends(get_user), "Depends(get_user)"], + [Depends(use_cache=False), "Depends(NoneType, use_cache=False)"], + [Depends(get_user, use_cache=False), "Depends(get_user, use_cache=False)"], + [ + Depends(dependency_scope="lifespan"), + 'Depends(NoneType, dependency_scope="lifespan")', + ], + [ + Depends(get_user, dependency_scope="lifespan"), + 'Depends(get_user, dependency_scope="lifespan")', + ], + [ + Depends(use_cache=False, dependency_scope="lifespan"), + 'Depends(NoneType, use_cache=False, dependency_scope="lifespan")', + ], + [ + Depends(get_user, use_cache=False, dependency_scope="lifespan"), + 'Depends(get_user, use_cache=False, dependency_scope="lifespan")', + ], + ], +) +def test_depends_repr(depends, expected_repr): + assert repr(depends) == expected_repr diff --git a/tests/test_router_events.py b/tests/test_router_events.py index dd7ff3314..2f110e684 100644 --- a/tests/test_router_events.py +++ b/tests/test_router_events.py @@ -199,6 +199,7 @@ def test_router_nested_lifespan_state_overriding_by_parent() -> None: "app_specific": True, "router_specific": True, "overridden": "app", + "__fastapi__": {"lifespan_scoped_dependencies": {}}, } @@ -216,7 +217,7 @@ def test_merged_no_return_lifespans_return_none() -> None: app.include_router(router) with TestClient(app) as client: - assert not client.app_state + assert client.app_state == {"__fastapi__": {"lifespan_scoped_dependencies": {}}} def test_merged_mixed_state_lifespans() -> None: @@ -239,4 +240,7 @@ def test_merged_mixed_state_lifespans() -> None: app.include_router(router) with TestClient(app) as client: - assert client.app_state == {"router": True} + assert client.app_state == { + "router": True, + "__fastapi__": {"lifespan_scoped_dependencies": {}}, + } diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a.py b/tests/test_tutorial/test_dependencies/test_tutorial013a.py new file mode 100644 index 000000000..7b5d823f9 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a.py @@ -0,0 +1,65 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013a import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + +@pytest.fixture +def database_connection_mock(monkeypatch) -> MockDatabaseConnection: + mock = MockDatabaseConnection() + + monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock) + + return mock + + +def test_dependency_usage(database_connection_mock): + assert database_connection_mock.enter_count == 0 + assert database_connection_mock.exit_count == 0 + with TestClient(app) as test_client: + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + response = test_client.get("/users") + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 1 + + response = test_client.get("/items") + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 2 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py new file mode 100644 index 000000000..90775b6b0 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py @@ -0,0 +1,70 @@ +import sys +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +if sys.version_info >= (3, 9): + from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app + +from ...utils import needs_py39 + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + +@pytest.fixture +def database_connection_mock(monkeypatch) -> MockDatabaseConnection: + mock = MockDatabaseConnection() + + monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock) + + return mock + + +@needs_py39 +def test_dependency_usage(database_connection_mock): + assert database_connection_mock.enter_count == 0 + assert database_connection_mock.exit_count == 0 + with TestClient(app) as test_client: + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + response = test_client.get("/users") + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 1 + + response = test_client.get("/items") + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 2 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b.py b/tests/test_tutorial/test_dependencies/test_tutorial013b.py new file mode 100644 index 000000000..a7a092cf0 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b.py @@ -0,0 +1,130 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013b import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 3 + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + assert connection.get_records_count == 0 + assert connection.get_record_count == 0 + + response = test_client.get("/users") + assert response.status_code == 200 + assert response.json() == [] + + users_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1: + users_connection = connection + break + + assert users_connection is not None, ( + "No connection was found for users endpoint" + ) + + response = test_client.get("/groups") + assert response.status_code == 200 + assert response.json() == [] + + groups_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1 and connection is not users_connection: + groups_connection = connection + break + + assert groups_connection is not None, ( + "No connection was found for groups endpoint" + ) + assert groups_connection.get_records_count == 1 + + items_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 0: + items_connection = connection + break + + assert items_connection is not None, ( + "No connection was found for items endpoint" + ) + + response = test_client.get("/items") + assert response.status_code == 200 + assert response.json() == [] + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 0 + + response = test_client.get("/items/asd") + assert response.status_code == 200 + assert response.json() == { + "table_name": "items", + "record_id": "asd", + } + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 1 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py new file mode 100644 index 000000000..e782f729f --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py @@ -0,0 +1,135 @@ +import sys +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +if sys.version_info >= (3, 9): + from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app + +from ...utils import needs_py39 + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) + return connections + + +@needs_py39 +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 3 + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + assert connection.get_records_count == 0 + assert connection.get_record_count == 0 + + response = test_client.get("/users") + assert response.status_code == 200 + assert response.json() == [] + + users_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1: + users_connection = connection + break + + assert users_connection is not None, ( + "No connection was found for users endpoint" + ) + + response = test_client.get("/groups") + assert response.status_code == 200 + assert response.json() == [] + + groups_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1 and connection is not users_connection: + groups_connection = connection + break + + assert groups_connection is not None, ( + "No connection was found for groups endpoint" + ) + assert groups_connection.get_records_count == 1 + + items_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 0: + items_connection = connection + break + + assert items_connection is not None, ( + "No connection was found for items endpoint" + ) + + response = test_client.get("/items") + assert response.status_code == 200 + assert response.json() == [] + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 0 + + response = test_client.get("/items/asd") + assert response.status_code == 200 + assert response.json() == { + "table_name": "items", + "record_id": "asd", + } + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 1 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c.py b/tests/test_tutorial/test_dependencies/test_tutorial013c.py new file mode 100644 index 000000000..800eaade7 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c.py @@ -0,0 +1,78 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013c import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self, url: str): + self.url = url + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + + def _get_new_connection_mock(cls, url): + mock = MockDatabaseConnection(url) + connections.append(mock) + + return mock + + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.url == "sqlite:///database.db" + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get("/users/user") + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py new file mode 100644 index 000000000..80ac67f42 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py @@ -0,0 +1,83 @@ +import sys +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +if sys.version_info >= (3, 9): + from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app + +from ...utils import needs_py39 + + +class MockDatabaseConnection: + def __init__(self, url: str): + self.url = url + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + + def _get_new_connection_mock(cls, url): + mock = MockDatabaseConnection(url) + connections.append(mock) + + return mock + + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) + return connections + + +@needs_py39 +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.url == "sqlite:///database.db" + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get("/users/user") + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d.py b/tests/test_tutorial/test_dependencies/test_tutorial013d.py new file mode 100644 index 000000000..eb01a7232 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d.py @@ -0,0 +1,76 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013d import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get("/users/user") + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py new file mode 100644 index 000000000..8563325da --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py @@ -0,0 +1,81 @@ +import sys +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +if sys.version_info >= (3, 9): + from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app + +from ...utils import needs_py39 + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) + return connections + + +@needs_py39 +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get("/users/user") + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1