Browse Source

Applied ruff linting

pull/12529/head
Nir Schulman 9 months ago
parent
commit
d512b03bd3
  1. 18
      fastapi/dependencies/models.py
  2. 68
      fastapi/dependencies/utils.py
  3. 18
      fastapi/routing.py
  4. 280
      tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py
  5. 334
      tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py
  6. 97
      tests/test_lifespan_scoped_dependencies/testing_utilities.py

18
fastapi/dependencies/models.py

@ -12,7 +12,10 @@ class SecurityRequirement:
scopes: Optional[Sequence[str]] = None
LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any]]
LifespanDependantCacheKey: TypeAlias = Union[
Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any]
]
@dataclass
class LifespanDependant:
@ -30,9 +33,9 @@ class LifespanDependant:
elif self.name is not None:
self.cache_key = (self.caller, self.name)
else:
assert self.index is not None, (
"Lifespan dependency must have an associated name or index."
)
assert (
self.index is not None
), "Lifespan dependency must have an associated name or index."
self.cache_key = (self.caller, self.index)
@ -49,8 +52,7 @@ class EndpointDependant:
call: Optional[Callable[..., Any]] = None
use_cache: bool = True
index: Optional[int] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(
init=False)
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
path_params: List[ModelField] = field(default_factory=list)
query_params: List[ModelField] = field(default_factory=list)
header_params: List[ModelField] = field(default_factory=list)
@ -74,11 +76,11 @@ class EndpointDependant:
def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]:
lifespan_dependencies = cast(
List[Union[EndpointDependant, LifespanDependant]],
self.lifespan_dependencies
self.lifespan_dependencies,
)
endpoint_dependencies = cast(
List[Union[EndpointDependant, LifespanDependant]],
self.endpoint_dependencies
self.endpoint_dependencies,
)
return tuple(lifespan_dependencies + endpoint_dependencies)

68
fastapi/dependencies/utils.py

@ -147,11 +147,7 @@ def get_param_sub_dependant(
def get_parameterless_sub_dependant(
*,
depends: params.Depends,
path: str,
caller: Callable[..., Any],
index: int
*, depends: params.Depends, path: str, caller: Callable[..., Any], index: int
) -> Union[EndpointDependant, LifespanDependant]:
assert callable(
depends.dependency
@ -161,7 +157,7 @@ def get_parameterless_sub_dependant(
dependency=depends.dependency,
path=path,
caller=caller,
index=index
index=index,
)
@ -181,7 +177,7 @@ def get_sub_dependant(
call=dependency,
name=name,
use_cache=depends.use_cache,
index=index
index=index,
)
elif depends.dependency_scope == "endpoint":
security_requirement = None
@ -202,15 +198,15 @@ def get_sub_dependant(
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
index=index
index=index,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
return sub_dependant
else:
raise InvalidDependencyScope(
f"Dependency \"{name}\" of {caller} has an invalid "
f"scope: \"{depends.dependency_scope}\""
f'Dependency "{name}" of {caller} has an invalid '
f'scope: "{depends.dependency_scope}"'
)
@ -233,7 +229,7 @@ def get_flat_dependant(
security_requirements=dependant.security_requirements.copy(),
lifespan_dependencies=dependant.lifespan_dependencies.copy(),
use_cache=dependant.use_cache,
path=dependant.path
path=dependant.path,
)
for sub_dependant in dependant.endpoint_dependencies:
if skip_repeats and sub_dependant.cache_key in visited:
@ -310,16 +306,12 @@ def get_lifespan_dependant(
call: Callable[..., Any],
name: Optional[str] = None,
use_cache: bool = True,
index: Optional[int] = None
index: Optional[int] = None,
) -> LifespanDependant:
dependency_signature = get_typed_signature(call)
signature_params = dependency_signature.parameters
dependant = LifespanDependant(
call=call,
name=name,
use_cache=use_cache,
caller=caller,
index=index
call=call, name=name, use_cache=use_cache, caller=caller, index=index
)
for param_name, param in signature_params.items():
param_details = analyze_param(
@ -330,16 +322,17 @@ def get_lifespan_dependant(
)
if param_details.depends is None:
raise DependencyScopeConflict(
f"Lifespan scoped dependency \"{dependant.name}\" was defined "
f"with an invalid argument: \"{param_name}\" which is "
f"\"endpoint\" scoped. Lifespan scoped dependencies may only "
f"use lifespan scoped sub-dependencies.")
f'Lifespan scoped dependency "{dependant.name}" was defined '
f'with an invalid argument: "{param_name}" which is '
f'"endpoint" scoped. Lifespan scoped dependencies may only '
f"use lifespan scoped sub-dependencies."
)
if param_details.depends.dependency_scope != "lifespan":
raise DependencyScopeConflict(
f"Lifespan scoped dependency {dependant.name} was defined with the "
f"sub-dependency \"{param_name}\" which is "
f"\"{param_details.depends.dependency_scope}\" scoped. "
f'sub-dependency "{param_name}" which is '
f'"{param_details.depends.dependency_scope}" scoped. '
f"Lifespan scoped dependencies may only use lifespan scoped "
f"sub-dependencies."
)
@ -350,7 +343,7 @@ def get_lifespan_dependant(
name=param_name,
call=param_details.depends.dependency,
use_cache=param_details.depends.use_cache,
caller=call
caller=call,
)
dependant.dependencies.append(sub_dependant)
@ -364,7 +357,7 @@ def get_endpoint_dependant(
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
index: Optional[int] = None
index: Optional[int] = None,
) -> EndpointDependant:
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
@ -375,7 +368,7 @@ def get_endpoint_dependant(
path=path,
security_scopes=security_scopes,
use_cache=use_cache,
index=index
index=index,
)
for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names
@ -692,14 +685,12 @@ async def solve_lifespan_dependant(
call = dependant.call
dependant_to_solve = dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
call = getattr(
dependency_overrides_provider,
"dependency_overrides",
{}
).get(dependant.call, dependant.call)
call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(
dependant.call, dependant.call
)
dependant_to_solve = get_lifespan_dependant(
caller=dependant.caller,
call=call,
@ -725,9 +716,7 @@ async def solve_lifespan_dependant(
if is_gen_callable(call) or is_async_gen_callable(call):
value = await solve_generator(
call=call,
stack=async_exit_stack,
sub_values=dependency_arguments
call=call, stack=async_exit_stack, sub_values=dependency_arguments
)
elif is_coroutine_callable(call):
value = await call(**dependency_arguments)
@ -773,7 +762,8 @@ async def solve_dependencies(
try:
lifespan_scoped_dependencies = request.state.__fastapi__[
"lifespan_scoped_dependencies"]
"lifespan_scoped_dependencies"
]
except (AttributeError, KeyError) as e:
raise UninitializedLifespanDependency(
"FastAPI's internal lifespan was not initialized correctly."
@ -783,8 +773,8 @@ async def solve_dependencies(
value = lifespan_scoped_dependencies[lifespan_sub_dependant.cache_key]
except KeyError as e:
raise UninitializedLifespanDependency(
f"Dependency \"{lifespan_sub_dependant.name}\" of "
f"`{dependant.call}` was not initialized correctly."
f'Dependency "{lifespan_sub_dependant.name}" of '
f"`{dependant.call}` was not initialized correctly."
) from e
values[lifespan_sub_dependant.name] = value

18
fastapi/routing.py

@ -400,13 +400,12 @@ class APIWebSocketRoute(routing.WebSocketRoute):
self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint)
self.dependant = get_endpoint_dependant(
path=self.path_format, call=self.endpoint
)
for i, depends in list(enumerate(self.dependencies))[::-1]:
sub_dependant = get_parameterless_sub_dependant(
depends=depends,
path=self.path_format,
caller=self.__call__,
index=i
depends=depends, path=self.path_format, caller=self.__call__, index=i
)
if depends.dependency_scope == "endpoint":
assert isinstance(sub_dependant, EndpointDependant)
@ -566,13 +565,12 @@ class APIRoute(routing.Route):
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint)
self.dependant = get_endpoint_dependant(
path=self.path_format, call=self.endpoint
)
for i, depends in list(enumerate(self.dependencies))[::-1]:
sub_dependant = get_parameterless_sub_dependant(
depends=depends,
path=self.path_format,
caller=self.__call__,
index=i
depends=depends, path=self.path_format, caller=self.__call__, index=i
)
if depends.dependency_scope == "endpoint":
assert isinstance(sub_dependant, EndpointDependant)

280
tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py

@ -33,13 +33,13 @@ from tests.test_lifespan_scoped_dependencies.testing_utilities import (
def expect_correct_amount_of_dependency_activations(
*,
app: FastAPI,
dependency_factory: DependencyFactory,
override_dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int,
is_websocket: bool
*,
app: FastAPI,
dependency_factory: DependencyFactory,
override_dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int,
is_websocket: bool,
) -> None:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
@ -62,17 +62,22 @@ def expect_correct_amount_of_dependency_activations(
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
assert override_dependency_factory.activation_times == expected_activation_times
assert (
override_dependency_factory.activation_times
== expected_activation_times
)
assert override_dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == 0
assert override_dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION,
):
assert dependency_factory.deactivation_times == 0
assert override_dependency_factory.deactivation_times == expected_activation_times
assert (
override_dependency_factory.deactivation_times == expected_activation_times
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@ -80,16 +85,10 @@ def expect_correct_amount_of_dependency_activations(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
app = FastAPI()
@ -108,14 +107,16 @@ def test_endpoint_dependencies(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
),
],
expected_value=11
expected_value=11,
)
if routing_style == "router_endpoint":
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations(
app=app,
@ -123,7 +124,7 @@ def test_endpoint_dependencies(
override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", 11)] * 2,
expected_activation_times=1,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@ -133,45 +134,40 @@ def test_endpoint_dependencies(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
dependency_duplication,
is_websocket
dependency_style: DependencyStyle,
routing_style,
use_cache,
dependency_duplication,
is_websocket,
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
use_cache=use_cache,
)
if routing_style == "app":
app = FastAPI(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=app,
path="/test",
is_websocket=is_websocket
router=app, path="/test", is_websocket=is_websocket
)
else:
app = FastAPI()
router = APIRouter(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=router,
path="/test",
is_websocket=is_websocket
router=router, path="/test", is_websocket=is_websocket
)
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations(
app=app,
@ -179,31 +175,29 @@ def test_router_dependencies(
override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1 if use_cache else dependency_duplication,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"],
is_websocket
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"],
is_websocket,
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
use_cache=use_cache,
)
app = FastAPI()
@ -215,8 +209,8 @@ def test_dependency_cache_in_same_dependency(
router = APIRouter()
async def dependency(
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
) -> List[int]:
return [sub_dependency1, sub_dependency2]
@ -224,19 +218,22 @@ def test_dependency_cache_in_same_dependency(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[List[int], Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
)]
annotation=Annotated[
List[int],
Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
),
],
)
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[
dependency_factory.get_dependency()
] = override_dependency_factory.get_dependency()
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
if use_cache:
expect_correct_amount_of_dependency_activations(
@ -248,7 +245,7 @@ def test_dependency_cache_in_same_dependency(
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
@ -260,7 +257,7 @@ def test_dependency_cache_in_same_dependency(
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=2,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@ -269,21 +266,15 @@ def test_dependency_cache_in_same_dependency(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
use_cache=use_cache,
)
app = FastAPI()
@ -309,9 +300,9 @@ def test_dependency_cache_in_same_endpoint(
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[
dependency_factory.get_dependency()
] = override_dependency_factory.get_dependency()
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
if use_cache:
expect_correct_amount_of_dependency_activations(
@ -323,7 +314,7 @@ def test_dependency_cache_in_same_endpoint(
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
@ -335,29 +326,24 @@ def test_dependency_cache_in_same_endpoint(
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=3,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
use_cache=use_cache,
)
app = FastAPI()
@ -392,8 +378,9 @@ def test_dependency_cache_in_different_endpoints(
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[
dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
if use_cache:
expect_correct_amount_of_dependency_activations(
@ -407,7 +394,7 @@ def test_dependency_cache_in_different_endpoints(
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
@ -421,27 +408,23 @@ def test_dependency_cache_in_different_endpoints(
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=5,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
dependency_style: DependencyStyle,
routing_style,
is_websocket
dependency_style: DependencyStyle, routing_style, is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=False
use_cache=False,
)
app = FastAPI()
@ -462,8 +445,9 @@ def test_no_cached_dependency(
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[
dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations(
app=app,
@ -471,24 +455,27 @@ def test_no_cached_dependency(
override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", 11)] * 2,
expected_activation_times=1,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("annotation", [
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
])
@pytest.mark.parametrize(
"annotation",
[
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
],
)
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation,
is_websocket
annotation, is_websocket
):
async def dependency_func() -> None:
yield
@ -503,9 +490,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[None,
Depends(dependency_func, dependency_scope="lifespan")
]
annotation=Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")
],
)
with pytest.raises(DependencyScopeConflict):
@ -516,20 +503,16 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies(
dependency_style: DependencyStyle,
is_websocket
dependency_style: DependencyStyle, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
async def lifespan_scoped_dependency(
param: Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)]
param: Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
],
) -> AsyncGenerator[int, None]:
yield param
@ -540,12 +523,13 @@ def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_sco
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
int,
Depends(lifespan_scoped_dependency, dependency_scope="lifespan")
int, Depends(lifespan_scoped_dependency, dependency_scope="lifespan")
],
)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations(
app=app,
@ -553,15 +537,14 @@ def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_sco
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 11)] * 2,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("depends_class", [Depends, Security])
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class,
is_websocket
depends_class, is_websocket
):
async def sub_dependency() -> None:
pass
@ -569,7 +552,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen
async def dependency_func() -> None:
yield
async def override_dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None:
async def override_dependency_func(
param: Annotated[None, depends_class(sub_dependency)],
) -> None:
yield
app = FastAPI()
@ -578,7 +563,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")]
annotation=Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")
],
)
app.dependency_overrides[dependency_func] = override_dependency_func
@ -593,12 +580,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_override_lifespan_scoped_dependencies(
use_cache,
dependency_style: DependencyStyle,
routing_style,
is_websocket
use_cache, dependency_style: DependencyStyle, routing_style, is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, should_error=True)
depends = Depends(
@ -619,13 +603,15 @@ def test_bad_override_lifespan_scoped_dependencies(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends]
annotation=Annotated[int, depends],
)
if routing_style == "router_endpoint":
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
with pytest.raises(IntentionallyBadDependency) as exception_info:
with TestClient(app):

334
tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py

@ -40,12 +40,12 @@ from tests.test_lifespan_scoped_dependencies.testing_utilities import (
def expect_correct_amount_of_dependency_activations(
*,
app: FastAPI,
dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int,
is_websocket: bool
*,
app: FastAPI,
dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int,
is_websocket: bool,
) -> None:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
@ -64,21 +64,23 @@ def expect_correct_amount_of_dependency_activations(
assert dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION,
):
assert dependency_factory.deactivation_times == expected_activation_times
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False], ids=["With Cache", "Without Cache"])
@pytest.mark.parametrize(
"use_cache", [True, False], ids=["With Cache", "Without Cache"]
)
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket: bool,
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket: bool,
):
dependency_factory = DependencyFactory(dependency_style)
@ -93,12 +95,15 @@ def test_endpoint_dependencies(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[None, Depends(
annotation=Annotated[
None,
Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)],
expected_value=1
),
],
expected_value=1,
)
if routing_style == "router_endpoint":
@ -109,45 +114,42 @@ def test_endpoint_dependencies(
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_duplication", [1, 2])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
dependency_duplication,
is_websocket: bool,
dependency_style: DependencyStyle,
routing_style,
use_cache,
dependency_duplication,
is_websocket: bool,
):
dependency_factory= DependencyFactory(dependency_style)
dependency_factory = DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
use_cache=use_cache,
)
if routing_style == "app":
app = FastAPI(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=app,
path="/test",
is_websocket=is_websocket
router=app, path="/test", is_websocket=is_websocket
)
else:
app = FastAPI()
router = APIRouter(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=router,
path="/test",
is_websocket=is_websocket
router=router, path="/test", is_websocket=is_websocket
)
app.include_router(router)
@ -157,27 +159,28 @@ def test_router_dependencies(
dependency_factory=dependency_factory,
urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1 if use_cache else dependency_duplication,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"],
is_websocket: bool,
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"],
is_websocket: bool,
):
dependency_factory= DependencyFactory(dependency_style)
dependency_factory = DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
use_cache=use_cache,
)
app = FastAPI()
@ -189,8 +192,8 @@ def test_dependency_cache_in_same_dependency(
router = APIRouter()
async def dependency(
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
) -> List[int]:
return [sub_dependency1, sub_dependency2]
@ -198,11 +201,14 @@ def test_dependency_cache_in_same_dependency(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[List[int], Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
)]
annotation=Annotated[
List[int],
Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
),
],
)
if routing_style == "router":
@ -217,7 +223,7 @@ def test_dependency_cache_in_same_dependency(
],
dependency_factory=dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
@ -228,7 +234,7 @@ def test_dependency_cache_in_same_dependency(
],
dependency_factory=dependency_factory,
expected_activation_times=2,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@ -237,17 +243,14 @@ def test_dependency_cache_in_same_dependency(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
dependency_factory = DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
use_cache=use_cache,
)
app = FastAPI()
@ -267,7 +270,7 @@ def test_dependency_cache_in_same_endpoint(
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)]
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
if routing_style == "router":
@ -282,7 +285,7 @@ def test_dependency_cache_in_same_endpoint(
],
dependency_factory=dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
@ -293,25 +296,23 @@ def test_dependency_cache_in_same_endpoint(
],
dependency_factory=dependency_factory,
expected_activation_times=3,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
dependency_factory = DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
use_cache=use_cache,
)
app = FastAPI()
@ -331,7 +332,7 @@ def test_dependency_cache_in_different_endpoints(
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)]
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
create_endpoint_3_annotations(
@ -340,7 +341,7 @@ def test_dependency_cache_in_different_endpoints(
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)]
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
if routing_style == "router":
@ -357,7 +358,7 @@ def test_dependency_cache_in_different_endpoints(
],
dependency_factory=dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
@ -370,23 +371,24 @@ def test_dependency_cache_in_different_endpoints(
],
dependency_factory=dependency_factory,
expected_activation_times=5,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
dependency_style: DependencyStyle,
routing_style,
is_websocket,
dependency_style: DependencyStyle,
routing_style,
is_websocket,
):
dependency_factory= DependencyFactory(dependency_style)
dependency_factory = DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=False
use_cache=False,
)
app = FastAPI()
@ -402,7 +404,7 @@ def test_no_cached_dependency(
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1
expected_value=1,
)
if routing_style == "router":
@ -413,25 +415,27 @@ def test_no_cached_dependency(
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("annotation", [
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
])
@pytest.mark.parametrize(
"annotation",
[
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
],
)
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation,
is_websocket
annotation, is_websocket
):
async def dependency_func(param: annotation) -> None:
yield
@ -444,8 +448,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
None,
Depends(dependency_func, dependency_scope="lifespan")
None, Depends(dependency_func, dependency_scope="lifespan")
],
)
@ -453,16 +456,15 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
dependency_style: DependencyStyle,
is_websocket
dependency_style: DependencyStyle, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
async def lifespan_scoped_dependency(
param: Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)]
param: Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
],
) -> AsyncGenerator[int, None]:
yield param
@ -473,7 +475,7 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, Depends(lifespan_scoped_dependency)],
expected_value=1
expected_value=1,
)
expect_correct_amount_of_dependency_activations(
@ -481,24 +483,22 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
dependency_factory=dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2,
is_websocket=is_websocket
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize([
"dependency_style",
"supports_teardown"
], [
(DependencyStyle.SYNC_FUNCTION, False),
(DependencyStyle.ASYNC_FUNCTION, False),
(DependencyStyle.SYNC_GENERATOR, True),
(DependencyStyle.ASYNC_GENERATOR, True),
])
@pytest.mark.parametrize(
["dependency_style", "supports_teardown"],
[
(DependencyStyle.SYNC_FUNCTION, False),
(DependencyStyle.ASYNC_FUNCTION, False),
(DependencyStyle.SYNC_GENERATOR, True),
(DependencyStyle.ASYNC_GENERATOR, True),
],
)
def test_the_same_dependency_can_work_in_different_scopes(
dependency_style: DependencyStyle,
supports_teardown,
is_websocket
dependency_style: DependencyStyle, supports_teardown, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
@ -507,14 +507,14 @@ def test_the_same_dependency_can_work_in_different_scopes(
router=app,
path="/test",
is_websocket=is_websocket,
annotation1=Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="endpoint"
)],
annotation2=Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)],
annotation1=Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="endpoint"),
],
annotation2=Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
],
)
if is_websocket:
get_response = use_websocket
@ -548,17 +548,20 @@ def test_the_same_dependency_can_work_in_different_scopes(
assert dependency_factory.deactivation_times == 0
@pytest.mark.parametrize("lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"])
@pytest.mark.parametrize(
"lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"]
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
dependency_style: DependencyStyle,
is_websocket,
lifespan_style: Literal["lifespan_function", "lifespan_events"]
dependency_style: DependencyStyle,
is_websocket,
lifespan_style: Literal["lifespan_function", "lifespan_events"],
):
lifespan_started = False
lifespan_ended = False
if lifespan_style == "lifespan_generator":
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]:
nonlocal lifespan_started
@ -571,6 +574,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
elif lifespan_style == "events_decorator":
app = FastAPI()
with warnings.catch_warnings(action="ignore", category=DeprecationWarning):
@app.on_event("startup")
async def startup() -> None:
nonlocal lifespan_started
@ -581,6 +585,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
nonlocal lifespan_ended
lifespan_ended = True
elif lifespan_style == "events_constructor":
async def startup() -> None:
nonlocal lifespan_started
lifespan_started = True
@ -588,6 +593,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
async def shutdown() -> None:
nonlocal lifespan_ended
lifespan_ended = True
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown])
else:
assert_never(lifespan_style)
@ -598,11 +604,11 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)],
expected_value=1
annotation=Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
],
expected_value=1,
)
expect_correct_amount_of_dependency_activations(
@ -610,20 +616,22 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
dependency_factory=dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2,
is_websocket=is_websocket
is_websocket=is_websocket,
)
assert lifespan_started and lifespan_ended
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("depends_class", [Depends, Security])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class,
is_websocket
depends_class, is_websocket
):
async def sub_dependency() -> None:
pass
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None:
async def dependency_func(
param: Annotated[None, depends_class(sub_dependency)],
) -> None:
yield
app = FastAPI()
@ -633,20 +641,20 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")],
annotation=Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")
],
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_dependencies_must_provide_correct_dependency_scope(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
@ -656,19 +664,21 @@ def test_dependencies_must_provide_correct_dependency_scope(
router = APIRouter()
with pytest.raises(
InvalidDependencyScope,
match=r'Dependency "value" of .* has an invalid scope: '
r'"incorrect"'
InvalidDependencyScope,
match=r'Dependency "value" of .* has an invalid scope: ' r'"incorrect"',
):
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[None, Depends(
dependency_factory.get_dependency(),
dependency_scope="incorrect",
use_cache=use_cache,
)]
annotation=Annotated[
None,
Depends(
dependency_factory.get_dependency(),
dependency_scope="incorrect",
use_cache=use_cache,
),
],
)
@ -677,12 +687,9 @@ def test_dependencies_must_provide_correct_dependency_scope(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_incorrect_dependency_scope(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
@ -705,7 +712,7 @@ def test_endpoints_report_incorrect_dependency_scope(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends]
annotation=Annotated[int, depends],
)
@ -714,12 +721,9 @@ def test_endpoints_report_incorrect_dependency_scope(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
@ -739,7 +743,7 @@ def test_endpoints_report_uninitialized_dependency(
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1
expected_value=1,
)
if routing_style == "router_endpoint":
@ -757,7 +761,9 @@ def test_endpoints_report_uninitialized_dependency(
else:
client.post("/test")
finally:
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = (
dependencies
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@ -765,10 +771,7 @@ def test_endpoints_report_uninitialized_dependency(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_internal_lifespan(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
@ -790,7 +793,7 @@ def test_endpoints_report_uninitialized_internal_lifespan(
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1
expected_value=1,
)
if routing_style == "router_endpoint":
@ -816,12 +819,9 @@ def test_endpoints_report_uninitialized_internal_lifespan(
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_lifespan_scoped_dependencies(
use_cache,
dependency_style: DependencyStyle,
routing_style,
is_websocket
use_cache, dependency_style: DependencyStyle, routing_style, is_websocket
):
dependency_factory= DependencyFactory(dependency_style, should_error=True)
dependency_factory = DependencyFactory(dependency_style, should_error=True)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
@ -841,7 +841,7 @@ def test_bad_lifespan_scoped_dependencies(
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1
expected_value=1,
)
if routing_style == "router_endpoint":

97
tests/test_lifespan_scoped_dependencies/testing_utilities.py

@ -5,7 +5,7 @@ from fastapi import APIRouter, FastAPI, WebSocket
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
T = TypeVar('T')
T = TypeVar("T")
class DependencyStyle(StrEnum):
@ -21,10 +21,11 @@ class IntentionallyBadDependency(Exception):
class DependencyFactory:
def __init__(
self,
dependency_style: DependencyStyle, *,
should_error: bool = False,
value_offset: int = 0,
self,
dependency_style: DependencyStyle,
*,
should_error: bool = False,
value_offset: int = 0,
):
self.activation_times = 0
self.deactivation_times = 0
@ -90,12 +91,13 @@ def use_websocket(client: TestClient, url: str) -> Any:
def create_endpoint_0_annotations(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(websocket: WebSocket) -> None:
await websocket.accept()
@ -104,25 +106,24 @@ def create_endpoint_0_annotations(
except WebSocketDisconnect:
pass
else:
@router.post(path)
async def endpoint() -> None:
return None
def create_endpoint_1_annotation(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation: Any,
expected_value: Any = None
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation: Any,
expected_value: Any = None,
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(
websocket: WebSocket,
value: annotation
) -> None:
async def endpoint(websocket: WebSocket, value: annotation) -> None:
if expected_value is not None:
assert value == expected_value
@ -132,29 +133,30 @@ def create_endpoint_1_annotation(
except WebSocketDisconnect:
pass
else:
@router.post(path)
async def endpoint(
value: annotation
) -> None:
async def endpoint(value: annotation) -> None:
if expected_value is not None:
assert value == expected_value
return value
def create_endpoint_2_annotations(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation1: Any,
annotation2: Any,
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation1: Any,
annotation2: Any,
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(
websocket: WebSocket,
value1: annotation1,
value2: annotation2,
websocket: WebSocket,
value1: annotation1,
value2: annotation2,
) -> None:
await websocket.accept()
try:
@ -162,30 +164,32 @@ def create_endpoint_2_annotations(
except WebSocketDisconnect:
await websocket.close()
else:
@router.post(path)
async def endpoint(
value1: annotation1,
value2: annotation2,
value1: annotation1,
value2: annotation2,
) -> list[Any]:
return [value1, value2]
def create_endpoint_3_annotations(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation1: Any,
annotation2: Any,
annotation3: Any
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation1: Any,
annotation2: Any,
annotation3: Any,
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(
websocket: WebSocket,
value1: annotation1,
value2: annotation2,
value3: annotation3
websocket: WebSocket,
value1: annotation1,
value2: annotation2,
value3: annotation3,
) -> None:
await websocket.accept()
try:
@ -193,10 +197,9 @@ def create_endpoint_3_annotations(
except WebSocketDisconnect:
await websocket.close()
else:
@router.post(path)
async def endpoint(
value1: annotation1,
value2: annotation2,
value3: annotation3
value1: annotation1, value2: annotation2, value3: annotation3
) -> list[Any]:
return [value1, value2, value3]

Loading…
Cancel
Save