Maxim Martynov 2 days ago
committed by GitHub
parent
commit
f63d5e80ff
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 153
      fastapi/dependencies/utils.py
  2. 45
      fastapi/routing.py

153
fastapi/dependencies/utils.py

@ -1,4 +1,5 @@
import inspect import inspect
import sys
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
@ -17,6 +18,7 @@ from typing import (
Union, Union,
cast, cast,
) )
from weakref import WeakKeyDictionary
import anyio import anyio
from fastapi import params from fastapi import params
@ -47,10 +49,7 @@ from fastapi._compat import (
value_is_sequence, value_is_sequence,
) )
from fastapi.background import BackgroundTasks from fastapi.background import BackgroundTasks
from fastapi.concurrency import ( from fastapi.concurrency import asynccontextmanager, contextmanager_in_threadpool
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
@ -88,6 +87,33 @@ multipart_incorrect_install_error = (
) )
class CallableInfo:
__slots__ = (
"typed_signature",
"is_gen_callable",
"is_async_gen_callable",
"is_coroutine_callable",
)
def __init__(
self,
typed_signature: inspect.Signature,
is_gen_callable: bool,
is_async_gen_callable: bool,
is_coroutine_callable: bool,
) -> None:
self.typed_signature = typed_signature
self.is_gen_callable = is_gen_callable
self.is_async_gen_callable = is_async_gen_callable
self.is_coroutine_callable = is_coroutine_callable
if sys.version_info < (3, 9):
CallableInfoCache = WeakKeyDictionary
else:
CallableInfoCache = WeakKeyDictionary[Callable[..., Any], CallableInfo]
def ensure_multipart_is_installed() -> None: def ensure_multipart_is_installed() -> None:
try: try:
from python_multipart import __version__ from python_multipart import __version__
@ -121,6 +147,7 @@ def get_param_sub_dependant(
depends: params.Depends, depends: params.Depends,
path: str, path: str,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
callable_info_cache: CallableInfoCache,
) -> Dependant: ) -> Dependant:
assert depends.dependency assert depends.dependency
return get_sub_dependant( return get_sub_dependant(
@ -129,14 +156,25 @@ def get_param_sub_dependant(
path=path, path=path,
name=param_name, name=param_name,
security_scopes=security_scopes, security_scopes=security_scopes,
callable_info_cache=callable_info_cache,
) )
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: def get_parameterless_sub_dependant(
*,
depends: params.Depends,
path: str,
callable_info_cache: CallableInfoCache,
) -> Dependant:
assert callable(depends.dependency), ( assert callable(depends.dependency), (
"A parameter-less dependency must have a callable dependency" "A parameter-less dependency must have a callable dependency"
) )
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) return get_sub_dependant(
depends=depends,
dependency=depends.dependency,
path=path,
callable_info_cache=callable_info_cache,
)
def get_sub_dependant( def get_sub_dependant(
@ -146,6 +184,7 @@ def get_sub_dependant(
path: str, path: str,
name: Optional[str] = None, name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
callable_info_cache: CallableInfoCache,
) -> Dependant: ) -> Dependant:
security_requirement = None security_requirement = None
security_scopes = security_scopes or [] security_scopes = security_scopes or []
@ -165,6 +204,7 @@ def get_sub_dependant(
name=name, name=name,
security_scopes=security_scopes, security_scopes=security_scopes,
use_cache=depends.use_cache, use_cache=depends.use_cache,
callable_info_cache=callable_info_cache,
) )
if security_requirement: if security_requirement:
sub_dependant.security_requirements.append(security_requirement) sub_dependant.security_requirements.append(security_requirement)
@ -240,7 +280,16 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
) )
for param in signature.parameters.values() for param in signature.parameters.values()
] ]
typed_signature = inspect.Signature(typed_params) return_annotation = signature.return_annotation
if return_annotation is inspect.Signature.empty:
return_annotation = None
else:
return_annotation = get_typed_annotation(return_annotation, globalns)
typed_signature = inspect.Signature(
typed_params,
return_annotation=return_annotation,
)
return typed_signature return typed_signature
@ -251,17 +300,6 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
return annotation return annotation
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
signature = inspect.signature(call)
annotation = signature.return_annotation
if annotation is inspect.Signature.empty:
return None
globalns = getattr(call, "__globals__", {})
return get_typed_annotation(annotation, globalns)
def get_dependant( def get_dependant(
*, *,
path: str, path: str,
@ -269,10 +307,17 @@ def get_dependant(
name: Optional[str] = None, name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
use_cache: bool = True, use_cache: bool = True,
callable_info_cache: Optional[CallableInfoCache] = None,
) -> Dependant: ) -> Dependant:
path_param_names = get_path_param_names(path) path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call) callable_info_cache = prepare_callable_info_cache(
signature_params = endpoint_signature.parameters existing_cache=callable_info_cache,
)
callable_info = get_cached_callable_info(
call=call,
callable_info_cache=callable_info_cache,
)
signature_params = callable_info.typed_signature.parameters
dependant = Dependant( dependant = Dependant(
call=call, call=call,
name=name, name=name,
@ -294,6 +339,7 @@ def get_dependant(
depends=param_details.depends, depends=param_details.depends,
path=path, path=path,
security_scopes=security_scopes, security_scopes=security_scopes,
callable_info_cache=callable_info_cache,
) )
dependant.dependencies.append(sub_dependant) dependant.dependencies.append(sub_dependant)
continue continue
@ -527,6 +573,42 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
dependant.cookie_params.append(field) dependant.cookie_params.append(field)
def prepare_callable_info_cache(
call: Optional[Callable[..., Any]] = None,
existing_cache: Optional[CallableInfoCache] = None,
) -> CallableInfoCache:
if existing_cache is None:
existing_cache = WeakKeyDictionary()
if call is not None:
get_cached_callable_info(call, existing_cache)
return existing_cache
def get_cached_callable_info(
call: Callable[..., Any],
callable_info_cache: CallableInfoCache,
) -> CallableInfo:
try:
callable_info = callable_info_cache.get(call)
except TypeError:
# cannot create weakref, don't add to cache
return inspect_callable(call)
if callable_info is None:
callable_info = inspect_callable(call)
callable_info_cache.setdefault(call, callable_info)
return callable_info
def inspect_callable(call: Callable[..., Any]) -> CallableInfo:
return CallableInfo(
typed_signature=get_typed_signature(call),
is_gen_callable=is_gen_callable(call),
is_async_gen_callable=is_async_gen_callable(call),
is_coroutine_callable=is_coroutine_callable(call),
)
def is_coroutine_callable(call: Callable[..., Any]) -> bool: def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isroutine(call): if inspect.isroutine(call):
return inspect.iscoroutinefunction(call) return inspect.iscoroutinefunction(call)
@ -551,11 +633,15 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
async def solve_generator( async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] *,
call: Callable[..., Any],
callable_info: CallableInfo,
stack: AsyncExitStack,
sub_values: Dict[str, Any],
) -> Any: ) -> Any:
if is_gen_callable(call): if callable_info.is_gen_callable:
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call): elif callable_info.is_async_gen_callable:
cm = asynccontextmanager(call)(**sub_values) cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm) return await stack.enter_async_context(cm)
@ -578,11 +664,15 @@ async def solve_dependencies(
response: Optional[Response] = None, response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
callable_info_cache: Optional[CallableInfoCache] = None,
async_exit_stack: AsyncExitStack, async_exit_stack: AsyncExitStack,
embed_body_fields: bool, embed_body_fields: bool,
) -> SolvedDependency: ) -> SolvedDependency:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[Any] = [] errors: List[Any] = []
callable_info_cache = prepare_callable_info_cache(
existing_cache=callable_info_cache,
)
if response is None: if response is None:
response = Response() response = Response()
del response.headers["content-length"] del response.headers["content-length"]
@ -610,6 +700,8 @@ async def solve_dependencies(
call=call, call=call,
name=sub_dependant.name, name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes, security_scopes=sub_dependant.security_scopes,
use_cache=sub_dependant.use_cache,
callable_info_cache=callable_info_cache,
) )
solved_result = await solve_dependencies( solved_result = await solve_dependencies(
@ -620,6 +712,7 @@ async def solve_dependencies(
response=response, response=response,
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache, dependency_cache=dependency_cache,
callable_info_cache=callable_info_cache,
async_exit_stack=async_exit_stack, async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields, embed_body_fields=embed_body_fields,
) )
@ -630,11 +723,19 @@ async def solve_dependencies(
continue continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key] solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call): else:
callable_info = get_cached_callable_info(
call=call,
callable_info_cache=callable_info_cache,
)
if callable_info.is_gen_callable or callable_info.is_async_gen_callable:
solved = await solve_generator( solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values call=call,
callable_info=callable_info,
stack=async_exit_stack,
sub_values=solved_result.values,
) )
elif is_coroutine_callable(call): elif callable_info.is_coroutine_callable:
solved = await call(**solved_result.values) solved = await call(**solved_result.values)
else: else:
solved = await run_in_threadpool(call, **solved_result.values) solved = await run_in_threadpool(call, **solved_result.values)

45
fastapi/routing.py

@ -34,12 +34,14 @@ from fastapi._compat import (
from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import ( from fastapi.dependencies.utils import (
CallableInfoCache,
_should_embed_body_fields, _should_embed_body_fields,
get_body_field, get_body_field,
get_cached_callable_info,
get_dependant, get_dependant,
get_flat_dependant, get_flat_dependant,
get_parameterless_sub_dependant, get_parameterless_sub_dependant,
get_typed_return_annotation, prepare_callable_info_cache,
solve_dependencies, solve_dependencies,
) )
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
@ -229,6 +231,7 @@ def get_request_handler(
response_model_exclude_none: bool = False, response_model_exclude_none: bool = False,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False, embed_body_fields: bool = False,
callable_info_cache: Optional[CallableInfoCache] = None,
) -> Callable[[Request], Coroutine[Any, Any, Response]]: ) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function" assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call) is_coroutine = asyncio.iscoroutinefunction(dependant.call)
@ -238,6 +241,12 @@ def get_request_handler(
else: else:
actual_response_class = response_class actual_response_class = response_class
# callables are not changing between requests, so we can cache them here
callable_info_cache = prepare_callable_info_cache(
call=dependant.call,
existing_cache=callable_info_cache,
)
async def app(request: Request) -> Response: async def app(request: Request) -> Response:
response: Union[Response, None] = None response: Union[Response, None] = None
async with AsyncExitStack() as file_stack: async with AsyncExitStack() as file_stack:
@ -294,6 +303,7 @@ def get_request_handler(
dependant=dependant, dependant=dependant,
body=body, body=body,
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,
callable_info_cache=callable_info_cache,
async_exit_stack=async_exit_stack, async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields, embed_body_fields=embed_body_fields,
) )
@ -362,7 +372,14 @@ def get_websocket_app(
dependant: Dependant, dependant: Dependant,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False, embed_body_fields: bool = False,
callable_info_cache: Optional[CallableInfoCache] = None,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
# callables are not changing between requests, so we can cache them here
callable_info_cache = prepare_callable_info_cache(
call=dependant.call,
existing_cache=callable_info_cache,
)
async def app(websocket: WebSocket) -> None: async def app(websocket: WebSocket) -> None:
async with AsyncExitStack() as async_exit_stack: async with AsyncExitStack() as async_exit_stack:
# TODO: remove this scope later, after a few releases # TODO: remove this scope later, after a few releases
@ -373,6 +390,7 @@ def get_websocket_app(
request=websocket, request=websocket,
dependant=dependant, dependant=dependant,
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,
callable_info_cache=callable_info_cache,
async_exit_stack=async_exit_stack, async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields, embed_body_fields=embed_body_fields,
) )
@ -398,14 +416,23 @@ class APIWebSocketRoute(routing.WebSocketRoute):
) -> None: ) -> None:
self.path = path self.path = path
self.endpoint = endpoint self.endpoint = endpoint
self.callable_info_cache = prepare_callable_info_cache(call=endpoint)
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or []) self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) self.dependant = get_dependant(
path=self.path_format,
call=self.endpoint,
callable_info_cache=self.callable_info_cache,
)
for depends in self.dependencies[::-1]: for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert( self.dependant.dependencies.insert(
0, 0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format), get_parameterless_sub_dependant(
depends=depends,
path=self.path_format,
callable_info_cache=self.callable_info_cache,
),
) )
self._flat_dependant = get_flat_dependant(self.dependant) self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields( self._embed_body_fields = _should_embed_body_fields(
@ -416,6 +443,7 @@ class APIWebSocketRoute(routing.WebSocketRoute):
dependant=self.dependant, dependant=self.dependant,
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,
embed_body_fields=self._embed_body_fields, embed_body_fields=self._embed_body_fields,
callable_info_cache=self.callable_info_cache,
) )
) )
@ -463,8 +491,10 @@ class APIRoute(routing.Route):
) -> None: ) -> None:
self.path = path self.path = path
self.endpoint = endpoint self.endpoint = endpoint
self.callable_info_cache = prepare_callable_info_cache(call=endpoint)
if isinstance(response_model, DefaultPlaceholder): if isinstance(response_model, DefaultPlaceholder):
return_annotation = get_typed_return_annotation(endpoint) callable_info = get_cached_callable_info(endpoint, self.callable_info_cache)
return_annotation = callable_info.typed_signature.return_annotation
if lenient_issubclass(return_annotation, Response): if lenient_issubclass(return_annotation, Response):
response_model = None response_model = None
else: else:
@ -556,7 +586,11 @@ class APIRoute(routing.Route):
for depends in self.dependencies[::-1]: for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert( self.dependant.dependencies.insert(
0, 0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format), get_parameterless_sub_dependant(
depends=depends,
path=self.path_format,
callable_info_cache=self.callable_info_cache,
),
) )
self._flat_dependant = get_flat_dependant(self.dependant) self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields( self._embed_body_fields = _should_embed_body_fields(
@ -584,6 +618,7 @@ class APIRoute(routing.Route):
response_model_exclude_none=self.response_model_exclude_none, response_model_exclude_none=self.response_model_exclude_none,
dependency_overrides_provider=self.dependency_overrides_provider, dependency_overrides_provider=self.dependency_overrides_provider,
embed_body_fields=self._embed_body_fields, embed_body_fields=self._embed_body_fields,
callable_info_cache=self.callable_info_cache,
) )
def matches(self, scope: Scope) -> Tuple[Match, Scope]: def matches(self, scope: Scope) -> Tuple[Match, Scope]:

Loading…
Cancel
Save