Browse Source

♻️ Refactor and simplify internal data from `solve_dependencies()` using dataclasses (#12100)

pull/12103/head
Sebastián Ramírez 7 months ago
committed by GitHub
parent
commit
5b7fa3900e
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 45
      fastapi/dependencies/utils.py
  2. 33
      fastapi/routing.py

45
fastapi/dependencies/utils.py

@ -529,6 +529,15 @@ async def solve_generator(
return await stack.enter_async_context(cm)
@dataclass
class SolvedDependency:
values: Dict[str, Any]
errors: List[Any]
background_tasks: Optional[StarletteBackgroundTasks]
response: Response
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
async def solve_dependencies(
*,
request: Union[Request, WebSocket],
@ -539,13 +548,7 @@ async def solve_dependencies(
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
async_exit_stack: AsyncExitStack,
) -> Tuple[
Dict[str, Any],
List[Any],
Optional[StarletteBackgroundTasks],
Response,
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
if response is None:
@ -587,27 +590,21 @@ async def solve_dependencies(
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
)
(
sub_values,
sub_errors,
background_tasks,
_, # the subdependency returns the same response we have
sub_dependency_cache,
) = solved_result
dependency_cache.update(sub_dependency_cache)
if sub_errors:
errors.extend(sub_errors)
background_tasks = solved_result.background_tasks
dependency_cache.update(solved_result.dependency_cache)
if solved_result.errors:
errors.extend(solved_result.errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=sub_values
call=call, stack=async_exit_stack, sub_values=solved_result.values
)
elif is_coroutine_callable(call):
solved = await call(**sub_values)
solved = await call(**solved_result.values)
else:
solved = await run_in_threadpool(call, **sub_values)
solved = await run_in_threadpool(call, **solved_result.values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
@ -654,7 +651,13 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
return values, errors, background_tasks, response, dependency_cache
return SolvedDependency(
values=values,
errors=errors,
background_tasks=background_tasks,
response=response,
dependency_cache=dependency_cache,
)
def request_params_to_args(

33
fastapi/routing.py

@ -292,26 +292,34 @@ def get_request_handler(
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
)
values, errors, background_tasks, sub_response, _ = solved_result
errors = solved_result.errors
if not errors:
raw_response = await run_endpoint_function(
dependant=dependant, values=values, is_coroutine=is_coroutine
dependant=dependant,
values=solved_result.values,
is_coroutine=is_coroutine,
)
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = background_tasks
raw_response.background = solved_result.background_tasks
response = raw_response
else:
response_args: Dict[str, Any] = {"background": background_tasks}
response_args: Dict[str, Any] = {
"background": solved_result.background_tasks
}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else sub_response.status_code
status_code
if status_code
else solved_result.response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if sub_response.status_code:
response_args["status_code"] = sub_response.status_code
if solved_result.response.status_code:
response_args["status_code"] = (
solved_result.response.status_code
)
content = await serialize_response(
field=response_field,
response_content=raw_response,
@ -326,7 +334,7 @@ def get_request_handler(
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(sub_response.headers.raw)
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body
@ -360,11 +368,12 @@ def get_websocket_app(
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
)
values, errors, _, _2, _3 = solved_result
if errors:
raise WebSocketRequestValidationError(_normalize_errors(errors))
if solved_result.errors:
raise WebSocketRequestValidationError(
_normalize_errors(solved_result.errors)
)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**values)
await dependant.call(**solved_result.values)
return app

Loading…
Cancel
Save