Browse Source

Make request handler extensible

pull/12241/head
Stanislav Zmiev 7 months ago
parent
commit
dc97105518
Failed to extract signature
  1. 187
      fastapi/routing.py

187
fastapi/routing.py

@ -33,6 +33,7 @@ from fastapi._compat import (
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
SolvedDependency,
_should_embed_body_fields,
get_body_field,
get_dependant,
@ -240,52 +241,12 @@ def get_request_handler(
async def app(request: Request) -> Response:
response: Union[Response, None] = None
async with AsyncExitStack() as file_stack:
try:
body: Any = None
if body_field:
if is_body_form:
body = await request.form()
file_stack.push_async_callback(body.close)
else:
body_bytes = await request.body()
if body_bytes:
json_body: Any = Undefined
content_type_value = request.headers.get("content-type")
if not content_type_value:
json_body = await request.json()
else:
message = email.message.Message()
message["content-type"] = content_type_value
if message.get_content_maintype() == "application":
subtype = message.get_content_subtype()
if subtype == "json" or subtype.endswith("+json"):
json_body = await request.json()
if json_body != Undefined:
body = json_body
else:
body = body_bytes
except json.JSONDecodeError as e:
validation_error = RequestValidationError(
[
{
"type": "json_invalid",
"loc": ("body", e.pos),
"msg": "JSON decode error",
"input": {},
"ctx": {"error": e.msg},
}
],
body=e.doc,
)
raise validation_error from e
except HTTPException:
# If a middleware raises an HTTPException, it should be raised again
raise
except Exception as e:
http_error = HTTPException(
status_code=400, detail="There was an error parsing the body"
)
raise http_error from e
body = await parse_request_body(
request,
body_field=body_field,
is_body_form=is_body_form,
file_stack=file_stack,
)
errors: List[Any] = []
async with AsyncExitStack() as async_exit_stack:
solved_result = await solve_dependencies(
@ -303,42 +264,10 @@ def get_request_handler(
values=solved_result.values,
is_coroutine=is_coroutine,
)
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = solved_result.background_tasks
response = raw_response
else:
response_args: Dict[str, Any] = {
"background": solved_result.background_tasks
}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code
if status_code
else solved_result.response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if solved_result.response.status_code:
response_args["status_code"] = (
solved_result.response.status_code
)
content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw)
response = await prepare_response_object(
raw_response,
solved_result,
)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body
@ -354,9 +283,103 @@ def get_request_handler(
)
return response
async def prepare_response_object(
raw_response: Any,
solved_result: SolvedDependency,
) -> Response:
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = solved_result.background_tasks
response = raw_response
else:
response_args: Dict[str, Any] = {
"background": solved_result.background_tasks
}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else solved_result.response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if solved_result.response.status_code:
response_args["status_code"] = solved_result.response.status_code
content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw)
return response
return app
async def parse_request_body(
request: Request,
*,
body_field: Optional[ModelField],
is_body_form: bool | None,
file_stack: AsyncExitStack,
) -> Any:
try:
body: Any = None
if body_field:
if is_body_form:
body = await request.form()
file_stack.push_async_callback(body.close)
else:
body_bytes = await request.body()
if body_bytes:
json_body: Any = Undefined
content_type_value = request.headers.get("content-type")
if not content_type_value:
json_body = await request.json()
else:
message = email.message.Message()
message["content-type"] = content_type_value
if message.get_content_maintype() == "application":
subtype = message.get_content_subtype()
if subtype == "json" or subtype.endswith("+json"):
json_body = await request.json()
if json_body != Undefined:
body = json_body
else:
body = body_bytes
except json.JSONDecodeError as e:
validation_error = RequestValidationError(
[
{
"type": "json_invalid",
"loc": ("body", e.pos),
"msg": "JSON decode error",
"input": {},
"ctx": {"error": e.msg},
}
],
body=e.doc,
)
raise validation_error from e
except HTTPException:
# If a middleware raises an HTTPException, it should be raised again
raise
except Exception as e:
http_error = HTTPException(
status_code=400, detail="There was an error parsing the body"
)
raise http_error from e
return body
def get_websocket_app(
dependant: Dependant,
dependency_overrides_provider: Optional[Any] = None,

Loading…
Cancel
Save