pythonasyncioapiasyncfastapiframeworkjsonjson-schemaopenapiopenapi3pydanticpython-typespython3redocreststarletteswaggerswagger-uiuvicornweb
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
703 lines
20 KiB
703 lines
20 KiB
from enum import StrEnum, auto
|
|
from typing import Any, AsyncGenerator, List, Tuple, TypeVar
|
|
|
|
import pytest
|
|
from fastapi import (
|
|
APIRouter,
|
|
BackgroundTasks,
|
|
Body,
|
|
Cookie,
|
|
Depends,
|
|
FastAPI,
|
|
File,
|
|
Form,
|
|
Header,
|
|
Path,
|
|
Query,
|
|
)
|
|
from fastapi.exceptions import FastAPIError
|
|
from fastapi.params import Security
|
|
from fastapi.security import SecurityScopes
|
|
from starlette.testclient import TestClient
|
|
from typing_extensions import Annotated, Generator, Literal, assert_never
|
|
|
|
T = TypeVar('T')
|
|
|
|
|
|
class DependencyStyle(StrEnum):
|
|
SYNC_FUNCTION = auto()
|
|
ASYNC_FUNCTION = auto()
|
|
SYNC_GENERATOR = auto()
|
|
ASYNC_GENERATOR = auto()
|
|
|
|
|
|
class DependencyFactory:
|
|
def __init__(
|
|
self,
|
|
dependency_style: DependencyStyle, *,
|
|
should_error: bool = False
|
|
):
|
|
self.activation_times = 0
|
|
self.deactivation_times = 0
|
|
self.dependency_style = dependency_style
|
|
self._should_error = should_error
|
|
|
|
def get_dependency(self):
|
|
if self.dependency_style == DependencyStyle.SYNC_FUNCTION:
|
|
return self._synchronous_function_dependency
|
|
|
|
if self.dependency_style == DependencyStyle.SYNC_GENERATOR:
|
|
return self._synchronous_generator_dependency
|
|
|
|
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION:
|
|
return self._asynchronous_function_dependency
|
|
|
|
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR:
|
|
return self._asynchronous_generator_dependency
|
|
|
|
assert_never(self.dependency_style)
|
|
|
|
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]:
|
|
self.activation_times += 1
|
|
if self._should_error:
|
|
raise ValueError(self.activation_times)
|
|
|
|
yield self.activation_times
|
|
self.deactivation_times += 1
|
|
|
|
def _synchronous_generator_dependency(self) -> Generator[T, None, None]:
|
|
self.activation_times += 1
|
|
if self._should_error:
|
|
raise ValueError(self.activation_times)
|
|
|
|
yield self.activation_times
|
|
self.deactivation_times += 1
|
|
|
|
async def _asynchronous_function_dependency(self) -> T:
|
|
self.activation_times += 1
|
|
if self._should_error:
|
|
raise ValueError(self.activation_times)
|
|
|
|
return self.activation_times
|
|
|
|
def _synchronous_function_dependency(self) -> T:
|
|
self.activation_times += 1
|
|
if self._should_error:
|
|
raise ValueError(self.activation_times)
|
|
|
|
return self.activation_times
|
|
|
|
|
|
def _expect_correct_amount_of_dependency_activations(
|
|
*,
|
|
app: FastAPI,
|
|
dependency_factory: DependencyFactory,
|
|
urls_and_responses: List[Tuple[str, Any]],
|
|
expected_activation_times: int
|
|
) -> None:
|
|
assert dependency_factory.activation_times == 0
|
|
assert dependency_factory.deactivation_times == 0
|
|
with TestClient(app) as client:
|
|
assert dependency_factory.activation_times == expected_activation_times
|
|
assert dependency_factory.deactivation_times == 0
|
|
|
|
for url, expected_response in urls_and_responses:
|
|
response = client.post(url)
|
|
response.raise_for_status()
|
|
assert response.json() == expected_response
|
|
|
|
assert dependency_factory.activation_times == expected_activation_times
|
|
assert dependency_factory.deactivation_times == 0
|
|
|
|
assert dependency_factory.activation_times == expected_activation_times
|
|
if dependency_factory.dependency_style not in (
|
|
DependencyStyle.SYNC_FUNCTION,
|
|
DependencyStyle.ASYNC_FUNCTION
|
|
):
|
|
assert dependency_factory.deactivation_times == expected_activation_times
|
|
|
|
@pytest.mark.parametrize("use_cache", [True, False])
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
|
|
def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache):
|
|
dependency_factory= DependencyFactory(dependency_style)
|
|
|
|
app = FastAPI()
|
|
|
|
if routing_style == "app_endpoint":
|
|
router = app
|
|
else:
|
|
router = APIRouter()
|
|
|
|
@router.post("/test")
|
|
async def endpoint(
|
|
dependency: Annotated[None, Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan",
|
|
use_cache=use_cache,
|
|
)]
|
|
) -> None:
|
|
assert dependency == 1
|
|
return dependency
|
|
|
|
if routing_style == "router_endpoint":
|
|
app.include_router(router)
|
|
|
|
_expect_correct_amount_of_dependency_activations(
|
|
app=app,
|
|
dependency_factory=dependency_factory,
|
|
urls_and_responses=[("/test", 1)] * 2,
|
|
expected_activation_times=1
|
|
)
|
|
|
|
@pytest.mark.parametrize("use_cache", [True, False])
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app", "router"])
|
|
def test_router_dependencies(
|
|
dependency_style: DependencyStyle,
|
|
routing_style,
|
|
use_cache
|
|
):
|
|
dependency_factory= DependencyFactory(dependency_style)
|
|
|
|
depends = Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan",
|
|
use_cache=use_cache
|
|
)
|
|
|
|
if routing_style == "app":
|
|
app = FastAPI(dependencies=[depends])
|
|
|
|
@app.post("/test")
|
|
async def endpoint() -> None:
|
|
return None
|
|
else:
|
|
app = FastAPI()
|
|
router = APIRouter(dependencies=[depends])
|
|
|
|
@router.post("/test")
|
|
async def endpoint() -> None:
|
|
return None
|
|
|
|
app.include_router(router)
|
|
|
|
_expect_correct_amount_of_dependency_activations(
|
|
app=app,
|
|
dependency_factory=dependency_factory,
|
|
urls_and_responses=[("/test", None)] * 2,
|
|
expected_activation_times=1
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("use_cache", [True, False])
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app", "router"])
|
|
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
|
|
def test_dependency_cache_in_same_dependency(
|
|
dependency_style: DependencyStyle,
|
|
routing_style,
|
|
use_cache,
|
|
main_dependency_scope: Literal["endpoint", "lifespan"]
|
|
):
|
|
dependency_factory= DependencyFactory(dependency_style)
|
|
|
|
depends = Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan",
|
|
use_cache=use_cache
|
|
)
|
|
|
|
app = FastAPI()
|
|
|
|
if routing_style == "app":
|
|
router = app
|
|
|
|
else:
|
|
router = APIRouter()
|
|
|
|
async def dependency(
|
|
sub_dependency1: Annotated[int, depends],
|
|
sub_dependency2: Annotated[int, depends],
|
|
) -> List[int]:
|
|
return [sub_dependency1, sub_dependency2]
|
|
|
|
@router.post("/test")
|
|
async def endpoint(
|
|
dependency: Annotated[List[int], Depends(
|
|
dependency,
|
|
use_cache=use_cache,
|
|
dependency_scope=main_dependency_scope,
|
|
)]
|
|
) -> List[int]:
|
|
return dependency
|
|
|
|
if routing_style == "router":
|
|
app.include_router(router)
|
|
|
|
if use_cache:
|
|
_expect_correct_amount_of_dependency_activations(
|
|
app=app,
|
|
urls_and_responses=[
|
|
("/test", [1, 1]),
|
|
("/test", [1, 1]),
|
|
],
|
|
dependency_factory=dependency_factory,
|
|
expected_activation_times=1
|
|
)
|
|
else:
|
|
_expect_correct_amount_of_dependency_activations(
|
|
app=app,
|
|
urls_and_responses=[
|
|
("/test", [1, 2]),
|
|
("/test", [1, 2]),
|
|
],
|
|
dependency_factory=dependency_factory,
|
|
expected_activation_times=2
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("use_cache", [True, False])
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app", "router"])
|
|
def test_dependency_cache_in_same_endpoint(
|
|
dependency_style: DependencyStyle,
|
|
routing_style,
|
|
use_cache
|
|
):
|
|
dependency_factory= DependencyFactory(dependency_style)
|
|
|
|
depends = Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan",
|
|
use_cache=use_cache
|
|
)
|
|
|
|
app = FastAPI()
|
|
|
|
if routing_style == "app":
|
|
router = app
|
|
|
|
else:
|
|
router = APIRouter()
|
|
|
|
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
|
|
return dependency3
|
|
|
|
@router.post("/test1")
|
|
async def endpoint(
|
|
dependency1: Annotated[int, depends],
|
|
dependency2: Annotated[int, depends],
|
|
dependency3: Annotated[int, Depends(endpoint_dependency)]
|
|
) -> List[int]:
|
|
return [dependency1, dependency2, dependency3]
|
|
|
|
if routing_style == "router":
|
|
app.include_router(router)
|
|
|
|
if use_cache:
|
|
_expect_correct_amount_of_dependency_activations(
|
|
app=app,
|
|
urls_and_responses=[
|
|
("/test1", [1, 1, 1]),
|
|
("/test1", [1, 1, 1]),
|
|
],
|
|
dependency_factory=dependency_factory,
|
|
expected_activation_times=1
|
|
)
|
|
else:
|
|
_expect_correct_amount_of_dependency_activations(
|
|
app=app,
|
|
urls_and_responses=[
|
|
("/test1", [1, 2, 3]),
|
|
("/test1", [1, 2, 3]),
|
|
],
|
|
dependency_factory=dependency_factory,
|
|
expected_activation_times=3
|
|
)
|
|
|
|
@pytest.mark.parametrize("use_cache", [True, False])
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app", "router"])
|
|
def test_dependency_cache_in_different_endpoints(
|
|
dependency_style: DependencyStyle,
|
|
routing_style,
|
|
use_cache
|
|
):
|
|
dependency_factory= DependencyFactory(dependency_style)
|
|
|
|
depends = Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan",
|
|
use_cache=use_cache
|
|
)
|
|
|
|
app = FastAPI()
|
|
|
|
if routing_style == "app":
|
|
router = app
|
|
|
|
else:
|
|
router = APIRouter()
|
|
|
|
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
|
|
return dependency3
|
|
|
|
@router.post("/test1")
|
|
async def endpoint(
|
|
dependency1: Annotated[int, depends],
|
|
dependency2: Annotated[int, depends],
|
|
dependency3: Annotated[int, Depends(endpoint_dependency)]
|
|
) -> List[int]:
|
|
return [dependency1, dependency2, dependency3]
|
|
|
|
@router.post("/test2")
|
|
async def endpoint2(
|
|
dependency1: Annotated[int, depends],
|
|
dependency2: Annotated[int, depends],
|
|
dependency3: Annotated[int, Depends(endpoint_dependency)]
|
|
) -> List[int]:
|
|
return [dependency1, dependency2, dependency3]
|
|
|
|
if routing_style == "router":
|
|
app.include_router(router)
|
|
|
|
if use_cache:
|
|
_expect_correct_amount_of_dependency_activations(
|
|
app=app,
|
|
urls_and_responses=[
|
|
("/test1", [1, 1, 1]),
|
|
("/test2", [1, 1, 1]),
|
|
("/test1", [1, 1, 1]),
|
|
("/test2", [1, 1, 1]),
|
|
],
|
|
dependency_factory=dependency_factory,
|
|
expected_activation_times=1
|
|
)
|
|
else:
|
|
_expect_correct_amount_of_dependency_activations(
|
|
app=app,
|
|
urls_and_responses=[
|
|
("/test1", [1, 2, 3]),
|
|
("/test2", [4, 5, 3]),
|
|
("/test1", [1, 2, 3]),
|
|
("/test2", [4, 5, 3]),
|
|
],
|
|
dependency_factory=dependency_factory,
|
|
expected_activation_times=5
|
|
)
|
|
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app", "router"])
|
|
def test_no_cached_dependency(
|
|
dependency_style: DependencyStyle,
|
|
routing_style,
|
|
):
|
|
dependency_factory= DependencyFactory(dependency_style)
|
|
|
|
depends = Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan",
|
|
use_cache=False
|
|
)
|
|
|
|
app = FastAPI()
|
|
|
|
if routing_style == "app":
|
|
router = app
|
|
|
|
else:
|
|
router = APIRouter()
|
|
|
|
@router.post("/test")
|
|
async def endpoint(
|
|
dependency: Annotated[int, depends],
|
|
) -> int:
|
|
return dependency
|
|
|
|
if routing_style == "router":
|
|
app.include_router(router)
|
|
|
|
_expect_correct_amount_of_dependency_activations(
|
|
app=app,
|
|
dependency_factory=dependency_factory,
|
|
urls_and_responses=[("/test", 1)] * 2,
|
|
expected_activation_times=1
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("annotation", [
|
|
Annotated[str, Path()],
|
|
Annotated[str, Body()],
|
|
Annotated[str, Query()],
|
|
Annotated[str, Header()],
|
|
SecurityScopes,
|
|
Annotated[str, Cookie()],
|
|
Annotated[str, Form()],
|
|
Annotated[str, File()],
|
|
BackgroundTasks,
|
|
])
|
|
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
|
|
annotation
|
|
):
|
|
async def dependency_func(param: annotation) -> None:
|
|
yield
|
|
|
|
app = FastAPI()
|
|
|
|
with pytest.raises(FastAPIError):
|
|
@app.post("/test")
|
|
async def endpoint(
|
|
dependency: Annotated[
|
|
None, Depends(dependency_func, dependency_scope="lifespan")]
|
|
) -> None:
|
|
return
|
|
|
|
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
|
|
dependency_style: DependencyStyle
|
|
):
|
|
dependency_factory = DependencyFactory(dependency_style)
|
|
|
|
async def lifespan_scoped_dependency(
|
|
param: Annotated[int, Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan"
|
|
)]
|
|
) -> AsyncGenerator[int, None]:
|
|
yield param
|
|
|
|
app = FastAPI()
|
|
|
|
@app.post("/test")
|
|
async def endpoint(
|
|
dependency: Annotated[int, Depends(
|
|
lifespan_scoped_dependency,
|
|
dependency_scope="lifespan"
|
|
)]
|
|
) -> int:
|
|
return dependency
|
|
|
|
_expect_correct_amount_of_dependency_activations(
|
|
app=app,
|
|
dependency_factory=dependency_factory,
|
|
expected_activation_times=1,
|
|
urls_and_responses=[("/test", 1)] * 2
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("depends_class", [Depends, Security])
|
|
@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[
|
|
"websocket", "endpoint"
|
|
])
|
|
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
|
|
depends_class,
|
|
route_type
|
|
):
|
|
async def sub_dependency() -> None:
|
|
pass
|
|
|
|
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None:
|
|
yield
|
|
|
|
app = FastAPI()
|
|
route_decorator = route_type(app, "/test")
|
|
|
|
with pytest.raises(FastAPIError):
|
|
@route_decorator
|
|
async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")]
|
|
) -> None:
|
|
return
|
|
|
|
@pytest.mark.parametrize("use_cache", [True, False])
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
|
|
def test_dependencies_must_provide_correct_dependency_scope(
|
|
dependency_style: DependencyStyle,
|
|
routing_style,
|
|
use_cache
|
|
):
|
|
dependency_factory= DependencyFactory(dependency_style)
|
|
|
|
app = FastAPI()
|
|
|
|
if routing_style == "app_endpoint":
|
|
router = app
|
|
else:
|
|
router = APIRouter()
|
|
|
|
with pytest.raises(FastAPIError):
|
|
@router.post("/test")
|
|
async def endpoint(
|
|
dependency: Annotated[None, Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="incorrect",
|
|
use_cache=use_cache,
|
|
)]
|
|
) -> None:
|
|
assert dependency == 1
|
|
return dependency
|
|
|
|
|
|
@pytest.mark.parametrize("use_cache", [True, False])
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
|
|
def test_endpoints_report_incorrect_dependency_scope(
|
|
dependency_style: DependencyStyle,
|
|
routing_style,
|
|
use_cache
|
|
):
|
|
dependency_factory= DependencyFactory(dependency_style)
|
|
|
|
app = FastAPI()
|
|
|
|
if routing_style == "app_endpoint":
|
|
router = app
|
|
else:
|
|
router = APIRouter()
|
|
|
|
depends = Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan",
|
|
use_cache=use_cache,
|
|
)
|
|
# We intentionally change the dependency scope here to bypass the
|
|
# validation at the function level.
|
|
depends.dependency_scope = "asdad"
|
|
|
|
with pytest.raises(FastAPIError):
|
|
@router.post("/test")
|
|
async def endpoint(
|
|
dependency: Annotated[int, depends]
|
|
) -> int:
|
|
assert dependency == 1
|
|
return dependency
|
|
|
|
|
|
@pytest.mark.parametrize("use_cache", [True, False])
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
|
|
def test_endpoints_report_uninitialized_dependency(
|
|
dependency_style: DependencyStyle,
|
|
routing_style,
|
|
use_cache
|
|
):
|
|
dependency_factory= DependencyFactory(dependency_style)
|
|
|
|
app = FastAPI()
|
|
|
|
if routing_style == "app_endpoint":
|
|
router = app
|
|
else:
|
|
router = APIRouter()
|
|
|
|
depends = Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan",
|
|
use_cache=use_cache,
|
|
)
|
|
|
|
@router.post("/test")
|
|
async def endpoint(
|
|
dependency: Annotated[int, depends]
|
|
) -> int:
|
|
assert dependency == 1
|
|
return dependency
|
|
|
|
if routing_style == "router_endpoint":
|
|
app.include_router(router)
|
|
|
|
with TestClient(app) as client:
|
|
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"]
|
|
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {}
|
|
|
|
try:
|
|
with pytest.raises(FastAPIError):
|
|
client.post("/test")
|
|
finally:
|
|
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies
|
|
|
|
|
|
@pytest.mark.parametrize("use_cache", [True, False])
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
|
|
def test_endpoints_report_uninitialized_internal_lifespan(
|
|
dependency_style: DependencyStyle,
|
|
routing_style,
|
|
use_cache
|
|
):
|
|
dependency_factory= DependencyFactory(dependency_style)
|
|
|
|
app = FastAPI()
|
|
|
|
if routing_style == "app_endpoint":
|
|
router = app
|
|
else:
|
|
router = APIRouter()
|
|
|
|
depends = Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan",
|
|
use_cache=use_cache,
|
|
)
|
|
|
|
@router.post("/test")
|
|
async def endpoint(
|
|
dependency: Annotated[int, depends]
|
|
) -> int:
|
|
assert dependency == 1
|
|
return dependency
|
|
|
|
if routing_style == "router_endpoint":
|
|
app.include_router(router)
|
|
|
|
with TestClient(app) as client:
|
|
internal_state = client.app_state["__fastapi__"]
|
|
del client.app_state["__fastapi__"]
|
|
|
|
try:
|
|
with pytest.raises(FastAPIError):
|
|
client.post("/test")
|
|
finally:
|
|
client.app_state["__fastapi__"] = internal_state
|
|
|
|
|
|
@pytest.mark.parametrize("use_cache", [True, False])
|
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
|
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
|
|
def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style):
|
|
dependency_factory= DependencyFactory(dependency_style, should_error=True)
|
|
depends = Depends(
|
|
dependency_factory.get_dependency(),
|
|
dependency_scope="lifespan",
|
|
use_cache=use_cache,
|
|
)
|
|
|
|
app = FastAPI()
|
|
|
|
if routing_style == "app_endpoint":
|
|
router = app
|
|
|
|
else:
|
|
router = APIRouter()
|
|
|
|
@router.post("/test")
|
|
async def endpoint(
|
|
dependency: Annotated[int, depends]
|
|
) -> int:
|
|
assert dependency == 1
|
|
return dependency
|
|
|
|
if routing_style == "router_endpoint":
|
|
app.include_router(router)
|
|
|
|
with pytest.raises(ValueError) as exception_info:
|
|
with TestClient(app):
|
|
pass
|
|
|
|
assert exception_info.value.args == (1,)
|
|
|
|
|
|
# TODO: Add tests for dependency_overrides
|
|
# TODO: Add a websocket equivalent to all tests
|
|
|