From fef3a69e3ecddb39facb4272ab14a7639000cdc9 Mon Sep 17 00:00:00 2001 From: Maxim Claeys Date: Mon, 21 Oct 2024 12:20:52 +0200 Subject: [PATCH 1/2] Fix support for nullable form fields --- fastapi/dependencies/utils.py | 47 ++++++++++++------- tests/test_forms_nullable_param.py | 47 +++++++++++++++++++ tests/test_forms_single_model.py | 4 +- .../test_tutorial001.py | 2 +- .../test_tutorial002.py | 2 +- .../test_tutorial001.py | 2 + .../test_tutorial002.py | 7 ++- 7 files changed, 90 insertions(+), 21 deletions(-) create mode 100644 tests/test_forms_nullable_param.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 87653c80d..018f3ded8 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -87,6 +87,9 @@ multipart_incorrect_install_error = ( "pip install python-multipart\n" ) +# Sentinel value for unspecified fields +_NOT_SPECIFIED = object() + def ensure_multipart_is_installed() -> None: try: @@ -690,7 +693,7 @@ async def solve_dependencies( def _validate_value_with_model_field( *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] ) -> Tuple[Any, List[Any]]: - if value is None: + if value is _NOT_SPECIFIED: if field.required: return None, [get_missing_field_error(loc=loc)] else: @@ -712,20 +715,24 @@ def _get_multidict_value( if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): value = values.getlist(alias) else: - value = values.get(alias, None) + value = values.get(alias, _NOT_SPECIFIED) if ( - value is None - or ( - isinstance(field.field_info, params.Form) - and isinstance(value, str) # For type checks - and value == "" - ) - or (is_sequence_field(field) and len(value) == 0) + isinstance(field.field_info, params.Form) + and isinstance(value, str) # For type checks + and value == "" ): + # Empty strings in a form can be a representation of None values + _, error = field.validate(value=None) + # If None is an accepted value for this field, use that + if error is None: + value = None + + if value == _NOT_SPECIFIED or (is_sequence_field(field) and len(value) == 0): if field.required: - return + return _NOT_SPECIFIED else: return deepcopy(field.default) + return value @@ -763,7 +770,7 @@ def request_params_to_args( else field.name.replace("_", "-") ) value = _get_multidict_value(field, received_params, alias=alias) - if value is not None: + if value != _NOT_SPECIFIED: params_to_process[field.name] = value processed_keys.add(alias or field.alias) processed_keys.add(field.name) @@ -830,6 +837,8 @@ async def _extract_form_body( first_field = body_fields[0] first_field_info = first_field.field_info + processed_keys = set() + for field in body_fields: value = _get_multidict_value(field, received_body) if ( @@ -857,10 +866,13 @@ async def _extract_form_body( for sub_value in value: tg.start_soon(process_fn, sub_value.read) value = serialize_sequence_value(field=field, value=results) - if value is not None: + if value != _NOT_SPECIFIED: values[field.alias] = value + processed_keys.add(field.alias) + processed_keys.add(field.name) + for key, value in received_body.items(): - if key not in values: + if key not in processed_keys: values[key] = value return values @@ -888,15 +900,18 @@ async def request_body_to_args( if single_not_embedded_field: loc: Tuple[str, ...] = ("body",) v_, errors_ = _validate_value_with_model_field( - field=first_field, value=body_to_process, values=values, loc=loc + field=first_field, + value=body_to_process if body_to_process is not None else _NOT_SPECIFIED, + values=values, + loc=loc, ) return {first_field.name: v_}, errors_ for field in body_fields: loc = ("body", field.alias) - value: Optional[Any] = None + value: Optional[Any] = _NOT_SPECIFIED if body_to_process is not None: try: - value = body_to_process.get(field.alias) + value = _get_multidict_value(field, values=body_to_process) # If the received body is a list, not a dict except AttributeError: errors.append(get_missing_field_error(loc)) diff --git a/tests/test_forms_nullable_param.py b/tests/test_forms_nullable_param.py new file mode 100644 index 000000000..7e770bfd5 --- /dev/null +++ b/tests/test_forms_nullable_param.py @@ -0,0 +1,47 @@ +from typing import Optional +from uuid import UUID, uuid4 + +import pytest +from fastapi import FastAPI, Form +from fastapi.testclient import TestClient +from typing_extensions import Annotated + +app = FastAPI() + +default_uuid = uuid4() + + +@app.post("/form-optional/") +def post_form_optional( + test_id: Annotated[Optional[UUID], Form(alias="testId")] = default_uuid, +) -> Optional[UUID]: + return test_id + + +@app.post("/form-required/") +def post_form_required( + test_id: Annotated[Optional[UUID], Form(alias="testId")], +) -> Optional[UUID]: + return test_id + + +client = TestClient(app) + + +def test_unspecified_optional() -> None: + response = client.post("/form-optional/", data={}) + assert response.status_code == 200, response.text + assert response.json() == str(default_uuid) + + +def test_unspecified_required() -> None: + response = client.post("/form-required/", data={}) + assert response.status_code == 422, response.text + + +@pytest.mark.parametrize("url", ["/form-optional/", "/form-required/"]) +@pytest.mark.parametrize("test_id", [None, str(uuid4())]) +def test_specified(url: str, test_id: Optional[str]) -> None: + response = client.post(url, data={"testId": test_id}) + assert response.status_code == 200, response.text + assert response.json() == test_id diff --git a/tests/test_forms_single_model.py b/tests/test_forms_single_model.py index 880ab3820..64d4cd238 100644 --- a/tests/test_forms_single_model.py +++ b/tests/test_forms_single_model.py @@ -104,13 +104,13 @@ def test_no_data(): "type": "missing", "loc": ["body", "username"], "msg": "Field required", - "input": {"tags": ["foo", "bar"], "with": "nothing"}, + "input": {"age": None, "tags": ["foo", "bar"], "with": "nothing"}, }, { "type": "missing", "loc": ["body", "lastname"], "msg": "Field required", - "input": {"tags": ["foo", "bar"], "with": "nothing"}, + "input": {"age": None, "tags": ["foo", "bar"], "with": "nothing"}, }, ] } diff --git a/tests/test_tutorial/test_cookie_param_models/test_tutorial001.py b/tests/test_tutorial/test_cookie_param_models/test_tutorial001.py index 60643185a..e812e0ac6 100644 --- a/tests/test_tutorial/test_cookie_param_models/test_tutorial001.py +++ b/tests/test_tutorial/test_cookie_param_models/test_tutorial001.py @@ -62,7 +62,7 @@ def test_cookie_param_model_invalid(client: TestClient): "type": "missing", "loc": ["cookie", "session_id"], "msg": "Field required", - "input": {}, + "input": {"fatebook_tracker": None, "googall_tracker": None}, } ] } diff --git a/tests/test_tutorial/test_cookie_param_models/test_tutorial002.py b/tests/test_tutorial/test_cookie_param_models/test_tutorial002.py index 30adadc8a..2f1bf4098 100644 --- a/tests/test_tutorial/test_cookie_param_models/test_tutorial002.py +++ b/tests/test_tutorial/test_cookie_param_models/test_tutorial002.py @@ -67,7 +67,7 @@ def test_cookie_param_model_invalid(client: TestClient): "type": "missing", "loc": ["cookie", "session_id"], "msg": "Field required", - "input": {}, + "input": {"fatebook_tracker": None, "googall_tracker": None}, } ] } diff --git a/tests/test_tutorial/test_header_param_models/test_tutorial001.py b/tests/test_tutorial/test_header_param_models/test_tutorial001.py index 06b2404cf..e8147decc 100644 --- a/tests/test_tutorial/test_header_param_models/test_tutorial001.py +++ b/tests/test_tutorial/test_header_param_models/test_tutorial001.py @@ -77,6 +77,8 @@ def test_header_param_model_invalid(client: TestClient): "accept-encoding": "gzip, deflate", "connection": "keep-alive", "user-agent": "testclient", + "if_modified_since": None, + "traceparent": None, }, } ) diff --git a/tests/test_tutorial/test_header_param_models/test_tutorial002.py b/tests/test_tutorial/test_header_param_models/test_tutorial002.py index e07655a0c..e704839ad 100644 --- a/tests/test_tutorial/test_header_param_models/test_tutorial002.py +++ b/tests/test_tutorial/test_header_param_models/test_tutorial002.py @@ -75,7 +75,12 @@ def test_header_param_model_invalid(client: TestClient): "type": "missing", "loc": ["header", "save_data"], "msg": "Field required", - "input": {"x_tag": [], "host": "testserver"}, + "input": { + "x_tag": [], + "host": "testserver", + "if_modified_since": None, + "traceparent": None, + }, } ) | IsDict( From b5032f39a0c49b9811ed0906360cfc947340ab27 Mon Sep 17 00:00:00 2001 From: Maxim Claeys Date: Mon, 21 Oct 2024 12:30:34 +0200 Subject: [PATCH 2/2] Fix support for pydantic v1 --- fastapi/dependencies/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 018f3ded8..c5105efad 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -722,7 +722,7 @@ def _get_multidict_value( and value == "" ): # Empty strings in a form can be a representation of None values - _, error = field.validate(value=None) + _, error = field.validate(None, {}, loc=()) # If None is an accepted value for this field, use that if error is None: value = None