diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index b5778327b..bb2e7dff7 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -203,27 +203,31 @@ def get_openapi_path( operation["callbacks"] = callbacks if route.responses: for (additional_status_code, response) in route.responses.items(): + process_response = response.copy() assert isinstance( - response, dict + process_response, dict ), "An additional response must be a dict" field = route.response_fields.get(additional_status_code) if field: response_schema, _, _ = field_schema( field, model_name_map=model_name_map, ref_prefix=REF_PREFIX ) - response.setdefault("content", {}).setdefault( + process_response.setdefault("content", {}).setdefault( route_response_media_type or "application/json", {} )["schema"] = response_schema status_text: Optional[str] = status_code_ranges.get( str(additional_status_code).upper() ) or http.client.responses.get(int(additional_status_code)) - response.setdefault( + process_response.setdefault( "description", status_text or "Additional Response" ) status_code_key = str(additional_status_code).upper() if status_code_key == "DEFAULT": status_code_key = "default" - operation.setdefault("responses", {})[status_code_key] = response + process_response.pop("model", None) + operation.setdefault("responses", {})[ + status_code_key + ] = process_response status_code = str(route.status_code) operation.setdefault("responses", {}).setdefault(status_code, {})[ "description" diff --git a/tests/test_additional_responses_custom_model_in_callback.py b/tests/test_additional_responses_custom_model_in_callback.py new file mode 100644 index 000000000..36dd0d6db --- /dev/null +++ b/tests/test_additional_responses_custom_model_in_callback.py @@ -0,0 +1,138 @@ +from fastapi import APIRouter, FastAPI +from fastapi.testclient import TestClient +from pydantic import BaseModel, HttpUrl +from starlette.responses import JSONResponse + + +class CustomModel(BaseModel): + a: int + + +app = FastAPI() + +callback_router = APIRouter(default_response_class=JSONResponse) + + +@callback_router.get( + "{$callback_url}/callback/", responses={400: {"model": CustomModel}} +) +def callback_route(): + pass # pragma: no cover + + +@app.post("/", callbacks=callback_router.routes) +def main_route(callback_url: HttpUrl): + pass # pragma: no cover + + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/": { + "post": { + "summary": "Main Route", + "operationId": "main_route__post", + "parameters": [ + { + "required": True, + "schema": { + "title": "Callback Url", + "maxLength": 2083, + "minLength": 1, + "type": "string", + "format": "uri", + }, + "name": "callback_url", + "in": "query", + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "callbacks": { + "callback_route": { + "{$callback_url}/callback/": { + "get": { + "summary": "Callback Route", + "operationId": "callback_route__callback_url__callback__get", + "responses": { + "400": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomModel" + } + } + }, + "description": "Bad Request", + }, + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + }, + } + } + } + }, + } + } + }, + "components": { + "schemas": { + "CustomModel": { + "title": "CustomModel", + "required": ["a"], + "type": "object", + "properties": {"a": {"title": "A", "type": "integer"}}, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + } + }, +} + +client = TestClient(app) + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == openapi_schema