Browse Source

️ Cache callable inspection result

pull/13974/head
Martynov Maxim 2 days ago
committed by Maxim Martynov
parent
commit
06a0b1e368
  1. 161
      fastapi/dependencies/utils.py
  2. 45
      fastapi/routing.py

161
fastapi/dependencies/utils.py

@ -1,4 +1,5 @@
import inspect
import sys
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
@ -17,6 +18,7 @@ from typing import (
Union,
cast,
)
from weakref import WeakKeyDictionary
import anyio
from fastapi import params
@ -47,10 +49,7 @@ from fastapi._compat import (
value_is_sequence,
)
from fastapi.background import BackgroundTasks
from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.concurrency import asynccontextmanager, contextmanager_in_threadpool
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
@ -88,6 +87,33 @@ multipart_incorrect_install_error = (
)
class CallableInfo:
__slots__ = (
"typed_signature",
"is_gen_callable",
"is_async_gen_callable",
"is_coroutine_callable",
)
def __init__(
self,
typed_signature: inspect.Signature,
is_gen_callable: bool,
is_async_gen_callable: bool,
is_coroutine_callable: bool,
) -> None:
self.typed_signature = typed_signature
self.is_gen_callable = is_gen_callable
self.is_async_gen_callable = is_async_gen_callable
self.is_coroutine_callable = is_coroutine_callable
if sys.version_info < (3, 9):
CallableInfoCache = WeakKeyDictionary
else:
CallableInfoCache = WeakKeyDictionary[Callable[..., Any], CallableInfo]
def ensure_multipart_is_installed() -> None:
try:
from python_multipart import __version__
@ -121,6 +147,7 @@ def get_param_sub_dependant(
depends: params.Depends,
path: str,
security_scopes: Optional[List[str]] = None,
callable_info_cache: CallableInfoCache,
) -> Dependant:
assert depends.dependency
return get_sub_dependant(
@ -129,14 +156,25 @@ def get_param_sub_dependant(
path=path,
name=param_name,
security_scopes=security_scopes,
callable_info_cache=callable_info_cache,
)
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
def get_parameterless_sub_dependant(
*,
depends: params.Depends,
path: str,
callable_info_cache: CallableInfoCache,
) -> Dependant:
assert callable(depends.dependency), (
"A parameter-less dependency must have a callable dependency"
)
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
return get_sub_dependant(
depends=depends,
dependency=depends.dependency,
path=path,
callable_info_cache=callable_info_cache,
)
def get_sub_dependant(
@ -146,6 +184,7 @@ def get_sub_dependant(
path: str,
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
callable_info_cache: CallableInfoCache,
) -> Dependant:
security_requirement = None
security_scopes = security_scopes or []
@ -165,6 +204,7 @@ def get_sub_dependant(
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
callable_info_cache=callable_info_cache,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
@ -240,7 +280,16 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return_annotation = signature.return_annotation
if return_annotation is inspect.Signature.empty:
return_annotation = None
else:
return_annotation = get_typed_annotation(return_annotation, globalns)
typed_signature = inspect.Signature(
typed_params,
return_annotation=return_annotation,
)
return typed_signature
@ -251,17 +300,6 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
return annotation
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
signature = inspect.signature(call)
annotation = signature.return_annotation
if annotation is inspect.Signature.empty:
return None
globalns = getattr(call, "__globals__", {})
return get_typed_annotation(annotation, globalns)
def get_dependant(
*,
path: str,
@ -269,10 +307,17 @@ def get_dependant(
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
callable_info_cache: Optional[CallableInfoCache] = None,
) -> Dependant:
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
callable_info_cache = prepare_callable_info_cache(
existing_cache=callable_info_cache,
)
callable_info = get_cached_callable_info(
call=call,
callable_info_cache=callable_info_cache,
)
signature_params = callable_info.typed_signature.parameters
dependant = Dependant(
call=call,
name=name,
@ -294,6 +339,7 @@ def get_dependant(
depends=param_details.depends,
path=path,
security_scopes=security_scopes,
callable_info_cache=callable_info_cache,
)
dependant.dependencies.append(sub_dependant)
continue
@ -527,6 +573,42 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
dependant.cookie_params.append(field)
def prepare_callable_info_cache(
call: Optional[Callable[..., Any]] = None,
existing_cache: Optional[CallableInfoCache] = None,
) -> CallableInfoCache:
if existing_cache is None:
existing_cache = WeakKeyDictionary()
if call is not None:
get_cached_callable_info(call, existing_cache)
return existing_cache
def get_cached_callable_info(
call: Callable[..., Any],
callable_info_cache: CallableInfoCache,
) -> CallableInfo:
try:
callable_info = callable_info_cache.get(call)
except TypeError:
# cannot create weakref, don't add to cache
return inspect_callable(call)
if callable_info is None:
callable_info = inspect_callable(call)
callable_info_cache.setdefault(call, callable_info)
return callable_info
def inspect_callable(call: Callable[..., Any]) -> CallableInfo:
return CallableInfo(
typed_signature=get_typed_signature(call),
is_gen_callable=is_gen_callable(call),
is_async_gen_callable=is_async_gen_callable(call),
is_coroutine_callable=is_coroutine_callable(call),
)
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
@ -551,11 +633,15 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
*,
call: Callable[..., Any],
callable_info: CallableInfo,
stack: AsyncExitStack,
sub_values: Dict[str, Any],
) -> Any:
if is_gen_callable(call):
if callable_info.is_gen_callable:
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call):
elif callable_info.is_async_gen_callable:
cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm)
@ -578,11 +664,15 @@ async def solve_dependencies(
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
callable_info_cache: Optional[CallableInfoCache] = None,
async_exit_stack: AsyncExitStack,
embed_body_fields: bool,
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
callable_info_cache = prepare_callable_info_cache(
existing_cache=callable_info_cache,
)
if response is None:
response = Response()
del response.headers["content-length"]
@ -610,6 +700,8 @@ async def solve_dependencies(
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
use_cache=sub_dependant.use_cache,
callable_info_cache=callable_info_cache,
)
solved_result = await solve_dependencies(
@ -620,6 +712,7 @@ async def solve_dependencies(
response=response,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
callable_info_cache=callable_info_cache,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
@ -630,14 +723,22 @@ async def solve_dependencies(
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values
)
elif is_coroutine_callable(call):
solved = await call(**solved_result.values)
else:
solved = await run_in_threadpool(call, **solved_result.values)
callable_info = get_cached_callable_info(
call=call,
callable_info_cache=callable_info_cache,
)
if callable_info.is_gen_callable or callable_info.is_async_gen_callable:
solved = await solve_generator(
call=call,
callable_info=callable_info,
stack=async_exit_stack,
sub_values=solved_result.values,
)
elif callable_info.is_coroutine_callable:
solved = await call(**solved_result.values)
else:
solved = await run_in_threadpool(call, **solved_result.values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:

45
fastapi/routing.py

@ -34,12 +34,14 @@ from fastapi._compat import (
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
CallableInfoCache,
_should_embed_body_fields,
get_body_field,
get_cached_callable_info,
get_dependant,
get_flat_dependant,
get_parameterless_sub_dependant,
get_typed_return_annotation,
prepare_callable_info_cache,
solve_dependencies,
)
from fastapi.encoders import jsonable_encoder
@ -229,6 +231,7 @@ def get_request_handler(
response_model_exclude_none: bool = False,
dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False,
callable_info_cache: Optional[CallableInfoCache] = None,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
@ -238,6 +241,12 @@ def get_request_handler(
else:
actual_response_class = response_class
# callables are not changing between requests, so we can cache them here
callable_info_cache = prepare_callable_info_cache(
call=dependant.call,
existing_cache=callable_info_cache,
)
async def app(request: Request) -> Response:
response: Union[Response, None] = None
async with AsyncExitStack() as file_stack:
@ -294,6 +303,7 @@ def get_request_handler(
dependant=dependant,
body=body,
dependency_overrides_provider=dependency_overrides_provider,
callable_info_cache=callable_info_cache,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
@ -362,7 +372,14 @@ def get_websocket_app(
dependant: Dependant,
dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False,
callable_info_cache: Optional[CallableInfoCache] = None,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
# callables are not changing between requests, so we can cache them here
callable_info_cache = prepare_callable_info_cache(
call=dependant.call,
existing_cache=callable_info_cache,
)
async def app(websocket: WebSocket) -> None:
async with AsyncExitStack() as async_exit_stack:
# TODO: remove this scope later, after a few releases
@ -373,6 +390,7 @@ def get_websocket_app(
request=websocket,
dependant=dependant,
dependency_overrides_provider=dependency_overrides_provider,
callable_info_cache=callable_info_cache,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
@ -398,14 +416,23 @@ class APIWebSocketRoute(routing.WebSocketRoute):
) -> None:
self.path = path
self.endpoint = endpoint
self.callable_info_cache = prepare_callable_info_cache(call=endpoint)
self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
self.dependant = get_dependant(
path=self.path_format,
call=self.endpoint,
callable_info_cache=self.callable_info_cache,
)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
get_parameterless_sub_dependant(
depends=depends,
path=self.path_format,
callable_info_cache=self.callable_info_cache,
),
)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
@ -416,6 +443,7 @@ class APIWebSocketRoute(routing.WebSocketRoute):
dependant=self.dependant,
dependency_overrides_provider=dependency_overrides_provider,
embed_body_fields=self._embed_body_fields,
callable_info_cache=self.callable_info_cache,
)
)
@ -463,8 +491,10 @@ class APIRoute(routing.Route):
) -> None:
self.path = path
self.endpoint = endpoint
self.callable_info_cache = prepare_callable_info_cache(call=endpoint)
if isinstance(response_model, DefaultPlaceholder):
return_annotation = get_typed_return_annotation(endpoint)
callable_info = get_cached_callable_info(endpoint, self.callable_info_cache)
return_annotation = callable_info.typed_signature.return_annotation
if lenient_issubclass(return_annotation, Response):
response_model = None
else:
@ -556,7 +586,11 @@ class APIRoute(routing.Route):
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
get_parameterless_sub_dependant(
depends=depends,
path=self.path_format,
callable_info_cache=self.callable_info_cache,
),
)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
@ -584,6 +618,7 @@ class APIRoute(routing.Route):
response_model_exclude_none=self.response_model_exclude_none,
dependency_overrides_provider=self.dependency_overrides_provider,
embed_body_fields=self._embed_body_fields,
callable_info_cache=self.callable_info_cache,
)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:

Loading…
Cancel
Save