diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 28c57c296..c898ab7db 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -131,12 +131,17 @@ def get_flat_dependant(dependant: Dependant) -> Dependant: def is_scalar_field(field: Field) -> bool: - return ( + if not ( field.shape == Shape.SINGLETON and not lenient_issubclass(field.type_, BaseModel) and not lenient_issubclass(field.type_, sequence_types + (dict,)) and not isinstance(field.schema, params.Body) - ) + ): + return False + if field.sub_fields: + if not all(is_scalar_field(f) for f in field.sub_fields): + return False + return True def is_scalar_sequence_field(field: Field) -> bool: diff --git a/tests/test_union_body.py b/tests/test_union_body.py new file mode 100644 index 000000000..b7b8b38b2 --- /dev/null +++ b/tests/test_union_body.py @@ -0,0 +1,124 @@ +from typing import Optional, Union + +from fastapi import FastAPI +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + + +class Item(BaseModel): + name: Optional[str] = None + + +class OtherItem(BaseModel): + price: int + + +@app.post("/items/") +def save_union_body(item: Union[OtherItem, Item]): + return {"item": item} + + +client = TestClient(app) + +item_openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/items/": { + "post": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Save Union Body", + "operationId": "save_union_body_items__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "title": "Item", + "anyOf": [ + {"$ref": "#/components/schemas/OtherItem"}, + {"$ref": "#/components/schemas/Item"}, + ], + } + } + }, + "required": True, + }, + } + } + }, + "components": { + "schemas": { + "OtherItem": { + "title": "OtherItem", + "required": ["price"], + "type": "object", + "properties": {"price": {"title": "Price", "type": "integer"}}, + }, + "Item": { + "title": "Item", + "type": "object", + "properties": {"name": {"title": "Name", "type": "string"}}, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + } + }, +} + + +def test_item_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == item_openapi_schema + + +def test_post_other_item(): + response = client.post("/items/", json={"price": 100}) + assert response.status_code == 200 + assert response.json() == {"item": {"price": 100}} + + +def test_post_item(): + response = client.post("/items/", json={"name": "Foo"}) + assert response.status_code == 200 + assert response.json() == {"item": {"name": "Foo"}} diff --git a/tests/test_union_inherited_body.py b/tests/test_union_inherited_body.py new file mode 100644 index 000000000..022326e51 --- /dev/null +++ b/tests/test_union_inherited_body.py @@ -0,0 +1,140 @@ +import sys +from typing import Optional, Union + +import pytest +from fastapi import FastAPI +from pydantic import BaseModel +from starlette.testclient import TestClient + +# In Python 3.6: +# u = Union[ExtendedItem, Item] == __main__.Item + +# But in Python 3.7: +# u = Union[ExtendedItem, Item] == typing.Union[__main__.ExtendedItem, __main__.Item] +skip_py36 = pytest.mark.skipif(sys.version_info < (3, 7), reason="skip python3.6") + +app = FastAPI() + + +class Item(BaseModel): + name: Optional[str] = None + + +class ExtendedItem(Item): + age: int + + +@app.post("/items/") +def save_union_different_body(item: Union[ExtendedItem, Item]): + return {"item": item} + + +client = TestClient(app) + + +inherited_item_openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/items/": { + "post": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Save Union Different Body", + "operationId": "save_union_different_body_items__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "title": "Item", + "anyOf": [ + {"$ref": "#/components/schemas/ExtendedItem"}, + {"$ref": "#/components/schemas/Item"}, + ], + } + } + }, + "required": True, + }, + } + } + }, + "components": { + "schemas": { + "Item": { + "title": "Item", + "type": "object", + "properties": {"name": {"title": "Name", "type": "string"}}, + }, + "ExtendedItem": { + "title": "ExtendedItem", + "required": ["age"], + "type": "object", + "properties": { + "name": {"title": "Name", "type": "string"}, + "age": {"title": "Age", "type": "integer"}, + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + } + }, +} + + +@skip_py36 +def test_inherited_item_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == inherited_item_openapi_schema + + +@skip_py36 +def test_post_extended_item(): + response = client.post("/items/", json={"name": "Foo", "age": 5}) + assert response.status_code == 200 + assert response.json() == {"item": {"name": "Foo", "age": 5}} + + +@skip_py36 +def test_post_item(): + response = client.post("/items/", json={"name": "Foo"}) + assert response.status_code == 200 + assert response.json() == {"item": {"name": "Foo"}}