From 5658b92b4c99f57325ba745c0b86e39dc27b6eab Mon Sep 17 00:00:00 2001
From: merlinz01 <na@notaccessible.xyz>
Date: Thu, 5 Sep 2024 14:32:14 -0400
Subject: [PATCH] Pass None instead of the default value to parameters that
 accept it when null is given

Signed-off-by: merlinz01 <na@notaccessible.xyz>
---
 fastapi/dependencies/utils.py                | 16 ++++-
 tests/test_none_passed_when_null_received.py | 65 ++++++++++++++++++++
 2 files changed, 79 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_none_passed_when_null_received.py

diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py
index 98ce17b55..8bc0f6016 100644
--- a/fastapi/dependencies/utils.py
+++ b/fastapi/dependencies/utils.py
@@ -2,6 +2,7 @@ import inspect
 from contextlib import AsyncExitStack, contextmanager
 from copy import copy, deepcopy
 from dataclasses import dataclass
+import types
 from typing import (
     Any,
     Callable,
@@ -87,6 +88,8 @@ multipart_incorrect_install_error = (
     "pip install python-multipart\n"
 )
 
+_unset: Any = object()
+
 
 def ensure_multipart_is_installed() -> None:
     try:
@@ -668,12 +671,21 @@ async def solve_dependencies(
     )
 
 
+def _accepts_none(field: ModelField) -> bool:
+    origin = get_origin(field.type_)
+    return (origin is Union or origin is types.UnionType) and type(None) in get_args(
+        field.type_
+    )
+
+
 def _validate_value_with_model_field(
     *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
 ) -> Tuple[Any, List[Any]]:
-    if value is None:
+    if value is None or value is _unset:
         if field.required:
             return None, [get_missing_field_error(loc=loc)]
+        elif value is None and _accepts_none(field):
+            return value, []
         else:
             return deepcopy(field.default), []
     v_, errors_ = field.validate(value, values, loc=loc)
@@ -820,7 +832,7 @@ async def request_body_to_args(
         return {first_field.name: v_}, errors_
     for field in body_fields:
         loc = ("body", field.alias)
-        value: Optional[Any] = None
+        value: Any = _unset
         if body_to_process is not None:
             try:
                 value = body_to_process.get(field.alias)
diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py
new file mode 100644
index 000000000..b2c4e8796
--- /dev/null
+++ b/tests/test_none_passed_when_null_received.py
@@ -0,0 +1,65 @@
+from typing import Optional, Union, Annotated
+
+from fastapi import FastAPI, Body
+from fastapi.testclient import TestClient
+
+app = FastAPI()
+SENTINEL = 1234567890
+
+
+@app.post("/api1")
+def api1(integer_or_null: Annotated[int | None, Body(embed=True)] = SENTINEL) -> dict:
+    return {"received": integer_or_null}
+
+
+@app.post("/api2")
+def api2(
+    integer_or_null: Annotated[Optional[int], Body(embed=True)] = SENTINEL
+) -> dict:
+    return {"received": integer_or_null}
+
+
+@app.post("/api3")
+def api3(
+    integer_or_null: Annotated[Union[int, None], Body(embed=True)] = SENTINEL
+) -> dict:
+    return {"received": integer_or_null}
+
+
+client = TestClient(app)
+
+
+def test_api1_integer():
+    response = client.post("/api1", json={"integer_or_null": 100})
+    assert response.status_code == 200, response.text
+    assert response.json() == {"received": 100}
+
+
+def test_api1_null():
+    response = client.post("/api1", json={"integer_or_null": None})
+    assert response.status_code == 200, response.text
+    assert response.json() == {"received": None}
+
+
+def test_api2_integer():
+    response = client.post("/api2", json={"integer_or_null": 100})
+    assert response.status_code == 200, response.text
+    assert response.json() == {"received": 100}
+
+
+def test_api2_null():
+    response = client.post("/api2", json={"integer_or_null": None})
+    assert response.status_code == 200, response.text
+    assert response.json() == {"received": None}
+
+
+def test_api3_integer():
+    response = client.post("/api3", json={"integer_or_null": 100})
+    assert response.status_code == 200, response.text
+    assert response.json() == {"received": 100}
+
+
+def test_api3_null():
+    response = client.post("/api3", json={"integer_or_null": None})
+    assert response.status_code == 200, response.text
+    assert response.json() == {"received": None}