diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 40dffba64b..a8c53df9a1 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -139,22 +139,23 @@ def get_flat_dependant( dependant: Dependant, *, skip_repeats: bool = False, - visited: list[DependencyCacheKey] | None = None, + visited: set[DependencyCacheKey] | None = None, parent_oauth_scopes: list[str] | None = None, ) -> Dependant: - if visited is None: - visited = [] - visited.append(dependant.cache_key) - use_parent_oauth_scopes = (parent_oauth_scopes or []) + ( - dependant.oauth_scopes or [] - ) + visited_set: set[DependencyCacheKey] | None = None + if skip_repeats: + visited_set = set() if visited is None else visited + visited_set.add(dependant.cache_key) + if ( + parent_oauth_scopes + or dependant.parent_oauth_scopes + or dependant.own_oauth_scopes + ): + use_parent_oauth_scopes = (parent_oauth_scopes or []) + dependant.oauth_scopes + else: + use_parent_oauth_scopes = [] flat_dependant = Dependant( - path_params=dependant.path_params.copy(), - query_params=dependant.query_params.copy(), - header_params=dependant.header_params.copy(), - cookie_params=dependant.cookie_params.copy(), - body_params=dependant.body_params.copy(), name=dependant.name, call=dependant.call, request_param_name=dependant.request_param_name, @@ -169,21 +170,41 @@ def get_flat_dependant( path=dependant.path, scope=dependant.scope, ) + if dependant.path_params: + flat_dependant.path_params = dependant.path_params.copy() + if dependant.query_params: + flat_dependant.query_params = dependant.query_params.copy() + if dependant.header_params: + flat_dependant.header_params = dependant.header_params.copy() + if dependant.cookie_params: + flat_dependant.cookie_params = dependant.cookie_params.copy() + if dependant.body_params: + flat_dependant.body_params = dependant.body_params.copy() + child_parent_oauth_scopes = ( + flat_dependant.oauth_scopes + if flat_dependant.parent_oauth_scopes or flat_dependant.own_oauth_scopes + else [] + ) for sub_dependant in dependant.dependencies: - if skip_repeats and sub_dependant.cache_key in visited: + if visited_set is not None and sub_dependant.cache_key in visited_set: continue flat_sub = get_flat_dependant( sub_dependant, skip_repeats=skip_repeats, - visited=visited, - parent_oauth_scopes=flat_dependant.oauth_scopes, + visited=visited_set, + parent_oauth_scopes=child_parent_oauth_scopes, ) flat_dependant.dependencies.append(flat_sub) - flat_dependant.path_params.extend(flat_sub.path_params) - flat_dependant.query_params.extend(flat_sub.query_params) - flat_dependant.header_params.extend(flat_sub.header_params) - flat_dependant.cookie_params.extend(flat_sub.cookie_params) - flat_dependant.body_params.extend(flat_sub.body_params) + if flat_sub.path_params: + flat_dependant.path_params.extend(flat_sub.path_params) + if flat_sub.query_params: + flat_dependant.query_params.extend(flat_sub.query_params) + if flat_sub.header_params: + flat_dependant.header_params.extend(flat_sub.header_params) + if flat_sub.cookie_params: + flat_dependant.cookie_params.extend(flat_sub.cookie_params) + if flat_sub.body_params: + flat_dependant.body_params.extend(flat_sub.body_params) flat_dependant.dependencies.extend(flat_sub.dependencies) return flat_dependant @@ -201,8 +222,11 @@ def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]: return fields -def get_flat_params(dependant: Dependant) -> list[ModelField]: - flat_dependant = get_flat_dependant(dependant, skip_repeats=True) +def get_flat_params( + dependant: Dependant, *, flat_dependant: Dependant | None = None +) -> list[ModelField]: + if flat_dependant is None: + flat_dependant = get_flat_dependant(dependant, skip_repeats=True) path_params = _get_flat_fields_from_params(flat_dependant.path_params) query_params = _get_flat_fields_from_params(flat_dependant.query_params) header_params = _get_flat_fields_from_params(flat_dependant.header_params) diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 1c7a17c4ca..095c010ea8 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -82,31 +82,34 @@ def get_openapi_security_definitions( flat_dependant: Dependant, ) -> tuple[dict[str, Any], list[dict[str, Any]]]: security_definitions = {} + security_definition_cache: dict[int, dict[str, Any]] = {} # Use a dict to merge scopes for same security scheme - operation_security_dict: dict[str, list[str]] = {} + operation_security_dict: dict[str, dict[str, None]] = {} for security_dependency in flat_dependant._security_dependencies: - security_definition = jsonable_encoder( - security_dependency._security_scheme.model, - by_alias=True, - exclude_none=True, - ) - security_name = security_dependency._security_scheme.scheme_name + security_scheme = security_dependency._security_scheme + security_name = security_scheme.scheme_name + security_definition = security_definition_cache.get(id(security_scheme)) + if security_definition is None: + security_definition = jsonable_encoder( + security_scheme.model, + by_alias=True, + exclude_none=True, + ) + security_definition_cache[id(security_scheme)] = security_definition security_definitions[security_name] = security_definition - # Merge scopes for the same security scheme - if security_name not in operation_security_dict: - operation_security_dict[security_name] = [] - for scope in security_dependency.oauth_scopes or []: - if scope not in operation_security_dict[security_name]: - operation_security_dict[security_name].append(scope) + # Merge scopes for the same security scheme, preserving insertion order + operation_security_dict.setdefault(security_name, {}).update( + dict.fromkeys(security_dependency.oauth_scopes) + ) operation_security = [ - {name: scopes} for name, scopes in operation_security_dict.items() + {name: list(scopes)} for name, scopes in operation_security_dict.items() ] return security_definitions, operation_security def _get_openapi_operation_parameters( *, - dependant: Dependant, + flat_dependant: Dependant, model_name_map: ModelNameMap, field_mapping: dict[ tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any] @@ -114,7 +117,6 @@ def _get_openapi_operation_parameters( separate_input_output_schemas: bool = True, ) -> list[dict[str, Any]]: parameters = [] - flat_dependant = get_flat_dependant(dependant, skip_repeats=True) path_params = _get_flat_fields_from_params(flat_dependant.path_params) query_params = _get_flat_fields_from_params(flat_dependant.query_params) header_params = _get_flat_fields_from_params(flat_dependant.header_params) @@ -278,12 +280,12 @@ def get_openapi_path( assert current_response_class, "A response class is needed to generate OpenAPI" route_response_media_type: str | None = current_response_class.media_type if route.include_in_schema: + flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True) for method in route.methods: operation = get_openapi_operation_metadata( route=route, method=method, operation_ids=operation_ids ) parameters: list[dict[str, Any]] = [] - flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True) security_definitions, operation_security = get_openapi_security_definitions( flat_dependant=flat_dependant ) @@ -292,7 +294,7 @@ def get_openapi_path( if security_definitions: security_schemes.update(security_definitions) operation_parameters = _get_openapi_operation_parameters( - dependant=route.dependant, + flat_dependant=flat_dependant, model_name_map=model_name_map, field_mapping=field_mapping, separate_input_output_schemas=separate_input_output_schemas, @@ -452,7 +454,9 @@ def get_openapi_path( deep_dict_update(openapi_response, process_response) openapi_response["description"] = description http422 = "422" - all_route_params = get_flat_params(route.dependant) + all_route_params = get_flat_params( + route.dependant, flat_dependant=flat_dependant + ) if (all_route_params or route.body_field) and not any( status in operation["responses"] for status in [http422, "4XX", "default"]