Browse Source

Merge c90d831672 into 1d434dec47

pull/12135/merge
Merlin 5 days ago
committed by GitHub
parent
commit
9afc3a88fd
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 52
      fastapi/dependencies/utils.py
  2. 2
      fastapi/routing.py
  3. 101
      tests/test_none_passed_when_null_received.py

52
fastapi/dependencies/utils.py

@ -1,4 +1,6 @@
import inspect
import sys
import types
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
@ -26,6 +28,7 @@ from fastapi._compat import (
ModelField,
RequiredParam,
Undefined,
UndefinedType,
_regenerate_error_with_loc,
copy_field_info,
create_body_model,
@ -573,7 +576,7 @@ async def solve_dependencies(
*,
request: Union[Request, WebSocket],
dependant: Dependant,
body: Optional[Union[Dict[str, Any], FormData]] = None,
body: Optional[Union[Dict[str, Any], FormData, UndefinedType]] = Undefined,
background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
@ -695,10 +698,36 @@ async def solve_dependencies(
)
if PYDANTIC_V2:
if sys.hexversion >= 0x30A0000:
def _allows_none(field: ModelField) -> bool:
origin = get_origin(field.type_)
return (origin is Union or origin is types.UnionType) and type(
None
) in get_args(field.type_)
else:
def _allows_none(field: ModelField) -> bool:
origin = get_origin(field.type_)
return origin is Union and type(None) in get_args(field.type_)
else:
def _allows_none(field: ModelField) -> bool:
return field.allow_none # type: ignore
def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]:
if value is Undefined:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
if value is None:
if _allows_none(field):
return value, []
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
@ -717,12 +746,13 @@ def _get_multidict_value(
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
) -> Any:
alias = alias or field.alias
value: Any
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
value = values.getlist(alias)
else:
value = values.get(alias, None)
value = values.get(alias, Undefined)
if (
value is None
value is Undefined
or (
isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks
@ -731,7 +761,7 @@ def _get_multidict_value(
or (is_sequence_field(field) and len(value) == 0)
):
if field.required:
return
return Undefined
else:
return deepcopy(field.default)
return value
@ -779,7 +809,7 @@ def request_params_to_args(
else field.name.replace("_", "-")
)
value = _get_multidict_value(field, received_params, alias=alias)
if value is not None:
if value is not Undefined and value is not None:
params_to_process[field.name] = value
processed_keys.add(alias or field.alias)
processed_keys.add(field.name)
@ -873,7 +903,7 @@ async def _extract_form_body(
for sub_value in value:
tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results)
if value is not None:
if value is not Undefined and value is not None:
values[field.alias] = value
for key, value in received_body.items():
if key not in values:
@ -883,7 +913,7 @@ async def _extract_form_body(
async def request_body_to_args(
body_fields: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]],
received_body: Optional[Union[Dict[str, Any], FormData, UndefinedType]],
embed_body_fields: bool,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
values: Dict[str, Any] = {}
@ -909,10 +939,12 @@ async def request_body_to_args(
return {first_field.name: v_}, errors_
for field in body_fields:
loc = ("body", field.alias)
value: Optional[Any] = None
if body_to_process is not None:
value: Any = Undefined
if body_to_process is not None and not isinstance(
body_to_process, UndefinedType
):
try:
value = body_to_process.get(field.alias)
value = body_to_process.get(field.alias, Undefined)
# If the received body is a list, not a dict
except AttributeError:
errors.append(get_missing_field_error(loc))

2
fastapi/routing.py

@ -241,7 +241,7 @@ def get_request_handler(
response: Union[Response, None] = None
async with AsyncExitStack() as file_stack:
try:
body: Any = None
body: Any = Undefined
if body_field:
if is_body_form:
body = await request.form()

101
tests/test_none_passed_when_null_received.py

@ -0,0 +1,101 @@
import sys
from typing import Optional, Union
import pytest
from dirty_equals import IsDict
from fastapi import Body, FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
DEFAULT = 1234567890
endpoints = []
if sys.hexversion >= 0x30A0000:
from typing import Annotated
@app.post("/api1")
def api1(
integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api1")
if sys.hexversion >= 0x3090000:
from typing import Annotated
@app.post("/api2")
def api2(
integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api2")
@app.post("/api3")
def api3(
integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api3")
@app.post("/api4")
def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> dict:
return {"received": integer_or_null}
endpoints.append("/api4")
@app.post("/api5")
def api5(integer: int = Body(embed=True)) -> dict:
return {"received": integer}
client = TestClient(app)
@pytest.mark.parametrize("api", endpoints)
def test_apis(api):
response = client.post(api, json={"integer_or_null": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
response = client.post(api, json={"integer_or_null": None})
assert response.status_code == 200, response.text
assert response.json() == {"received": None}
def test_required_field():
response = client.post("/api5", json={"integer": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
response = client.post("/api5", json={"integer": None})
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"loc": ["body", "integer"],
"msg": "Field required",
"type": "missing",
"input": None,
}
]
}
) | IsDict(
{
"detail": [
{
"loc": ["body", "integer"],
"msg": "field required",
"type": "value_error.missing",
}
]
}
)
Loading…
Cancel
Save