diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..d68123092 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -14,6 +14,7 @@ from typing import ( ) from fastapi import routing +from fastapi._compat import GenerateJsonSchema from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.exception_handlers import ( http_exception_handler, @@ -752,6 +753,17 @@ class FastAPI(Starlette): """ ), ] = True, + schema_generator_class: Annotated[ + Type[GenerateJsonSchema], + Doc( + """ + Schema generator to use for the OpenAPI specification. + You probably don't need it, but it's available. + + This affects the generated OpenAPI (e.g. visible at `/docs`). + """ + ), + ] = GenerateJsonSchema, swagger_ui_parameters: Annotated[ Optional[Dict[str, Any]], Doc( @@ -833,6 +845,7 @@ class FastAPI(Starlette): self.root_path_in_servers = root_path_in_servers self.docs_url = docs_url self.redoc_url = redoc_url + self.schema_generator_class = schema_generator_class self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url self.swagger_ui_init_oauth = swagger_ui_init_oauth self.swagger_ui_parameters = swagger_ui_parameters @@ -992,6 +1005,7 @@ class FastAPI(Starlette): tags=self.openapi_tags, servers=self.servers, separate_input_output_schemas=self.separate_input_output_schemas, + schema_generator_class=self.schema_generator_class, ) return self.openapi_schema diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index bd8f3c106..de6ff15a8 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -468,6 +468,7 @@ def get_openapi( contact: Optional[Dict[str, Union[str, Any]]] = None, license_info: Optional[Dict[str, Union[str, Any]]] = None, separate_input_output_schemas: bool = True, + schema_generator_class: Type[GenerateJsonSchema] = GenerateJsonSchema, ) -> Dict[str, Any]: info: Dict[str, Any] = {"title": title, "version": version} if summary: @@ -489,7 +490,7 @@ def get_openapi( operation_ids: Set[str] = set() all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or [])) model_name_map = get_compat_model_name_map(all_fields) - schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) + schema_generator = schema_generator_class(ref_template=REF_TEMPLATE) field_mapping, definitions = get_definitions( fields=all_fields, schema_generator=schema_generator, diff --git a/tests/test_openapi_schema_generator.py b/tests/test_openapi_schema_generator.py new file mode 100644 index 000000000..669850b2b --- /dev/null +++ b/tests/test_openapi_schema_generator.py @@ -0,0 +1,197 @@ +from typing import Literal + +from fastapi import FastAPI, Query +from fastapi._compat import PYDANTIC_V2, GenerateJsonSchema +from fastapi.openapi.utils import JsonSchemaValue +from fastapi.testclient import TestClient + + +class MyGenerateJsonSchema(GenerateJsonSchema): + """Custom JSON schema generator.""" + + def literal_schema(self, schema) -> JsonSchemaValue: + result = super().literal_schema(schema) + if "const" not in result: + return result + + # Here, we want to exclude "enum" and "type" from the result. + return { + key: value for key, value in result.items() if key not in ("type", "enum") + } + + +app = FastAPI(schema_generator_class=MyGenerateJsonSchema) + + +@app.get("/foo") +def foo( + my_const_param: Literal["cd"] = Query(), + my_enum_param: Literal["h", "ij"] = Query(), + my_const_param_with_default: Literal["ab"] = Query(default="ab"), + my_enum_param_with_default: Literal["ef", "g"] = Query(default="g"), +): + return {"message": "Hello World"} + + +client = TestClient(app) + + +def test_app(): + response = client.get("/foo?my_const_param=cd&my_enum_param=ij") + assert response.status_code == 200, response.text + + +def test_openapi_schema(): + if PYDANTIC_V2: + parameters = [ + { + "name": "my_const_param", + "in": "query", + "required": True, + "schema": { + "title": "My Const Param", + "const": "cd", + }, + }, + { + "name": "my_enum_param", + "in": "query", + "required": True, + "schema": { + "title": "My Enum Param", + "type": "string", + "enum": ["h", "ij"], + }, + }, + { + "name": "my_const_param_with_default", + "in": "query", + "required": False, + "schema": { + "title": "My Const Param With Default", + "const": "ab", + "default": "ab", + }, + }, + { + "name": "my_enum_param_with_default", + "in": "query", + "required": False, + "schema": { + "title": "My Enum Param With Default", + "type": "string", + "enum": ["ef", "g"], + "default": "g", + }, + }, + ] + else: + # pydantic v1 does not use a JSON schema generator, and FastAPI + # only defines it for compatibility. + parameters = [ + { + "name": "my_const_param", + "in": "query", + "required": True, + "schema": { + "title": "My Const Param", + "type": "string", + "enum": ["cd"], + }, + }, + { + "name": "my_enum_param", + "in": "query", + "required": True, + "schema": { + "title": "My Enum Param", + "type": "string", + "enum": ["h", "ij"], + }, + }, + { + "name": "my_const_param_with_default", + "in": "query", + "required": False, + "schema": { + "title": "My Const Param With Default", + "type": "string", + "enum": ["ab"], + "default": "ab", + }, + }, + { + "name": "my_enum_param_with_default", + "in": "query", + "required": False, + "schema": { + "title": "My Enum Param With Default", + "type": "string", + "enum": ["ef", "g"], + "default": "g", + }, + }, + ] + + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == { + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "title": "Detail", + "type": "array", + } + }, + "title": "HTTPValidationError", + "type": "object", + }, + "ValidationError": { + "properties": { + "loc": { + "items": { + "anyOf": [{"type": "string"}, {"type": "integer"}] + }, + "title": "Location", + "type": "array", + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + "required": ["loc", "msg", "type"], + "title": "ValidationError", + "type": "object", + }, + } + }, + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/foo": { + "get": { + "summary": "Foo", + "operationId": "foo_foo_get", + "parameters": parameters, + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + }