From 19c53b21c19fc193b070281baa9e585d26cec91c Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 31 Aug 2019 00:46:05 +0300 Subject: [PATCH] :sparkles: Allow using custom 422 validation error and use media type from response class in schema (#437) * media_type of additional responses from the response_class * Use HTTPValidationError only if a custom one is not defined (Fixes: #429) --- fastapi/openapi/utils.py | 38 +++--- ...tional_responses_custom_validationerror.py | 100 +++++++++++++++ ...ional_responses_default_validationerror.py | 85 +++++++++++++ ...est_additional_responses_response_class.py | 117 ++++++++++++++++++ 4 files changed, 322 insertions(+), 18 deletions(-) create mode 100644 tests/test_additional_responses_custom_validationerror.py create mode 100644 tests/test_additional_responses_default_validationerror.py create mode 100644 tests/test_additional_responses_response_class.py diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index c3cc120fd..6c987a29f 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -80,15 +80,11 @@ def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, L def get_openapi_operation_parameters( all_route_params: Sequence[Field] -) -> Tuple[Dict[str, Dict], List[Dict[str, Any]]]: - definitions: Dict[str, Dict] = {} +) -> List[Dict[str, Any]]: parameters = [] for param in all_route_params: schema = param.schema schema = cast(Param, schema) - if "ValidationError" not in definitions: - definitions["ValidationError"] = validation_error_definition - definitions["HTTPValidationError"] = validation_error_response_definition parameter = { "name": param.alias, "in": schema.in_.value, @@ -100,7 +96,7 @@ def get_openapi_operation_parameters( if schema.deprecated: parameter["deprecated"] = schema.deprecated parameters.append(parameter) - return definitions, parameters + return parameters def get_openapi_operation_request_body( @@ -168,10 +164,7 @@ def get_openapi_path( if security_definitions: security_schemes.update(security_definitions) all_route_params = get_openapi_params(route.dependant) - validation_definitions, operation_parameters = get_openapi_operation_parameters( - all_route_params=all_route_params - ) - definitions.update(validation_definitions) + operation_parameters = get_openapi_operation_parameters(all_route_params) parameters.extend(operation_parameters) if parameters: operation["parameters"] = parameters @@ -181,11 +174,6 @@ def get_openapi_path( ) if request_body_oai: operation["requestBody"] = request_body_oai - if "ValidationError" not in definitions: - definitions["ValidationError"] = validation_error_definition - definitions[ - "HTTPValidationError" - ] = validation_error_response_definition if route.responses: for (additional_status_code, response) in route.responses.items(): assert isinstance( @@ -197,7 +185,7 @@ def get_openapi_path( field, model_name_map=model_name_map, ref_prefix=REF_PREFIX ) response.setdefault("content", {}).setdefault( - "application/json", {} + route.response_class.media_type, {} )["schema"] = response_schema status_text: Optional[str] = status_code_ranges.get( str(additional_status_code).upper() @@ -228,8 +216,15 @@ def get_openapi_path( ).setdefault("content", {}).setdefault(route.response_class.media_type, {})[ "schema" ] = response_schema - if all_route_params or route.body_field: - operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = { + + http422 = str(HTTP_422_UNPROCESSABLE_ENTITY) + if (all_route_params or route.body_field) and not any( + [ + status in operation["responses"] + for status in [http422, "4xx", "default"] + ] + ): + operation["responses"][http422] = { "description": "Validation Error", "content": { "application/json": { @@ -237,6 +232,13 @@ def get_openapi_path( } }, } + if "ValidationError" not in definitions: + definitions.update( + { + "ValidationError": validation_error_definition, + "HTTPValidationError": validation_error_response_definition, + } + ) path[method.lower()] = operation return path, security_schemes, definitions diff --git a/tests/test_additional_responses_custom_validationerror.py b/tests/test_additional_responses_custom_validationerror.py new file mode 100644 index 000000000..37982eef4 --- /dev/null +++ b/tests/test_additional_responses_custom_validationerror.py @@ -0,0 +1,100 @@ +import typing + +from fastapi import FastAPI +from pydantic import BaseModel +from starlette.responses import JSONResponse +from starlette.testclient import TestClient + +app = FastAPI() + + +class JsonApiResponse(JSONResponse): + media_type = "application/vnd.api+json" + + +class Error(BaseModel): + status: str + title: str + + +class JsonApiError(BaseModel): + errors: typing.List[Error] + + +@app.get( + "/a/{id}", + response_class=JsonApiResponse, + responses={422: {"description": "Error", "model": JsonApiError}}, +) +async def a(id): + pass # pragma: no cover + + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/a/{id}": { + "get": { + "responses": { + "422": { + "description": "Error", + "content": { + "application/vnd.api+json": { + "schema": {"$ref": "#/components/schemas/JsonApiError"} + } + }, + }, + "200": { + "description": "Successful Response", + "content": {"application/vnd.api+json": {"schema": {}}}, + }, + }, + "summary": "A", + "operationId": "a_a__id__get", + "parameters": [ + { + "required": True, + "schema": {"title": "Id"}, + "name": "id", + "in": "path", + } + ], + } + } + }, + "components": { + "schemas": { + "Error": { + "title": "Error", + "required": ["status", "title"], + "type": "object", + "properties": { + "status": {"title": "Status", "type": "string"}, + "title": {"title": "Title", "type": "string"}, + }, + }, + "JsonApiError": { + "title": "JsonApiError", + "required": ["errors"], + "type": "object", + "properties": { + "errors": { + "title": "Errors", + "type": "array", + "items": {"$ref": "#/components/schemas/Error"}, + } + }, + }, + } + }, +} + + +client = TestClient(app) + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema diff --git a/tests/test_additional_responses_default_validationerror.py b/tests/test_additional_responses_default_validationerror.py new file mode 100644 index 000000000..ac22bf573 --- /dev/null +++ b/tests/test_additional_responses_default_validationerror.py @@ -0,0 +1,85 @@ +from fastapi import FastAPI +from starlette.testclient import TestClient + +app = FastAPI() + + +@app.get("/a/{id}") +async def a(id): + pass # pragma: no cover + + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/a/{id}": { + "get": { + "responses": { + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + }, + "summary": "A", + "operationId": "a_a__id__get", + "parameters": [ + { + "required": True, + "schema": {"title": "Id"}, + "name": "id", + "in": "path", + } + ], + } + } + }, + "components": { + "schemas": { + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + } + }, +} + + +client = TestClient(app) + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema diff --git a/tests/test_additional_responses_response_class.py b/tests/test_additional_responses_response_class.py new file mode 100644 index 000000000..81c28e348 --- /dev/null +++ b/tests/test_additional_responses_response_class.py @@ -0,0 +1,117 @@ +import typing + +from fastapi import FastAPI +from pydantic import BaseModel +from starlette.responses import JSONResponse +from starlette.testclient import TestClient + +app = FastAPI() + + +class JsonApiResponse(JSONResponse): + media_type = "application/vnd.api+json" + + +class Error(BaseModel): + status: str + title: str + + +class JsonApiError(BaseModel): + errors: typing.List[Error] + + +@app.get( + "/a", + response_class=JsonApiResponse, + responses={500: {"description": "Error", "model": JsonApiError}}, +) +async def a(): + pass # pragma: no cover + + +@app.get("/b", responses={500: {"description": "Error", "model": Error}}) +async def b(): + pass # pragma: no cover + + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/a": { + "get": { + "responses": { + "500": { + "description": "Error", + "content": { + "application/vnd.api+json": { + "schema": {"$ref": "#/components/schemas/JsonApiError"} + } + }, + }, + "200": { + "description": "Successful Response", + "content": {"application/vnd.api+json": {"schema": {}}}, + }, + }, + "summary": "A", + "operationId": "a_a_get", + } + }, + "/b": { + "get": { + "responses": { + "500": { + "description": "Error", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Error"} + } + }, + }, + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + }, + "summary": "B", + "operationId": "b_b_get", + } + }, + }, + "components": { + "schemas": { + "Error": { + "title": "Error", + "required": ["status", "title"], + "type": "object", + "properties": { + "status": {"title": "Status", "type": "string"}, + "title": {"title": "Title", "type": "string"}, + }, + }, + "JsonApiError": { + "title": "JsonApiError", + "required": ["errors"], + "type": "object", + "properties": { + "errors": { + "title": "Errors", + "type": "array", + "items": {"$ref": "#/components/schemas/Error"}, + } + }, + }, + } + }, +} + + +client = TestClient(app) + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema