|
|
@ -29,13 +29,13 @@ from fastapi.dependencies.utils import ( |
|
|
|
) |
|
|
|
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder |
|
|
|
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError |
|
|
|
from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY |
|
|
|
from fastapi.types import DecoratedCallable |
|
|
|
from fastapi.utils import ( |
|
|
|
create_cloned_field, |
|
|
|
create_response_field, |
|
|
|
generate_unique_id, |
|
|
|
get_value_or_default, |
|
|
|
is_body_allowed_for_status_code, |
|
|
|
) |
|
|
|
from pydantic import BaseModel |
|
|
|
from pydantic.error_wrappers import ErrorWrapper, ValidationError |
|
|
@ -232,7 +232,17 @@ def get_request_handler( |
|
|
|
if raw_response.background is None: |
|
|
|
raw_response.background = background_tasks |
|
|
|
return raw_response |
|
|
|
response_data = await serialize_response( |
|
|
|
response_args: Dict[str, Any] = {"background": background_tasks} |
|
|
|
# If status_code was set, use it, otherwise use the default from the |
|
|
|
# response class, in the case of redirect it's 307 |
|
|
|
current_status_code = ( |
|
|
|
status_code if status_code else sub_response.status_code |
|
|
|
) |
|
|
|
if current_status_code is not None: |
|
|
|
response_args["status_code"] = current_status_code |
|
|
|
if sub_response.status_code: |
|
|
|
response_args["status_code"] = sub_response.status_code |
|
|
|
content = await serialize_response( |
|
|
|
field=response_field, |
|
|
|
response_content=raw_response, |
|
|
|
include=response_model_include, |
|
|
@ -243,15 +253,10 @@ def get_request_handler( |
|
|
|
exclude_none=response_model_exclude_none, |
|
|
|
is_coroutine=is_coroutine, |
|
|
|
) |
|
|
|
response_args: Dict[str, Any] = {"background": background_tasks} |
|
|
|
# If status_code was set, use it, otherwise use the default from the |
|
|
|
# response class, in the case of redirect it's 307 |
|
|
|
if status_code is not None: |
|
|
|
response_args["status_code"] = status_code |
|
|
|
response = actual_response_class(response_data, **response_args) |
|
|
|
response = actual_response_class(content, **response_args) |
|
|
|
if not is_body_allowed_for_status_code(status_code): |
|
|
|
response.body = b"" |
|
|
|
response.headers.raw.extend(sub_response.headers.raw) |
|
|
|
if sub_response.status_code: |
|
|
|
response.status_code = sub_response.status_code |
|
|
|
return response |
|
|
|
|
|
|
|
return app |
|
|
@ -377,8 +382,8 @@ class APIRoute(routing.Route): |
|
|
|
status_code = int(status_code) |
|
|
|
self.status_code = status_code |
|
|
|
if self.response_model: |
|
|
|
assert ( |
|
|
|
status_code not in STATUS_CODES_WITH_NO_BODY |
|
|
|
assert is_body_allowed_for_status_code( |
|
|
|
status_code |
|
|
|
), f"Status code {status_code} must not have a response body" |
|
|
|
response_name = "Response_" + self.unique_id |
|
|
|
self.response_field = create_response_field( |
|
|
@ -410,8 +415,8 @@ class APIRoute(routing.Route): |
|
|
|
assert isinstance(response, dict), "An additional response must be a dict" |
|
|
|
model = response.get("model") |
|
|
|
if model: |
|
|
|
assert ( |
|
|
|
additional_status_code not in STATUS_CODES_WITH_NO_BODY |
|
|
|
assert is_body_allowed_for_status_code( |
|
|
|
additional_status_code |
|
|
|
), f"Status code {additional_status_code} must not have a response body" |
|
|
|
response_name = f"Response_{additional_status_code}_{self.unique_id}" |
|
|
|
response_field = create_response_field(name=response_name, type_=model) |
|
|
|