Browse Source

perf: optimize OpenAPI schema generation (get_flat_dependant)

Reduce redundant work in the OpenAPI schema-generation path with no
change to generated output.

- get_openapi_path: compute get_flat_dependant once per route instead of
  three times, threading the result into _get_openapi_operation_parameters
  and get_flat_params.
- get_flat_dependant: use a set for visited-tracking (O(1) membership),
  skip tracking entirely when skip_repeats is False, copy param lists
  copy-on-write, and short-circuit oauth-scope computation when there are
  no scopes.
- get_openapi_security_definitions: O(1) scope dedup via dict.fromkeys
  (insertion order preserved) and cache jsonable_encoder per security
  scheme instead of per referencing dependency.
pull/15638/head
sebastianbreguel 7 days ago
parent
commit
8426695d29
  1. 68
      fastapi/dependencies/utils.py
  2. 42
      fastapi/openapi/utils.py

68
fastapi/dependencies/utils.py

@ -139,22 +139,23 @@ def get_flat_dependant(
dependant: Dependant,
*,
skip_repeats: bool = False,
visited: list[DependencyCacheKey] | None = None,
visited: set[DependencyCacheKey] | None = None,
parent_oauth_scopes: list[str] | None = None,
) -> Dependant:
if visited is None:
visited = []
visited.append(dependant.cache_key)
use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
dependant.oauth_scopes or []
)
visited_set: set[DependencyCacheKey] | None = None
if skip_repeats:
visited_set = set() if visited is None else visited
visited_set.add(dependant.cache_key)
if (
parent_oauth_scopes
or dependant.parent_oauth_scopes
or dependant.own_oauth_scopes
):
use_parent_oauth_scopes = (parent_oauth_scopes or []) + dependant.oauth_scopes
else:
use_parent_oauth_scopes = []
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
name=dependant.name,
call=dependant.call,
request_param_name=dependant.request_param_name,
@ -169,21 +170,41 @@ def get_flat_dependant(
path=dependant.path,
scope=dependant.scope,
)
if dependant.path_params:
flat_dependant.path_params = dependant.path_params.copy()
if dependant.query_params:
flat_dependant.query_params = dependant.query_params.copy()
if dependant.header_params:
flat_dependant.header_params = dependant.header_params.copy()
if dependant.cookie_params:
flat_dependant.cookie_params = dependant.cookie_params.copy()
if dependant.body_params:
flat_dependant.body_params = dependant.body_params.copy()
child_parent_oauth_scopes = (
flat_dependant.oauth_scopes
if flat_dependant.parent_oauth_scopes or flat_dependant.own_oauth_scopes
else []
)
for sub_dependant in dependant.dependencies:
if skip_repeats and sub_dependant.cache_key in visited:
if visited_set is not None and sub_dependant.cache_key in visited_set:
continue
flat_sub = get_flat_dependant(
sub_dependant,
skip_repeats=skip_repeats,
visited=visited,
parent_oauth_scopes=flat_dependant.oauth_scopes,
visited=visited_set,
parent_oauth_scopes=child_parent_oauth_scopes,
)
flat_dependant.dependencies.append(flat_sub)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
if flat_sub.path_params:
flat_dependant.path_params.extend(flat_sub.path_params)
if flat_sub.query_params:
flat_dependant.query_params.extend(flat_sub.query_params)
if flat_sub.header_params:
flat_dependant.header_params.extend(flat_sub.header_params)
if flat_sub.cookie_params:
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
if flat_sub.body_params:
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.dependencies.extend(flat_sub.dependencies)
return flat_dependant
@ -201,8 +222,11 @@ def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]:
return fields
def get_flat_params(dependant: Dependant) -> list[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
def get_flat_params(
dependant: Dependant, *, flat_dependant: Dependant | None = None
) -> list[ModelField]:
if flat_dependant is None:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
header_params = _get_flat_fields_from_params(flat_dependant.header_params)

42
fastapi/openapi/utils.py

@ -82,31 +82,34 @@ def get_openapi_security_definitions(
flat_dependant: Dependant,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
security_definitions = {}
security_definition_cache: dict[int, dict[str, Any]] = {}
# Use a dict to merge scopes for same security scheme
operation_security_dict: dict[str, list[str]] = {}
operation_security_dict: dict[str, dict[str, None]] = {}
for security_dependency in flat_dependant._security_dependencies:
security_definition = jsonable_encoder(
security_dependency._security_scheme.model,
by_alias=True,
exclude_none=True,
)
security_name = security_dependency._security_scheme.scheme_name
security_scheme = security_dependency._security_scheme
security_name = security_scheme.scheme_name
security_definition = security_definition_cache.get(id(security_scheme))
if security_definition is None:
security_definition = jsonable_encoder(
security_scheme.model,
by_alias=True,
exclude_none=True,
)
security_definition_cache[id(security_scheme)] = security_definition
security_definitions[security_name] = security_definition
# Merge scopes for the same security scheme
if security_name not in operation_security_dict:
operation_security_dict[security_name] = []
for scope in security_dependency.oauth_scopes or []:
if scope not in operation_security_dict[security_name]:
operation_security_dict[security_name].append(scope)
# Merge scopes for the same security scheme, preserving insertion order
operation_security_dict.setdefault(security_name, {}).update(
dict.fromkeys(security_dependency.oauth_scopes)
)
operation_security = [
{name: scopes} for name, scopes in operation_security_dict.items()
{name: list(scopes)} for name, scopes in operation_security_dict.items()
]
return security_definitions, operation_security
def _get_openapi_operation_parameters(
*,
dependant: Dependant,
flat_dependant: Dependant,
model_name_map: ModelNameMap,
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
@ -114,7 +117,6 @@ def _get_openapi_operation_parameters(
separate_input_output_schemas: bool = True,
) -> list[dict[str, Any]]:
parameters = []
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
header_params = _get_flat_fields_from_params(flat_dependant.header_params)
@ -278,12 +280,12 @@ def get_openapi_path(
assert current_response_class, "A response class is needed to generate OpenAPI"
route_response_media_type: str | None = current_response_class.media_type
if route.include_in_schema:
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
for method in route.methods:
operation = get_openapi_operation_metadata(
route=route, method=method, operation_ids=operation_ids
)
parameters: list[dict[str, Any]] = []
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
)
@ -292,7 +294,7 @@ def get_openapi_path(
if security_definitions:
security_schemes.update(security_definitions)
operation_parameters = _get_openapi_operation_parameters(
dependant=route.dependant,
flat_dependant=flat_dependant,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
@ -452,7 +454,9 @@ def get_openapi_path(
deep_dict_update(openapi_response, process_response)
openapi_response["description"] = description
http422 = "422"
all_route_params = get_flat_params(route.dependant)
all_route_params = get_flat_params(
route.dependant, flat_dependant=flat_dependant
)
if (all_route_params or route.body_field) and not any(
status in operation["responses"]
for status in [http422, "4XX", "default"]

Loading…
Cancel
Save