Browse Source

Merge b5032f39a0 into 76b324d95b

pull/12502/merge
maxclaey 2 days ago
committed by GitHub
parent
commit
24ccec04f3
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 47
      fastapi/dependencies/utils.py
  2. 47
      tests/test_forms_nullable_param.py
  3. 4
      tests/test_forms_single_model.py
  4. 2
      tests/test_tutorial/test_cookie_param_models/test_tutorial001.py
  5. 2
      tests/test_tutorial/test_cookie_param_models/test_tutorial002.py
  6. 2
      tests/test_tutorial/test_header_param_models/test_tutorial001.py
  7. 7
      tests/test_tutorial/test_header_param_models/test_tutorial002.py

47
fastapi/dependencies/utils.py

@ -87,6 +87,9 @@ multipart_incorrect_install_error = (
"pip install python-multipart\n"
)
# Sentinel value for unspecified fields
_NOT_SPECIFIED = object()
def ensure_multipart_is_installed() -> None:
try:
@ -698,7 +701,7 @@ async def solve_dependencies(
def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]:
if value is None:
if value is _NOT_SPECIFIED:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
@ -720,20 +723,24 @@ def _get_multidict_value(
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
value = values.getlist(alias)
else:
value = values.get(alias, None)
value = values.get(alias, _NOT_SPECIFIED)
if (
value is None
or (
isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks
and value == ""
)
or (is_sequence_field(field) and len(value) == 0)
isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks
and value == ""
):
# Empty strings in a form can be a representation of None values
_, error = field.validate(None, {}, loc=())
# If None is an accepted value for this field, use that
if error is None:
value = None
if value == _NOT_SPECIFIED or (is_sequence_field(field) and len(value) == 0):
if field.required:
return
return _NOT_SPECIFIED
else:
return deepcopy(field.default)
return value
@ -779,7 +786,7 @@ def request_params_to_args(
else field.name.replace("_", "-")
)
value = _get_multidict_value(field, received_params, alias=alias)
if value is not None:
if value != _NOT_SPECIFIED:
params_to_process[field.name] = value
processed_keys.add(alias or field.alias)
processed_keys.add(field.name)
@ -846,6 +853,8 @@ async def _extract_form_body(
first_field = body_fields[0]
first_field_info = first_field.field_info
processed_keys = set()
for field in body_fields:
value = _get_multidict_value(field, received_body)
if (
@ -873,10 +882,13 @@ async def _extract_form_body(
for sub_value in value:
tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results)
if value is not None:
if value != _NOT_SPECIFIED:
values[field.alias] = value
processed_keys.add(field.alias)
processed_keys.add(field.name)
for key, value in received_body.items():
if key not in values:
if key not in processed_keys:
values[key] = value
return values
@ -904,15 +916,18 @@ async def request_body_to_args(
if single_not_embedded_field:
loc: Tuple[str, ...] = ("body",)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=body_to_process, values=values, loc=loc
field=first_field,
value=body_to_process if body_to_process is not None else _NOT_SPECIFIED,
values=values,
loc=loc,
)
return {first_field.name: v_}, errors_
for field in body_fields:
loc = ("body", field.alias)
value: Optional[Any] = None
value: Optional[Any] = _NOT_SPECIFIED
if body_to_process is not None:
try:
value = body_to_process.get(field.alias)
value = _get_multidict_value(field, values=body_to_process)
# If the received body is a list, not a dict
except AttributeError:
errors.append(get_missing_field_error(loc))

47
tests/test_forms_nullable_param.py

@ -0,0 +1,47 @@
from typing import Optional
from uuid import UUID, uuid4
import pytest
from fastapi import FastAPI, Form
from fastapi.testclient import TestClient
from typing_extensions import Annotated
app = FastAPI()
default_uuid = uuid4()
@app.post("/form-optional/")
def post_form_optional(
test_id: Annotated[Optional[UUID], Form(alias="testId")] = default_uuid,
) -> Optional[UUID]:
return test_id
@app.post("/form-required/")
def post_form_required(
test_id: Annotated[Optional[UUID], Form(alias="testId")],
) -> Optional[UUID]:
return test_id
client = TestClient(app)
def test_unspecified_optional() -> None:
response = client.post("/form-optional/", data={})
assert response.status_code == 200, response.text
assert response.json() == str(default_uuid)
def test_unspecified_required() -> None:
response = client.post("/form-required/", data={})
assert response.status_code == 422, response.text
@pytest.mark.parametrize("url", ["/form-optional/", "/form-required/"])
@pytest.mark.parametrize("test_id", [None, str(uuid4())])
def test_specified(url: str, test_id: Optional[str]) -> None:
response = client.post(url, data={"testId": test_id})
assert response.status_code == 200, response.text
assert response.json() == test_id

4
tests/test_forms_single_model.py

@ -104,13 +104,13 @@ def test_no_data():
"type": "missing",
"loc": ["body", "username"],
"msg": "Field required",
"input": {"tags": ["foo", "bar"], "with": "nothing"},
"input": {"age": None, "tags": ["foo", "bar"], "with": "nothing"},
},
{
"type": "missing",
"loc": ["body", "lastname"],
"msg": "Field required",
"input": {"tags": ["foo", "bar"], "with": "nothing"},
"input": {"age": None, "tags": ["foo", "bar"], "with": "nothing"},
},
]
}

2
tests/test_tutorial/test_cookie_param_models/test_tutorial001.py

@ -62,7 +62,7 @@ def test_cookie_param_model_invalid(client: TestClient):
"type": "missing",
"loc": ["cookie", "session_id"],
"msg": "Field required",
"input": {},
"input": {"fatebook_tracker": None, "googall_tracker": None},
}
]
}

2
tests/test_tutorial/test_cookie_param_models/test_tutorial002.py

@ -67,7 +67,7 @@ def test_cookie_param_model_invalid(client: TestClient):
"type": "missing",
"loc": ["cookie", "session_id"],
"msg": "Field required",
"input": {},
"input": {"fatebook_tracker": None, "googall_tracker": None},
}
]
}

2
tests/test_tutorial/test_header_param_models/test_tutorial001.py

@ -77,6 +77,8 @@ def test_header_param_model_invalid(client: TestClient):
"accept-encoding": "gzip, deflate",
"connection": "keep-alive",
"user-agent": "testclient",
"if_modified_since": None,
"traceparent": None,
},
}
)

7
tests/test_tutorial/test_header_param_models/test_tutorial002.py

@ -75,7 +75,12 @@ def test_header_param_model_invalid(client: TestClient):
"type": "missing",
"loc": ["header", "save_data"],
"msg": "Field required",
"input": {"x_tag": [], "host": "testserver"},
"input": {
"x_tag": [],
"host": "testserver",
"if_modified_since": None,
"traceparent": None,
},
}
)
| IsDict(

Loading…
Cancel
Save