From c5817912d2be25bb310bf9da517882f57bbe7bb5 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 31 Aug 2019 01:32:39 +0300 Subject: [PATCH] :bug: use media_type from Body params for OpenAPI requestBody (Fixes: #431) (#439) --- fastapi/dependencies/utils.py | 12 +++- ...test_request_body_parameters_media_type.py | 67 +++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 tests/test_request_body_parameters_media_type.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index f9e42d0a8..7f0f59092 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -559,6 +559,8 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]: for f in flat_dependant.body_params: BodyModel.__fields__[f.name] = get_schema_compatible_field(field=f) required = any(True for f in flat_dependant.body_params if f.required) + + BodySchema_kwargs: Dict[str, Any] = dict(default=None) if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params): BodySchema: Type[params.Body] = params.File elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params): @@ -566,6 +568,14 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]: else: BodySchema = params.Body + body_param_media_types = [ + getattr(f.schema, "media_type") + for f in flat_dependant.body_params + if isinstance(f.schema, params.Body) + ] + if len(set(body_param_media_types)) == 1: + BodySchema_kwargs["media_type"] = body_param_media_types[0] + field = Field( name="body", type_=BodyModel, @@ -574,6 +584,6 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]: model_config=BaseConfig, class_validators={}, alias="body", - schema=BodySchema(None), + schema=BodySchema(**BodySchema_kwargs), ) return field diff --git a/tests/test_request_body_parameters_media_type.py b/tests/test_request_body_parameters_media_type.py new file mode 100644 index 000000000..89b98b220 --- /dev/null +++ b/tests/test_request_body_parameters_media_type.py @@ -0,0 +1,67 @@ +import typing + +from fastapi import Body, FastAPI +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +media_type = "application/vnd.api+json" + +# NOTE: These are not valid JSON:API resources +# but they are fine for testing requestBody with custom media_type +class Product(BaseModel): + name: str + price: float + + +class Shop(BaseModel): + name: str + + +@app.post("/products") +async def create_product(data: Product = Body(..., media_type=media_type, embed=True)): + pass # pragma: no cover + + +@app.post("/shops") +async def create_shop( + data: Shop = Body(..., media_type=media_type), + included: typing.List[Product] = Body([], media_type=media_type), +): + pass # pragma: no cover + + +create_product_request_body = { + "content": { + "application/vnd.api+json": { + "schema": {"$ref": "#/components/schemas/Body_create_product_products_post"} + } + }, + "required": True, +} + +create_shop_request_body = { + "content": { + "application/vnd.api+json": { + "schema": {"$ref": "#/components/schemas/Body_create_shop_shops_post"} + } + }, + "required": True, +} + +client = TestClient(app) + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + openapi_schema = response.json() + assert ( + openapi_schema["paths"]["/products"]["post"]["requestBody"] + == create_product_request_body + ) + assert ( + openapi_schema["paths"]["/shops"]["post"]["requestBody"] + == create_shop_request_body + )