From 693962173071ed99cdcdc1fca9b982144a6669bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 13 Jun 2019 18:37:48 +0200 Subject: [PATCH] bug: Fix handling an empty-body request with a required body param (#311) * :bug: Fix solving a required body param from an empty body request * :white_check_mark: Add tests for receiving required body parameters with body not provided --- fastapi/dependencies/utils.py | 19 +- .../test_tutorial003.py | 172 ++++++++++++++++++ 2 files changed, 184 insertions(+), 7 deletions(-) create mode 100644 tests/test_tutorial/test_body_multiple_params/test_tutorial003.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index a16c64904..61dd01142 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -297,7 +297,7 @@ async def solve_dependencies( *, request: Union[Request, WebSocket], dependant: Dependant, - body: Dict[str, Any] = None, + body: Optional[Union[Dict[str, Any], FormData]] = None, background_tasks: BackgroundTasks = None, response: Response = None, dependency_overrides_provider: Any = None, @@ -388,7 +388,7 @@ async def solve_dependencies( errors += path_errors + query_errors + header_errors + cookie_errors if dependant.body_params: body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above - dependant.body_params, body + required_params=dependant.body_params, received_body=body ) values.update(body_values) errors.extend(body_errors) @@ -447,7 +447,8 @@ def request_params_to_args( async def request_body_to_args( - required_params: List[Field], received_body: Dict[str, Any] + required_params: List[Field], + received_body: Optional[Union[Dict[str, Any], FormData]], ) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: values = {} errors = [] @@ -457,10 +458,14 @@ async def request_body_to_args( if len(required_params) == 1 and not embed: received_body = {field.alias: received_body} for field in required_params: - if field.shape in sequence_shapes and isinstance(received_body, FormData): - value = received_body.getlist(field.alias) - else: - value = received_body.get(field.alias) + value = None + if received_body is not None: + if field.shape in sequence_shapes and isinstance( + received_body, FormData + ): + value = received_body.getlist(field.alias) + else: + value = received_body.get(field.alias) if ( value is None or (isinstance(field.schema, params.Form) and value == "") diff --git a/tests/test_tutorial/test_body_multiple_params/test_tutorial003.py b/tests/test_tutorial/test_body_multiple_params/test_tutorial003.py new file mode 100644 index 000000000..9a1c56bc2 --- /dev/null +++ b/tests/test_tutorial/test_body_multiple_params/test_tutorial003.py @@ -0,0 +1,172 @@ +import pytest +from starlette.testclient import TestClient + +from body_multiple_params.tutorial003 import app + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/items/{item_id}": { + "put": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Update Item", + "operationId": "update_item_items__item_id__put", + "parameters": [ + { + "required": True, + "schema": {"title": "Item_Id", "type": "integer"}, + "name": "item_id", + "in": "path", + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Body_update_item"} + } + }, + "required": True, + }, + } + } + }, + "components": { + "schemas": { + "Item": { + "title": "Item", + "required": ["name", "price"], + "type": "object", + "properties": { + "name": {"title": "Name", "type": "string"}, + "price": {"title": "Price", "type": "number"}, + "description": {"title": "Description", "type": "string"}, + "tax": {"title": "Tax", "type": "number"}, + }, + }, + "User": { + "title": "User", + "required": ["username"], + "type": "object", + "properties": { + "username": {"title": "Username", "type": "string"}, + "full_name": {"title": "Full_Name", "type": "string"}, + }, + }, + "Body_update_item": { + "title": "Body_update_item", + "required": ["item", "user", "importance"], + "type": "object", + "properties": { + "item": {"$ref": "#/components/schemas/Item"}, + "user": {"$ref": "#/components/schemas/User"}, + "importance": {"title": "Importance", "type": "integer"}, + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +# Test required and embedded body parameters with no bodies sent +@pytest.mark.parametrize( + "path,body,expected_status,expected_response", + [ + ( + "/items/5", + { + "importance": 2, + "item": {"name": "Foo", "price": 50.5}, + "user": {"username": "Dave"}, + }, + 200, + { + "item_id": 5, + "importance": 2, + "item": { + "name": "Foo", + "price": 50.5, + "description": None, + "tax": None, + }, + "user": {"username": "Dave", "full_name": None}, + }, + ), + ( + "/items/5", + None, + 422, + { + "detail": [ + { + "loc": ["body", "item"], + "msg": "field required", + "type": "value_error.missing", + }, + { + "loc": ["body", "user"], + "msg": "field required", + "type": "value_error.missing", + }, + { + "loc": ["body", "importance"], + "msg": "field required", + "type": "value_error.missing", + }, + ] + }, + ), + ], +) +def test_post_body(path, body, expected_status, expected_response): + response = client.put(path, json=body) + assert response.status_code == expected_status + assert response.json() == expected_response