diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 84dfa4d03..ae0bd1409 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,4 +1,6 @@ import inspect +import sys +import types from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass @@ -26,6 +28,7 @@ from fastapi._compat import ( ModelField, RequiredParam, Undefined, + UndefinedType, _regenerate_error_with_loc, copy_field_info, create_body_model, @@ -573,7 +576,7 @@ async def solve_dependencies( *, request: Union[Request, WebSocket], dependant: Dependant, - body: Optional[Union[Dict[str, Any], FormData]] = None, + body: Optional[Union[Dict[str, Any], FormData, UndefinedType]] = Undefined, background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, @@ -695,10 +698,36 @@ async def solve_dependencies( ) +if PYDANTIC_V2: + if sys.hexversion >= 0x30A0000: + + def _allows_none(field: ModelField) -> bool: + origin = get_origin(field.type_) + return (origin is Union or origin is types.UnionType) and type( + None + ) in get_args(field.type_) + else: + + def _allows_none(field: ModelField) -> bool: + origin = get_origin(field.type_) + return origin is Union and type(None) in get_args(field.type_) +else: + + def _allows_none(field: ModelField) -> bool: + return field.allow_none # type: ignore + + def _validate_value_with_model_field( *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] ) -> Tuple[Any, List[Any]]: + if value is Undefined: + if field.required: + return None, [get_missing_field_error(loc=loc)] + else: + return deepcopy(field.default), [] if value is None: + if _allows_none(field): + return value, [] if field.required: return None, [get_missing_field_error(loc=loc)] else: @@ -717,12 +746,13 @@ def _get_multidict_value( field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None ) -> Any: alias = alias or field.alias + value: Any if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): value = values.getlist(alias) else: - value = values.get(alias, None) + value = values.get(alias, Undefined) if ( - value is None + value is Undefined or ( isinstance(field.field_info, params.Form) and isinstance(value, str) # For type checks @@ -731,7 +761,7 @@ def _get_multidict_value( or (is_sequence_field(field) and len(value) == 0) ): if field.required: - return + return Undefined else: return deepcopy(field.default) return value @@ -779,7 +809,7 @@ def request_params_to_args( else field.name.replace("_", "-") ) value = _get_multidict_value(field, received_params, alias=alias) - if value is not None: + if value is not Undefined and value is not None: params_to_process[field.name] = value processed_keys.add(alias or field.alias) processed_keys.add(field.name) @@ -873,7 +903,7 @@ async def _extract_form_body( for sub_value in value: tg.start_soon(process_fn, sub_value.read) value = serialize_sequence_value(field=field, value=results) - if value is not None: + if value is not Undefined and value is not None: values[field.alias] = value for key, value in received_body.items(): if key not in values: @@ -883,7 +913,7 @@ async def _extract_form_body( async def request_body_to_args( body_fields: List[ModelField], - received_body: Optional[Union[Dict[str, Any], FormData]], + received_body: Optional[Union[Dict[str, Any], FormData, UndefinedType]], embed_body_fields: bool, ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: values: Dict[str, Any] = {} @@ -909,10 +939,12 @@ async def request_body_to_args( return {first_field.name: v_}, errors_ for field in body_fields: loc = ("body", field.alias) - value: Optional[Any] = None - if body_to_process is not None: + value: Any = Undefined + if body_to_process is not None and not isinstance( + body_to_process, UndefinedType + ): try: - value = body_to_process.get(field.alias) + value = body_to_process.get(field.alias, Undefined) # If the received body is a list, not a dict except AttributeError: errors.append(get_missing_field_error(loc)) diff --git a/fastapi/routing.py b/fastapi/routing.py index 457481e32..80d22cc3b 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -241,7 +241,7 @@ def get_request_handler( response: Union[Response, None] = None async with AsyncExitStack() as file_stack: try: - body: Any = None + body: Any = Undefined if body_field: if is_body_form: body = await request.form() diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py new file mode 100644 index 000000000..3da094c6a --- /dev/null +++ b/tests/test_none_passed_when_null_received.py @@ -0,0 +1,101 @@ +import sys +from typing import Optional, Union + +import pytest +from dirty_equals import IsDict +from fastapi import Body, FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() +DEFAULT = 1234567890 + +endpoints = [] + +if sys.hexversion >= 0x30A0000: + from typing import Annotated + + @app.post("/api1") + def api1( + integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT, + ) -> dict: + return {"received": integer_or_null} + + endpoints.append("/api1") + + +if sys.hexversion >= 0x3090000: + from typing import Annotated + + @app.post("/api2") + def api2( + integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT, + ) -> dict: + return {"received": integer_or_null} + + endpoints.append("/api2") + + @app.post("/api3") + def api3( + integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT, + ) -> dict: + return {"received": integer_or_null} + + endpoints.append("/api3") + + +@app.post("/api4") +def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> dict: + return {"received": integer_or_null} + + +endpoints.append("/api4") + + +@app.post("/api5") +def api5(integer: int = Body(embed=True)) -> dict: + return {"received": integer} + + +client = TestClient(app) + + +@pytest.mark.parametrize("api", endpoints) +def test_apis(api): + response = client.post(api, json={"integer_or_null": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"received": 100} + + response = client.post(api, json={"integer_or_null": None}) + assert response.status_code == 200, response.text + assert response.json() == {"received": None} + + +def test_required_field(): + response = client.post("/api5", json={"integer": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"received": 100} + + response = client.post("/api5", json={"integer": None}) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "loc": ["body", "integer"], + "msg": "Field required", + "type": "missing", + "input": None, + } + ] + } + ) | IsDict( + { + "detail": [ + { + "loc": ["body", "integer"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + )