From 5658b92b4c99f57325ba745c0b86e39dc27b6eab Mon Sep 17 00:00:00 2001 From: merlinz01 <na@notaccessible.xyz> Date: Thu, 5 Sep 2024 14:32:14 -0400 Subject: [PATCH] Pass None instead of the default value to parameters that accept it when null is given Signed-off-by: merlinz01 <na@notaccessible.xyz> --- fastapi/dependencies/utils.py | 16 ++++- tests/test_none_passed_when_null_received.py | 65 ++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 tests/test_none_passed_when_null_received.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 98ce17b55..8bc0f6016 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -2,6 +2,7 @@ import inspect from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass +import types from typing import ( Any, Callable, @@ -87,6 +88,8 @@ multipart_incorrect_install_error = ( "pip install python-multipart\n" ) +_unset: Any = object() + def ensure_multipart_is_installed() -> None: try: @@ -668,12 +671,21 @@ async def solve_dependencies( ) +def _accepts_none(field: ModelField) -> bool: + origin = get_origin(field.type_) + return (origin is Union or origin is types.UnionType) and type(None) in get_args( + field.type_ + ) + + def _validate_value_with_model_field( *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] ) -> Tuple[Any, List[Any]]: - if value is None: + if value is None or value is _unset: if field.required: return None, [get_missing_field_error(loc=loc)] + elif value is None and _accepts_none(field): + return value, [] else: return deepcopy(field.default), [] v_, errors_ = field.validate(value, values, loc=loc) @@ -820,7 +832,7 @@ async def request_body_to_args( return {first_field.name: v_}, errors_ for field in body_fields: loc = ("body", field.alias) - value: Optional[Any] = None + value: Any = _unset if body_to_process is not None: try: value = body_to_process.get(field.alias) diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py new file mode 100644 index 000000000..b2c4e8796 --- /dev/null +++ b/tests/test_none_passed_when_null_received.py @@ -0,0 +1,65 @@ +from typing import Optional, Union, Annotated + +from fastapi import FastAPI, Body +from fastapi.testclient import TestClient + +app = FastAPI() +SENTINEL = 1234567890 + + +@app.post("/api1") +def api1(integer_or_null: Annotated[int | None, Body(embed=True)] = SENTINEL) -> dict: + return {"received": integer_or_null} + + +@app.post("/api2") +def api2( + integer_or_null: Annotated[Optional[int], Body(embed=True)] = SENTINEL +) -> dict: + return {"received": integer_or_null} + + +@app.post("/api3") +def api3( + integer_or_null: Annotated[Union[int, None], Body(embed=True)] = SENTINEL +) -> dict: + return {"received": integer_or_null} + + +client = TestClient(app) + + +def test_api1_integer(): + response = client.post("/api1", json={"integer_or_null": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"received": 100} + + +def test_api1_null(): + response = client.post("/api1", json={"integer_or_null": None}) + assert response.status_code == 200, response.text + assert response.json() == {"received": None} + + +def test_api2_integer(): + response = client.post("/api2", json={"integer_or_null": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"received": 100} + + +def test_api2_null(): + response = client.post("/api2", json={"integer_or_null": None}) + assert response.status_code == 200, response.text + assert response.json() == {"received": None} + + +def test_api3_integer(): + response = client.post("/api3", json={"integer_or_null": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"received": 100} + + +def test_api3_null(): + response = client.post("/api3", json={"integer_or_null": None}) + assert response.status_code == 200, response.text + assert response.json() == {"received": None}