From 76c4d317fd4119282141db787f7cd82a417de4b5 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Thu, 6 Mar 2025 19:06:44 -0800 Subject: [PATCH] don't prefill defaults in form input --- fastapi/dependencies/utils.py | 13 ++- tests/test_forms_defaults.py | 192 +++++++++++++++++++++++++++++++ tests/test_forms_single_model.py | 4 +- 3 files changed, 202 insertions(+), 7 deletions(-) create mode 100644 tests/test_forms_defaults.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index d205d17fa..dbab0b523 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -714,7 +714,10 @@ def _validate_value_with_model_field( def _get_multidict_value( - field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None + field: ModelField, + values: Mapping[str, Any], + alias: Union[str, None] = None, + return_default=True, ) -> Any: alias = alias or field.alias if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): @@ -730,10 +733,10 @@ def _get_multidict_value( ) or (is_sequence_field(field) and len(value) == 0) ): - if field.required: - return - else: + if return_default and not field.required: return deepcopy(field.default) + else: + return None return value @@ -839,7 +842,7 @@ async def _extract_form_body( first_field_info = first_field.field_info for field in body_fields: - value = _get_multidict_value(field, received_body) + value = _get_multidict_value(field, received_body, return_default=False) if ( isinstance(first_field_info, params.File) and is_bytes_field(field) diff --git a/tests/test_forms_defaults.py b/tests/test_forms_defaults.py new file mode 100644 index 000000000..2fc0976e6 --- /dev/null +++ b/tests/test_forms_defaults.py @@ -0,0 +1,192 @@ +from typing import Annotated, Optional + +import pytest +from fastapi import FastAPI, Form +from pydantic import BaseModel, Field, model_validator +from starlette.testclient import TestClient + + +class Parent(BaseModel): + init_input: dict + # importantly, no default here + + @model_validator(mode="before") + def validate_inputs(cls, value: dict) -> dict: + """ + model validators in before mode should receive values passed + to model instantiation before any further validation + """ + # we should not be double-instantiating the models + assert isinstance(value, dict) + value["init_input"] = value.copy() + + # differentiate between explicit Nones and unpassed values + if "true_if_unset" not in value: + value["true_if_unset"] = True + return value + + +class StandardModel(Parent): + default_true: bool = True + default_false: bool = False + default_none: Optional[bool] = None + default_zero: int = 0 + true_if_unset: Optional[bool] = None + + +class FieldModel(Parent): + default_true: bool = Field(default=True) + default_false: bool = Field(default=False) + default_none: Optional[bool] = Field(default=None) + default_zero: int = Field(default=0) + true_if_unset: Optional[bool] = Field(default=None) + + +class AnnotatedFieldModel(Parent): + default_true: Annotated[bool, Field(default=True)] + default_false: Annotated[bool, Field(default=False)] + default_none: Annotated[Optional[bool], Field(default=None)] + default_zero: Annotated[int, Field(default=0)] + true_if_unset: Annotated[Optional[bool], Field(default=None)] + + +class AnnotatedFormModel(Parent): + default_true: Annotated[bool, Form(default=True)] + default_false: Annotated[bool, Form(default=False)] + default_none: Annotated[Optional[bool], Form(default=None)] + default_zero: Annotated[int, Form(default=0)] + true_if_unset: Annotated[Optional[bool], Form(default=None)] + + +class ResponseModel(BaseModel): + fields_set: list = Field(default_factory=list) + dumped_fields_no_exclude: dict = Field(default_factory=dict) + dumped_fields_exclude_default: dict = Field(default_factory=dict) + dumped_fields_exclude_unset: dict = Field(default_factory=dict) + init_input: dict + + @classmethod + def from_value(cls, value: Parent) -> "ResponseModel": + return ResponseModel( + init_input=value.init_input, + fields_set=list(value.model_fields_set), + dumped_fields_no_exclude=value.model_dump(), + dumped_fields_exclude_default=value.model_dump(exclude_defaults=True), + dumped_fields_exclude_unset=value.model_dump(exclude_unset=True), + ) + + +app = FastAPI() + + +@app.post("/form/standard") +async def form_standard(value: Annotated[StandardModel, Form()]) -> ResponseModel: + return ResponseModel.from_value(value) + + +@app.post("/form/field") +async def form_field(value: Annotated[FieldModel, Form()]) -> ResponseModel: + return ResponseModel.from_value(value) + + +@app.post("/form/annotated-field") +async def form_annotated_field( + value: Annotated[AnnotatedFieldModel, Form()], +) -> ResponseModel: + return ResponseModel.from_value(value) + + +@app.post("/form/annotated-form") +async def form_annotated_form( + value: Annotated[AnnotatedFormModel, Form()], +) -> ResponseModel: + return ResponseModel.from_value(value) + + +@app.post("/json/standard") +async def json_standard(value: StandardModel) -> ResponseModel: + return ResponseModel.from_value(value) + + +@app.post("/json/field") +async def json_field(value: FieldModel) -> ResponseModel: + return ResponseModel.from_value(value) + + +@app.post("/json/annotated-field") +async def json_annotated_field(value: AnnotatedFieldModel) -> ResponseModel: + return ResponseModel.from_value(value) + + +@app.post("/json/annotated-form") +async def json_annotated_form(value: AnnotatedFormModel) -> ResponseModel: + return ResponseModel.from_value(value) + + +MODEL_TYPES = { + "standard": StandardModel, + "field": FieldModel, + "annotated-field": AnnotatedFieldModel, + "annotated-form": AnnotatedFormModel, +} +ENCODINGS = ("form", "json") + + +@pytest.fixture(scope="module") +def client() -> TestClient: + with TestClient(app) as test_client: + yield test_client + + +@pytest.mark.parametrize("encoding", ENCODINGS) +@pytest.mark.parametrize("model_type", MODEL_TYPES.keys()) +def test_no_prefill_defaults_all_unset(encoding, model_type, client, monkeypatch): + """ + When the model is instantiated by the server, it should not have its defaults prefilled + """ + + endpoint = f"/{encoding}/{model_type}" + if encoding == "form": + res = client.post(endpoint, data={}) + else: + res = client.post(endpoint, json={}) + + assert res.status_code == 200 + response_model = ResponseModel(**res.json()) + assert response_model.init_input == {} + assert len(response_model.fields_set) == 2 + assert response_model.dumped_fields_no_exclude["true_if_unset"] is True + + +@pytest.mark.parametrize("encoding", ENCODINGS) +@pytest.mark.parametrize("model_type", MODEL_TYPES.keys()) +def test_no_prefill_defaults_partially_set(encoding, model_type, client, monkeypatch): + """ + When the model is instantiated by the server, it should not have its defaults prefilled, + and pydantic should be able to differentiate between unset and default values when some are passed + """ + endpoint = f"/{encoding}/{model_type}" + if encoding == "form": + data = {"true_if_unset": "False", "default_false": "True", "default_zero": "0"} + res = client.post(endpoint, data=data) + else: + data = {"true_if_unset": False, "default_false": True, "default_zero": 0} + res = client.post(endpoint, json=data) + + data_with_init_input = data.copy() + data_with_init_input["init_input"] = data.copy() + + assert res.status_code == 200 + response_model = ResponseModel(**res.json()) + assert response_model.init_input == data + assert len(response_model.fields_set) == 4 + dumped_exclude_unset = MODEL_TYPES[model_type](**data).model_dump( + exclude_unset=True + ) + assert response_model.dumped_fields_exclude_unset == dumped_exclude_unset + assert response_model.dumped_fields_no_exclude["true_if_unset"] is False + dumped_exclude_default = MODEL_TYPES[model_type](**data).model_dump( + exclude_defaults=True + ) + assert "default_zero" not in dumped_exclude_default + assert "default_zero" not in response_model.dumped_fields_exclude_default diff --git a/tests/test_forms_single_model.py b/tests/test_forms_single_model.py index 880ab3820..c57ee973a 100644 --- a/tests/test_forms_single_model.py +++ b/tests/test_forms_single_model.py @@ -104,13 +104,13 @@ def test_no_data(): "type": "missing", "loc": ["body", "username"], "msg": "Field required", - "input": {"tags": ["foo", "bar"], "with": "nothing"}, + "input": {}, }, { "type": "missing", "loc": ["body", "lastname"], "msg": "Field required", - "input": {"tags": ["foo", "bar"], "with": "nothing"}, + "input": {}, }, ] }