|
|
@ -33,6 +33,7 @@ from fastapi._compat import ( |
|
|
|
from fastapi.datastructures import Default, DefaultPlaceholder |
|
|
|
from fastapi.dependencies.models import Dependant |
|
|
|
from fastapi.dependencies.utils import ( |
|
|
|
SolvedDependency, |
|
|
|
_should_embed_body_fields, |
|
|
|
get_body_field, |
|
|
|
get_dependant, |
|
|
@ -240,52 +241,12 @@ def get_request_handler( |
|
|
|
async def app(request: Request) -> Response: |
|
|
|
response: Union[Response, None] = None |
|
|
|
async with AsyncExitStack() as file_stack: |
|
|
|
try: |
|
|
|
body: Any = None |
|
|
|
if body_field: |
|
|
|
if is_body_form: |
|
|
|
body = await request.form() |
|
|
|
file_stack.push_async_callback(body.close) |
|
|
|
else: |
|
|
|
body_bytes = await request.body() |
|
|
|
if body_bytes: |
|
|
|
json_body: Any = Undefined |
|
|
|
content_type_value = request.headers.get("content-type") |
|
|
|
if not content_type_value: |
|
|
|
json_body = await request.json() |
|
|
|
else: |
|
|
|
message = email.message.Message() |
|
|
|
message["content-type"] = content_type_value |
|
|
|
if message.get_content_maintype() == "application": |
|
|
|
subtype = message.get_content_subtype() |
|
|
|
if subtype == "json" or subtype.endswith("+json"): |
|
|
|
json_body = await request.json() |
|
|
|
if json_body != Undefined: |
|
|
|
body = json_body |
|
|
|
else: |
|
|
|
body = body_bytes |
|
|
|
except json.JSONDecodeError as e: |
|
|
|
validation_error = RequestValidationError( |
|
|
|
[ |
|
|
|
{ |
|
|
|
"type": "json_invalid", |
|
|
|
"loc": ("body", e.pos), |
|
|
|
"msg": "JSON decode error", |
|
|
|
"input": {}, |
|
|
|
"ctx": {"error": e.msg}, |
|
|
|
} |
|
|
|
], |
|
|
|
body=e.doc, |
|
|
|
) |
|
|
|
raise validation_error from e |
|
|
|
except HTTPException: |
|
|
|
# If a middleware raises an HTTPException, it should be raised again |
|
|
|
raise |
|
|
|
except Exception as e: |
|
|
|
http_error = HTTPException( |
|
|
|
status_code=400, detail="There was an error parsing the body" |
|
|
|
) |
|
|
|
raise http_error from e |
|
|
|
body = await parse_request_body( |
|
|
|
request, |
|
|
|
body_field=body_field, |
|
|
|
is_body_form=is_body_form, |
|
|
|
file_stack=file_stack, |
|
|
|
) |
|
|
|
errors: List[Any] = [] |
|
|
|
async with AsyncExitStack() as async_exit_stack: |
|
|
|
solved_result = await solve_dependencies( |
|
|
@ -303,42 +264,10 @@ def get_request_handler( |
|
|
|
values=solved_result.values, |
|
|
|
is_coroutine=is_coroutine, |
|
|
|
) |
|
|
|
if isinstance(raw_response, Response): |
|
|
|
if raw_response.background is None: |
|
|
|
raw_response.background = solved_result.background_tasks |
|
|
|
response = raw_response |
|
|
|
else: |
|
|
|
response_args: Dict[str, Any] = { |
|
|
|
"background": solved_result.background_tasks |
|
|
|
} |
|
|
|
# If status_code was set, use it, otherwise use the default from the |
|
|
|
# response class, in the case of redirect it's 307 |
|
|
|
current_status_code = ( |
|
|
|
status_code |
|
|
|
if status_code |
|
|
|
else solved_result.response.status_code |
|
|
|
) |
|
|
|
if current_status_code is not None: |
|
|
|
response_args["status_code"] = current_status_code |
|
|
|
if solved_result.response.status_code: |
|
|
|
response_args["status_code"] = ( |
|
|
|
solved_result.response.status_code |
|
|
|
) |
|
|
|
content = await serialize_response( |
|
|
|
field=response_field, |
|
|
|
response_content=raw_response, |
|
|
|
include=response_model_include, |
|
|
|
exclude=response_model_exclude, |
|
|
|
by_alias=response_model_by_alias, |
|
|
|
exclude_unset=response_model_exclude_unset, |
|
|
|
exclude_defaults=response_model_exclude_defaults, |
|
|
|
exclude_none=response_model_exclude_none, |
|
|
|
is_coroutine=is_coroutine, |
|
|
|
) |
|
|
|
response = actual_response_class(content, **response_args) |
|
|
|
if not is_body_allowed_for_status_code(response.status_code): |
|
|
|
response.body = b"" |
|
|
|
response.headers.raw.extend(solved_result.response.headers.raw) |
|
|
|
response = await prepare_response_object( |
|
|
|
raw_response, |
|
|
|
solved_result, |
|
|
|
) |
|
|
|
if errors: |
|
|
|
validation_error = RequestValidationError( |
|
|
|
_normalize_errors(errors), body=body |
|
|
@ -354,9 +283,103 @@ def get_request_handler( |
|
|
|
) |
|
|
|
return response |
|
|
|
|
|
|
|
async def prepare_response_object( |
|
|
|
raw_response: Any, |
|
|
|
solved_result: SolvedDependency, |
|
|
|
) -> Response: |
|
|
|
if isinstance(raw_response, Response): |
|
|
|
if raw_response.background is None: |
|
|
|
raw_response.background = solved_result.background_tasks |
|
|
|
response = raw_response |
|
|
|
else: |
|
|
|
response_args: Dict[str, Any] = { |
|
|
|
"background": solved_result.background_tasks |
|
|
|
} |
|
|
|
# If status_code was set, use it, otherwise use the default from the |
|
|
|
# response class, in the case of redirect it's 307 |
|
|
|
current_status_code = ( |
|
|
|
status_code if status_code else solved_result.response.status_code |
|
|
|
) |
|
|
|
if current_status_code is not None: |
|
|
|
response_args["status_code"] = current_status_code |
|
|
|
if solved_result.response.status_code: |
|
|
|
response_args["status_code"] = solved_result.response.status_code |
|
|
|
content = await serialize_response( |
|
|
|
field=response_field, |
|
|
|
response_content=raw_response, |
|
|
|
include=response_model_include, |
|
|
|
exclude=response_model_exclude, |
|
|
|
by_alias=response_model_by_alias, |
|
|
|
exclude_unset=response_model_exclude_unset, |
|
|
|
exclude_defaults=response_model_exclude_defaults, |
|
|
|
exclude_none=response_model_exclude_none, |
|
|
|
is_coroutine=is_coroutine, |
|
|
|
) |
|
|
|
response = actual_response_class(content, **response_args) |
|
|
|
if not is_body_allowed_for_status_code(response.status_code): |
|
|
|
response.body = b"" |
|
|
|
response.headers.raw.extend(solved_result.response.headers.raw) |
|
|
|
return response |
|
|
|
|
|
|
|
return app |
|
|
|
|
|
|
|
|
|
|
|
async def parse_request_body( |
|
|
|
request: Request, |
|
|
|
*, |
|
|
|
body_field: Optional[ModelField], |
|
|
|
is_body_form: bool | None, |
|
|
|
file_stack: AsyncExitStack, |
|
|
|
) -> Any: |
|
|
|
try: |
|
|
|
body: Any = None |
|
|
|
if body_field: |
|
|
|
if is_body_form: |
|
|
|
body = await request.form() |
|
|
|
file_stack.push_async_callback(body.close) |
|
|
|
else: |
|
|
|
body_bytes = await request.body() |
|
|
|
if body_bytes: |
|
|
|
json_body: Any = Undefined |
|
|
|
content_type_value = request.headers.get("content-type") |
|
|
|
if not content_type_value: |
|
|
|
json_body = await request.json() |
|
|
|
else: |
|
|
|
message = email.message.Message() |
|
|
|
message["content-type"] = content_type_value |
|
|
|
if message.get_content_maintype() == "application": |
|
|
|
subtype = message.get_content_subtype() |
|
|
|
if subtype == "json" or subtype.endswith("+json"): |
|
|
|
json_body = await request.json() |
|
|
|
if json_body != Undefined: |
|
|
|
body = json_body |
|
|
|
else: |
|
|
|
body = body_bytes |
|
|
|
except json.JSONDecodeError as e: |
|
|
|
validation_error = RequestValidationError( |
|
|
|
[ |
|
|
|
{ |
|
|
|
"type": "json_invalid", |
|
|
|
"loc": ("body", e.pos), |
|
|
|
"msg": "JSON decode error", |
|
|
|
"input": {}, |
|
|
|
"ctx": {"error": e.msg}, |
|
|
|
} |
|
|
|
], |
|
|
|
body=e.doc, |
|
|
|
) |
|
|
|
raise validation_error from e |
|
|
|
except HTTPException: |
|
|
|
# If a middleware raises an HTTPException, it should be raised again |
|
|
|
raise |
|
|
|
except Exception as e: |
|
|
|
http_error = HTTPException( |
|
|
|
status_code=400, detail="There was an error parsing the body" |
|
|
|
) |
|
|
|
raise http_error from e |
|
|
|
return body |
|
|
|
|
|
|
|
|
|
|
|
def get_websocket_app( |
|
|
|
dependant: Dependant, |
|
|
|
dependency_overrides_provider: Optional[Any] = None, |
|
|
|