Browse Source

Merge b5032f39a0 into 76b324d95b

pull/12502/merge
maxclaey 2 days ago
committed by GitHub
parent
commit
24ccec04f3
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 47
      fastapi/dependencies/utils.py
  2. 47
      tests/test_forms_nullable_param.py
  3. 4
      tests/test_forms_single_model.py
  4. 2
      tests/test_tutorial/test_cookie_param_models/test_tutorial001.py
  5. 2
      tests/test_tutorial/test_cookie_param_models/test_tutorial002.py
  6. 2
      tests/test_tutorial/test_header_param_models/test_tutorial001.py
  7. 7
      tests/test_tutorial/test_header_param_models/test_tutorial002.py

47
fastapi/dependencies/utils.py

@ -87,6 +87,9 @@ multipart_incorrect_install_error = (
"pip install python-multipart\n" "pip install python-multipart\n"
) )
# Sentinel value for unspecified fields
_NOT_SPECIFIED = object()
def ensure_multipart_is_installed() -> None: def ensure_multipart_is_installed() -> None:
try: try:
@ -698,7 +701,7 @@ async def solve_dependencies(
def _validate_value_with_model_field( def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]: ) -> Tuple[Any, List[Any]]:
if value is None: if value is _NOT_SPECIFIED:
if field.required: if field.required:
return None, [get_missing_field_error(loc=loc)] return None, [get_missing_field_error(loc=loc)]
else: else:
@ -720,20 +723,24 @@ def _get_multidict_value(
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
value = values.getlist(alias) value = values.getlist(alias)
else: else:
value = values.get(alias, None) value = values.get(alias, _NOT_SPECIFIED)
if ( if (
value is None isinstance(field.field_info, params.Form)
or ( and isinstance(value, str) # For type checks
isinstance(field.field_info, params.Form) and value == ""
and isinstance(value, str) # For type checks
and value == ""
)
or (is_sequence_field(field) and len(value) == 0)
): ):
# Empty strings in a form can be a representation of None values
_, error = field.validate(None, {}, loc=())
# If None is an accepted value for this field, use that
if error is None:
value = None
if value == _NOT_SPECIFIED or (is_sequence_field(field) and len(value) == 0):
if field.required: if field.required:
return return _NOT_SPECIFIED
else: else:
return deepcopy(field.default) return deepcopy(field.default)
return value return value
@ -779,7 +786,7 @@ def request_params_to_args(
else field.name.replace("_", "-") else field.name.replace("_", "-")
) )
value = _get_multidict_value(field, received_params, alias=alias) value = _get_multidict_value(field, received_params, alias=alias)
if value is not None: if value != _NOT_SPECIFIED:
params_to_process[field.name] = value params_to_process[field.name] = value
processed_keys.add(alias or field.alias) processed_keys.add(alias or field.alias)
processed_keys.add(field.name) processed_keys.add(field.name)
@ -846,6 +853,8 @@ async def _extract_form_body(
first_field = body_fields[0] first_field = body_fields[0]
first_field_info = first_field.field_info first_field_info = first_field.field_info
processed_keys = set()
for field in body_fields: for field in body_fields:
value = _get_multidict_value(field, received_body) value = _get_multidict_value(field, received_body)
if ( if (
@ -873,10 +882,13 @@ async def _extract_form_body(
for sub_value in value: for sub_value in value:
tg.start_soon(process_fn, sub_value.read) tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results) value = serialize_sequence_value(field=field, value=results)
if value is not None: if value != _NOT_SPECIFIED:
values[field.alias] = value values[field.alias] = value
processed_keys.add(field.alias)
processed_keys.add(field.name)
for key, value in received_body.items(): for key, value in received_body.items():
if key not in values: if key not in processed_keys:
values[key] = value values[key] = value
return values return values
@ -904,15 +916,18 @@ async def request_body_to_args(
if single_not_embedded_field: if single_not_embedded_field:
loc: Tuple[str, ...] = ("body",) loc: Tuple[str, ...] = ("body",)
v_, errors_ = _validate_value_with_model_field( v_, errors_ = _validate_value_with_model_field(
field=first_field, value=body_to_process, values=values, loc=loc field=first_field,
value=body_to_process if body_to_process is not None else _NOT_SPECIFIED,
values=values,
loc=loc,
) )
return {first_field.name: v_}, errors_ return {first_field.name: v_}, errors_
for field in body_fields: for field in body_fields:
loc = ("body", field.alias) loc = ("body", field.alias)
value: Optional[Any] = None value: Optional[Any] = _NOT_SPECIFIED
if body_to_process is not None: if body_to_process is not None:
try: try:
value = body_to_process.get(field.alias) value = _get_multidict_value(field, values=body_to_process)
# If the received body is a list, not a dict # If the received body is a list, not a dict
except AttributeError: except AttributeError:
errors.append(get_missing_field_error(loc)) errors.append(get_missing_field_error(loc))

47
tests/test_forms_nullable_param.py

@ -0,0 +1,47 @@
from typing import Optional
from uuid import UUID, uuid4
import pytest
from fastapi import FastAPI, Form
from fastapi.testclient import TestClient
from typing_extensions import Annotated
app = FastAPI()
default_uuid = uuid4()
@app.post("/form-optional/")
def post_form_optional(
test_id: Annotated[Optional[UUID], Form(alias="testId")] = default_uuid,
) -> Optional[UUID]:
return test_id
@app.post("/form-required/")
def post_form_required(
test_id: Annotated[Optional[UUID], Form(alias="testId")],
) -> Optional[UUID]:
return test_id
client = TestClient(app)
def test_unspecified_optional() -> None:
response = client.post("/form-optional/", data={})
assert response.status_code == 200, response.text
assert response.json() == str(default_uuid)
def test_unspecified_required() -> None:
response = client.post("/form-required/", data={})
assert response.status_code == 422, response.text
@pytest.mark.parametrize("url", ["/form-optional/", "/form-required/"])
@pytest.mark.parametrize("test_id", [None, str(uuid4())])
def test_specified(url: str, test_id: Optional[str]) -> None:
response = client.post(url, data={"testId": test_id})
assert response.status_code == 200, response.text
assert response.json() == test_id

4
tests/test_forms_single_model.py

@ -104,13 +104,13 @@ def test_no_data():
"type": "missing", "type": "missing",
"loc": ["body", "username"], "loc": ["body", "username"],
"msg": "Field required", "msg": "Field required",
"input": {"tags": ["foo", "bar"], "with": "nothing"}, "input": {"age": None, "tags": ["foo", "bar"], "with": "nothing"},
}, },
{ {
"type": "missing", "type": "missing",
"loc": ["body", "lastname"], "loc": ["body", "lastname"],
"msg": "Field required", "msg": "Field required",
"input": {"tags": ["foo", "bar"], "with": "nothing"}, "input": {"age": None, "tags": ["foo", "bar"], "with": "nothing"},
}, },
] ]
} }

2
tests/test_tutorial/test_cookie_param_models/test_tutorial001.py

@ -62,7 +62,7 @@ def test_cookie_param_model_invalid(client: TestClient):
"type": "missing", "type": "missing",
"loc": ["cookie", "session_id"], "loc": ["cookie", "session_id"],
"msg": "Field required", "msg": "Field required",
"input": {}, "input": {"fatebook_tracker": None, "googall_tracker": None},
} }
] ]
} }

2
tests/test_tutorial/test_cookie_param_models/test_tutorial002.py

@ -67,7 +67,7 @@ def test_cookie_param_model_invalid(client: TestClient):
"type": "missing", "type": "missing",
"loc": ["cookie", "session_id"], "loc": ["cookie", "session_id"],
"msg": "Field required", "msg": "Field required",
"input": {}, "input": {"fatebook_tracker": None, "googall_tracker": None},
} }
] ]
} }

2
tests/test_tutorial/test_header_param_models/test_tutorial001.py

@ -77,6 +77,8 @@ def test_header_param_model_invalid(client: TestClient):
"accept-encoding": "gzip, deflate", "accept-encoding": "gzip, deflate",
"connection": "keep-alive", "connection": "keep-alive",
"user-agent": "testclient", "user-agent": "testclient",
"if_modified_since": None,
"traceparent": None,
}, },
} }
) )

7
tests/test_tutorial/test_header_param_models/test_tutorial002.py

@ -75,7 +75,12 @@ def test_header_param_model_invalid(client: TestClient):
"type": "missing", "type": "missing",
"loc": ["header", "save_data"], "loc": ["header", "save_data"],
"msg": "Field required", "msg": "Field required",
"input": {"x_tag": [], "host": "testserver"}, "input": {
"x_tag": [],
"host": "testserver",
"if_modified_since": None,
"traceparent": None,
},
} }
) )
| IsDict( | IsDict(

Loading…
Cancel
Save