diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 43ab4a098..1a660f5d3 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -188,6 +188,16 @@ def get_flat_dependant( return flat_dependant +def get_flat_params(dependant: Dependant) -> List[ModelField]: + flat_dependant = get_flat_dependant(dependant, skip_repeats=True) + return ( + flat_dependant.path_params + + flat_dependant.query_params + + flat_dependant.header_params + + flat_dependant.cookie_params + ) + + def is_scalar_field(field: ModelField) -> bool: field_info = get_field_info(field) if not ( diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index c1e66fc8d..b5778327b 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -1,9 +1,10 @@ import http.client -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast +from enum import Enum +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast from fastapi import routing from fastapi.dependencies.models import Dependant -from fastapi.dependencies.utils import get_flat_dependant +from fastapi.dependencies.utils import get_flat_dependant, get_flat_params from fastapi.encoders import jsonable_encoder from fastapi.openapi.constants import ( METHODS_WITH_BODY, @@ -15,11 +16,14 @@ from fastapi.params import Body, Param from fastapi.utils import ( generate_operation_id_for_path, get_field_info, - get_flat_models_from_routes, get_model_definitions, ) from pydantic import BaseModel -from pydantic.schema import field_schema, get_model_name_map +from pydantic.schema import ( + field_schema, + get_flat_models_from_fields, + get_model_name_map, +) from pydantic.utils import lenient_issubclass from starlette.responses import JSONResponse from starlette.routing import BaseRoute @@ -64,16 +68,6 @@ status_code_ranges: Dict[str, str] = { } -def get_openapi_params(dependant: Dependant) -> List[ModelField]: - flat_dependant = get_flat_dependant(dependant, skip_repeats=True) - return ( - flat_dependant.path_params - + flat_dependant.query_params - + flat_dependant.header_params - + flat_dependant.cookie_params - ) - - def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]: security_definitions = {} operation_security = [] @@ -90,17 +84,22 @@ def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, L def get_openapi_operation_parameters( + *, all_route_params: Sequence[ModelField], + model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str] ) -> List[Dict[str, Any]]: parameters = [] for param in all_route_params: field_info = get_field_info(param) field_info = cast(Param, field_info) + # ignore mypy error until enum schemas are released parameter = { "name": param.alias, "in": field_info.in_.value, "required": param.required, - "schema": field_schema(param, model_name_map={})[0], + "schema": field_schema( + param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore + )[0], } if field_info.description: parameter["description"] = field_info.description @@ -111,13 +110,16 @@ def get_openapi_operation_parameters( def get_openapi_operation_request_body( - *, body_field: Optional[ModelField], model_name_map: Dict[Type[BaseModel], str] + *, + body_field: Optional[ModelField], + model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str] ) -> Optional[Dict]: if not body_field: return None assert isinstance(body_field, ModelField) + # ignore mypy error until enum schemas are released body_schema, _, _ = field_schema( - body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX + body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore ) field_info = cast(Body, get_field_info(body_field)) request_media_type = field_info.media_type @@ -176,8 +178,10 @@ def get_openapi_path( operation.setdefault("security", []).extend(operation_security) if security_definitions: security_schemes.update(security_definitions) - all_route_params = get_openapi_params(route.dependant) - operation_parameters = get_openapi_operation_parameters(all_route_params) + all_route_params = get_flat_params(route.dependant) + operation_parameters = get_openapi_operation_parameters( + all_route_params=all_route_params, model_name_map=model_name_map + ) parameters.extend(operation_parameters) if parameters: operation["parameters"] = list( @@ -270,6 +274,38 @@ def get_openapi_path( return path, security_schemes, definitions +def get_flat_models_from_routes( + routes: Sequence[BaseRoute], +) -> Set[Union[Type[BaseModel], Type[Enum]]]: + body_fields_from_routes: List[ModelField] = [] + responses_from_routes: List[ModelField] = [] + request_fields_from_routes: List[ModelField] = [] + callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set() + for route in routes: + if getattr(route, "include_in_schema", None) and isinstance( + route, routing.APIRoute + ): + if route.body_field: + assert isinstance( + route.body_field, ModelField + ), "A request body must be a Pydantic Field" + body_fields_from_routes.append(route.body_field) + if route.response_field: + responses_from_routes.append(route.response_field) + if route.response_fields: + responses_from_routes.extend(route.response_fields.values()) + if route.callbacks: + callback_flat_models |= get_flat_models_from_routes(route.callbacks) + params = get_flat_params(route.dependant) + request_fields_from_routes.extend(params) + + flat_models = callback_flat_models | get_flat_models_from_fields( + body_fields_from_routes + responses_from_routes + request_fields_from_routes, + known_models=set(), + ) + return flat_models + + def get_openapi( *, title: str, @@ -286,9 +322,11 @@ def get_openapi( components: Dict[str, Dict] = {} paths: Dict[str, Dict] = {} flat_models = get_flat_models_from_routes(routes) - model_name_map = get_model_name_map(flat_models) + # ignore mypy error until enum schemas are released + model_name_map = get_model_name_map(flat_models) # type: ignore + # ignore mypy error until enum schemas are released definitions = get_model_definitions( - flat_models=flat_models, model_name_map=model_name_map + flat_models=flat_models, model_name_map=model_name_map # type: ignore ) for route in routes: if isinstance(route, routing.APIRoute): diff --git a/fastapi/utils.py b/fastapi/utils.py index 154dd9aa1..c9022fbc3 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -1,17 +1,16 @@ import functools import re from dataclasses import is_dataclass -from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast +from enum import Enum +from typing import Any, Dict, Optional, Set, Type, Union, cast import fastapi -from fastapi import routing from fastapi.logger import logger from fastapi.openapi.constants import REF_PREFIX from pydantic import BaseConfig, BaseModel, create_model from pydantic.class_validators import Validator -from pydantic.schema import get_flat_models_from_fields, model_process_schema +from pydantic.schema import model_process_schema from pydantic.utils import lenient_issubclass -from starlette.routing import BaseRoute try: from pydantic.fields import FieldInfo, ModelField, UndefinedType @@ -50,38 +49,16 @@ def warning_response_model_skip_defaults_deprecated() -> None: ) -def get_flat_models_from_routes(routes: Sequence[BaseRoute]) -> Set[Type[BaseModel]]: - body_fields_from_routes: List[ModelField] = [] - responses_from_routes: List[ModelField] = [] - callback_flat_models: Set[Type[BaseModel]] = set() - for route in routes: - if getattr(route, "include_in_schema", None) and isinstance( - route, routing.APIRoute - ): - if route.body_field: - assert isinstance( - route.body_field, ModelField - ), "A request body must be a Pydantic Field" - body_fields_from_routes.append(route.body_field) - if route.response_field: - responses_from_routes.append(route.response_field) - if route.response_fields: - responses_from_routes.extend(route.response_fields.values()) - if route.callbacks: - callback_flat_models |= get_flat_models_from_routes(route.callbacks) - flat_models = callback_flat_models | get_flat_models_from_fields( - body_fields_from_routes + responses_from_routes, known_models=set() - ) - return flat_models - - def get_model_definitions( - *, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str] + *, + flat_models: Set[Union[Type[BaseModel], Type[Enum]]], + model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], ) -> Dict[str, Any]: definitions: Dict[str, Dict] = {} for model in flat_models: + # ignore mypy error until enum schemas are released m_schema, m_definitions, m_nested_models = model_process_schema( - model, model_name_map=model_name_map, ref_prefix=REF_PREFIX + model, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore ) definitions.update(m_definitions) model_name = model_name_map[model] diff --git a/scripts/format.sh b/scripts/format.sh index bbcb04354..07ce78f69 100755 --- a/scripts/format.sh +++ b/scripts/format.sh @@ -3,4 +3,4 @@ set -x autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place docs_src fastapi tests scripts --exclude=__init__.py black fastapi tests docs_src scripts -isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --apply fastapi tests docs_src scripts +isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --thirdparty pydantic --thirdparty starlette --apply fastapi tests docs_src scripts diff --git a/scripts/lint.sh b/scripts/lint.sh index 6472f1845..ec0f7f41b 100755 --- a/scripts/lint.sh +++ b/scripts/lint.sh @@ -5,4 +5,4 @@ set -x mypy fastapi black fastapi tests --check -isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi fastapi tests +isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi --thirdparty fastapi --thirdparty pydantic --thirdparty starlette fastapi tests diff --git a/tests/test_tutorial/test_path_params/test_tutorial005.py b/tests/test_tutorial/test_path_params/test_tutorial005.py index b0e0535e8..836a6264b 100644 --- a/tests/test_tutorial/test_path_params/test_tutorial005.py +++ b/tests/test_tutorial/test_path_params/test_tutorial005.py @@ -87,7 +87,7 @@ openapi_schema2 = { "parameters": [ { "required": True, - "schema": {"$ref": "#/definitions/ModelName"}, + "schema": {"$ref": "#/components/schemas/ModelName"}, "name": "model_name", "in": "path", } @@ -124,6 +124,12 @@ openapi_schema2 = { } }, }, + "ModelName": { + "title": "ModelName", + "enum": ["alexnet", "resnet", "lenet"], + "type": "string", + "description": "An enumeration.", + }, "ValidationError": { "title": "ValidationError", "required": ["loc", "msg", "type"],