Browse Source

🐛 Fix Enum handling with their own schema definitions (#1463)

* 🐛 Fix extra support for enum with its own schema

*  Fix/update test for enum with its own schema

* 🐛 Fix type declarations

* 🔧 Update format and lint scripts to support locally installed Pydantic and Starlette

* 🐛 Add temporary type ignores while enum schemas are merged
pull/1467/head
Sebastián Ramírez 5 years ago
committed by GitHub
parent
commit
5984233223
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      fastapi/dependencies/utils.py
  2. 80
      fastapi/openapi/utils.py
  3. 39
      fastapi/utils.py
  4. 2
      scripts/format.sh
  5. 2
      scripts/lint.sh
  6. 8
      tests/test_tutorial/test_path_params/test_tutorial005.py

10
fastapi/dependencies/utils.py

@ -188,6 +188,16 @@ def get_flat_dependant(
return flat_dependant
def get_flat_params(dependant: Dependant) -> List[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)
def is_scalar_field(field: ModelField) -> bool:
field_info = get_field_info(field)
if not (

80
fastapi/openapi/utils.py

@ -1,9 +1,10 @@
import http.client
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
from fastapi import routing
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import (
METHODS_WITH_BODY,
@ -15,11 +16,14 @@ from fastapi.params import Body, Param
from fastapi.utils import (
generate_operation_id_for_path,
get_field_info,
get_flat_models_from_routes,
get_model_definitions,
)
from pydantic import BaseModel
from pydantic.schema import field_schema, get_model_name_map
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
get_model_name_map,
)
from pydantic.utils import lenient_issubclass
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
@ -64,16 +68,6 @@ status_code_ranges: Dict[str, str] = {
}
def get_openapi_params(dependant: Dependant) -> List[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)
def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
security_definitions = {}
operation_security = []
@ -90,17 +84,22 @@ def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, L
def get_openapi_operation_parameters(
*,
all_route_params: Sequence[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
) -> List[Dict[str, Any]]:
parameters = []
for param in all_route_params:
field_info = get_field_info(param)
field_info = cast(Param, field_info)
# ignore mypy error until enum schemas are released
parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
"schema": field_schema(
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
)[0],
}
if field_info.description:
parameter["description"] = field_info.description
@ -111,13 +110,16 @@ def get_openapi_operation_parameters(
def get_openapi_operation_request_body(
*, body_field: Optional[ModelField], model_name_map: Dict[Type[BaseModel], str]
*,
body_field: Optional[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
) -> Optional[Dict]:
if not body_field:
return None
assert isinstance(body_field, ModelField)
# ignore mypy error until enum schemas are released
body_schema, _, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
)
field_info = cast(Body, get_field_info(body_field))
request_media_type = field_info.media_type
@ -176,8 +178,10 @@ def get_openapi_path(
operation.setdefault("security", []).extend(operation_security)
if security_definitions:
security_schemes.update(security_definitions)
all_route_params = get_openapi_params(route.dependant)
operation_parameters = get_openapi_operation_parameters(all_route_params)
all_route_params = get_flat_params(route.dependant)
operation_parameters = get_openapi_operation_parameters(
all_route_params=all_route_params, model_name_map=model_name_map
)
parameters.extend(operation_parameters)
if parameters:
operation["parameters"] = list(
@ -270,6 +274,38 @@ def get_openapi_path(
return path, security_schemes, definitions
def get_flat_models_from_routes(
routes: Sequence[BaseRoute],
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
request_fields_from_routes: List[ModelField] = []
callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert isinstance(
route.body_field, ModelField
), "A request body must be a Pydantic Field"
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.callbacks:
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
params = get_flat_params(route.dependant)
request_fields_from_routes.extend(params)
flat_models = callback_flat_models | get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes + request_fields_from_routes,
known_models=set(),
)
return flat_models
def get_openapi(
*,
title: str,
@ -286,9 +322,11 @@ def get_openapi(
components: Dict[str, Dict] = {}
paths: Dict[str, Dict] = {}
flat_models = get_flat_models_from_routes(routes)
model_name_map = get_model_name_map(flat_models)
# ignore mypy error until enum schemas are released
model_name_map = get_model_name_map(flat_models) # type: ignore
# ignore mypy error until enum schemas are released
definitions = get_model_definitions(
flat_models=flat_models, model_name_map=model_name_map
flat_models=flat_models, model_name_map=model_name_map # type: ignore
)
for route in routes:
if isinstance(route, routing.APIRoute):

39
fastapi/utils.py

@ -1,17 +1,16 @@
import functools
import re
from dataclasses import is_dataclass
from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast
from enum import Enum
from typing import Any, Dict, Optional, Set, Type, Union, cast
import fastapi
from fastapi import routing
from fastapi.logger import logger
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model
from pydantic.class_validators import Validator
from pydantic.schema import get_flat_models_from_fields, model_process_schema
from pydantic.schema import model_process_schema
from pydantic.utils import lenient_issubclass
from starlette.routing import BaseRoute
try:
from pydantic.fields import FieldInfo, ModelField, UndefinedType
@ -50,38 +49,16 @@ def warning_response_model_skip_defaults_deprecated() -> None:
)
def get_flat_models_from_routes(routes: Sequence[BaseRoute]) -> Set[Type[BaseModel]]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
callback_flat_models: Set[Type[BaseModel]] = set()
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert isinstance(
route.body_field, ModelField
), "A request body must be a Pydantic Field"
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.callbacks:
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
flat_models = callback_flat_models | get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes, known_models=set()
)
return flat_models
def get_model_definitions(
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
*,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
definitions: Dict[str, Dict] = {}
for model in flat_models:
# ignore mypy error until enum schemas are released
m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
)
definitions.update(m_definitions)
model_name = model_name_map[model]

2
scripts/format.sh

@ -3,4 +3,4 @@ set -x
autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place docs_src fastapi tests scripts --exclude=__init__.py
black fastapi tests docs_src scripts
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --apply fastapi tests docs_src scripts
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --thirdparty pydantic --thirdparty starlette --apply fastapi tests docs_src scripts

2
scripts/lint.sh

@ -5,4 +5,4 @@ set -x
mypy fastapi
black fastapi tests --check
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi fastapi tests
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi --thirdparty fastapi --thirdparty pydantic --thirdparty starlette fastapi tests

8
tests/test_tutorial/test_path_params/test_tutorial005.py

@ -87,7 +87,7 @@ openapi_schema2 = {
"parameters": [
{
"required": True,
"schema": {"$ref": "#/definitions/ModelName"},
"schema": {"$ref": "#/components/schemas/ModelName"},
"name": "model_name",
"in": "path",
}
@ -124,6 +124,12 @@ openapi_schema2 = {
}
},
},
"ModelName": {
"title": "ModelName",
"enum": ["alexnet", "resnet", "lenet"],
"type": "string",
"description": "An enumeration.",
},
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],

Loading…
Cancel
Save