Browse Source

Merge 967734e6b7 into 6e69d62bfe

pull/12135/merge
Merlin 3 days ago
committed by GitHub
parent
commit
bb0a9baa98
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 52
      fastapi/dependencies/utils.py
  2. 2
      fastapi/routing.py
  3. 101
      tests/test_none_passed_when_null_received.py

52
fastapi/dependencies/utils.py

@ -1,4 +1,6 @@
import inspect import inspect
import sys
import types
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
@ -26,6 +28,7 @@ from fastapi._compat import (
ModelField, ModelField,
RequiredParam, RequiredParam,
Undefined, Undefined,
UndefinedType,
_regenerate_error_with_loc, _regenerate_error_with_loc,
copy_field_info, copy_field_info,
create_body_model, create_body_model,
@ -573,7 +576,7 @@ async def solve_dependencies(
*, *,
request: Union[Request, WebSocket], request: Union[Request, WebSocket],
dependant: Dependant, dependant: Dependant,
body: Optional[Union[Dict[str, Any], FormData]] = None, body: Optional[Union[Dict[str, Any], FormData, UndefinedType]] = Undefined,
background_tasks: Optional[StarletteBackgroundTasks] = None, background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None, response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
@ -695,10 +698,36 @@ async def solve_dependencies(
) )
if PYDANTIC_V2:
if sys.hexversion >= 0x30A0000:
def _allows_none(field: ModelField) -> bool:
origin = get_origin(field.type_)
return (origin is Union or origin is types.UnionType) and type(
None
) in get_args(field.type_)
else:
def _allows_none(field: ModelField) -> bool:
origin = get_origin(field.type_)
return origin is Union and type(None) in get_args(field.type_)
else:
def _allows_none(field: ModelField) -> bool:
return field.allow_none # type: ignore
def _validate_value_with_model_field( def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]: ) -> Tuple[Any, List[Any]]:
if value is Undefined:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
if value is None: if value is None:
if _allows_none(field):
return value, []
if field.required: if field.required:
return None, [get_missing_field_error(loc=loc)] return None, [get_missing_field_error(loc=loc)]
else: else:
@ -717,12 +746,13 @@ def _get_multidict_value(
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
) -> Any: ) -> Any:
alias = alias or field.alias alias = alias or field.alias
value: Any
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
value = values.getlist(alias) value = values.getlist(alias)
else: else:
value = values.get(alias, None) value = values.get(alias, Undefined)
if ( if (
value is None value is Undefined
or ( or (
isinstance(field.field_info, params.Form) isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks and isinstance(value, str) # For type checks
@ -731,7 +761,7 @@ def _get_multidict_value(
or (is_sequence_field(field) and len(value) == 0) or (is_sequence_field(field) and len(value) == 0)
): ):
if field.required: if field.required:
return return Undefined
else: else:
return deepcopy(field.default) return deepcopy(field.default)
return value return value
@ -779,7 +809,7 @@ def request_params_to_args(
else field.name.replace("_", "-") else field.name.replace("_", "-")
) )
value = _get_multidict_value(field, received_params, alias=alias) value = _get_multidict_value(field, received_params, alias=alias)
if value is not None: if value is not Undefined and value is not None:
params_to_process[field.name] = value params_to_process[field.name] = value
processed_keys.add(alias or field.alias) processed_keys.add(alias or field.alias)
processed_keys.add(field.name) processed_keys.add(field.name)
@ -894,7 +924,7 @@ async def _extract_form_body(
for sub_value in value: for sub_value in value:
tg.start_soon(process_fn, sub_value.read) tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results) value = serialize_sequence_value(field=field, value=results)
if value is not None: if value is not Undefined and value is not None:
values[field.alias] = value values[field.alias] = value
for key, value in received_body.items(): for key, value in received_body.items():
if key not in values: if key not in values:
@ -904,7 +934,7 @@ async def _extract_form_body(
async def request_body_to_args( async def request_body_to_args(
body_fields: List[ModelField], body_fields: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]], received_body: Optional[Union[Dict[str, Any], FormData, UndefinedType]],
embed_body_fields: bool, embed_body_fields: bool,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
@ -930,10 +960,12 @@ async def request_body_to_args(
return {first_field.name: v_}, errors_ return {first_field.name: v_}, errors_
for field in body_fields: for field in body_fields:
loc = ("body", field.alias) loc = ("body", field.alias)
value: Optional[Any] = None value: Any = Undefined
if body_to_process is not None: if body_to_process is not None and not isinstance(
body_to_process, UndefinedType
):
try: try:
value = body_to_process.get(field.alias) value = body_to_process.get(field.alias, Undefined)
# If the received body is a list, not a dict # If the received body is a list, not a dict
except AttributeError: except AttributeError:
errors.append(get_missing_field_error(loc)) errors.append(get_missing_field_error(loc))

2
fastapi/routing.py

@ -242,7 +242,7 @@ def get_request_handler(
response: Union[Response, None] = None response: Union[Response, None] = None
async with AsyncExitStack() as file_stack: async with AsyncExitStack() as file_stack:
try: try:
body: Any = None body: Any = Undefined
if body_field: if body_field:
if is_body_form: if is_body_form:
body = await request.form() body = await request.form()

101
tests/test_none_passed_when_null_received.py

@ -0,0 +1,101 @@
import sys
from typing import Optional, Union
import pytest
from dirty_equals import IsDict
from fastapi import Body, FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
DEFAULT = 1234567890
endpoints = []
if sys.hexversion >= 0x30A0000:
from typing import Annotated
@app.post("/api1")
def api1(
integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api1")
if sys.hexversion >= 0x3090000:
from typing import Annotated
@app.post("/api2")
def api2(
integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api2")
@app.post("/api3")
def api3(
integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api3")
@app.post("/api4")
def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> dict:
return {"received": integer_or_null}
endpoints.append("/api4")
@app.post("/api5")
def api5(integer: int = Body(embed=True)) -> dict:
return {"received": integer}
client = TestClient(app)
@pytest.mark.parametrize("api", endpoints)
def test_apis(api):
response = client.post(api, json={"integer_or_null": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
response = client.post(api, json={"integer_or_null": None})
assert response.status_code == 200, response.text
assert response.json() == {"received": None}
def test_required_field():
response = client.post("/api5", json={"integer": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
response = client.post("/api5", json={"integer": None})
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"loc": ["body", "integer"],
"msg": "Field required",
"type": "missing",
"input": None,
}
]
}
) | IsDict(
{
"detail": [
{
"loc": ["body", "integer"],
"msg": "field required",
"type": "value_error.missing",
}
]
}
)
Loading…
Cancel
Save