Browse Source

Pass None instead of the default value to parameters that accept it when null is given

Signed-off-by: merlinz01 <[email protected]>
pull/12135/head
merlinz01 9 months ago
parent
commit
5658b92b4c
  1. 16
      fastapi/dependencies/utils.py
  2. 65
      tests/test_none_passed_when_null_received.py

16
fastapi/dependencies/utils.py

@ -2,6 +2,7 @@ import inspect
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
import types
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -87,6 +88,8 @@ multipart_incorrect_install_error = (
"pip install python-multipart\n" "pip install python-multipart\n"
) )
_unset: Any = object()
def ensure_multipart_is_installed() -> None: def ensure_multipart_is_installed() -> None:
try: try:
@ -668,12 +671,21 @@ async def solve_dependencies(
) )
def _accepts_none(field: ModelField) -> bool:
origin = get_origin(field.type_)
return (origin is Union or origin is types.UnionType) and type(None) in get_args(
field.type_
)
def _validate_value_with_model_field( def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]: ) -> Tuple[Any, List[Any]]:
if value is None: if value is None or value is _unset:
if field.required: if field.required:
return None, [get_missing_field_error(loc=loc)] return None, [get_missing_field_error(loc=loc)]
elif value is None and _accepts_none(field):
return value, []
else: else:
return deepcopy(field.default), [] return deepcopy(field.default), []
v_, errors_ = field.validate(value, values, loc=loc) v_, errors_ = field.validate(value, values, loc=loc)
@ -820,7 +832,7 @@ async def request_body_to_args(
return {first_field.name: v_}, errors_ return {first_field.name: v_}, errors_
for field in body_fields: for field in body_fields:
loc = ("body", field.alias) loc = ("body", field.alias)
value: Optional[Any] = None value: Any = _unset
if body_to_process is not None: if body_to_process is not None:
try: try:
value = body_to_process.get(field.alias) value = body_to_process.get(field.alias)

65
tests/test_none_passed_when_null_received.py

@ -0,0 +1,65 @@
from typing import Optional, Union, Annotated
from fastapi import FastAPI, Body
from fastapi.testclient import TestClient
app = FastAPI()
SENTINEL = 1234567890
@app.post("/api1")
def api1(integer_or_null: Annotated[int | None, Body(embed=True)] = SENTINEL) -> dict:
return {"received": integer_or_null}
@app.post("/api2")
def api2(
integer_or_null: Annotated[Optional[int], Body(embed=True)] = SENTINEL
) -> dict:
return {"received": integer_or_null}
@app.post("/api3")
def api3(
integer_or_null: Annotated[Union[int, None], Body(embed=True)] = SENTINEL
) -> dict:
return {"received": integer_or_null}
client = TestClient(app)
def test_api1_integer():
response = client.post("/api1", json={"integer_or_null": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
def test_api1_null():
response = client.post("/api1", json={"integer_or_null": None})
assert response.status_code == 200, response.text
assert response.json() == {"received": None}
def test_api2_integer():
response = client.post("/api2", json={"integer_or_null": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
def test_api2_null():
response = client.post("/api2", json={"integer_or_null": None})
assert response.status_code == 200, response.text
assert response.json() == {"received": None}
def test_api3_integer():
response = client.post("/api3", json={"integer_or_null": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
def test_api3_null():
response = client.post("/api3", json={"integer_or_null": None})
assert response.status_code == 200, response.text
assert response.json() == {"received": None}
Loading…
Cancel
Save