Browse Source

Merge ef4c2c8d27 into 6e69d62bfe

pull/13657/merge
Kinuax 2 days ago
committed by GitHub
parent
commit
a49e1e1e2d
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 21
      fastapi/_compat.py
  2. 30
      fastapi/applications.py
  3. 24
      fastapi/dependencies/utils.py
  4. 67
      fastapi/routing.py
  5. 10
      fastapi/utils.py
  6. 171
      tests/test_validation_error_fields.py

21
fastapi/_compat.py

@ -90,6 +90,8 @@ if PYDANTIC_V2:
field_info: FieldInfo
name: str
mode: Literal["validation", "serialization"] = "validation"
include_error_input: bool = True
include_error_url: bool = False
@property
def alias(self) -> str:
@ -132,7 +134,11 @@ if PYDANTIC_V2:
)
except ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
errors=exc.errors(
include_input=self.include_error_input,
include_url=self.include_error_url,
),
loc_prefix=loc,
)
def serialize(
@ -272,11 +278,16 @@ if PYDANTIC_V2:
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def get_missing_field_error(
loc: Tuple[str, ...],
include_error_input: bool = True,
include_error_url: bool = False,
) -> Dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0]
error["input"] = None
).errors(include_input=include_error_input, include_url=include_error_url)[0]
if include_error_input:
error["input"] = None
return error # type: ignore[return-value]
def create_body_model(
@ -514,7 +525,7 @@ else:
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: # type: ignore[misc]
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]

30
fastapi/applications.py

@ -752,6 +752,26 @@ class FastAPI(Starlette):
"""
),
] = True,
include_error_input: Annotated[
bool,
Doc(
"""
To include (or not) the field `input` in the validation error of all *path operations*.
This does not affect the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = True,
include_error_url: Annotated[
bool,
Doc(
"""
To include (or not) the field `url` in the validation error of all *path operations*.
This does not affect the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = False,
swagger_ui_parameters: Annotated[
Optional[Dict[str, Any]],
Doc(
@ -941,6 +961,8 @@ class FastAPI(Starlette):
callbacks=callbacks,
deprecated=deprecated,
include_in_schema=include_in_schema,
include_error_input=include_error_input,
include_error_url=include_error_url,
responses=responses,
generate_unique_id_function=generate_unique_id_function,
)
@ -1076,6 +1098,8 @@ class FastAPI(Starlette):
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
include_error_input: bool = True,
include_error_url: bool = False,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(
JSONResponse
),
@ -1106,6 +1130,8 @@ class FastAPI(Starlette):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=include_error_input,
include_error_url=include_error_url,
response_class=response_class,
name=name,
openapi_extra=openapi_extra,
@ -1134,6 +1160,8 @@ class FastAPI(Starlette):
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
include_error_input: bool = True,
include_error_url: bool = False,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
@ -1163,6 +1191,8 @@ class FastAPI(Starlette):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=include_error_input,
include_error_url=include_error_url,
response_class=response_class,
name=name,
openapi_extra=openapi_extra,

24
fastapi/dependencies/utils.py

@ -269,6 +269,8 @@ def get_dependant(
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
include_error_input: bool = True,
include_error_url: bool = False,
) -> Dependant:
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
@ -287,6 +289,8 @@ def get_dependant(
annotation=param.annotation,
value=param.default,
is_path_param=is_path_param,
include_error_input=include_error_input,
include_error_url=include_error_url,
)
if param_details.depends is not None:
sub_dependant = get_param_sub_dependant(
@ -351,6 +355,8 @@ def analyze_param(
annotation: Any,
value: Any,
is_path_param: bool,
include_error_input: bool,
include_error_url: bool,
) -> ParamDetails:
field_info = None
depends = None
@ -492,6 +498,8 @@ def analyze_param(
alias=alias,
required=field_info.default in (RequiredParam, Undefined),
field_info=field_info,
include_error_input=include_error_input,
include_error_url=include_error_url,
)
if is_path_param:
assert is_scalar_field(field=field), (
@ -700,7 +708,13 @@ def _validate_value_with_model_field(
) -> Tuple[Any, List[Any]]:
if value is None:
if field.required:
return None, [get_missing_field_error(loc=loc)]
if PYDANTIC_V2:
error = get_missing_field_error(
loc, field.include_error_input, field.include_error_url
)
else:
error = get_missing_field_error(loc)
return None, [error]
else:
return deepcopy(field.default), []
v_, errors_ = field.validate(value, values, loc=loc)
@ -936,7 +950,13 @@ async def request_body_to_args(
value = body_to_process.get(field.alias)
# If the received body is a list, not a dict
except AttributeError:
errors.append(get_missing_field_error(loc))
if PYDANTIC_V2:
error = get_missing_field_error(
loc, field.include_error_input, field.include_error_url
)
else:
error = get_missing_field_error(loc)
errors.append(error)
continue
v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc

67
fastapi/routing.py

@ -451,6 +451,8 @@ class APIRoute(routing.Route):
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
include_error_input: bool = True,
include_error_url: bool = False,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(
JSONResponse
),
@ -481,6 +483,8 @@ class APIRoute(routing.Route):
self.response_model_exclude_defaults = response_model_exclude_defaults
self.response_model_exclude_none = response_model_exclude_none
self.include_in_schema = include_in_schema
self.include_error_input = include_error_input
self.include_error_url = include_error_url
self.response_class = response_class
self.dependency_overrides_provider = dependency_overrides_provider
self.callbacks = callbacks
@ -513,6 +517,8 @@ class APIRoute(routing.Route):
name=response_name,
type_=self.response_model,
mode="serialization",
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
)
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class
@ -543,7 +549,11 @@ class APIRoute(routing.Route):
)
response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_model_field(
name=response_name, type_=model, mode="serialization"
name=response_name,
type_=model,
mode="serialization",
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
)
response_fields[additional_status_code] = response_field
if response_fields:
@ -552,7 +562,12 @@ class APIRoute(routing.Route):
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
self.dependant = get_dependant(
path=self.path_format,
call=self.endpoint,
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
@ -819,6 +834,26 @@ class APIRouter(routing.Router):
"""
),
] = True,
include_error_input: Annotated[
bool,
Doc(
"""
To include (or not) the field `input` in the validation error of all *path operations*.
This does not affect the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = True,
include_error_url: Annotated[
bool,
Doc(
"""
To include (or not) the field `url` in the validation error of all *path operations*.
This does not affect the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = False,
generate_unique_id_function: Annotated[
Callable[[APIRoute], str],
Doc(
@ -853,6 +888,8 @@ class APIRouter(routing.Router):
self.dependencies = list(dependencies or [])
self.deprecated = deprecated
self.include_in_schema = include_in_schema
self.include_error_input = include_error_input
self.include_error_url = include_error_url
self.responses = responses or {}
self.callbacks = callbacks or []
self.dependency_overrides_provider = dependency_overrides_provider
@ -902,6 +939,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
include_error_input: bool = True,
include_error_url: bool = False,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(
JSONResponse
),
@ -952,6 +991,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema and self.include_in_schema,
include_error_input=include_error_input,
include_error_url=include_error_url,
response_class=current_response_class,
name=name,
dependency_overrides_provider=self.dependency_overrides_provider,
@ -983,6 +1024,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
include_error_input: bool = True,
include_error_url: bool = False,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
@ -1013,6 +1056,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=include_error_input,
include_error_url=include_error_url,
response_class=response_class,
name=name,
callbacks=callbacks,
@ -1323,6 +1368,8 @@ class APIRouter(routing.Router):
include_in_schema=route.include_in_schema
and self.include_in_schema
and include_in_schema,
include_error_input=route.include_error_input,
include_error_url=route.include_error_url,
response_class=use_response_class,
name=route.name,
route_class_override=type(route),
@ -1734,6 +1781,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
response_class=response_class,
name=name,
callbacks=callbacks,
@ -2116,6 +2165,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
response_class=response_class,
name=name,
callbacks=callbacks,
@ -2498,6 +2549,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
response_class=response_class,
name=name,
callbacks=callbacks,
@ -2875,6 +2928,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
response_class=response_class,
name=name,
callbacks=callbacks,
@ -3252,6 +3307,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
response_class=response_class,
name=name,
callbacks=callbacks,
@ -3634,6 +3691,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
response_class=response_class,
name=name,
callbacks=callbacks,
@ -4016,6 +4075,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
response_class=response_class,
name=name,
callbacks=callbacks,
@ -4398,6 +4459,8 @@ class APIRouter(routing.Router):
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
include_error_input=self.include_error_input,
include_error_url=self.include_error_url,
response_class=response_class,
name=name,
callbacks=callbacks,

10
fastapi/utils.py

@ -70,6 +70,8 @@ def create_model_field(
field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None,
mode: Literal["validation", "serialization"] = "validation",
include_error_input: bool = True,
include_error_url: bool = False,
) -> ModelField:
class_validators = class_validators or {}
if PYDANTIC_V2:
@ -80,7 +82,13 @@ def create_model_field(
field_info = field_info or FieldInfo()
kwargs = {"name": name, "field_info": field_info}
if PYDANTIC_V2:
kwargs.update({"mode": mode})
kwargs.update(
{
"mode": mode,
"include_error_input": include_error_input,
"include_error_url": include_error_url,
}
)
else:
kwargs.update(
{

171
tests/test_validation_error_fields.py

@ -0,0 +1,171 @@
import pytest
from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel
from .utils import needs_pydanticv1, needs_pydanticv2
@needs_pydanticv2
@pytest.mark.parametrize(
"include_error_input,include_error_url",
[(False, False), (False, True), (True, False), (True, True)],
)
def test_input_and_url_fields_with_pydanticv2(include_error_input, include_error_url):
app = FastAPI(
include_error_input=include_error_input, include_error_url=include_error_url
)
@app.get("/get1/{path_param}")
def get1(path_param: int): ...
@app.get("/get2/")
def get2(query_param: int): ...
class Body1(BaseModel): ...
class Body2(BaseModel): ...
@app.post("/post1/")
def post1(body1: Body1, body2: Body2): ...
router = APIRouter(
include_error_input=include_error_input, include_error_url=include_error_url
)
@router.get("/get3/{path_param}")
def get3(path_param: int): ...
@router.get("/get4/")
def get4(query_param: int): ...
@router.post("/post2/")
def post2(body1: Body1, body2: Body2): ...
app.include_router(router)
client = TestClient(app)
with client:
invalid = "not-an-integer"
for path in ["get1", "get3"]:
response = client.get(f"/{path}/{invalid}")
assert response.status_code == 422, response.text
error = response.json()["detail"][0]
if include_error_input:
assert error["input"] == invalid
else:
assert "input" not in error
if include_error_url:
assert "url" in error
else:
assert "url" not in error
for path in ["get2", "get4"]:
response = client.get(f"/{path}/")
assert response.status_code == 422, response.text
error = response.json()["detail"][0]
if include_error_input:
assert error["type"] == "missing"
assert error["input"] is None
else:
assert "input" not in error
if include_error_url:
assert "url" in error
else:
assert "url" not in error
response = client.get(f"/{path}/?query_param={invalid}")
assert response.status_code == 422, response.text
error = response.json()["detail"][0]
if include_error_input:
assert error["input"] == invalid
else:
assert "input" not in error
if include_error_url:
assert "url" in error
else:
assert "url" not in error
for path in ["post1", "post2"]:
response = client.post(f"/{path}/", json=["not-a-dict"])
assert response.status_code == 422
error = response.json()["detail"][0]
if include_error_input:
assert error["type"] == "missing"
assert error["input"] is None
else:
assert "input" not in error
if include_error_url:
assert "url" in error
else:
assert "url" not in error
# TODO: remove when deprecating Pydantic v1
@needs_pydanticv1
@pytest.mark.parametrize(
"include_error_input,include_error_url",
[(False, False), (False, True), (True, False), (True, True)],
)
def test_input_and_url_fields_with_pydanticv1(include_error_input, include_error_url):
app = FastAPI(
include_error_input=include_error_input, include_error_url=include_error_url
)
@app.get("/get1/{path_param}")
def get1(path_param: int): ...
@app.get("/get2/")
def get2(query_param: int): ...
class Body1(BaseModel): ...
class Body2(BaseModel): ...
@app.post("/post1/")
def post1(body1: Body1, body2: Body2): ...
router = APIRouter(
include_error_input=include_error_input, include_error_url=include_error_url
)
@router.get("/get3/{path_param}")
def get3(path_param: int): ...
@router.get("/get4/")
def get4(query_param: int): ...
@router.post("/post2/")
def post2(body1: Body1, body2: Body2): ...
app.include_router(router)
client = TestClient(app)
with client:
invalid = "not-an-integer"
for path in ["get1", "get3"]:
response = client.get(f"/{path}/{invalid}")
assert response.status_code == 422, response.text
error = response.json()["detail"][0]
assert "input" not in error
assert "url" not in error
for path in ["get2", "get4"]:
response = client.get(f"/{path}/")
assert response.status_code == 422, response.text
error = response.json()["detail"][0]
assert "input" not in error
assert "url" not in error
response = client.get(f"/{path}/?query_param={invalid}")
assert response.status_code == 422, response.text
error = response.json()["detail"][0]
assert "input" not in error
assert "url" not in error
for path in ["post1", "post2"]:
response = client.post(f"/{path}/", json=["not-a-dict"])
assert response.status_code == 422
error = response.json()["detail"][0]
assert "input" not in error
assert "url" not in error
Loading…
Cancel
Save