|
|
@ -4,6 +4,8 @@ from time import sleep |
|
|
|
from typing import Any, AsyncGenerator, Dict, List, Tuple |
|
|
|
|
|
|
|
import pytest |
|
|
|
from setuptools import depends |
|
|
|
|
|
|
|
from fastapi import ( |
|
|
|
APIRouter, |
|
|
|
BackgroundTasks, |
|
|
@ -19,6 +21,7 @@ from fastapi import ( |
|
|
|
Request, |
|
|
|
WebSocket, |
|
|
|
) |
|
|
|
from fastapi.dependencies.utils import get_endpoint_dependant |
|
|
|
from fastapi.exceptions import ( |
|
|
|
DependencyScopeConflict, |
|
|
|
InvalidDependencyScope, |
|
|
@ -443,7 +446,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
|
|
|
annotation, is_websocket |
|
|
|
): |
|
|
|
async def dependency_func(param: annotation) -> None: |
|
|
|
yield |
|
|
|
yield # pragma: nocover |
|
|
|
|
|
|
|
app = FastAPI() |
|
|
|
|
|
|
@ -598,7 +601,8 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( |
|
|
|
async def shutdown() -> None: |
|
|
|
nonlocal lifespan_ended |
|
|
|
lifespan_ended = True |
|
|
|
elif lifespan_style == "events_constructor": |
|
|
|
else: |
|
|
|
assert lifespan_style == "events_constructor" |
|
|
|
|
|
|
|
async def startup() -> None: |
|
|
|
nonlocal lifespan_started |
|
|
@ -609,8 +613,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( |
|
|
|
lifespan_ended = True |
|
|
|
|
|
|
|
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) |
|
|
|
else: |
|
|
|
assert_never(lifespan_style) |
|
|
|
|
|
|
|
|
|
|
|
dependency_factory = DependencyFactory(dependency_style) |
|
|
|
|
|
|
@ -641,12 +644,12 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
|
|
|
depends_class, is_websocket |
|
|
|
): |
|
|
|
async def sub_dependency() -> None: |
|
|
|
pass |
|
|
|
pass # pragma: nocover |
|
|
|
|
|
|
|
async def dependency_func( |
|
|
|
param: Annotated[None, depends_class(sub_dependency)], |
|
|
|
) -> None: |
|
|
|
yield |
|
|
|
pass # pragma: nocover |
|
|
|
|
|
|
|
app = FastAPI() |
|
|
|
|
|
|
@ -730,6 +733,39 @@ def test_endpoints_report_incorrect_dependency_scope( |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|
|
|
@pytest.mark.parametrize("use_cache", [True, False]) |
|
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
|
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|
|
|
def test_endpoints_report_incorrect_dependency_scope_at_router_scope( |
|
|
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|
|
|
): |
|
|
|
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) |
|
|
|
|
|
|
|
depends = Depends( |
|
|
|
dependency_factory.get_dependency(), |
|
|
|
dependency_scope="lifespan" |
|
|
|
) |
|
|
|
|
|
|
|
# We intentionally change the dependency scope here to bypass the |
|
|
|
# validation at the function level. |
|
|
|
depends.dependency_scope = "asdad" |
|
|
|
|
|
|
|
if routing_style == "app_endpoint": |
|
|
|
app = FastAPI(dependencies=[depends]) |
|
|
|
router = app |
|
|
|
else: |
|
|
|
router = APIRouter(dependencies=[depends]) |
|
|
|
|
|
|
|
|
|
|
|
with pytest.raises(InvalidDependencyScope): |
|
|
|
create_endpoint_0_annotations( |
|
|
|
router=router, |
|
|
|
path="/test", |
|
|
|
is_websocket=is_websocket, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|
|
|
@pytest.mark.parametrize("use_cache", [True, False]) |
|
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
|
@ -866,3 +902,26 @@ def test_bad_lifespan_scoped_dependencies( |
|
|
|
pass |
|
|
|
|
|
|
|
assert exception_info.value.args == (1,) |
|
|
|
|
|
|
|
def test_endpoint_dependant_backwards_compatibility(): |
|
|
|
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) |
|
|
|
|
|
|
|
def endpoint( |
|
|
|
dependency1: Annotated[int, Depends(dependency_factory.get_dependency())], |
|
|
|
dependency2: Annotated[int, Depends( |
|
|
|
dependency_factory.get_dependency(), |
|
|
|
dependency_scope="lifespan" |
|
|
|
)], |
|
|
|
): |
|
|
|
pass # pragma: nocover |
|
|
|
|
|
|
|
dependant = get_endpoint_dependant( |
|
|
|
path="/test", |
|
|
|
call=endpoint, |
|
|
|
name="endpoint", |
|
|
|
) |
|
|
|
|
|
|
|
assert dependant.dependencies == tuple( |
|
|
|
dependant.lifespan_dependencies + |
|
|
|
dependant.endpoint_dependencies |
|
|
|
) |