From 06a0b1e36880b6c44cba144d8d3f0326a4c7317e Mon Sep 17 00:00:00 2001 From: Martynov Maxim Date: Wed, 6 Aug 2025 13:03:16 +0300 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Cache=20callable=20inspect?= =?UTF-8?q?ion=20result?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 161 +++++++++++++++++++++++++++------- fastapi/routing.py | 45 ++++++++-- 2 files changed, 171 insertions(+), 35 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..753a4fca6 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,4 +1,5 @@ import inspect +import sys from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass @@ -17,6 +18,7 @@ from typing import ( Union, cast, ) +from weakref import WeakKeyDictionary import anyio from fastapi import params @@ -47,10 +49,7 @@ from fastapi._compat import ( value_is_sequence, ) from fastapi.background import BackgroundTasks -from fastapi.concurrency import ( - asynccontextmanager, - contextmanager_in_threadpool, -) +from fastapi.concurrency import asynccontextmanager, contextmanager_in_threadpool from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.logger import logger from fastapi.security.base import SecurityBase @@ -88,6 +87,33 @@ multipart_incorrect_install_error = ( ) +class CallableInfo: + __slots__ = ( + "typed_signature", + "is_gen_callable", + "is_async_gen_callable", + "is_coroutine_callable", + ) + + def __init__( + self, + typed_signature: inspect.Signature, + is_gen_callable: bool, + is_async_gen_callable: bool, + is_coroutine_callable: bool, + ) -> None: + self.typed_signature = typed_signature + self.is_gen_callable = is_gen_callable + self.is_async_gen_callable = is_async_gen_callable + self.is_coroutine_callable = is_coroutine_callable + + +if sys.version_info < (3, 9): + CallableInfoCache = WeakKeyDictionary +else: + CallableInfoCache = WeakKeyDictionary[Callable[..., Any], CallableInfo] + + def ensure_multipart_is_installed() -> None: try: from python_multipart import __version__ @@ -121,6 +147,7 @@ def get_param_sub_dependant( depends: params.Depends, path: str, security_scopes: Optional[List[str]] = None, + callable_info_cache: CallableInfoCache, ) -> Dependant: assert depends.dependency return get_sub_dependant( @@ -129,14 +156,25 @@ def get_param_sub_dependant( path=path, name=param_name, security_scopes=security_scopes, + callable_info_cache=callable_info_cache, ) -def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: +def get_parameterless_sub_dependant( + *, + depends: params.Depends, + path: str, + callable_info_cache: CallableInfoCache, +) -> Dependant: assert callable(depends.dependency), ( "A parameter-less dependency must have a callable dependency" ) - return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) + return get_sub_dependant( + depends=depends, + dependency=depends.dependency, + path=path, + callable_info_cache=callable_info_cache, + ) def get_sub_dependant( @@ -146,6 +184,7 @@ def get_sub_dependant( path: str, name: Optional[str] = None, security_scopes: Optional[List[str]] = None, + callable_info_cache: CallableInfoCache, ) -> Dependant: security_requirement = None security_scopes = security_scopes or [] @@ -165,6 +204,7 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, + callable_info_cache=callable_info_cache, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) @@ -240,7 +280,16 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: ) for param in signature.parameters.values() ] - typed_signature = inspect.Signature(typed_params) + return_annotation = signature.return_annotation + if return_annotation is inspect.Signature.empty: + return_annotation = None + else: + return_annotation = get_typed_annotation(return_annotation, globalns) + + typed_signature = inspect.Signature( + typed_params, + return_annotation=return_annotation, + ) return typed_signature @@ -251,17 +300,6 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: return annotation -def get_typed_return_annotation(call: Callable[..., Any]) -> Any: - signature = inspect.signature(call) - annotation = signature.return_annotation - - if annotation is inspect.Signature.empty: - return None - - globalns = getattr(call, "__globals__", {}) - return get_typed_annotation(annotation, globalns) - - def get_dependant( *, path: str, @@ -269,10 +307,17 @@ def get_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + callable_info_cache: Optional[CallableInfoCache] = None, ) -> Dependant: path_param_names = get_path_param_names(path) - endpoint_signature = get_typed_signature(call) - signature_params = endpoint_signature.parameters + callable_info_cache = prepare_callable_info_cache( + existing_cache=callable_info_cache, + ) + callable_info = get_cached_callable_info( + call=call, + callable_info_cache=callable_info_cache, + ) + signature_params = callable_info.typed_signature.parameters dependant = Dependant( call=call, name=name, @@ -294,6 +339,7 @@ def get_dependant( depends=param_details.depends, path=path, security_scopes=security_scopes, + callable_info_cache=callable_info_cache, ) dependant.dependencies.append(sub_dependant) continue @@ -527,6 +573,42 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: dependant.cookie_params.append(field) +def prepare_callable_info_cache( + call: Optional[Callable[..., Any]] = None, + existing_cache: Optional[CallableInfoCache] = None, +) -> CallableInfoCache: + if existing_cache is None: + existing_cache = WeakKeyDictionary() + if call is not None: + get_cached_callable_info(call, existing_cache) + return existing_cache + + +def get_cached_callable_info( + call: Callable[..., Any], + callable_info_cache: CallableInfoCache, +) -> CallableInfo: + try: + callable_info = callable_info_cache.get(call) + except TypeError: + # cannot create weakref, don't add to cache + return inspect_callable(call) + + if callable_info is None: + callable_info = inspect_callable(call) + callable_info_cache.setdefault(call, callable_info) + return callable_info + + +def inspect_callable(call: Callable[..., Any]) -> CallableInfo: + return CallableInfo( + typed_signature=get_typed_signature(call), + is_gen_callable=is_gen_callable(call), + is_async_gen_callable=is_async_gen_callable(call), + is_coroutine_callable=is_coroutine_callable(call), + ) + + def is_coroutine_callable(call: Callable[..., Any]) -> bool: if inspect.isroutine(call): return inspect.iscoroutinefunction(call) @@ -551,11 +633,15 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] + *, + call: Callable[..., Any], + callable_info: CallableInfo, + stack: AsyncExitStack, + sub_values: Dict[str, Any], ) -> Any: - if is_gen_callable(call): + if callable_info.is_gen_callable: cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) - elif is_async_gen_callable(call): + elif callable_info.is_async_gen_callable: cm = asynccontextmanager(call)(**sub_values) return await stack.enter_async_context(cm) @@ -578,11 +664,15 @@ async def solve_dependencies( response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, + callable_info_cache: Optional[CallableInfoCache] = None, async_exit_stack: AsyncExitStack, embed_body_fields: bool, ) -> SolvedDependency: values: Dict[str, Any] = {} errors: List[Any] = [] + callable_info_cache = prepare_callable_info_cache( + existing_cache=callable_info_cache, + ) if response is None: response = Response() del response.headers["content-length"] @@ -610,6 +700,8 @@ async def solve_dependencies( call=call, name=sub_dependant.name, security_scopes=sub_dependant.security_scopes, + use_cache=sub_dependant.use_cache, + callable_info_cache=callable_info_cache, ) solved_result = await solve_dependencies( @@ -620,6 +712,7 @@ async def solve_dependencies( response=response, dependency_overrides_provider=dependency_overrides_provider, dependency_cache=dependency_cache, + callable_info_cache=callable_info_cache, async_exit_stack=async_exit_stack, embed_body_fields=embed_body_fields, ) @@ -630,14 +723,22 @@ async def solve_dependencies( continue if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] - elif is_gen_callable(call) or is_async_gen_callable(call): - solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=solved_result.values - ) - elif is_coroutine_callable(call): - solved = await call(**solved_result.values) else: - solved = await run_in_threadpool(call, **solved_result.values) + callable_info = get_cached_callable_info( + call=call, + callable_info_cache=callable_info_cache, + ) + if callable_info.is_gen_callable or callable_info.is_async_gen_callable: + solved = await solve_generator( + call=call, + callable_info=callable_info, + stack=async_exit_stack, + sub_values=solved_result.values, + ) + elif callable_info.is_coroutine_callable: + solved = await call(**solved_result.values) + else: + solved = await run_in_threadpool(call, **solved_result.values) if sub_dependant.name is not None: values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: diff --git a/fastapi/routing.py b/fastapi/routing.py index 54c75a027..b81587fd2 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -34,12 +34,14 @@ from fastapi._compat import ( from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import ( + CallableInfoCache, _should_embed_body_fields, get_body_field, + get_cached_callable_info, get_dependant, get_flat_dependant, get_parameterless_sub_dependant, - get_typed_return_annotation, + prepare_callable_info_cache, solve_dependencies, ) from fastapi.encoders import jsonable_encoder @@ -229,6 +231,7 @@ def get_request_handler( response_model_exclude_none: bool = False, dependency_overrides_provider: Optional[Any] = None, embed_body_fields: bool = False, + callable_info_cache: Optional[CallableInfoCache] = None, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" is_coroutine = asyncio.iscoroutinefunction(dependant.call) @@ -238,6 +241,12 @@ def get_request_handler( else: actual_response_class = response_class + # callables are not changing between requests, so we can cache them here + callable_info_cache = prepare_callable_info_cache( + call=dependant.call, + existing_cache=callable_info_cache, + ) + async def app(request: Request) -> Response: response: Union[Response, None] = None async with AsyncExitStack() as file_stack: @@ -294,6 +303,7 @@ def get_request_handler( dependant=dependant, body=body, dependency_overrides_provider=dependency_overrides_provider, + callable_info_cache=callable_info_cache, async_exit_stack=async_exit_stack, embed_body_fields=embed_body_fields, ) @@ -362,7 +372,14 @@ def get_websocket_app( dependant: Dependant, dependency_overrides_provider: Optional[Any] = None, embed_body_fields: bool = False, + callable_info_cache: Optional[CallableInfoCache] = None, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: + # callables are not changing between requests, so we can cache them here + callable_info_cache = prepare_callable_info_cache( + call=dependant.call, + existing_cache=callable_info_cache, + ) + async def app(websocket: WebSocket) -> None: async with AsyncExitStack() as async_exit_stack: # TODO: remove this scope later, after a few releases @@ -373,6 +390,7 @@ def get_websocket_app( request=websocket, dependant=dependant, dependency_overrides_provider=dependency_overrides_provider, + callable_info_cache=callable_info_cache, async_exit_stack=async_exit_stack, embed_body_fields=embed_body_fields, ) @@ -398,14 +416,23 @@ class APIWebSocketRoute(routing.WebSocketRoute): ) -> None: self.path = path self.endpoint = endpoint + self.callable_info_cache = prepare_callable_info_cache(call=endpoint) self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_dependant( + path=self.path_format, + call=self.endpoint, + callable_info_cache=self.callable_info_cache, + ) for depends in self.dependencies[::-1]: self.dependant.dependencies.insert( 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), + get_parameterless_sub_dependant( + depends=depends, + path=self.path_format, + callable_info_cache=self.callable_info_cache, + ), ) self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( @@ -416,6 +443,7 @@ class APIWebSocketRoute(routing.WebSocketRoute): dependant=self.dependant, dependency_overrides_provider=dependency_overrides_provider, embed_body_fields=self._embed_body_fields, + callable_info_cache=self.callable_info_cache, ) ) @@ -463,8 +491,10 @@ class APIRoute(routing.Route): ) -> None: self.path = path self.endpoint = endpoint + self.callable_info_cache = prepare_callable_info_cache(call=endpoint) if isinstance(response_model, DefaultPlaceholder): - return_annotation = get_typed_return_annotation(endpoint) + callable_info = get_cached_callable_info(endpoint, self.callable_info_cache) + return_annotation = callable_info.typed_signature.return_annotation if lenient_issubclass(return_annotation, Response): response_model = None else: @@ -556,7 +586,11 @@ class APIRoute(routing.Route): for depends in self.dependencies[::-1]: self.dependant.dependencies.insert( 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), + get_parameterless_sub_dependant( + depends=depends, + path=self.path_format, + callable_info_cache=self.callable_info_cache, + ), ) self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( @@ -584,6 +618,7 @@ class APIRoute(routing.Route): response_model_exclude_none=self.response_model_exclude_none, dependency_overrides_provider=self.dependency_overrides_provider, embed_body_fields=self._embed_body_fields, + callable_info_cache=self.callable_info_cache, ) def matches(self, scope: Scope) -> Tuple[Match, Scope]: