Browse Source

feat: add schema generator class param to FastAPI

This feature allows setting a custom JSON schema generator at
app-level, in order to customize JSON schema from pydantic-core
schema generation as described in the following pydantic guide:

https://docs.pydantic.dev/2.1/usage/json_schema/#customizing-the-json-schema-generation-process
pull/12455/head
Thomas Touhey 6 months ago
parent
commit
af47e4e34f
  1. 14
      fastapi/applications.py
  2. 3
      fastapi/openapi/utils.py
  3. 197
      tests/test_openapi_schema_generator.py

14
fastapi/applications.py

@ -14,6 +14,7 @@ from typing import (
) )
from fastapi import routing from fastapi import routing
from fastapi._compat import GenerateJsonSchema
from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.exception_handlers import ( from fastapi.exception_handlers import (
http_exception_handler, http_exception_handler,
@ -752,6 +753,17 @@ class FastAPI(Starlette):
""" """
), ),
] = True, ] = True,
schema_generator_class: Annotated[
Type[GenerateJsonSchema],
Doc(
"""
Schema generator to use for the OpenAPI specification.
You probably don't need it, but it's available.
This affects the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = GenerateJsonSchema,
swagger_ui_parameters: Annotated[ swagger_ui_parameters: Annotated[
Optional[Dict[str, Any]], Optional[Dict[str, Any]],
Doc( Doc(
@ -833,6 +845,7 @@ class FastAPI(Starlette):
self.root_path_in_servers = root_path_in_servers self.root_path_in_servers = root_path_in_servers
self.docs_url = docs_url self.docs_url = docs_url
self.redoc_url = redoc_url self.redoc_url = redoc_url
self.schema_generator_class = schema_generator_class
self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
self.swagger_ui_init_oauth = swagger_ui_init_oauth self.swagger_ui_init_oauth = swagger_ui_init_oauth
self.swagger_ui_parameters = swagger_ui_parameters self.swagger_ui_parameters = swagger_ui_parameters
@ -992,6 +1005,7 @@ class FastAPI(Starlette):
tags=self.openapi_tags, tags=self.openapi_tags,
servers=self.servers, servers=self.servers,
separate_input_output_schemas=self.separate_input_output_schemas, separate_input_output_schemas=self.separate_input_output_schemas,
schema_generator_class=self.schema_generator_class,
) )
return self.openapi_schema return self.openapi_schema

3
fastapi/openapi/utils.py

@ -468,6 +468,7 @@ def get_openapi(
contact: Optional[Dict[str, Union[str, Any]]] = None, contact: Optional[Dict[str, Union[str, Any]]] = None,
license_info: Optional[Dict[str, Union[str, Any]]] = None, license_info: Optional[Dict[str, Union[str, Any]]] = None,
separate_input_output_schemas: bool = True, separate_input_output_schemas: bool = True,
schema_generator_class: Type[GenerateJsonSchema] = GenerateJsonSchema,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
info: Dict[str, Any] = {"title": title, "version": version} info: Dict[str, Any] = {"title": title, "version": version}
if summary: if summary:
@ -489,7 +490,7 @@ def get_openapi(
operation_ids: Set[str] = set() operation_ids: Set[str] = set()
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or [])) all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
model_name_map = get_compat_model_name_map(all_fields) model_name_map = get_compat_model_name_map(all_fields)
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) schema_generator = schema_generator_class(ref_template=REF_TEMPLATE)
field_mapping, definitions = get_definitions( field_mapping, definitions = get_definitions(
fields=all_fields, fields=all_fields,
schema_generator=schema_generator, schema_generator=schema_generator,

197
tests/test_openapi_schema_generator.py

@ -0,0 +1,197 @@
from typing import Literal
from fastapi import FastAPI, Query
from fastapi._compat import PYDANTIC_V2, GenerateJsonSchema
from fastapi.openapi.utils import JsonSchemaValue
from fastapi.testclient import TestClient
class MyGenerateJsonSchema(GenerateJsonSchema):
"""Custom JSON schema generator."""
def literal_schema(self, schema) -> JsonSchemaValue:
result = super().literal_schema(schema)
if "const" not in result:
return result
# Here, we want to exclude "enum" and "type" from the result.
return {
key: value for key, value in result.items() if key not in ("type", "enum")
}
app = FastAPI(schema_generator_class=MyGenerateJsonSchema)
@app.get("/foo")
def foo(
my_const_param: Literal["cd"] = Query(),
my_enum_param: Literal["h", "ij"] = Query(),
my_const_param_with_default: Literal["ab"] = Query(default="ab"),
my_enum_param_with_default: Literal["ef", "g"] = Query(default="g"),
):
return {"message": "Hello World"}
client = TestClient(app)
def test_app():
response = client.get("/foo?my_const_param=cd&my_enum_param=ij")
assert response.status_code == 200, response.text
def test_openapi_schema():
if PYDANTIC_V2:
parameters = [
{
"name": "my_const_param",
"in": "query",
"required": True,
"schema": {
"title": "My Const Param",
"const": "cd",
},
},
{
"name": "my_enum_param",
"in": "query",
"required": True,
"schema": {
"title": "My Enum Param",
"type": "string",
"enum": ["h", "ij"],
},
},
{
"name": "my_const_param_with_default",
"in": "query",
"required": False,
"schema": {
"title": "My Const Param With Default",
"const": "ab",
"default": "ab",
},
},
{
"name": "my_enum_param_with_default",
"in": "query",
"required": False,
"schema": {
"title": "My Enum Param With Default",
"type": "string",
"enum": ["ef", "g"],
"default": "g",
},
},
]
else:
# pydantic v1 does not use a JSON schema generator, and FastAPI
# only defines it for compatibility.
parameters = [
{
"name": "my_const_param",
"in": "query",
"required": True,
"schema": {
"title": "My Const Param",
"type": "string",
"enum": ["cd"],
},
},
{
"name": "my_enum_param",
"in": "query",
"required": True,
"schema": {
"title": "My Enum Param",
"type": "string",
"enum": ["h", "ij"],
},
},
{
"name": "my_const_param_with_default",
"in": "query",
"required": False,
"schema": {
"title": "My Const Param With Default",
"type": "string",
"enum": ["ab"],
"default": "ab",
},
},
{
"name": "my_enum_param_with_default",
"in": "query",
"required": False,
"schema": {
"title": "My Enum Param With Default",
"type": "string",
"enum": ["ef", "g"],
"default": "g",
},
},
]
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == {
"components": {
"schemas": {
"HTTPValidationError": {
"properties": {
"detail": {
"items": {"$ref": "#/components/schemas/ValidationError"},
"title": "Detail",
"type": "array",
}
},
"title": "HTTPValidationError",
"type": "object",
},
"ValidationError": {
"properties": {
"loc": {
"items": {
"anyOf": [{"type": "string"}, {"type": "integer"}]
},
"title": "Location",
"type": "array",
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
"required": ["loc", "msg", "type"],
"title": "ValidationError",
"type": "object",
},
}
},
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/foo": {
"get": {
"summary": "Foo",
"operationId": "foo_foo_get",
"parameters": parameters,
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
}
}
},
}
Loading…
Cancel
Save