|
|
@ -49,6 +49,7 @@ from fastapi.background import BackgroundTasks |
|
|
|
from fastapi.concurrency import ( |
|
|
|
asynccontextmanager, |
|
|
|
contextmanager_in_threadpool, |
|
|
|
run_in_threadpool, |
|
|
|
) |
|
|
|
from fastapi.dependencies.models import Dependant, SecurityRequirement |
|
|
|
from fastapi.logger import logger |
|
|
@ -58,7 +59,6 @@ from fastapi.security.open_id_connect_url import OpenIdConnect |
|
|
|
from fastapi.utils import create_model_field, get_path_param_names |
|
|
|
from pydantic.fields import FieldInfo |
|
|
|
from starlette.background import BackgroundTasks as StarletteBackgroundTasks |
|
|
|
from starlette.concurrency import run_in_threadpool |
|
|
|
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile |
|
|
|
from starlette.requests import HTTPConnection, Request |
|
|
|
from starlette.responses import Response |
|
|
@ -149,6 +149,7 @@ def get_sub_dependant( |
|
|
|
name=name, |
|
|
|
security_scopes=security_scopes, |
|
|
|
use_cache=depends.use_cache, |
|
|
|
limiter=depends.limiter, |
|
|
|
) |
|
|
|
if security_requirement: |
|
|
|
sub_dependant.security_requirements.append(security_requirement) |
|
|
@ -176,6 +177,7 @@ def get_flat_dependant( |
|
|
|
body_params=dependant.body_params.copy(), |
|
|
|
security_requirements=dependant.security_requirements.copy(), |
|
|
|
use_cache=dependant.use_cache, |
|
|
|
limiter=dependant.limiter, |
|
|
|
path=dependant.path, |
|
|
|
) |
|
|
|
for sub_dependant in dependant.dependencies: |
|
|
@ -244,6 +246,7 @@ def get_dependant( |
|
|
|
name: Optional[str] = None, |
|
|
|
security_scopes: Optional[List[str]] = None, |
|
|
|
use_cache: bool = True, |
|
|
|
limiter: Optional[anyio.CapacityLimiter] = None, |
|
|
|
) -> Dependant: |
|
|
|
path_param_names = get_path_param_names(path) |
|
|
|
endpoint_signature = get_typed_signature(call) |
|
|
@ -254,6 +257,7 @@ def get_dependant( |
|
|
|
path=path, |
|
|
|
security_scopes=security_scopes, |
|
|
|
use_cache=use_cache, |
|
|
|
limiter=limiter, |
|
|
|
) |
|
|
|
for param_name, param in signature_params.items(): |
|
|
|
is_path_param = param_name in path_param_names |
|
|
@ -529,10 +533,12 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: |
|
|
|
|
|
|
|
|
|
|
|
async def solve_generator( |
|
|
|
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] |
|
|
|
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any], |
|
|
|
limiter: Optional[anyio.CapacityLimiter] = None, |
|
|
|
) -> Any: |
|
|
|
if is_gen_callable(call): |
|
|
|
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) |
|
|
|
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values), |
|
|
|
limiter=limiter) |
|
|
|
elif is_async_gen_callable(call): |
|
|
|
cm = asynccontextmanager(call)(**sub_values) |
|
|
|
return await stack.enter_async_context(cm) |
|
|
@ -587,6 +593,7 @@ async def solve_dependencies( |
|
|
|
call=call, |
|
|
|
name=sub_dependant.name, |
|
|
|
security_scopes=sub_dependant.security_scopes, |
|
|
|
limiter=sub_dependant.limiter, |
|
|
|
) |
|
|
|
|
|
|
|
solved_result = await solve_dependencies( |
|
|
@ -608,12 +615,14 @@ async def solve_dependencies( |
|
|
|
solved = dependency_cache[sub_dependant.cache_key] |
|
|
|
elif is_gen_callable(call) or is_async_gen_callable(call): |
|
|
|
solved = await solve_generator( |
|
|
|
call=call, stack=async_exit_stack, sub_values=solved_result.values |
|
|
|
call=call, stack=async_exit_stack, sub_values=solved_result.values, |
|
|
|
limiter=sub_dependant.limiter, |
|
|
|
) |
|
|
|
elif is_coroutine_callable(call): |
|
|
|
solved = await call(**solved_result.values) |
|
|
|
else: |
|
|
|
solved = await run_in_threadpool(call, **solved_result.values) |
|
|
|
solved = await run_in_threadpool(call, _limiter=sub_dependant.limiter, |
|
|
|
**solved_result.values) |
|
|
|
if sub_dependant.name is not None: |
|
|
|
values[sub_dependant.name] = solved |
|
|
|
if sub_dependant.cache_key not in dependency_cache: |
|
|
|