diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 5ebdddaf6..ed03df88b 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -529,6 +529,15 @@ async def solve_generator( return await stack.enter_async_context(cm) +@dataclass +class SolvedDependency: + values: Dict[str, Any] + errors: List[Any] + background_tasks: Optional[StarletteBackgroundTasks] + response: Response + dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] + + async def solve_dependencies( *, request: Union[Request, WebSocket], @@ -539,13 +548,7 @@ async def solve_dependencies( dependency_overrides_provider: Optional[Any] = None, dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, async_exit_stack: AsyncExitStack, -) -> Tuple[ - Dict[str, Any], - List[Any], - Optional[StarletteBackgroundTasks], - Response, - Dict[Tuple[Callable[..., Any], Tuple[str]], Any], -]: +) -> SolvedDependency: values: Dict[str, Any] = {} errors: List[Any] = [] if response is None: @@ -587,27 +590,21 @@ async def solve_dependencies( dependency_cache=dependency_cache, async_exit_stack=async_exit_stack, ) - ( - sub_values, - sub_errors, - background_tasks, - _, # the subdependency returns the same response we have - sub_dependency_cache, - ) = solved_result - dependency_cache.update(sub_dependency_cache) - if sub_errors: - errors.extend(sub_errors) + background_tasks = solved_result.background_tasks + dependency_cache.update(solved_result.dependency_cache) + if solved_result.errors: + errors.extend(solved_result.errors) continue if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] elif is_gen_callable(call) or is_async_gen_callable(call): solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=sub_values + call=call, stack=async_exit_stack, sub_values=solved_result.values ) elif is_coroutine_callable(call): - solved = await call(**sub_values) + solved = await call(**solved_result.values) else: - solved = await run_in_threadpool(call, **sub_values) + solved = await run_in_threadpool(call, **solved_result.values) if sub_dependant.name is not None: values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: @@ -654,7 +651,13 @@ async def solve_dependencies( values[dependant.security_scopes_param_name] = SecurityScopes( scopes=dependant.security_scopes ) - return values, errors, background_tasks, response, dependency_cache + return SolvedDependency( + values=values, + errors=errors, + background_tasks=background_tasks, + response=response, + dependency_cache=dependency_cache, + ) def request_params_to_args( diff --git a/fastapi/routing.py b/fastapi/routing.py index 49f1b6013..c46772017 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -292,26 +292,34 @@ def get_request_handler( dependency_overrides_provider=dependency_overrides_provider, async_exit_stack=async_exit_stack, ) - values, errors, background_tasks, sub_response, _ = solved_result + errors = solved_result.errors if not errors: raw_response = await run_endpoint_function( - dependant=dependant, values=values, is_coroutine=is_coroutine + dependant=dependant, + values=solved_result.values, + is_coroutine=is_coroutine, ) if isinstance(raw_response, Response): if raw_response.background is None: - raw_response.background = background_tasks + raw_response.background = solved_result.background_tasks response = raw_response else: - response_args: Dict[str, Any] = {"background": background_tasks} + response_args: Dict[str, Any] = { + "background": solved_result.background_tasks + } # If status_code was set, use it, otherwise use the default from the # response class, in the case of redirect it's 307 current_status_code = ( - status_code if status_code else sub_response.status_code + status_code + if status_code + else solved_result.response.status_code ) if current_status_code is not None: response_args["status_code"] = current_status_code - if sub_response.status_code: - response_args["status_code"] = sub_response.status_code + if solved_result.response.status_code: + response_args["status_code"] = ( + solved_result.response.status_code + ) content = await serialize_response( field=response_field, response_content=raw_response, @@ -326,7 +334,7 @@ def get_request_handler( response = actual_response_class(content, **response_args) if not is_body_allowed_for_status_code(response.status_code): response.body = b"" - response.headers.raw.extend(sub_response.headers.raw) + response.headers.raw.extend(solved_result.response.headers.raw) if errors: validation_error = RequestValidationError( _normalize_errors(errors), body=body @@ -360,11 +368,12 @@ def get_websocket_app( dependency_overrides_provider=dependency_overrides_provider, async_exit_stack=async_exit_stack, ) - values, errors, _, _2, _3 = solved_result - if errors: - raise WebSocketRequestValidationError(_normalize_errors(errors)) + if solved_result.errors: + raise WebSocketRequestValidationError( + _normalize_errors(solved_result.errors) + ) assert dependant.call is not None, "dependant.call must be a function" - await dependant.call(**values) + await dependant.call(**solved_result.values) return app