diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 808646cc2..4c8530d34 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -489,6 +489,7 @@ def get_openapi( contact: Optional[Dict[str, Union[str, Any]]] = None, license_info: Optional[Dict[str, Union[str, Any]]] = None, separate_input_output_schemas: bool = True, + schema_generator: Optional[GenerateJsonSchema] = None, ) -> Dict[str, Any]: info: Dict[str, Any] = {"title": title, "version": version} if summary: @@ -510,7 +511,7 @@ def get_openapi( operation_ids: Set[str] = set() all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or [])) model_name_map = get_compat_model_name_map(all_fields) - schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) + schema_generator = schema_generator or GenerateJsonSchema(ref_template=REF_TEMPLATE) field_mapping, definitions = get_definitions( fields=all_fields, schema_generator=schema_generator, diff --git a/tests/test_openapi_custom_schema_generator.py b/tests/test_openapi_custom_schema_generator.py new file mode 100644 index 000000000..b547992fc --- /dev/null +++ b/tests/test_openapi_custom_schema_generator.py @@ -0,0 +1,82 @@ +import pytest +from fastapi import FastAPI +from fastapi._compat import PYDANTIC_V2, GenerateJsonSchema +from fastapi.openapi.constants import REF_TEMPLATE +from fastapi.openapi.utils import get_openapi + +app = FastAPI() + + +@app.get("/") +def read_root(): + pass # pragma: no cover + + +# Custom schema generator that does nothing but tracks if it was called +class CustomJsonSchemaGenerator(GenerateJsonSchema): + def __init__(self): + super().__init__(ref_template=REF_TEMPLATE) + self.called = False + + def generate_definitions(self, *args, **kwargs): + self.called = True + return super().generate_definitions(*args, **kwargs) + + +def test_custom_schema_generator_called(): + custom_schema_generator = CustomJsonSchemaGenerator() + get_openapi( + title=app.title, + version=app.version, + routes=app.routes, + schema_generator=custom_schema_generator, + ) + + if PYDANTIC_V2: + assert custom_schema_generator.called is True + else: + assert ( # Pydantic v1 does not use custom schema generators + custom_schema_generator.called is False + ) + + +@pytest.mark.parametrize("use_custom_schema_generator", [True, False]) +def test_custom_schema_generator_openapi(use_custom_schema_generator: bool): + custom_schema_generator = ( + CustomJsonSchemaGenerator() if use_custom_schema_generator else None + ) + openapi = get_openapi( + title=app.title, + version=app.version, + routes=app.routes, + schema_generator=custom_schema_generator, + ) + + assert openapi == OPENAPI_SCHEMA + + +OPENAPI_SCHEMA = { + "info": { + "title": "FastAPI", + "version": "0.1.0", + }, + "openapi": "3.1.0", + "paths": { + "/": { + "get": { + "operationId": "read_root__get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": {}, + }, + }, + "description": "Successful Response", + }, + }, + "summary": "Read Root", + }, + }, + }, +}