Browse Source

🐛 Fix Enum handling with their own schema definitions (#1463)

* 🐛 Fix extra support for enum with its own schema

*  Fix/update test for enum with its own schema

* 🐛 Fix type declarations

* 🔧 Update format and lint scripts to support locally installed Pydantic and Starlette

* 🐛 Add temporary type ignores while enum schemas are merged
pull/1467/head
Sebastián Ramírez 5 years ago
committed by GitHub
parent
commit
5984233223
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      fastapi/dependencies/utils.py
  2. 80
      fastapi/openapi/utils.py
  3. 39
      fastapi/utils.py
  4. 2
      scripts/format.sh
  5. 2
      scripts/lint.sh
  6. 8
      tests/test_tutorial/test_path_params/test_tutorial005.py

10
fastapi/dependencies/utils.py

@ -188,6 +188,16 @@ def get_flat_dependant(
return flat_dependant return flat_dependant
def get_flat_params(dependant: Dependant) -> List[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)
def is_scalar_field(field: ModelField) -> bool: def is_scalar_field(field: ModelField) -> bool:
field_info = get_field_info(field) field_info = get_field_info(field)
if not ( if not (

80
fastapi/openapi/utils.py

@ -1,9 +1,10 @@
import http.client import http.client
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
from fastapi import routing from fastapi import routing
from fastapi.dependencies.models import Dependant from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import ( from fastapi.openapi.constants import (
METHODS_WITH_BODY, METHODS_WITH_BODY,
@ -15,11 +16,14 @@ from fastapi.params import Body, Param
from fastapi.utils import ( from fastapi.utils import (
generate_operation_id_for_path, generate_operation_id_for_path,
get_field_info, get_field_info,
get_flat_models_from_routes,
get_model_definitions, get_model_definitions,
) )
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.schema import field_schema, get_model_name_map from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
get_model_name_map,
)
from pydantic.utils import lenient_issubclass from pydantic.utils import lenient_issubclass
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.routing import BaseRoute from starlette.routing import BaseRoute
@ -64,16 +68,6 @@ status_code_ranges: Dict[str, str] = {
} }
def get_openapi_params(dependant: Dependant) -> List[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)
def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]: def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
security_definitions = {} security_definitions = {}
operation_security = [] operation_security = []
@ -90,17 +84,22 @@ def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, L
def get_openapi_operation_parameters( def get_openapi_operation_parameters(
*,
all_route_params: Sequence[ModelField], all_route_params: Sequence[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
parameters = [] parameters = []
for param in all_route_params: for param in all_route_params:
field_info = get_field_info(param) field_info = get_field_info(param)
field_info = cast(Param, field_info) field_info = cast(Param, field_info)
# ignore mypy error until enum schemas are released
parameter = { parameter = {
"name": param.alias, "name": param.alias,
"in": field_info.in_.value, "in": field_info.in_.value,
"required": param.required, "required": param.required,
"schema": field_schema(param, model_name_map={})[0], "schema": field_schema(
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
)[0],
} }
if field_info.description: if field_info.description:
parameter["description"] = field_info.description parameter["description"] = field_info.description
@ -111,13 +110,16 @@ def get_openapi_operation_parameters(
def get_openapi_operation_request_body( def get_openapi_operation_request_body(
*, body_field: Optional[ModelField], model_name_map: Dict[Type[BaseModel], str] *,
body_field: Optional[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
) -> Optional[Dict]: ) -> Optional[Dict]:
if not body_field: if not body_field:
return None return None
assert isinstance(body_field, ModelField) assert isinstance(body_field, ModelField)
# ignore mypy error until enum schemas are released
body_schema, _, _ = field_schema( body_schema, _, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
) )
field_info = cast(Body, get_field_info(body_field)) field_info = cast(Body, get_field_info(body_field))
request_media_type = field_info.media_type request_media_type = field_info.media_type
@ -176,8 +178,10 @@ def get_openapi_path(
operation.setdefault("security", []).extend(operation_security) operation.setdefault("security", []).extend(operation_security)
if security_definitions: if security_definitions:
security_schemes.update(security_definitions) security_schemes.update(security_definitions)
all_route_params = get_openapi_params(route.dependant) all_route_params = get_flat_params(route.dependant)
operation_parameters = get_openapi_operation_parameters(all_route_params) operation_parameters = get_openapi_operation_parameters(
all_route_params=all_route_params, model_name_map=model_name_map
)
parameters.extend(operation_parameters) parameters.extend(operation_parameters)
if parameters: if parameters:
operation["parameters"] = list( operation["parameters"] = list(
@ -270,6 +274,38 @@ def get_openapi_path(
return path, security_schemes, definitions return path, security_schemes, definitions
def get_flat_models_from_routes(
routes: Sequence[BaseRoute],
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
request_fields_from_routes: List[ModelField] = []
callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert isinstance(
route.body_field, ModelField
), "A request body must be a Pydantic Field"
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.callbacks:
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
params = get_flat_params(route.dependant)
request_fields_from_routes.extend(params)
flat_models = callback_flat_models | get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes + request_fields_from_routes,
known_models=set(),
)
return flat_models
def get_openapi( def get_openapi(
*, *,
title: str, title: str,
@ -286,9 +322,11 @@ def get_openapi(
components: Dict[str, Dict] = {} components: Dict[str, Dict] = {}
paths: Dict[str, Dict] = {} paths: Dict[str, Dict] = {}
flat_models = get_flat_models_from_routes(routes) flat_models = get_flat_models_from_routes(routes)
model_name_map = get_model_name_map(flat_models) # ignore mypy error until enum schemas are released
model_name_map = get_model_name_map(flat_models) # type: ignore
# ignore mypy error until enum schemas are released
definitions = get_model_definitions( definitions = get_model_definitions(
flat_models=flat_models, model_name_map=model_name_map flat_models=flat_models, model_name_map=model_name_map # type: ignore
) )
for route in routes: for route in routes:
if isinstance(route, routing.APIRoute): if isinstance(route, routing.APIRoute):

39
fastapi/utils.py

@ -1,17 +1,16 @@
import functools import functools
import re import re
from dataclasses import is_dataclass from dataclasses import is_dataclass
from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast from enum import Enum
from typing import Any, Dict, Optional, Set, Type, Union, cast
import fastapi import fastapi
from fastapi import routing
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.openapi.constants import REF_PREFIX from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model from pydantic import BaseConfig, BaseModel, create_model
from pydantic.class_validators import Validator from pydantic.class_validators import Validator
from pydantic.schema import get_flat_models_from_fields, model_process_schema from pydantic.schema import model_process_schema
from pydantic.utils import lenient_issubclass from pydantic.utils import lenient_issubclass
from starlette.routing import BaseRoute
try: try:
from pydantic.fields import FieldInfo, ModelField, UndefinedType from pydantic.fields import FieldInfo, ModelField, UndefinedType
@ -50,38 +49,16 @@ def warning_response_model_skip_defaults_deprecated() -> None:
) )
def get_flat_models_from_routes(routes: Sequence[BaseRoute]) -> Set[Type[BaseModel]]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
callback_flat_models: Set[Type[BaseModel]] = set()
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert isinstance(
route.body_field, ModelField
), "A request body must be a Pydantic Field"
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.callbacks:
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
flat_models = callback_flat_models | get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes, known_models=set()
)
return flat_models
def get_model_definitions( def get_model_definitions(
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str] *,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
definitions: Dict[str, Dict] = {} definitions: Dict[str, Dict] = {}
for model in flat_models: for model in flat_models:
# ignore mypy error until enum schemas are released
m_schema, m_definitions, m_nested_models = model_process_schema( m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX model, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
) )
definitions.update(m_definitions) definitions.update(m_definitions)
model_name = model_name_map[model] model_name = model_name_map[model]

2
scripts/format.sh

@ -3,4 +3,4 @@ set -x
autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place docs_src fastapi tests scripts --exclude=__init__.py autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place docs_src fastapi tests scripts --exclude=__init__.py
black fastapi tests docs_src scripts black fastapi tests docs_src scripts
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --apply fastapi tests docs_src scripts isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --thirdparty pydantic --thirdparty starlette --apply fastapi tests docs_src scripts

2
scripts/lint.sh

@ -5,4 +5,4 @@ set -x
mypy fastapi mypy fastapi
black fastapi tests --check black fastapi tests --check
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi fastapi tests isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi --thirdparty fastapi --thirdparty pydantic --thirdparty starlette fastapi tests

8
tests/test_tutorial/test_path_params/test_tutorial005.py

@ -87,7 +87,7 @@ openapi_schema2 = {
"parameters": [ "parameters": [
{ {
"required": True, "required": True,
"schema": {"$ref": "#/definitions/ModelName"}, "schema": {"$ref": "#/components/schemas/ModelName"},
"name": "model_name", "name": "model_name",
"in": "path", "in": "path",
} }
@ -124,6 +124,12 @@ openapi_schema2 = {
} }
}, },
}, },
"ModelName": {
"title": "ModelName",
"enum": ["alexnet", "resnet", "lenet"],
"type": "string",
"description": "An enumeration.",
},
"ValidationError": { "ValidationError": {
"title": "ValidationError", "title": "ValidationError",
"required": ["loc", "msg", "type"], "required": ["loc", "msg", "type"],

Loading…
Cancel
Save