Browse Source

Pass None instead of the default value to parameters that accept it when null is given

Signed-off-by: merlinz01 <[email protected]>
pull/12135/head
merlinz01 7 months ago
parent
commit
ebc60c12a3
  1. 27
      fastapi/dependencies/utils.py
  2. 2
      fastapi/routing.py
  3. 47
      tests/test_none_passed_when_null_received.py

27
fastapi/dependencies/utils.py

@ -88,8 +88,6 @@ multipart_incorrect_install_error = (
"pip install python-multipart\n" "pip install python-multipart\n"
) )
_unset: Any = object()
def ensure_multipart_is_installed() -> None: def ensure_multipart_is_installed() -> None:
try: try:
@ -549,7 +547,7 @@ async def solve_dependencies(
*, *,
request: Union[Request, WebSocket], request: Union[Request, WebSocket],
dependant: Dependant, dependant: Dependant,
body: Optional[Union[Dict[str, Any], FormData]] = None, body: Optional[Union[Dict[str, Any], FormData]] = Undefined,
background_tasks: Optional[StarletteBackgroundTasks] = None, background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None, response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
@ -671,7 +669,7 @@ async def solve_dependencies(
) )
def _accepts_none(field: ModelField) -> bool: def _allows_none(field: ModelField) -> bool:
origin = get_origin(field.type_) origin = get_origin(field.type_)
return (origin is Union or origin is types.UnionType) and type(None) in get_args( return (origin is Union or origin is types.UnionType) and type(None) in get_args(
field.type_ field.type_
@ -681,11 +679,16 @@ def _accepts_none(field: ModelField) -> bool:
def _validate_value_with_model_field( def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]: ) -> Tuple[Any, List[Any]]:
if value is None or value is _unset: if value is Undefined:
if field.required: if field.required:
return None, [get_missing_field_error(loc=loc)] return None, [get_missing_field_error(loc=loc)]
elif value is None and _accepts_none(field): else:
return deepcopy(field.default), []
if value is None:
if _allows_none(field):
return value, [] return value, []
if field.required:
return None, [get_missing_field_error(loc=loc)]
else: else:
return deepcopy(field.default), [] return deepcopy(field.default), []
v_, errors_ = field.validate(value, values, loc=loc) v_, errors_ = field.validate(value, values, loc=loc)
@ -702,9 +705,9 @@ def _get_multidict_value(field: ModelField, values: Mapping[str, Any]) -> Any:
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
value = values.getlist(field.alias) value = values.getlist(field.alias)
else: else:
value = values.get(field.alias, None) value = values.get(field.alias, Undefined)
if ( if (
value is None value is Undefined
or ( or (
isinstance(field.field_info, params.Form) isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks and isinstance(value, str) # For type checks
@ -713,7 +716,7 @@ def _get_multidict_value(field: ModelField, values: Mapping[str, Any]) -> Any:
or (is_sequence_field(field) and len(value) == 0) or (is_sequence_field(field) and len(value) == 0)
): ):
if field.required: if field.required:
return return Undefined
else: else:
return deepcopy(field.default) return deepcopy(field.default)
return value return value
@ -799,7 +802,7 @@ async def _extract_form_body(
for sub_value in value: for sub_value in value:
tg.start_soon(process_fn, sub_value.read) tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results) value = serialize_sequence_value(field=field, value=results)
if value is not None: if value is not Undefined and value is not None:
values[field.name] = value values[field.name] = value
return values return values
@ -832,10 +835,10 @@ async def request_body_to_args(
return {first_field.name: v_}, errors_ return {first_field.name: v_}, errors_
for field in body_fields: for field in body_fields:
loc = ("body", field.alias) loc = ("body", field.alias)
value: Any = _unset value: Any = Undefined
if body_to_process is not None: if body_to_process is not None:
try: try:
value = body_to_process.get(field.alias) value = body_to_process.get(field.alias, Undefined)
# If the received body is a list, not a dict # If the received body is a list, not a dict
except AttributeError: except AttributeError:
errors.append(get_missing_field_error(loc)) errors.append(get_missing_field_error(loc))

2
fastapi/routing.py

@ -241,7 +241,7 @@ def get_request_handler(
response: Union[Response, None] = None response: Union[Response, None] = None
async with AsyncExitStack() as file_stack: async with AsyncExitStack() as file_stack:
try: try:
body: Any = None body: Any = Undefined
if body_field: if body_field:
if is_body_form: if is_body_form:
body = await request.form() body = await request.form()

47
tests/test_none_passed_when_null_received.py

@ -2,64 +2,43 @@ from typing import Optional, Union, Annotated
from fastapi import FastAPI, Body from fastapi import FastAPI, Body
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
import pytest
app = FastAPI() app = FastAPI()
SENTINEL = 1234567890 DEFAULT = 1234567890
@app.post("/api1") @app.post("/api1")
def api1(integer_or_null: Annotated[int | None, Body(embed=True)] = SENTINEL) -> dict: def api1(integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT) -> dict:
return {"received": integer_or_null} return {"received": integer_or_null}
@app.post("/api2") @app.post("/api2")
def api2( def api2(integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT) -> dict:
integer_or_null: Annotated[Optional[int], Body(embed=True)] = SENTINEL
) -> dict:
return {"received": integer_or_null} return {"received": integer_or_null}
@app.post("/api3") @app.post("/api3")
def api3( def api3(
integer_or_null: Annotated[Union[int, None], Body(embed=True)] = SENTINEL integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT
) -> dict: ) -> dict:
return {"received": integer_or_null} return {"received": integer_or_null}
client = TestClient(app) @app.post("/api4")
def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> dict:
return {"received": integer_or_null}
def test_api1_integer():
response = client.post("/api1", json={"integer_or_null": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
def test_api1_null():
response = client.post("/api1", json={"integer_or_null": None})
assert response.status_code == 200, response.text
assert response.json() == {"received": None}
def test_api2_integer():
response = client.post("/api2", json={"integer_or_null": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
def test_api2_null(): client = TestClient(app)
response = client.post("/api2", json={"integer_or_null": None})
assert response.status_code == 200, response.text
assert response.json() == {"received": None}
def test_api3_integer(): @pytest.mark.parametrize("api", ["/api1", "/api2", "/api3", "/api4"])
response = client.post("/api3", json={"integer_or_null": 100}) def test_api1_integer(api):
response = client.post(api, json={"integer_or_null": 100})
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"received": 100} assert response.json() == {"received": 100}
response = client.post(api, json={"integer_or_null": None})
def test_api3_null():
response = client.post("/api3", json={"integer_or_null": None})
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"received": None} assert response.json() == {"received": None}

Loading…
Cancel
Save