Browse Source

Applied ruff linting

pull/12529/head
Nir Schulman 9 months ago
parent
commit
d512b03bd3
  1. 18
      fastapi/dependencies/models.py
  2. 68
      fastapi/dependencies/utils.py
  3. 18
      fastapi/routing.py
  4. 280
      tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py
  5. 334
      tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py
  6. 97
      tests/test_lifespan_scoped_dependencies/testing_utilities.py

18
fastapi/dependencies/models.py

@ -12,7 +12,10 @@ class SecurityRequirement:
scopes: Optional[Sequence[str]] = None scopes: Optional[Sequence[str]] = None
LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any]] LifespanDependantCacheKey: TypeAlias = Union[
Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any]
]
@dataclass @dataclass
class LifespanDependant: class LifespanDependant:
@ -30,9 +33,9 @@ class LifespanDependant:
elif self.name is not None: elif self.name is not None:
self.cache_key = (self.caller, self.name) self.cache_key = (self.caller, self.name)
else: else:
assert self.index is not None, ( assert (
"Lifespan dependency must have an associated name or index." self.index is not None
) ), "Lifespan dependency must have an associated name or index."
self.cache_key = (self.caller, self.index) self.cache_key = (self.caller, self.index)
@ -49,8 +52,7 @@ class EndpointDependant:
call: Optional[Callable[..., Any]] = None call: Optional[Callable[..., Any]] = None
use_cache: bool = True use_cache: bool = True
index: Optional[int] = None index: Optional[int] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field( cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
init=False)
path_params: List[ModelField] = field(default_factory=list) path_params: List[ModelField] = field(default_factory=list)
query_params: List[ModelField] = field(default_factory=list) query_params: List[ModelField] = field(default_factory=list)
header_params: List[ModelField] = field(default_factory=list) header_params: List[ModelField] = field(default_factory=list)
@ -74,11 +76,11 @@ class EndpointDependant:
def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]: def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]:
lifespan_dependencies = cast( lifespan_dependencies = cast(
List[Union[EndpointDependant, LifespanDependant]], List[Union[EndpointDependant, LifespanDependant]],
self.lifespan_dependencies self.lifespan_dependencies,
) )
endpoint_dependencies = cast( endpoint_dependencies = cast(
List[Union[EndpointDependant, LifespanDependant]], List[Union[EndpointDependant, LifespanDependant]],
self.endpoint_dependencies self.endpoint_dependencies,
) )
return tuple(lifespan_dependencies + endpoint_dependencies) return tuple(lifespan_dependencies + endpoint_dependencies)

68
fastapi/dependencies/utils.py

@ -147,11 +147,7 @@ def get_param_sub_dependant(
def get_parameterless_sub_dependant( def get_parameterless_sub_dependant(
*, *, depends: params.Depends, path: str, caller: Callable[..., Any], index: int
depends: params.Depends,
path: str,
caller: Callable[..., Any],
index: int
) -> Union[EndpointDependant, LifespanDependant]: ) -> Union[EndpointDependant, LifespanDependant]:
assert callable( assert callable(
depends.dependency depends.dependency
@ -161,7 +157,7 @@ def get_parameterless_sub_dependant(
dependency=depends.dependency, dependency=depends.dependency,
path=path, path=path,
caller=caller, caller=caller,
index=index index=index,
) )
@ -181,7 +177,7 @@ def get_sub_dependant(
call=dependency, call=dependency,
name=name, name=name,
use_cache=depends.use_cache, use_cache=depends.use_cache,
index=index index=index,
) )
elif depends.dependency_scope == "endpoint": elif depends.dependency_scope == "endpoint":
security_requirement = None security_requirement = None
@ -202,15 +198,15 @@ def get_sub_dependant(
name=name, name=name,
security_scopes=security_scopes, security_scopes=security_scopes,
use_cache=depends.use_cache, use_cache=depends.use_cache,
index=index index=index,
) )
if security_requirement: if security_requirement:
sub_dependant.security_requirements.append(security_requirement) sub_dependant.security_requirements.append(security_requirement)
return sub_dependant return sub_dependant
else: else:
raise InvalidDependencyScope( raise InvalidDependencyScope(
f"Dependency \"{name}\" of {caller} has an invalid " f'Dependency "{name}" of {caller} has an invalid '
f"scope: \"{depends.dependency_scope}\"" f'scope: "{depends.dependency_scope}"'
) )
@ -233,7 +229,7 @@ def get_flat_dependant(
security_requirements=dependant.security_requirements.copy(), security_requirements=dependant.security_requirements.copy(),
lifespan_dependencies=dependant.lifespan_dependencies.copy(), lifespan_dependencies=dependant.lifespan_dependencies.copy(),
use_cache=dependant.use_cache, use_cache=dependant.use_cache,
path=dependant.path path=dependant.path,
) )
for sub_dependant in dependant.endpoint_dependencies: for sub_dependant in dependant.endpoint_dependencies:
if skip_repeats and sub_dependant.cache_key in visited: if skip_repeats and sub_dependant.cache_key in visited:
@ -310,16 +306,12 @@ def get_lifespan_dependant(
call: Callable[..., Any], call: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
use_cache: bool = True, use_cache: bool = True,
index: Optional[int] = None index: Optional[int] = None,
) -> LifespanDependant: ) -> LifespanDependant:
dependency_signature = get_typed_signature(call) dependency_signature = get_typed_signature(call)
signature_params = dependency_signature.parameters signature_params = dependency_signature.parameters
dependant = LifespanDependant( dependant = LifespanDependant(
call=call, call=call, name=name, use_cache=use_cache, caller=caller, index=index
name=name,
use_cache=use_cache,
caller=caller,
index=index
) )
for param_name, param in signature_params.items(): for param_name, param in signature_params.items():
param_details = analyze_param( param_details = analyze_param(
@ -330,16 +322,17 @@ def get_lifespan_dependant(
) )
if param_details.depends is None: if param_details.depends is None:
raise DependencyScopeConflict( raise DependencyScopeConflict(
f"Lifespan scoped dependency \"{dependant.name}\" was defined " f'Lifespan scoped dependency "{dependant.name}" was defined '
f"with an invalid argument: \"{param_name}\" which is " f'with an invalid argument: "{param_name}" which is '
f"\"endpoint\" scoped. Lifespan scoped dependencies may only " f'"endpoint" scoped. Lifespan scoped dependencies may only '
f"use lifespan scoped sub-dependencies.") f"use lifespan scoped sub-dependencies."
)
if param_details.depends.dependency_scope != "lifespan": if param_details.depends.dependency_scope != "lifespan":
raise DependencyScopeConflict( raise DependencyScopeConflict(
f"Lifespan scoped dependency {dependant.name} was defined with the " f"Lifespan scoped dependency {dependant.name} was defined with the "
f"sub-dependency \"{param_name}\" which is " f'sub-dependency "{param_name}" which is '
f"\"{param_details.depends.dependency_scope}\" scoped. " f'"{param_details.depends.dependency_scope}" scoped. '
f"Lifespan scoped dependencies may only use lifespan scoped " f"Lifespan scoped dependencies may only use lifespan scoped "
f"sub-dependencies." f"sub-dependencies."
) )
@ -350,7 +343,7 @@ def get_lifespan_dependant(
name=param_name, name=param_name,
call=param_details.depends.dependency, call=param_details.depends.dependency,
use_cache=param_details.depends.use_cache, use_cache=param_details.depends.use_cache,
caller=call caller=call,
) )
dependant.dependencies.append(sub_dependant) dependant.dependencies.append(sub_dependant)
@ -364,7 +357,7 @@ def get_endpoint_dependant(
name: Optional[str] = None, name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
use_cache: bool = True, use_cache: bool = True,
index: Optional[int] = None index: Optional[int] = None,
) -> EndpointDependant: ) -> EndpointDependant:
path_param_names = get_path_param_names(path) path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call) endpoint_signature = get_typed_signature(call)
@ -375,7 +368,7 @@ def get_endpoint_dependant(
path=path, path=path,
security_scopes=security_scopes, security_scopes=security_scopes,
use_cache=use_cache, use_cache=use_cache,
index=index index=index,
) )
for param_name, param in signature_params.items(): for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names is_path_param = param_name in path_param_names
@ -692,14 +685,12 @@ async def solve_lifespan_dependant(
call = dependant.call call = dependant.call
dependant_to_solve = dependant dependant_to_solve = dependant
if ( if (
dependency_overrides_provider dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides and dependency_overrides_provider.dependency_overrides
): ):
call = getattr( call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(
dependency_overrides_provider, dependant.call, dependant.call
"dependency_overrides", )
{}
).get(dependant.call, dependant.call)
dependant_to_solve = get_lifespan_dependant( dependant_to_solve = get_lifespan_dependant(
caller=dependant.caller, caller=dependant.caller,
call=call, call=call,
@ -725,9 +716,7 @@ async def solve_lifespan_dependant(
if is_gen_callable(call) or is_async_gen_callable(call): if is_gen_callable(call) or is_async_gen_callable(call):
value = await solve_generator( value = await solve_generator(
call=call, call=call, stack=async_exit_stack, sub_values=dependency_arguments
stack=async_exit_stack,
sub_values=dependency_arguments
) )
elif is_coroutine_callable(call): elif is_coroutine_callable(call):
value = await call(**dependency_arguments) value = await call(**dependency_arguments)
@ -773,7 +762,8 @@ async def solve_dependencies(
try: try:
lifespan_scoped_dependencies = request.state.__fastapi__[ lifespan_scoped_dependencies = request.state.__fastapi__[
"lifespan_scoped_dependencies"] "lifespan_scoped_dependencies"
]
except (AttributeError, KeyError) as e: except (AttributeError, KeyError) as e:
raise UninitializedLifespanDependency( raise UninitializedLifespanDependency(
"FastAPI's internal lifespan was not initialized correctly." "FastAPI's internal lifespan was not initialized correctly."
@ -783,8 +773,8 @@ async def solve_dependencies(
value = lifespan_scoped_dependencies[lifespan_sub_dependant.cache_key] value = lifespan_scoped_dependencies[lifespan_sub_dependant.cache_key]
except KeyError as e: except KeyError as e:
raise UninitializedLifespanDependency( raise UninitializedLifespanDependency(
f"Dependency \"{lifespan_sub_dependant.name}\" of " f'Dependency "{lifespan_sub_dependant.name}" of '
f"`{dependant.call}` was not initialized correctly." f"`{dependant.call}` was not initialized correctly."
) from e ) from e
values[lifespan_sub_dependant.name] = value values[lifespan_sub_dependant.name] = value

18
fastapi/routing.py

@ -400,13 +400,12 @@ class APIWebSocketRoute(routing.WebSocketRoute):
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or []) self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) self.dependant = get_endpoint_dependant(
path=self.path_format, call=self.endpoint
)
for i, depends in list(enumerate(self.dependencies))[::-1]: for i, depends in list(enumerate(self.dependencies))[::-1]:
sub_dependant = get_parameterless_sub_dependant( sub_dependant = get_parameterless_sub_dependant(
depends=depends, depends=depends, path=self.path_format, caller=self.__call__, index=i
path=self.path_format,
caller=self.__call__,
index=i
) )
if depends.dependency_scope == "endpoint": if depends.dependency_scope == "endpoint":
assert isinstance(sub_dependant, EndpointDependant) assert isinstance(sub_dependant, EndpointDependant)
@ -566,13 +565,12 @@ class APIRoute(routing.Route):
self.response_fields = {} self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable" assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) self.dependant = get_endpoint_dependant(
path=self.path_format, call=self.endpoint
)
for i, depends in list(enumerate(self.dependencies))[::-1]: for i, depends in list(enumerate(self.dependencies))[::-1]:
sub_dependant = get_parameterless_sub_dependant( sub_dependant = get_parameterless_sub_dependant(
depends=depends, depends=depends, path=self.path_format, caller=self.__call__, index=i
path=self.path_format,
caller=self.__call__,
index=i
) )
if depends.dependency_scope == "endpoint": if depends.dependency_scope == "endpoint":
assert isinstance(sub_dependant, EndpointDependant) assert isinstance(sub_dependant, EndpointDependant)

280
tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py

@ -33,13 +33,13 @@ from tests.test_lifespan_scoped_dependencies.testing_utilities import (
def expect_correct_amount_of_dependency_activations( def expect_correct_amount_of_dependency_activations(
*, *,
app: FastAPI, app: FastAPI,
dependency_factory: DependencyFactory, dependency_factory: DependencyFactory,
override_dependency_factory: DependencyFactory, override_dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]], urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int, expected_activation_times: int,
is_websocket: bool is_websocket: bool,
) -> None: ) -> None:
assert dependency_factory.activation_times == 0 assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0 assert dependency_factory.deactivation_times == 0
@ -62,17 +62,22 @@ def expect_correct_amount_of_dependency_activations(
assert dependency_factory.activation_times == 0 assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0 assert dependency_factory.deactivation_times == 0
assert override_dependency_factory.activation_times == expected_activation_times assert (
override_dependency_factory.activation_times
== expected_activation_times
)
assert override_dependency_factory.deactivation_times == 0 assert override_dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == 0 assert dependency_factory.activation_times == 0
assert override_dependency_factory.activation_times == expected_activation_times assert override_dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in ( if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION, DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION DependencyStyle.ASYNC_FUNCTION,
): ):
assert dependency_factory.deactivation_times == 0 assert dependency_factory.deactivation_times == 0
assert override_dependency_factory.deactivation_times == expected_activation_times assert (
override_dependency_factory.deactivation_times == expected_activation_times
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@ -80,16 +85,10 @@ def expect_correct_amount_of_dependency_activations(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies( def test_endpoint_dependencies(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
routing_style,
use_cache,
is_websocket
): ):
dependency_factory = DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory( override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
dependency_style,
value_offset=10
)
app = FastAPI() app = FastAPI()
@ -108,14 +107,16 @@ def test_endpoint_dependencies(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=use_cache, use_cache=use_cache,
) ),
], ],
expected_value=11 expected_value=11,
) )
if routing_style == "router_endpoint": if routing_style == "router_endpoint":
app.include_router(router) app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
app=app, app=app,
@ -123,7 +124,7 @@ def test_endpoint_dependencies(
override_dependency_factory=override_dependency_factory, override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", 11)] * 2, urls_and_responses=[("/test", 11)] * 2,
expected_activation_times=1, expected_activation_times=1,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@ -133,45 +134,40 @@ def test_endpoint_dependencies(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies( def test_router_dependencies(
dependency_style: DependencyStyle, dependency_style: DependencyStyle,
routing_style, routing_style,
use_cache, use_cache,
dependency_duplication, dependency_duplication,
is_websocket is_websocket,
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory( override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
dependency_style,
value_offset=10
)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=use_cache use_cache=use_cache,
) )
if routing_style == "app": if routing_style == "app":
app = FastAPI(dependencies=[depends] * dependency_duplication) app = FastAPI(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations( create_endpoint_0_annotations(
router=app, router=app, path="/test", is_websocket=is_websocket
path="/test",
is_websocket=is_websocket
) )
else: else:
app = FastAPI() app = FastAPI()
router = APIRouter(dependencies=[depends] * dependency_duplication) router = APIRouter(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations( create_endpoint_0_annotations(
router=router, router=router, path="/test", is_websocket=is_websocket
path="/test",
is_websocket=is_websocket
) )
app.include_router(router) app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
app=app, app=app,
@ -179,31 +175,29 @@ def test_router_dependencies(
override_dependency_factory=override_dependency_factory, override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", None)] * 2, urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1 if use_cache else dependency_duplication, expected_activation_times=1 if use_cache else dependency_duplication,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) @pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency( def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle, dependency_style: DependencyStyle,
routing_style, routing_style,
use_cache, use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"], main_dependency_scope: Literal["endpoint", "lifespan"],
is_websocket is_websocket,
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory( override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
dependency_style,
value_offset=10
)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=use_cache use_cache=use_cache,
) )
app = FastAPI() app = FastAPI()
@ -215,8 +209,8 @@ def test_dependency_cache_in_same_dependency(
router = APIRouter() router = APIRouter()
async def dependency( async def dependency(
sub_dependency1: Annotated[int, depends], sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends], sub_dependency2: Annotated[int, depends],
) -> List[int]: ) -> List[int]:
return [sub_dependency1, sub_dependency2] return [sub_dependency1, sub_dependency2]
@ -224,19 +218,22 @@ def test_dependency_cache_in_same_dependency(
router=router, router=router,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[List[int], Depends( annotation=Annotated[
dependency, List[int],
use_cache=use_cache, Depends(
dependency_scope=main_dependency_scope, dependency,
)] use_cache=use_cache,
dependency_scope=main_dependency_scope,
),
],
) )
if routing_style == "router": if routing_style == "router":
app.include_router(router) app.include_router(router)
app.dependency_overrides[ app.dependency_overrides[dependency_factory.get_dependency()] = (
dependency_factory.get_dependency() override_dependency_factory.get_dependency()
] = override_dependency_factory.get_dependency() )
if use_cache: if use_cache:
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -248,7 +245,7 @@ def test_dependency_cache_in_same_dependency(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory, override_dependency_factory=override_dependency_factory,
expected_activation_times=1, expected_activation_times=1,
is_websocket=is_websocket is_websocket=is_websocket,
) )
else: else:
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -260,7 +257,7 @@ def test_dependency_cache_in_same_dependency(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory, override_dependency_factory=override_dependency_factory,
expected_activation_times=2, expected_activation_times=2,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@ -269,21 +266,15 @@ def test_dependency_cache_in_same_dependency(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint( def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
routing_style,
use_cache,
is_websocket
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory( override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
dependency_style,
value_offset=10
)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=use_cache use_cache=use_cache,
) )
app = FastAPI() app = FastAPI()
@ -309,9 +300,9 @@ def test_dependency_cache_in_same_endpoint(
if routing_style == "router": if routing_style == "router":
app.include_router(router) app.include_router(router)
app.dependency_overrides[ app.dependency_overrides[dependency_factory.get_dependency()] = (
dependency_factory.get_dependency() override_dependency_factory.get_dependency()
] = override_dependency_factory.get_dependency() )
if use_cache: if use_cache:
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -323,7 +314,7 @@ def test_dependency_cache_in_same_endpoint(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory, override_dependency_factory=override_dependency_factory,
expected_activation_times=1, expected_activation_times=1,
is_websocket=is_websocket is_websocket=is_websocket,
) )
else: else:
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -335,29 +326,24 @@ def test_dependency_cache_in_same_endpoint(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory, override_dependency_factory=override_dependency_factory,
expected_activation_times=3, expected_activation_times=3,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints( def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
routing_style,
use_cache,
is_websocket
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory( override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
dependency_style,
value_offset=10
)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=use_cache use_cache=use_cache,
) )
app = FastAPI() app = FastAPI()
@ -392,8 +378,9 @@ def test_dependency_cache_in_different_endpoints(
if routing_style == "router": if routing_style == "router":
app.include_router(router) app.include_router(router)
app.dependency_overrides[ app.dependency_overrides[dependency_factory.get_dependency()] = (
dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() override_dependency_factory.get_dependency()
)
if use_cache: if use_cache:
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -407,7 +394,7 @@ def test_dependency_cache_in_different_endpoints(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory, override_dependency_factory=override_dependency_factory,
expected_activation_times=1, expected_activation_times=1,
is_websocket=is_websocket is_websocket=is_websocket,
) )
else: else:
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -421,27 +408,23 @@ def test_dependency_cache_in_different_endpoints(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory, override_dependency_factory=override_dependency_factory,
expected_activation_times=5, expected_activation_times=5,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency( def test_no_cached_dependency(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, routing_style, is_websocket
routing_style,
is_websocket
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory( override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
dependency_style,
value_offset=10
)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=False use_cache=False,
) )
app = FastAPI() app = FastAPI()
@ -462,8 +445,9 @@ def test_no_cached_dependency(
if routing_style == "router": if routing_style == "router":
app.include_router(router) app.include_router(router)
app.dependency_overrides[ app.dependency_overrides[dependency_factory.get_dependency()] = (
dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
app=app, app=app,
@ -471,24 +455,27 @@ def test_no_cached_dependency(
override_dependency_factory=override_dependency_factory, override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", 11)] * 2, urls_and_responses=[("/test", 11)] * 2,
expected_activation_times=1, expected_activation_times=1,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("annotation", [ @pytest.mark.parametrize(
Annotated[str, Path()], "annotation",
Annotated[str, Body()], [
Annotated[str, Query()], Annotated[str, Path()],
Annotated[str, Header()], Annotated[str, Body()],
SecurityScopes, Annotated[str, Query()],
Annotated[str, Cookie()], Annotated[str, Header()],
Annotated[str, Form()], SecurityScopes,
Annotated[str, File()], Annotated[str, Cookie()],
BackgroundTasks, Annotated[str, Form()],
]) Annotated[str, File()],
BackgroundTasks,
],
)
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation, annotation, is_websocket
is_websocket
): ):
async def dependency_func() -> None: async def dependency_func() -> None:
yield yield
@ -503,9 +490,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete
router=app, router=app,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[None, annotation=Annotated[
Depends(dependency_func, dependency_scope="lifespan") None, Depends(dependency_func, dependency_scope="lifespan")
] ],
) )
with pytest.raises(DependencyScopeConflict): with pytest.raises(DependencyScopeConflict):
@ -516,20 +503,16 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, is_websocket
is_websocket
): ):
dependency_factory = DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory( override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
dependency_style,
value_offset=10
)
async def lifespan_scoped_dependency( async def lifespan_scoped_dependency(
param: Annotated[int, Depends( param: Annotated[
dependency_factory.get_dependency(), int,
dependency_scope="lifespan" Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
)] ],
) -> AsyncGenerator[int, None]: ) -> AsyncGenerator[int, None]:
yield param yield param
@ -540,12 +523,13 @@ def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_sco
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[ annotation=Annotated[
int, int, Depends(lifespan_scoped_dependency, dependency_scope="lifespan")
Depends(lifespan_scoped_dependency, dependency_scope="lifespan")
], ],
) )
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
app=app, app=app,
@ -553,15 +537,14 @@ def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_sco
override_dependency_factory=override_dependency_factory, override_dependency_factory=override_dependency_factory,
expected_activation_times=1, expected_activation_times=1,
urls_and_responses=[("/test", 11)] * 2, urls_and_responses=[("/test", 11)] * 2,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("depends_class", [Depends, Security]) @pytest.mark.parametrize("depends_class", [Depends, Security])
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class, depends_class, is_websocket
is_websocket
): ):
async def sub_dependency() -> None: async def sub_dependency() -> None:
pass pass
@ -569,7 +552,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen
async def dependency_func() -> None: async def dependency_func() -> None:
yield yield
async def override_dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: async def override_dependency_func(
param: Annotated[None, depends_class(sub_dependency)],
) -> None:
yield yield
app = FastAPI() app = FastAPI()
@ -578,7 +563,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen
router=app, router=app,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] annotation=Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")
],
) )
app.dependency_overrides[dependency_func] = override_dependency_func app.dependency_overrides[dependency_func] = override_dependency_func
@ -593,12 +580,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_override_lifespan_scoped_dependencies( def test_bad_override_lifespan_scoped_dependencies(
use_cache, use_cache, dependency_style: DependencyStyle, routing_style, is_websocket
dependency_style: DependencyStyle,
routing_style,
is_websocket
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, should_error=True) override_dependency_factory = DependencyFactory(dependency_style, should_error=True)
depends = Depends( depends = Depends(
@ -619,13 +603,15 @@ def test_bad_override_lifespan_scoped_dependencies(
router=router, router=router,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[int, depends] annotation=Annotated[int, depends],
) )
if routing_style == "router_endpoint": if routing_style == "router_endpoint":
app.include_router(router) app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
with pytest.raises(IntentionallyBadDependency) as exception_info: with pytest.raises(IntentionallyBadDependency) as exception_info:
with TestClient(app): with TestClient(app):

334
tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py

@ -40,12 +40,12 @@ from tests.test_lifespan_scoped_dependencies.testing_utilities import (
def expect_correct_amount_of_dependency_activations( def expect_correct_amount_of_dependency_activations(
*, *,
app: FastAPI, app: FastAPI,
dependency_factory: DependencyFactory, dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]], urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int, expected_activation_times: int,
is_websocket: bool is_websocket: bool,
) -> None: ) -> None:
assert dependency_factory.activation_times == 0 assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0 assert dependency_factory.deactivation_times == 0
@ -64,21 +64,23 @@ def expect_correct_amount_of_dependency_activations(
assert dependency_factory.activation_times == expected_activation_times assert dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in ( if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION, DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION DependencyStyle.ASYNC_FUNCTION,
): ):
assert dependency_factory.deactivation_times == expected_activation_times assert dependency_factory.deactivation_times == expected_activation_times
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False], ids=["With Cache", "Without Cache"]) @pytest.mark.parametrize(
"use_cache", [True, False], ids=["With Cache", "Without Cache"]
)
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies( def test_endpoint_dependencies(
dependency_style: DependencyStyle, dependency_style: DependencyStyle,
routing_style, routing_style,
use_cache, use_cache,
is_websocket: bool, is_websocket: bool,
): ):
dependency_factory = DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
@ -93,12 +95,15 @@ def test_endpoint_dependencies(
router=router, router=router,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[None, Depends( annotation=Annotated[
None,
Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=use_cache, use_cache=use_cache,
)], ),
expected_value=1 ],
expected_value=1,
) )
if routing_style == "router_endpoint": if routing_style == "router_endpoint":
@ -109,45 +114,42 @@ def test_endpoint_dependencies(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2, urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1, expected_activation_times=1,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_duplication", [1, 2]) @pytest.mark.parametrize("dependency_duplication", [1, 2])
@pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies( def test_router_dependencies(
dependency_style: DependencyStyle, dependency_style: DependencyStyle,
routing_style, routing_style,
use_cache, use_cache,
dependency_duplication, dependency_duplication,
is_websocket: bool, is_websocket: bool,
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=use_cache use_cache=use_cache,
) )
if routing_style == "app": if routing_style == "app":
app = FastAPI(dependencies=[depends] * dependency_duplication) app = FastAPI(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations( create_endpoint_0_annotations(
router=app, router=app, path="/test", is_websocket=is_websocket
path="/test",
is_websocket=is_websocket
) )
else: else:
app = FastAPI() app = FastAPI()
router = APIRouter(dependencies=[depends] * dependency_duplication) router = APIRouter(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations( create_endpoint_0_annotations(
router=router, router=router, path="/test", is_websocket=is_websocket
path="/test",
is_websocket=is_websocket
) )
app.include_router(router) app.include_router(router)
@ -157,27 +159,28 @@ def test_router_dependencies(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
urls_and_responses=[("/test", None)] * 2, urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1 if use_cache else dependency_duplication, expected_activation_times=1 if use_cache else dependency_duplication,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) @pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency( def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle, dependency_style: DependencyStyle,
routing_style, routing_style,
use_cache, use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"], main_dependency_scope: Literal["endpoint", "lifespan"],
is_websocket: bool, is_websocket: bool,
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=use_cache use_cache=use_cache,
) )
app = FastAPI() app = FastAPI()
@ -189,8 +192,8 @@ def test_dependency_cache_in_same_dependency(
router = APIRouter() router = APIRouter()
async def dependency( async def dependency(
sub_dependency1: Annotated[int, depends], sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends], sub_dependency2: Annotated[int, depends],
) -> List[int]: ) -> List[int]:
return [sub_dependency1, sub_dependency2] return [sub_dependency1, sub_dependency2]
@ -198,11 +201,14 @@ def test_dependency_cache_in_same_dependency(
router=router, router=router,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[List[int], Depends( annotation=Annotated[
dependency, List[int],
use_cache=use_cache, Depends(
dependency_scope=main_dependency_scope, dependency,
)] use_cache=use_cache,
dependency_scope=main_dependency_scope,
),
],
) )
if routing_style == "router": if routing_style == "router":
@ -217,7 +223,7 @@ def test_dependency_cache_in_same_dependency(
], ],
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
expected_activation_times=1, expected_activation_times=1,
is_websocket=is_websocket is_websocket=is_websocket,
) )
else: else:
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -228,7 +234,7 @@ def test_dependency_cache_in_same_dependency(
], ],
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
expected_activation_times=2, expected_activation_times=2,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@ -237,17 +243,14 @@ def test_dependency_cache_in_same_dependency(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint( def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
routing_style,
use_cache,
is_websocket
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=use_cache use_cache=use_cache,
) )
app = FastAPI() app = FastAPI()
@ -267,7 +270,7 @@ def test_dependency_cache_in_same_endpoint(
is_websocket=is_websocket, is_websocket=is_websocket,
annotation1=Annotated[int, depends], annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends], annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)] annotation3=Annotated[int, Depends(endpoint_dependency)],
) )
if routing_style == "router": if routing_style == "router":
@ -282,7 +285,7 @@ def test_dependency_cache_in_same_endpoint(
], ],
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
expected_activation_times=1, expected_activation_times=1,
is_websocket=is_websocket is_websocket=is_websocket,
) )
else: else:
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -293,25 +296,23 @@ def test_dependency_cache_in_same_endpoint(
], ],
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
expected_activation_times=3, expected_activation_times=3,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints( def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
routing_style,
use_cache,
is_websocket
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=use_cache use_cache=use_cache,
) )
app = FastAPI() app = FastAPI()
@ -331,7 +332,7 @@ def test_dependency_cache_in_different_endpoints(
is_websocket=is_websocket, is_websocket=is_websocket,
annotation1=Annotated[int, depends], annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends], annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)] annotation3=Annotated[int, Depends(endpoint_dependency)],
) )
create_endpoint_3_annotations( create_endpoint_3_annotations(
@ -340,7 +341,7 @@ def test_dependency_cache_in_different_endpoints(
is_websocket=is_websocket, is_websocket=is_websocket,
annotation1=Annotated[int, depends], annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends], annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)] annotation3=Annotated[int, Depends(endpoint_dependency)],
) )
if routing_style == "router": if routing_style == "router":
@ -357,7 +358,7 @@ def test_dependency_cache_in_different_endpoints(
], ],
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
expected_activation_times=1, expected_activation_times=1,
is_websocket=is_websocket is_websocket=is_websocket,
) )
else: else:
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -370,23 +371,24 @@ def test_dependency_cache_in_different_endpoints(
], ],
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
expected_activation_times=5, expected_activation_times=5,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency( def test_no_cached_dependency(
dependency_style: DependencyStyle, dependency_style: DependencyStyle,
routing_style, routing_style,
is_websocket, is_websocket,
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
use_cache=False use_cache=False,
) )
app = FastAPI() app = FastAPI()
@ -402,7 +404,7 @@ def test_no_cached_dependency(
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[int, depends], annotation=Annotated[int, depends],
expected_value=1 expected_value=1,
) )
if routing_style == "router": if routing_style == "router":
@ -413,25 +415,27 @@ def test_no_cached_dependency(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2, urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1, expected_activation_times=1,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("annotation", [ @pytest.mark.parametrize(
Annotated[str, Path()], "annotation",
Annotated[str, Body()], [
Annotated[str, Query()], Annotated[str, Path()],
Annotated[str, Header()], Annotated[str, Body()],
SecurityScopes, Annotated[str, Query()],
Annotated[str, Cookie()], Annotated[str, Header()],
Annotated[str, Form()], SecurityScopes,
Annotated[str, File()], Annotated[str, Cookie()],
BackgroundTasks, Annotated[str, Form()],
]) Annotated[str, File()],
BackgroundTasks,
],
)
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation, annotation, is_websocket
is_websocket
): ):
async def dependency_func(param: annotation) -> None: async def dependency_func(param: annotation) -> None:
yield yield
@ -444,8 +448,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[ annotation=Annotated[
None, None, Depends(dependency_func, dependency_scope="lifespan")
Depends(dependency_func, dependency_scope="lifespan")
], ],
) )
@ -453,16 +456,15 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, is_websocket
is_websocket
): ):
dependency_factory = DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
async def lifespan_scoped_dependency( async def lifespan_scoped_dependency(
param: Annotated[int, Depends( param: Annotated[
dependency_factory.get_dependency(), int,
dependency_scope="lifespan" Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
)] ],
) -> AsyncGenerator[int, None]: ) -> AsyncGenerator[int, None]:
yield param yield param
@ -473,7 +475,7 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[int, Depends(lifespan_scoped_dependency)], annotation=Annotated[int, Depends(lifespan_scoped_dependency)],
expected_value=1 expected_value=1,
) )
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -481,24 +483,22 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
expected_activation_times=1, expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2, urls_and_responses=[("/test", 1)] * 2,
is_websocket=is_websocket is_websocket=is_websocket,
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize([ @pytest.mark.parametrize(
"dependency_style", ["dependency_style", "supports_teardown"],
"supports_teardown" [
], [ (DependencyStyle.SYNC_FUNCTION, False),
(DependencyStyle.SYNC_FUNCTION, False), (DependencyStyle.ASYNC_FUNCTION, False),
(DependencyStyle.ASYNC_FUNCTION, False), (DependencyStyle.SYNC_GENERATOR, True),
(DependencyStyle.SYNC_GENERATOR, True), (DependencyStyle.ASYNC_GENERATOR, True),
(DependencyStyle.ASYNC_GENERATOR, True), ],
]) )
def test_the_same_dependency_can_work_in_different_scopes( def test_the_same_dependency_can_work_in_different_scopes(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, supports_teardown, is_websocket
supports_teardown,
is_websocket
): ):
dependency_factory = DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
app = FastAPI() app = FastAPI()
@ -507,14 +507,14 @@ def test_the_same_dependency_can_work_in_different_scopes(
router=app, router=app,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation1=Annotated[int, Depends( annotation1=Annotated[
dependency_factory.get_dependency(), int,
dependency_scope="endpoint" Depends(dependency_factory.get_dependency(), dependency_scope="endpoint"),
)], ],
annotation2=Annotated[int, Depends( annotation2=Annotated[
dependency_factory.get_dependency(), int,
dependency_scope="lifespan" Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
)], ],
) )
if is_websocket: if is_websocket:
get_response = use_websocket get_response = use_websocket
@ -548,17 +548,20 @@ def test_the_same_dependency_can_work_in_different_scopes(
assert dependency_factory.deactivation_times == 0 assert dependency_factory.deactivation_times == 0
@pytest.mark.parametrize("lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"]) @pytest.mark.parametrize(
"lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"]
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
dependency_style: DependencyStyle, dependency_style: DependencyStyle,
is_websocket, is_websocket,
lifespan_style: Literal["lifespan_function", "lifespan_events"] lifespan_style: Literal["lifespan_function", "lifespan_events"],
): ):
lifespan_started = False lifespan_started = False
lifespan_ended = False lifespan_ended = False
if lifespan_style == "lifespan_generator": if lifespan_style == "lifespan_generator":
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]: async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]:
nonlocal lifespan_started nonlocal lifespan_started
@ -571,6 +574,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
elif lifespan_style == "events_decorator": elif lifespan_style == "events_decorator":
app = FastAPI() app = FastAPI()
with warnings.catch_warnings(action="ignore", category=DeprecationWarning): with warnings.catch_warnings(action="ignore", category=DeprecationWarning):
@app.on_event("startup") @app.on_event("startup")
async def startup() -> None: async def startup() -> None:
nonlocal lifespan_started nonlocal lifespan_started
@ -581,6 +585,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
nonlocal lifespan_ended nonlocal lifespan_ended
lifespan_ended = True lifespan_ended = True
elif lifespan_style == "events_constructor": elif lifespan_style == "events_constructor":
async def startup() -> None: async def startup() -> None:
nonlocal lifespan_started nonlocal lifespan_started
lifespan_started = True lifespan_started = True
@ -588,6 +593,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
async def shutdown() -> None: async def shutdown() -> None:
nonlocal lifespan_ended nonlocal lifespan_ended
lifespan_ended = True lifespan_ended = True
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) app = FastAPI(on_startup=[startup], on_shutdown=[shutdown])
else: else:
assert_never(lifespan_style) assert_never(lifespan_style)
@ -598,11 +604,11 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
router=app, router=app,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[int, Depends( annotation=Annotated[
dependency_factory.get_dependency(), int,
dependency_scope="lifespan" Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
)], ],
expected_value=1 expected_value=1,
) )
expect_correct_amount_of_dependency_activations( expect_correct_amount_of_dependency_activations(
@ -610,20 +616,22 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
dependency_factory=dependency_factory, dependency_factory=dependency_factory,
expected_activation_times=1, expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2, urls_and_responses=[("/test", 1)] * 2,
is_websocket=is_websocket is_websocket=is_websocket,
) )
assert lifespan_started and lifespan_ended assert lifespan_started and lifespan_ended
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("depends_class", [Depends, Security]) @pytest.mark.parametrize("depends_class", [Depends, Security])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class, depends_class, is_websocket
is_websocket
): ):
async def sub_dependency() -> None: async def sub_dependency() -> None:
pass pass
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: async def dependency_func(
param: Annotated[None, depends_class(sub_dependency)],
) -> None:
yield yield
app = FastAPI() app = FastAPI()
@ -633,20 +641,20 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
router=app, router=app,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")], annotation=Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")
],
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_dependencies_must_provide_correct_dependency_scope( def test_dependencies_must_provide_correct_dependency_scope(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
routing_style,
use_cache,
is_websocket
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
app = FastAPI() app = FastAPI()
@ -656,19 +664,21 @@ def test_dependencies_must_provide_correct_dependency_scope(
router = APIRouter() router = APIRouter()
with pytest.raises( with pytest.raises(
InvalidDependencyScope, InvalidDependencyScope,
match=r'Dependency "value" of .* has an invalid scope: ' match=r'Dependency "value" of .* has an invalid scope: ' r'"incorrect"',
r'"incorrect"'
): ):
create_endpoint_1_annotation( create_endpoint_1_annotation(
router=router, router=router,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[None, Depends( annotation=Annotated[
dependency_factory.get_dependency(), None,
dependency_scope="incorrect", Depends(
use_cache=use_cache, dependency_factory.get_dependency(),
)] dependency_scope="incorrect",
use_cache=use_cache,
),
],
) )
@ -677,12 +687,9 @@ def test_dependencies_must_provide_correct_dependency_scope(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_incorrect_dependency_scope( def test_endpoints_report_incorrect_dependency_scope(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
routing_style,
use_cache,
is_websocket
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
app = FastAPI() app = FastAPI()
@ -705,7 +712,7 @@ def test_endpoints_report_incorrect_dependency_scope(
router=router, router=router,
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[int, depends] annotation=Annotated[int, depends],
) )
@ -714,12 +721,9 @@ def test_endpoints_report_incorrect_dependency_scope(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_dependency( def test_endpoints_report_uninitialized_dependency(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
routing_style,
use_cache,
is_websocket
): ):
dependency_factory= DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
app = FastAPI() app = FastAPI()
@ -739,7 +743,7 @@ def test_endpoints_report_uninitialized_dependency(
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[int, depends], annotation=Annotated[int, depends],
expected_value=1 expected_value=1,
) )
if routing_style == "router_endpoint": if routing_style == "router_endpoint":
@ -757,7 +761,9 @@ def test_endpoints_report_uninitialized_dependency(
else: else:
client.post("/test") client.post("/test")
finally: finally:
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = (
dependencies
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@ -765,10 +771,7 @@ def test_endpoints_report_uninitialized_dependency(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_internal_lifespan( def test_endpoints_report_uninitialized_internal_lifespan(
dependency_style: DependencyStyle, dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
routing_style,
use_cache,
is_websocket
): ):
dependency_factory = DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
@ -790,7 +793,7 @@ def test_endpoints_report_uninitialized_internal_lifespan(
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[int, depends], annotation=Annotated[int, depends],
expected_value=1 expected_value=1,
) )
if routing_style == "router_endpoint": if routing_style == "router_endpoint":
@ -816,12 +819,9 @@ def test_endpoints_report_uninitialized_internal_lifespan(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_lifespan_scoped_dependencies( def test_bad_lifespan_scoped_dependencies(
use_cache, use_cache, dependency_style: DependencyStyle, routing_style, is_websocket
dependency_style: DependencyStyle,
routing_style,
is_websocket
): ):
dependency_factory= DependencyFactory(dependency_style, should_error=True) dependency_factory = DependencyFactory(dependency_style, should_error=True)
depends = Depends( depends = Depends(
dependency_factory.get_dependency(), dependency_factory.get_dependency(),
dependency_scope="lifespan", dependency_scope="lifespan",
@ -841,7 +841,7 @@ def test_bad_lifespan_scoped_dependencies(
path="/test", path="/test",
is_websocket=is_websocket, is_websocket=is_websocket,
annotation=Annotated[int, depends], annotation=Annotated[int, depends],
expected_value=1 expected_value=1,
) )
if routing_style == "router_endpoint": if routing_style == "router_endpoint":

97
tests/test_lifespan_scoped_dependencies/testing_utilities.py

@ -5,7 +5,7 @@ from fastapi import APIRouter, FastAPI, WebSocket
from starlette.testclient import TestClient from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect from starlette.websockets import WebSocketDisconnect
T = TypeVar('T') T = TypeVar("T")
class DependencyStyle(StrEnum): class DependencyStyle(StrEnum):
@ -21,10 +21,11 @@ class IntentionallyBadDependency(Exception):
class DependencyFactory: class DependencyFactory:
def __init__( def __init__(
self, self,
dependency_style: DependencyStyle, *, dependency_style: DependencyStyle,
should_error: bool = False, *,
value_offset: int = 0, should_error: bool = False,
value_offset: int = 0,
): ):
self.activation_times = 0 self.activation_times = 0
self.deactivation_times = 0 self.deactivation_times = 0
@ -90,12 +91,13 @@ def use_websocket(client: TestClient, url: str) -> Any:
def create_endpoint_0_annotations( def create_endpoint_0_annotations(
*, *,
router: Union[APIRouter, FastAPI], router: Union[APIRouter, FastAPI],
path: str, path: str,
is_websocket: bool, is_websocket: bool,
) -> None: ) -> None:
if is_websocket: if is_websocket:
@router.websocket(path) @router.websocket(path)
async def endpoint(websocket: WebSocket) -> None: async def endpoint(websocket: WebSocket) -> None:
await websocket.accept() await websocket.accept()
@ -104,25 +106,24 @@ def create_endpoint_0_annotations(
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass
else: else:
@router.post(path) @router.post(path)
async def endpoint() -> None: async def endpoint() -> None:
return None return None
def create_endpoint_1_annotation( def create_endpoint_1_annotation(
*, *,
router: Union[APIRouter, FastAPI], router: Union[APIRouter, FastAPI],
path: str, path: str,
is_websocket: bool, is_websocket: bool,
annotation: Any, annotation: Any,
expected_value: Any = None expected_value: Any = None,
) -> None: ) -> None:
if is_websocket: if is_websocket:
@router.websocket(path) @router.websocket(path)
async def endpoint( async def endpoint(websocket: WebSocket, value: annotation) -> None:
websocket: WebSocket,
value: annotation
) -> None:
if expected_value is not None: if expected_value is not None:
assert value == expected_value assert value == expected_value
@ -132,29 +133,30 @@ def create_endpoint_1_annotation(
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass
else: else:
@router.post(path) @router.post(path)
async def endpoint( async def endpoint(value: annotation) -> None:
value: annotation
) -> None:
if expected_value is not None: if expected_value is not None:
assert value == expected_value assert value == expected_value
return value return value
def create_endpoint_2_annotations( def create_endpoint_2_annotations(
*, *,
router: Union[APIRouter, FastAPI], router: Union[APIRouter, FastAPI],
path: str, path: str,
is_websocket: bool, is_websocket: bool,
annotation1: Any, annotation1: Any,
annotation2: Any, annotation2: Any,
) -> None: ) -> None:
if is_websocket: if is_websocket:
@router.websocket(path) @router.websocket(path)
async def endpoint( async def endpoint(
websocket: WebSocket, websocket: WebSocket,
value1: annotation1, value1: annotation1,
value2: annotation2, value2: annotation2,
) -> None: ) -> None:
await websocket.accept() await websocket.accept()
try: try:
@ -162,30 +164,32 @@ def create_endpoint_2_annotations(
except WebSocketDisconnect: except WebSocketDisconnect:
await websocket.close() await websocket.close()
else: else:
@router.post(path) @router.post(path)
async def endpoint( async def endpoint(
value1: annotation1, value1: annotation1,
value2: annotation2, value2: annotation2,
) -> list[Any]: ) -> list[Any]:
return [value1, value2] return [value1, value2]
def create_endpoint_3_annotations( def create_endpoint_3_annotations(
*, *,
router: Union[APIRouter, FastAPI], router: Union[APIRouter, FastAPI],
path: str, path: str,
is_websocket: bool, is_websocket: bool,
annotation1: Any, annotation1: Any,
annotation2: Any, annotation2: Any,
annotation3: Any annotation3: Any,
) -> None: ) -> None:
if is_websocket: if is_websocket:
@router.websocket(path) @router.websocket(path)
async def endpoint( async def endpoint(
websocket: WebSocket, websocket: WebSocket,
value1: annotation1, value1: annotation1,
value2: annotation2, value2: annotation2,
value3: annotation3 value3: annotation3,
) -> None: ) -> None:
await websocket.accept() await websocket.accept()
try: try:
@ -193,10 +197,9 @@ def create_endpoint_3_annotations(
except WebSocketDisconnect: except WebSocketDisconnect:
await websocket.close() await websocket.close()
else: else:
@router.post(path) @router.post(path)
async def endpoint( async def endpoint(
value1: annotation1, value1: annotation1, value2: annotation2, value3: annotation3
value2: annotation2,
value3: annotation3
) -> list[Any]: ) -> list[Any]:
return [value1, value2, value3] return [value1, value2, value3]

Loading…
Cancel
Save