Browse Source

feat: enhance JSON parse errors with line/column info and snippets

column positions and error snippets to JSON decode errors for
better debugging experience. Updates error location format and provides
context around the problematic JSON.
pull/14089/head
Arif Dogan 10 months ago
parent
commit
6b907a57f8
No known key found for this signature in database GPG Key ID: A2131177E83422CE
  1. 80
      fastapi/routing.py
  2. 103
      tests/test_json_error_improvements.py
  3. 14
      tests/test_tutorial/test_body/test_tutorial001.py

80
fastapi/routing.py

@ -85,7 +85,8 @@ def _prepare_response_content(
exclude_none: bool = False, exclude_none: bool = False,
) -> Any: ) -> Any:
if isinstance(res, BaseModel): if isinstance(res, BaseModel):
read_with_orm_mode = getattr(_get_model_config(res), "read_with_orm_mode", None) read_with_orm_mode = getattr(
_get_model_config(res), "read_with_orm_mode", None)
if read_with_orm_mode: if read_with_orm_mode:
# Let from_orm extract the data from this model instead of converting # Let from_orm extract the data from this model instead of converting
# it now to a dict. # it now to a dict.
@ -164,7 +165,8 @@ async def serialize_response(
exclude_none=exclude_none, exclude_none=exclude_none,
) )
if is_coroutine: if is_coroutine:
value, errors_ = field.validate(response_content, {}, loc=("response",)) value, errors_ = field.validate(
response_content, {}, loc=("response",))
else: else:
value, errors_ = await run_in_threadpool( value, errors_ = await run_in_threadpool(
field.validate, response_content, {}, loc=("response",) field.validate, response_content, {}, loc=("response",)
@ -219,7 +221,8 @@ def get_request_handler(
dependant: Dependant, dependant: Dependant,
body_field: Optional[ModelField] = None, body_field: Optional[ModelField] = None,
status_code: Optional[int] = None, status_code: Optional[int] = None,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse), response_class: Union[Type[Response],
DefaultPlaceholder] = Default(JSONResponse),
response_field: Optional[ModelField] = None, response_field: Optional[ModelField] = None,
response_model_include: Optional[IncEx] = None, response_model_include: Optional[IncEx] = None,
response_model_exclude: Optional[IncEx] = None, response_model_exclude: Optional[IncEx] = None,
@ -232,7 +235,8 @@ def get_request_handler(
) -> Callable[[Request], Coroutine[Any, Any, Response]]: ) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function" assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call) is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(body_field.field_info, params.Form) is_body_form = body_field and isinstance(
body_field.field_info, params.Form)
if isinstance(response_class, DefaultPlaceholder): if isinstance(response_class, DefaultPlaceholder):
actual_response_class: Type[Response] = response_class.value actual_response_class: Type[Response] = response_class.value
else: else:
@ -251,7 +255,8 @@ def get_request_handler(
body_bytes = await request.body() body_bytes = await request.body()
if body_bytes: if body_bytes:
json_body: Any = Undefined json_body: Any = Undefined
content_type_value = request.headers.get("content-type") content_type_value = request.headers.get(
"content-type")
if not content_type_value: if not content_type_value:
json_body = await request.json() json_body = await request.json()
else: else:
@ -266,14 +271,32 @@ def get_request_handler(
else: else:
body = body_bytes body = body_bytes
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
lines_before = e.doc[: e.pos].split("\n")
line_number = len(lines_before)
column_number = len(
lines_before[-1]) + 1 if lines_before else 1
start_pos = max(0, e.pos - 40)
end_pos = min(len(e.doc), e.pos + 40)
error_snippet = e.doc[start_pos:end_pos]
if start_pos > 0:
error_snippet = "..." + error_snippet
if end_pos < len(e.doc):
error_snippet = error_snippet + "..."
validation_error = RequestValidationError( validation_error = RequestValidationError(
[ [
{ {
"type": "json_invalid", "type": "json_invalid",
"loc": ("body", e.pos), "loc": ("body", line_number, column_number),
"msg": "JSON decode error", "msg": f"JSON decode error - {e.msg} at line {line_number}, column {column_number}",
"input": {}, "input": error_snippet,
"ctx": {"error": e.msg}, "ctx": {
"error": e.msg,
"position": e.pos,
"line": line_number,
"column": column_number,
},
} }
], ],
body=e.doc, body=e.doc,
@ -336,10 +359,12 @@ def get_request_handler(
exclude_none=response_model_exclude_none, exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine, is_coroutine=is_coroutine,
) )
response = actual_response_class(content, **response_args) response = actual_response_class(
content, **response_args)
if not is_body_allowed_for_status_code(response.status_code): if not is_body_allowed_for_status_code(response.status_code):
response.body = b"" response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw) response.headers.raw.extend(
solved_result.response.headers.raw)
if errors: if errors:
validation_error = RequestValidationError( validation_error = RequestValidationError(
_normalize_errors(errors), body=body _normalize_errors(errors), body=body
@ -400,12 +425,15 @@ class APIWebSocketRoute(routing.WebSocketRoute):
self.endpoint = endpoint self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or []) self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.path_regex, self.path_format, self.param_convertors = compile_path(
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) path)
self.dependant = get_dependant(
path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]: for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert( self.dependant.dependencies.insert(
0, 0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format), get_parameterless_sub_dependant(
depends=depends, path=self.path_format),
) )
self._flat_dependant = get_flat_dependant(self.dependant) self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields( self._embed_body_fields = _should_embed_body_fields(
@ -489,7 +517,8 @@ class APIRoute(routing.Route):
self.tags = tags or [] self.tags = tags or []
self.responses = responses or {} self.responses = responses or {}
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.path_regex, self.path_format, self.param_convertors = compile_path(
path)
if methods is None: if methods is None:
methods = ["GET"] methods = ["GET"]
self.methods: Set[str] = {method.upper() for method in methods} self.methods: Set[str] = {method.upper() for method in methods}
@ -529,13 +558,15 @@ class APIRoute(routing.Route):
self.response_field = None # type: ignore self.response_field = None # type: ignore
self.secure_cloned_response_field = None self.secure_cloned_response_field = None
self.dependencies = list(dependencies or []) self.dependencies = list(dependencies or [])
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") self.description = description or inspect.cleandoc(
self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text, # if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed" # truncate description text to the content preceding the first "form feed"
self.description = self.description.split("\f")[0].strip() self.description = self.description.split("\f")[0].strip()
response_fields = {} response_fields = {}
for additional_status_code, response in self.responses.items(): for additional_status_code, response in self.responses.items():
assert isinstance(response, dict), "An additional response must be a dict" assert isinstance(
response, dict), "An additional response must be a dict"
model = response.get("model") model = response.get("model")
if model: if model:
assert is_body_allowed_for_status_code(additional_status_code), ( assert is_body_allowed_for_status_code(additional_status_code), (
@ -547,16 +578,19 @@ class APIRoute(routing.Route):
) )
response_fields[additional_status_code] = response_field response_fields[additional_status_code] = response_field
if response_fields: if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields self.response_fields: Dict[Union[int,
str], ModelField] = response_fields
else: else:
self.response_fields = {} self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable" assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) self.dependant = get_dependant(
path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]: for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert( self.dependant.dependencies.insert(
0, 0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format), get_parameterless_sub_dependant(
depends=depends, path=self.path_format),
) )
self._flat_dependant = get_flat_dependant(self.dependant) self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields( self._embed_body_fields = _should_embed_body_fields(
@ -623,7 +657,8 @@ class APIRouter(routing.Router):
def __init__( def __init__(
self, self,
*, *,
prefix: Annotated[str, Doc("An optional path prefix for the router.")] = "", prefix: Annotated[str, Doc(
"An optional path prefix for the router.")] = "",
tags: Annotated[ tags: Annotated[
Optional[List[Union[str, Enum]]], Optional[List[Union[str, Enum]]],
Doc( Doc(
@ -1124,7 +1159,8 @@ class APIRouter(routing.Router):
self, self,
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")], router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],
*, *,
prefix: Annotated[str, Doc("An optional path prefix for the router.")] = "", prefix: Annotated[str, Doc(
"An optional path prefix for the router.")] = "",
tags: Annotated[ tags: Annotated[
Optional[List[Union[str, Enum]]], Optional[List[Union[str, Enum]]],
Doc( Doc(

103
tests/test_json_error_improvements.py

@ -0,0 +1,103 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel
app = FastAPI()
class Item(BaseModel):
name: str
price: float
description: str = None
@app.post("/items/")
async def create_item(item: Item):
return item
client = TestClient(app)
def test_json_decode_error_single_line():
response = client.post(
"/items/",
content='{"name": "Test", "price": None}',
headers={"Content-Type": "application/json"},
)
assert response.status_code == 422
error = response.json()["detail"][0]
assert error["loc"] == ["body", 1, 27]
assert "line 1" in error["msg"]
assert "column 27" in error["msg"]
assert error["ctx"]["line"] == 1
assert error["ctx"]["column"] == 27
assert "None" in error["input"]
def test_json_decode_error_multiline():
invalid_json = """
{
"name": "Test",
"price": 'invalid'
}"""
response = client.post(
"/items/", content=invalid_json, headers={"Content-Type": "application/json"}
)
assert response.status_code == 422
error = response.json()["detail"][0]
assert error["loc"] == ["body", 4, 12]
assert "line 4" in error["msg"]
assert "column 12" in error["msg"]
assert error["ctx"]["line"] == 4
assert error["ctx"]["column"] == 12
assert "invalid" in error["input"]
def test_json_decode_error_shows_snippet():
long_json = '{"very_long_field_name_here": "some value", "another_field": invalid}'
response = client.post(
"/items/", content=long_json, headers={"Content-Type": "application/json"}
)
assert response.status_code == 422
error = response.json()["detail"][0]
assert "..." in error["input"]
assert "invalid" in error["input"]
assert len(error["input"]) <= 83
def test_json_decode_error_empty_body():
response = client.post(
"/items/", content="", headers={"Content-Type": "application/json"}
)
assert response.status_code == 422
error = response.json()["detail"][0]
# Empty body is handled differently, not as a JSON decode error
assert error["loc"] == ["body"]
assert error["type"] == "missing"
def test_json_decode_error_unclosed_brace():
response = client.post(
"/items/",
content='{"name": "Test"',
headers={"Content-Type": "application/json"},
)
assert response.status_code == 422
error = response.json()["detail"][0]
assert "line" in error["msg"].lower()
assert "column" in error["msg"].lower()
assert error["type"] == "json_invalid"
assert "position" in error["ctx"]

14
tests/test_tutorial/test_body/test_tutorial001.py

@ -60,7 +60,8 @@ def test_post_with_str_float_description(client: TestClient):
def test_post_with_str_float_description_tax(client: TestClient): def test_post_with_str_float_description_tax(client: TestClient):
response = client.post( response = client.post(
"/items/", "/items/",
json={"name": "Foo", "price": "50.5", "description": "Some Foo", "tax": 0.3}, json={"name": "Foo", "price": "50.5",
"description": "Some Foo", "tax": 0.3},
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == {
@ -206,11 +207,14 @@ def test_post_broken_body(client: TestClient):
"detail": [ "detail": [
{ {
"type": "json_invalid", "type": "json_invalid",
"loc": ["body", 1], "loc": ["body", 1, 2],
"msg": "JSON decode error", "msg": "JSON decode error - Expecting property name enclosed in double quotes at line 1, column 2",
"input": {}, "input": "{some broken json}",
"ctx": { "ctx": {
"error": "Expecting property name enclosed in double quotes" "error": "Expecting property name enclosed in double quotes",
"position": 1,
"line": 1,
"column": 2,
}, },
} }
] ]

Loading…
Cancel
Save