You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

703 lines
20 KiB

from enum import StrEnum, auto
from typing import Any, AsyncGenerator, List, Tuple, TypeVar
import pytest
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Cookie,
Depends,
FastAPI,
File,
Form,
Header,
Path,
Query,
)
from fastapi.exceptions import FastAPIError
from fastapi.params import Security
from fastapi.security import SecurityScopes
from starlette.testclient import TestClient
from typing_extensions import Annotated, Generator, Literal, assert_never
# Generic placeholder type variable, available for annotations in this module.
T = TypeVar("T")
class DependencyStyle(StrEnum):
SYNC_FUNCTION = auto()
ASYNC_FUNCTION = auto()
SYNC_GENERATOR = auto()
ASYNC_GENERATOR = auto()
class DependencyFactory:
def __init__(
self,
dependency_style: DependencyStyle, *,
should_error: bool = False
):
self.activation_times = 0
self.deactivation_times = 0
self.dependency_style = dependency_style
self._should_error = should_error
def get_dependency(self):
if self.dependency_style == DependencyStyle.SYNC_FUNCTION:
return self._synchronous_function_dependency
if self.dependency_style == DependencyStyle.SYNC_GENERATOR:
return self._synchronous_generator_dependency
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION:
return self._asynchronous_function_dependency
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR:
return self._asynchronous_generator_dependency
assert_never(self.dependency_style)
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
yield self.activation_times
self.deactivation_times += 1
def _synchronous_generator_dependency(self) -> Generator[T, None, None]:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
yield self.activation_times
self.deactivation_times += 1
async def _asynchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
return self.activation_times
def _synchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
return self.activation_times
def _expect_correct_amount_of_dependency_activations(
    *,
    app: FastAPI,
    dependency_factory: DependencyFactory,
    urls_and_responses: List[Tuple[str, Any]],
    expected_activation_times: int,
) -> None:
    """Run ``app`` through one full lifespan and check the factory's counters.

    - Before startup, nothing may have been activated or torn down.
    - Startup must activate the dependency exactly ``expected_activation_times``
      times.
    - Each POST in ``urls_and_responses`` must succeed with the given JSON body
      and must neither activate nor tear down anything further.
    - After shutdown, the activation count must be unchanged. Unless the
      dependency is a plain (sync or async) function, it must also have been
      torn down once per activation.
    """
    factory = dependency_factory
    expected = expected_activation_times

    assert (factory.activation_times, factory.deactivation_times) == (0, 0)

    with TestClient(app) as test_client:
        assert (factory.activation_times, factory.deactivation_times) == (expected, 0)
        for path, expected_body in urls_and_responses:
            reply = test_client.post(path)
            reply.raise_for_status()
            assert reply.json() == expected_body
            assert (factory.activation_times, factory.deactivation_times) == (
                expected,
                0,
            )

    assert factory.activation_times == expected

    has_teardown = factory.dependency_style not in (
        DependencyStyle.SYNC_FUNCTION,
        DependencyStyle.ASYNC_FUNCTION,
    )
    if has_teardown:
        assert factory.deactivation_times == expected
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(
    dependency_style: DependencyStyle, routing_style, use_cache
):
    """A lifespan-scoped endpoint dependency is activated once for the lifespan.

    Both requests must receive the value from that single activation (``1``),
    whether the endpoint is registered on the app or on an included router.
    """
    dependency_factory = DependencyFactory(dependency_style)
    app = FastAPI()
    if routing_style == "app_endpoint":
        router = app
    else:
        router = APIRouter()

    # The dependency produces an int, so annotate it (and the endpoint's
    # return) as int instead of None, matching the other tests in this module.
    @router.post("/test")
    async def endpoint(
        dependency: Annotated[
            int,
            Depends(
                dependency_factory.get_dependency(),
                dependency_scope="lifespan",
                use_cache=use_cache,
            ),
        ],
    ) -> int:
        assert dependency == 1
        return dependency

    if routing_style == "router_endpoint":
        app.include_router(router)

    _expect_correct_amount_of_dependency_activations(
        app=app,
        dependency_factory=dependency_factory,
        urls_and_responses=[("/test", 1)] * 2,
        expected_activation_times=1,
    )
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
    dependency_style: DependencyStyle,
    routing_style,
    use_cache,
):
    """Lifespan-scoped dependencies declared on the app or a router run once.

    The dependency is attached via ``dependencies=[...]`` rather than an
    endpoint parameter; two requests must not trigger further activations.
    """
    factory = DependencyFactory(dependency_style)
    depends = Depends(
        factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )

    if routing_style == "app":
        app = FastAPI(dependencies=[depends])
        router = app
    else:
        app = FastAPI()
        router = APIRouter(dependencies=[depends])

    @router.post("/test")
    async def endpoint() -> None:
        return None

    if routing_style == "router":
        app.include_router(router)

    _expect_correct_amount_of_dependency_activations(
        app=app,
        dependency_factory=factory,
        urls_and_responses=[("/test", None)] * 2,
        expected_activation_times=1,
    )
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
    dependency_style: DependencyStyle,
    routing_style,
    use_cache,
    main_dependency_scope: Literal["endpoint", "lifespan"],
):
    """Two uses of one lifespan-scoped ``Depends`` inside a single dependency.

    With ``use_cache`` both uses share one activation (``[1, 1]``). Without it,
    each use is activated separately (``[1, 2]``). Either way, activation
    happens once per lifespan, not per request.
    """
    factory = DependencyFactory(dependency_style)
    depends = Depends(
        factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )
    app = FastAPI()
    router = app if routing_style == "app" else APIRouter()

    async def collect_twice(
        sub_dependency1: Annotated[int, depends],
        sub_dependency2: Annotated[int, depends],
    ) -> List[int]:
        return [sub_dependency1, sub_dependency2]

    @router.post("/test")
    async def endpoint(
        dependency: Annotated[
            List[int],
            Depends(
                collect_twice,
                use_cache=use_cache,
                dependency_scope=main_dependency_scope,
            ),
        ],
    ) -> List[int]:
        return dependency

    if routing_style == "router":
        app.include_router(router)

    expected_pair = [1, 1] if use_cache else [1, 2]
    _expect_correct_amount_of_dependency_activations(
        app=app,
        urls_and_responses=[("/test", expected_pair)] * 2,
        dependency_factory=factory,
        expected_activation_times=1 if use_cache else 2,
    )
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
    dependency_style: DependencyStyle,
    routing_style,
    use_cache,
):
    """Direct and nested uses of one lifespan-scoped ``Depends`` in one endpoint.

    With ``use_cache`` all three uses share one activation (``[1, 1, 1]``).
    Without it, each use is activated separately (``[1, 2, 3]``).
    """
    factory = DependencyFactory(dependency_style)
    depends = Depends(
        factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )
    app = FastAPI()
    router = app if routing_style == "app" else APIRouter()

    async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
        return dependency3

    @router.post("/test1")
    async def endpoint(
        dependency1: Annotated[int, depends],
        dependency2: Annotated[int, depends],
        dependency3: Annotated[int, Depends(endpoint_dependency)],
    ) -> List[int]:
        return [dependency1, dependency2, dependency3]

    if routing_style == "router":
        app.include_router(router)

    expected_triple = [1, 1, 1] if use_cache else [1, 2, 3]
    _expect_correct_amount_of_dependency_activations(
        app=app,
        urls_and_responses=[("/test1", expected_triple)] * 2,
        dependency_factory=factory,
        expected_activation_times=1 if use_cache else 3,
    )
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
    dependency_style: DependencyStyle,
    routing_style,
    use_cache,
):
    """One lifespan-scoped ``Depends`` shared by two endpoints.

    With ``use_cache`` every use shares a single activation. Without it, the
    direct uses in each endpoint get their own activations (1, 2 and 4, 5).
    The use nested in the shared ``endpoint_dependency`` resolves to the same
    activation (3) in both endpoints, for 5 activations in total.
    """
    factory = DependencyFactory(dependency_style)
    depends = Depends(
        factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )
    app = FastAPI()
    router = app if routing_style == "app" else APIRouter()

    async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
        return dependency3

    @router.post("/test1")
    async def endpoint(
        dependency1: Annotated[int, depends],
        dependency2: Annotated[int, depends],
        dependency3: Annotated[int, Depends(endpoint_dependency)],
    ) -> List[int]:
        return [dependency1, dependency2, dependency3]

    @router.post("/test2")
    async def endpoint2(
        dependency1: Annotated[int, depends],
        dependency2: Annotated[int, depends],
        dependency3: Annotated[int, Depends(endpoint_dependency)],
    ) -> List[int]:
        return [dependency1, dependency2, dependency3]

    if routing_style == "router":
        app.include_router(router)

    if use_cache:
        first_body, second_body, activations = [1, 1, 1], [1, 1, 1], 1
    else:
        first_body, second_body, activations = [1, 2, 3], [4, 5, 3], 5

    _expect_correct_amount_of_dependency_activations(
        app=app,
        urls_and_responses=[("/test1", first_body), ("/test2", second_body)] * 2,
        dependency_factory=factory,
        expected_activation_times=activations,
    )
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
    dependency_style: DependencyStyle,
    routing_style,
):
    """A single non-cached lifespan-scoped dependency is still activated once."""
    factory = DependencyFactory(dependency_style)
    depends = Depends(
        factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=False,
    )
    app = FastAPI()
    router = app if routing_style == "app" else APIRouter()

    @router.post("/test")
    async def endpoint(
        dependency: Annotated[int, depends],
    ) -> int:
        return dependency

    if routing_style == "router":
        app.include_router(router)

    _expect_correct_amount_of_dependency_activations(
        app=app,
        dependency_factory=factory,
        urls_and_responses=[("/test", 1)] * 2,
        expected_activation_times=1,
    )
@pytest.mark.parametrize("annotation", [
    Annotated[str, Path()],
    Annotated[str, Body()],
    Annotated[str, Query()],
    Annotated[str, Header()],
    SecurityScopes,
    Annotated[str, Cookie()],
    Annotated[str, Form()],
    Annotated[str, File()],
    BackgroundTasks,
])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
    annotation,
):
    """A lifespan-scoped dependency may not take request-scoped parameters.

    Registering an endpoint whose lifespan-scoped dependency declares any of
    the parametrized inputs must raise ``FastAPIError``.
    """
    # This is an async generator, so annotate it as one (as the other
    # generator dependencies in this module are) rather than ``-> None``.
    async def dependency_func(param: annotation) -> AsyncGenerator[None, None]:
        yield

    app = FastAPI()
    with pytest.raises(FastAPIError):
        @app.post("/test")
        async def endpoint(
            dependency: Annotated[
                None, Depends(dependency_func, dependency_scope="lifespan")
            ],
        ) -> None:
            return
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
    dependency_style: DependencyStyle,
):
    """Lifespan-scoped dependencies may depend on other lifespan-scoped ones.

    The inner factory dependency must still be activated exactly once for the
    whole lifespan, and its value (``1``) must reach every request.
    """
    factory = DependencyFactory(dependency_style)
    inner = Depends(factory.get_dependency(), dependency_scope="lifespan")

    async def lifespan_scoped_dependency(
        param: Annotated[int, inner],
    ) -> AsyncGenerator[int, None]:
        yield param

    outer = Depends(lifespan_scoped_dependency, dependency_scope="lifespan")
    app = FastAPI()

    @app.post("/test")
    async def endpoint(dependency: Annotated[int, outer]) -> int:
        return dependency

    _expect_correct_amount_of_dependency_activations(
        app=app,
        dependency_factory=factory,
        expected_activation_times=1,
        urls_and_responses=[("/test", 1)] * 2,
    )
@pytest.mark.parametrize("depends_class", [Depends, Security])
@pytest.mark.parametrize(
    "route_type",
    [FastAPI.post, FastAPI.websocket],
    # ids are positional: FastAPI.post -> "endpoint", FastAPI.websocket ->
    # "websocket". They were previously listed in the opposite order.
    ids=["endpoint", "websocket"],
)
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
    depends_class,
    route_type,
):
    """A lifespan-scoped dependency may not depend on an endpoint-scoped one.

    ``sub_dependency`` is declared with the default (endpoint) scope, so
    registering either an HTTP or a websocket route that uses ``dependency_func``
    with ``dependency_scope="lifespan"`` must raise ``FastAPIError``.
    """
    async def sub_dependency() -> None:
        pass

    async def dependency_func(
        param: Annotated[None, depends_class(sub_dependency)],
    ) -> AsyncGenerator[None, None]:
        yield

    app = FastAPI()
    route_decorator = route_type(app, "/test")
    with pytest.raises(FastAPIError):
        @route_decorator
        async def endpoint(
            x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")],
        ) -> None:
            return
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_dependencies_must_provide_correct_dependency_scope(
    dependency_style: DependencyStyle,
    routing_style,
    use_cache,
):
    """Declaring a dependency with an unknown ``dependency_scope`` fails.

    ``FastAPIError`` must be raised while the endpoint is being declared, so the
    endpoint body never runs.
    """
    factory = DependencyFactory(dependency_style)
    app = FastAPI()
    router = app if routing_style == "app_endpoint" else APIRouter()
    dependency_callable = factory.get_dependency()

    with pytest.raises(FastAPIError):
        @router.post("/test")
        async def endpoint(
            dependency: Annotated[
                None,
                Depends(
                    dependency_callable,
                    dependency_scope="incorrect",
                    use_cache=use_cache,
                ),
            ],
        ) -> None:
            assert dependency == 1
            return dependency
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_incorrect_dependency_scope(
    dependency_style: DependencyStyle,
    routing_style,
    use_cache,
):
    """Route registration rejects a ``Depends`` whose scope was corrupted later."""
    factory = DependencyFactory(dependency_style)
    app = FastAPI()
    router = app if routing_style == "app_endpoint" else APIRouter()
    depends = Depends(
        factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )
    # Overwrite the scope after construction to bypass the validation done
    # when ``Depends`` is created; route registration must catch it instead.
    depends.dependency_scope = "asdad"

    with pytest.raises(FastAPIError):
        @router.post("/test")
        async def endpoint(dependency: Annotated[int, depends]) -> int:
            assert dependency == 1
            return dependency
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_dependency(
    dependency_style: DependencyStyle,
    routing_style,
    use_cache,
):
    """A request fails with ``FastAPIError`` if the lifespan registry is emptied.

    The registry of initialized lifespan-scoped dependencies is temporarily
    replaced with an empty dict, then restored so shutdown can proceed.
    """
    factory = DependencyFactory(dependency_style)
    app = FastAPI()
    router = app if routing_style == "app_endpoint" else APIRouter()
    depends = Depends(
        factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )

    @router.post("/test")
    async def endpoint(dependency: Annotated[int, depends]) -> int:
        assert dependency == 1
        return dependency

    if routing_style == "router_endpoint":
        app.include_router(router)

    with TestClient(app) as client:
        fastapi_state = client.app_state["__fastapi__"]
        saved_dependencies = fastapi_state["lifespan_scoped_dependencies"]
        fastapi_state["lifespan_scoped_dependencies"] = {}
        try:
            with pytest.raises(FastAPIError):
                client.post("/test")
        finally:
            client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = (
                saved_dependencies
            )
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_internal_lifespan(
    dependency_style: DependencyStyle,
    routing_style,
    use_cache,
):
    """A request fails with ``FastAPIError`` if FastAPI's lifespan state is gone.

    The whole ``"__fastapi__"`` entry is temporarily removed from the app state,
    then put back so shutdown can proceed.
    """
    factory = DependencyFactory(dependency_style)
    app = FastAPI()
    router = app if routing_style == "app_endpoint" else APIRouter()
    depends = Depends(
        factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )

    @router.post("/test")
    async def endpoint(dependency: Annotated[int, depends]) -> int:
        assert dependency == 1
        return dependency

    if routing_style == "router_endpoint":
        app.include_router(router)

    with TestClient(app) as client:
        removed_state = client.app_state.pop("__fastapi__")
        try:
            with pytest.raises(FastAPIError):
                client.post("/test")
        finally:
            client.app_state["__fastapi__"] = removed_state
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_lifespan_scoped_dependencies(
    use_cache, dependency_style: DependencyStyle, routing_style
):
    """A lifespan-scoped dependency that raises makes app startup fail.

    The factory raises ``ValueError(1)`` on its first activation; entering the
    ``TestClient`` context must propagate exactly that error.
    """
    factory = DependencyFactory(dependency_style, should_error=True)
    depends = Depends(
        factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )
    app = FastAPI()
    router = app if routing_style == "app_endpoint" else APIRouter()

    @router.post("/test")
    async def endpoint(dependency: Annotated[int, depends]) -> int:
        assert dependency == 1
        return dependency

    if routing_style == "router_endpoint":
        app.include_router(router)

    with pytest.raises(ValueError) as exception_info, TestClient(app):
        pass

    assert exception_info.value.args == (1,)
# TODO: Add tests for dependency_overrides
# TODO: Add a websocket equivalent to all tests