|
|
@ -1,4 +1,5 @@ |
|
|
|
import inspect |
|
|
|
import sys |
|
|
|
from contextlib import AsyncExitStack, contextmanager |
|
|
|
from copy import copy, deepcopy |
|
|
|
from dataclasses import dataclass |
|
|
@ -17,6 +18,7 @@ from typing import ( |
|
|
|
Union, |
|
|
|
cast, |
|
|
|
) |
|
|
|
from weakref import WeakKeyDictionary |
|
|
|
|
|
|
|
import anyio |
|
|
|
from fastapi import params |
|
|
@ -47,10 +49,7 @@ from fastapi._compat import ( |
|
|
|
value_is_sequence, |
|
|
|
) |
|
|
|
from fastapi.background import BackgroundTasks |
|
|
|
from fastapi.concurrency import ( |
|
|
|
asynccontextmanager, |
|
|
|
contextmanager_in_threadpool, |
|
|
|
) |
|
|
|
from fastapi.concurrency import asynccontextmanager, contextmanager_in_threadpool |
|
|
|
from fastapi.dependencies.models import Dependant, SecurityRequirement |
|
|
|
from fastapi.logger import logger |
|
|
|
from fastapi.security.base import SecurityBase |
|
|
@ -88,6 +87,33 @@ multipart_incorrect_install_error = ( |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class CallableInfo: |
|
|
|
__slots__ = ( |
|
|
|
"typed_signature", |
|
|
|
"is_gen_callable", |
|
|
|
"is_async_gen_callable", |
|
|
|
"is_coroutine_callable", |
|
|
|
) |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
typed_signature: inspect.Signature, |
|
|
|
is_gen_callable: bool, |
|
|
|
is_async_gen_callable: bool, |
|
|
|
is_coroutine_callable: bool, |
|
|
|
) -> None: |
|
|
|
self.typed_signature = typed_signature |
|
|
|
self.is_gen_callable = is_gen_callable |
|
|
|
self.is_async_gen_callable = is_async_gen_callable |
|
|
|
self.is_coroutine_callable = is_coroutine_callable |
|
|
|
|
|
|
|
|
|
|
|
if sys.version_info < (3, 9): |
|
|
|
CallableInfoCache = WeakKeyDictionary |
|
|
|
else: |
|
|
|
CallableInfoCache = WeakKeyDictionary[Callable[..., Any], CallableInfo] |
|
|
|
|
|
|
|
|
|
|
|
def ensure_multipart_is_installed() -> None: |
|
|
|
try: |
|
|
|
from python_multipart import __version__ |
|
|
@ -121,6 +147,7 @@ def get_param_sub_dependant( |
|
|
|
depends: params.Depends, |
|
|
|
path: str, |
|
|
|
security_scopes: Optional[List[str]] = None, |
|
|
|
callable_info_cache: CallableInfoCache, |
|
|
|
) -> Dependant: |
|
|
|
assert depends.dependency |
|
|
|
return get_sub_dependant( |
|
|
@ -129,14 +156,25 @@ def get_param_sub_dependant( |
|
|
|
path=path, |
|
|
|
name=param_name, |
|
|
|
security_scopes=security_scopes, |
|
|
|
callable_info_cache=callable_info_cache, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: |
|
|
|
def get_parameterless_sub_dependant( |
|
|
|
*, |
|
|
|
depends: params.Depends, |
|
|
|
path: str, |
|
|
|
callable_info_cache: CallableInfoCache, |
|
|
|
) -> Dependant: |
|
|
|
assert callable(depends.dependency), ( |
|
|
|
"A parameter-less dependency must have a callable dependency" |
|
|
|
) |
|
|
|
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) |
|
|
|
return get_sub_dependant( |
|
|
|
depends=depends, |
|
|
|
dependency=depends.dependency, |
|
|
|
path=path, |
|
|
|
callable_info_cache=callable_info_cache, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def get_sub_dependant( |
|
|
@ -146,6 +184,7 @@ def get_sub_dependant( |
|
|
|
path: str, |
|
|
|
name: Optional[str] = None, |
|
|
|
security_scopes: Optional[List[str]] = None, |
|
|
|
callable_info_cache: CallableInfoCache, |
|
|
|
) -> Dependant: |
|
|
|
security_requirement = None |
|
|
|
security_scopes = security_scopes or [] |
|
|
@ -165,6 +204,7 @@ def get_sub_dependant( |
|
|
|
name=name, |
|
|
|
security_scopes=security_scopes, |
|
|
|
use_cache=depends.use_cache, |
|
|
|
callable_info_cache=callable_info_cache, |
|
|
|
) |
|
|
|
if security_requirement: |
|
|
|
sub_dependant.security_requirements.append(security_requirement) |
|
|
@ -240,7 +280,16 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: |
|
|
|
) |
|
|
|
for param in signature.parameters.values() |
|
|
|
] |
|
|
|
typed_signature = inspect.Signature(typed_params) |
|
|
|
return_annotation = signature.return_annotation |
|
|
|
if return_annotation is inspect.Signature.empty: |
|
|
|
return_annotation = None |
|
|
|
else: |
|
|
|
return_annotation = get_typed_annotation(return_annotation, globalns) |
|
|
|
|
|
|
|
typed_signature = inspect.Signature( |
|
|
|
typed_params, |
|
|
|
return_annotation=return_annotation, |
|
|
|
) |
|
|
|
return typed_signature |
|
|
|
|
|
|
|
|
|
|
@ -251,17 +300,6 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: |
|
|
|
return annotation |
|
|
|
|
|
|
|
|
|
|
|
def get_typed_return_annotation(call: Callable[..., Any]) -> Any: |
|
|
|
signature = inspect.signature(call) |
|
|
|
annotation = signature.return_annotation |
|
|
|
|
|
|
|
if annotation is inspect.Signature.empty: |
|
|
|
return None |
|
|
|
|
|
|
|
globalns = getattr(call, "__globals__", {}) |
|
|
|
return get_typed_annotation(annotation, globalns) |
|
|
|
|
|
|
|
|
|
|
|
def get_dependant( |
|
|
|
*, |
|
|
|
path: str, |
|
|
@ -269,10 +307,17 @@ def get_dependant( |
|
|
|
name: Optional[str] = None, |
|
|
|
security_scopes: Optional[List[str]] = None, |
|
|
|
use_cache: bool = True, |
|
|
|
callable_info_cache: Optional[CallableInfoCache] = None, |
|
|
|
) -> Dependant: |
|
|
|
path_param_names = get_path_param_names(path) |
|
|
|
endpoint_signature = get_typed_signature(call) |
|
|
|
signature_params = endpoint_signature.parameters |
|
|
|
callable_info_cache = prepare_callable_info_cache( |
|
|
|
existing_cache=callable_info_cache, |
|
|
|
) |
|
|
|
callable_info = get_cached_callable_info( |
|
|
|
call=call, |
|
|
|
callable_info_cache=callable_info_cache, |
|
|
|
) |
|
|
|
signature_params = callable_info.typed_signature.parameters |
|
|
|
dependant = Dependant( |
|
|
|
call=call, |
|
|
|
name=name, |
|
|
@ -294,6 +339,7 @@ def get_dependant( |
|
|
|
depends=param_details.depends, |
|
|
|
path=path, |
|
|
|
security_scopes=security_scopes, |
|
|
|
callable_info_cache=callable_info_cache, |
|
|
|
) |
|
|
|
dependant.dependencies.append(sub_dependant) |
|
|
|
continue |
|
|
@ -527,6 +573,42 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: |
|
|
|
dependant.cookie_params.append(field) |
|
|
|
|
|
|
|
|
|
|
|
def prepare_callable_info_cache( |
|
|
|
call: Optional[Callable[..., Any]] = None, |
|
|
|
existing_cache: Optional[CallableInfoCache] = None, |
|
|
|
) -> CallableInfoCache: |
|
|
|
if existing_cache is None: |
|
|
|
existing_cache = WeakKeyDictionary() |
|
|
|
if call is not None: |
|
|
|
get_cached_callable_info(call, existing_cache) |
|
|
|
return existing_cache |
|
|
|
|
|
|
|
|
|
|
|
def get_cached_callable_info( |
|
|
|
call: Callable[..., Any], |
|
|
|
callable_info_cache: CallableInfoCache, |
|
|
|
) -> CallableInfo: |
|
|
|
try: |
|
|
|
callable_info = callable_info_cache.get(call) |
|
|
|
except TypeError: |
|
|
|
# cannot create weakref, don't add to cache |
|
|
|
return inspect_callable(call) |
|
|
|
|
|
|
|
if callable_info is None: |
|
|
|
callable_info = inspect_callable(call) |
|
|
|
callable_info_cache.setdefault(call, callable_info) |
|
|
|
return callable_info |
|
|
|
|
|
|
|
|
|
|
|
def inspect_callable(call: Callable[..., Any]) -> CallableInfo: |
|
|
|
return CallableInfo( |
|
|
|
typed_signature=get_typed_signature(call), |
|
|
|
is_gen_callable=is_gen_callable(call), |
|
|
|
is_async_gen_callable=is_async_gen_callable(call), |
|
|
|
is_coroutine_callable=is_coroutine_callable(call), |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def is_coroutine_callable(call: Callable[..., Any]) -> bool: |
|
|
|
if inspect.isroutine(call): |
|
|
|
return inspect.iscoroutinefunction(call) |
|
|
@ -551,11 +633,15 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: |
|
|
|
|
|
|
|
|
|
|
|
async def solve_generator( |
|
|
|
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] |
|
|
|
*, |
|
|
|
call: Callable[..., Any], |
|
|
|
callable_info: CallableInfo, |
|
|
|
stack: AsyncExitStack, |
|
|
|
sub_values: Dict[str, Any], |
|
|
|
) -> Any: |
|
|
|
if is_gen_callable(call): |
|
|
|
if callable_info.is_gen_callable: |
|
|
|
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) |
|
|
|
elif is_async_gen_callable(call): |
|
|
|
elif callable_info.is_async_gen_callable: |
|
|
|
cm = asynccontextmanager(call)(**sub_values) |
|
|
|
return await stack.enter_async_context(cm) |
|
|
|
|
|
|
@ -578,11 +664,15 @@ async def solve_dependencies( |
|
|
|
response: Optional[Response] = None, |
|
|
|
dependency_overrides_provider: Optional[Any] = None, |
|
|
|
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, |
|
|
|
callable_info_cache: Optional[CallableInfoCache] = None, |
|
|
|
async_exit_stack: AsyncExitStack, |
|
|
|
embed_body_fields: bool, |
|
|
|
) -> SolvedDependency: |
|
|
|
values: Dict[str, Any] = {} |
|
|
|
errors: List[Any] = [] |
|
|
|
callable_info_cache = prepare_callable_info_cache( |
|
|
|
existing_cache=callable_info_cache, |
|
|
|
) |
|
|
|
if response is None: |
|
|
|
response = Response() |
|
|
|
del response.headers["content-length"] |
|
|
@ -610,6 +700,8 @@ async def solve_dependencies( |
|
|
|
call=call, |
|
|
|
name=sub_dependant.name, |
|
|
|
security_scopes=sub_dependant.security_scopes, |
|
|
|
use_cache=sub_dependant.use_cache, |
|
|
|
callable_info_cache=callable_info_cache, |
|
|
|
) |
|
|
|
|
|
|
|
solved_result = await solve_dependencies( |
|
|
@ -620,6 +712,7 @@ async def solve_dependencies( |
|
|
|
response=response, |
|
|
|
dependency_overrides_provider=dependency_overrides_provider, |
|
|
|
dependency_cache=dependency_cache, |
|
|
|
callable_info_cache=callable_info_cache, |
|
|
|
async_exit_stack=async_exit_stack, |
|
|
|
embed_body_fields=embed_body_fields, |
|
|
|
) |
|
|
@ -630,14 +723,22 @@ async def solve_dependencies( |
|
|
|
continue |
|
|
|
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: |
|
|
|
solved = dependency_cache[sub_dependant.cache_key] |
|
|
|
elif is_gen_callable(call) or is_async_gen_callable(call): |
|
|
|
solved = await solve_generator( |
|
|
|
call=call, stack=async_exit_stack, sub_values=solved_result.values |
|
|
|
) |
|
|
|
elif is_coroutine_callable(call): |
|
|
|
solved = await call(**solved_result.values) |
|
|
|
else: |
|
|
|
solved = await run_in_threadpool(call, **solved_result.values) |
|
|
|
callable_info = get_cached_callable_info( |
|
|
|
call=call, |
|
|
|
callable_info_cache=callable_info_cache, |
|
|
|
) |
|
|
|
if callable_info.is_gen_callable or callable_info.is_async_gen_callable: |
|
|
|
solved = await solve_generator( |
|
|
|
call=call, |
|
|
|
callable_info=callable_info, |
|
|
|
stack=async_exit_stack, |
|
|
|
sub_values=solved_result.values, |
|
|
|
) |
|
|
|
elif callable_info.is_coroutine_callable: |
|
|
|
solved = await call(**solved_result.values) |
|
|
|
else: |
|
|
|
solved = await run_in_threadpool(call, **solved_result.values) |
|
|
|
if sub_dependant.name is not None: |
|
|
|
values[sub_dependant.name] = solved |
|
|
|
if sub_dependant.cache_key not in dependency_cache: |
|
|
|