Browse Source

🐛 Admit valid types for Pydantic fields as responses models (#1017)

pull/1060/head
Patrick McKenna 5 years ago
committed by GitHub
parent
commit
afad59dfbb
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 103
      fastapi/dependencies/utils.py
  2. 6
      fastapi/exceptions.py
  3. 53
      fastapi/routing.py
  4. 71
      fastapi/utils.py
  5. 45
      tests/test_response_model_invalid.py
  6. 160
      tests/test_response_model_sub_types.py

103
fastapi/dependencies/utils.py

@ -27,7 +27,12 @@ from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.utils import PYDANTIC_1, get_field_info, get_path_param_names from fastapi.utils import (
PYDANTIC_1,
create_response_field,
get_field_info,
get_path_param_names,
)
from pydantic import BaseConfig, BaseModel, create_model from pydantic import BaseConfig, BaseModel, create_model
from pydantic.error_wrappers import ErrorWrapper from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError from pydantic.errors import MissingError
@ -362,31 +367,15 @@ def get_param_field(
alias = param.name.replace("_", "-") alias = param.name.replace("_", "-")
else: else:
alias = field_info.alias or param.name alias = field_info.alias or param.name
if PYDANTIC_1: field = create_response_field(
field = ModelField( name=param.name,
name=param.name, type_=annotation,
type_=annotation, default=None if required else default_value,
default=None if required else default_value, alias=alias,
alias=alias, required=required,
required=required, field_info=field_info,
model_config=BaseConfig, )
class_validators={}, field.required = required
field_info=field_info,
)
# TODO: remove when removing support for Pydantic < 1.2.0
field.required = required
else: # pragma: nocover
field = ModelField( # type: ignore
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=alias,
required=required,
model_config=BaseConfig,
class_validators={},
schema=field_info,
)
field.required = required
if not had_schema and not is_scalar_field(field=field): if not had_schema and not is_scalar_field(field=field):
if PYDANTIC_1: if PYDANTIC_1:
field.field_info = params.Body(field_info.default) field.field_info = params.Body(field_info.default)
@ -694,28 +683,16 @@ def get_schema_compatible_field(*, field: ModelField) -> ModelField:
use_type: type = bytes use_type: type = bytes
if field.shape in sequence_shapes: if field.shape in sequence_shapes:
use_type = List[bytes] use_type = List[bytes]
if PYDANTIC_1: out_field = create_response_field(
out_field = ModelField( name=field.name,
name=field.name, type_=use_type,
type_=use_type, class_validators=field.class_validators,
class_validators=field.class_validators, model_config=field.model_config,
model_config=field.model_config, default=field.default,
default=field.default, required=field.required,
required=field.required, alias=field.alias,
alias=field.alias, field_info=field.field_info if PYDANTIC_1 else field.schema, # type: ignore
field_info=field.field_info, )
)
else: # pragma: nocover
out_field = ModelField( # type: ignore
name=field.name,
type_=use_type,
class_validators=field.class_validators,
model_config=field.model_config,
default=field.default,
required=field.required,
alias=field.alias,
schema=field.schema, # type: ignore
)
return out_field return out_field
@ -754,26 +731,10 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
] ]
if len(set(body_param_media_types)) == 1: if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
if PYDANTIC_1: return create_response_field(
field = ModelField( name="body",
name="body", type_=BodyModel,
type_=BodyModel, required=required,
default=None, alias="body",
required=required, field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
model_config=BaseConfig, )
class_validators={},
alias="body",
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
)
else: # pragma: nocover
field = ModelField( # type: ignore
name="body",
type_=BodyModel,
default=None,
required=required,
model_config=BaseConfig,
class_validators={},
alias="body",
schema=BodyFieldInfo(**BodyFieldInfo_kwargs),
)
return field

6
fastapi/exceptions.py

@ -20,6 +20,12 @@ RequestErrorModel = create_model("Request")
WebSocketErrorModel = create_model("WebSocket") WebSocketErrorModel = create_model("WebSocket")
class FastAPIError(RuntimeError):
"""
A generic, FastAPI-specific error.
"""
class RequestValidationError(ValidationError): class RequestValidationError(ValidationError):
def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None: def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None:
self.body = body self.body = body

53
fastapi/routing.py

@ -17,13 +17,13 @@ from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
from fastapi.utils import ( from fastapi.utils import (
PYDANTIC_1, PYDANTIC_1,
create_cloned_field, create_cloned_field,
create_response_field,
generate_operation_id_for_path, generate_operation_id_for_path,
get_field_info, get_field_info,
warning_response_model_skip_defaults_deprecated, warning_response_model_skip_defaults_deprecated,
) )
from pydantic import BaseConfig, BaseModel from pydantic import BaseModel
from pydantic.error_wrappers import ErrorWrapper, ValidationError from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.utils import lenient_issubclass
from starlette import routing from starlette import routing
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
@ -243,26 +243,9 @@ class APIRoute(routing.Route):
status_code not in STATUS_CODES_WITH_NO_BODY status_code not in STATUS_CODES_WITH_NO_BODY
), f"Status code {status_code} must not have a response body" ), f"Status code {status_code} must not have a response body"
response_name = "Response_" + self.unique_id response_name = "Response_" + self.unique_id
if PYDANTIC_1: self.response_field = create_response_field(
self.response_field: Optional[ModelField] = ModelField( name=response_name, type_=self.response_model
name=response_name, )
type_=self.response_model,
class_validators={},
default=None,
required=False,
model_config=BaseConfig,
field_info=FieldInfo(None),
)
else:
self.response_field: Optional[ModelField] = ModelField( # type: ignore # pragma: nocover
name=response_name,
type_=self.response_model,
class_validators={},
default=None,
required=False,
model_config=BaseConfig,
schema=FieldInfo(None),
)
# Create a clone of the field, so that a Pydantic submodel is not returned # Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class # as is just because it's an instance of a subclass of a more limited class
# e.g. UserInDB (containing hashed_password) could be a subclass of User # e.g. UserInDB (containing hashed_password) could be a subclass of User
@ -274,7 +257,7 @@ class APIRoute(routing.Route):
ModelField ModelField
] = create_cloned_field(self.response_field) ] = create_cloned_field(self.response_field)
else: else:
self.response_field = None self.response_field = None # type: ignore
self.secure_cloned_response_field = None self.secure_cloned_response_field = None
self.status_code = status_code self.status_code = status_code
self.tags = tags or [] self.tags = tags or []
@ -297,30 +280,8 @@ class APIRoute(routing.Route):
assert ( assert (
additional_status_code not in STATUS_CODES_WITH_NO_BODY additional_status_code not in STATUS_CODES_WITH_NO_BODY
), f"Status code {additional_status_code} must not have a response body" ), f"Status code {additional_status_code} must not have a response body"
assert lenient_issubclass(
model, BaseModel
), "A response model must be a Pydantic model"
response_name = f"Response_{additional_status_code}_{self.unique_id}" response_name = f"Response_{additional_status_code}_{self.unique_id}"
if PYDANTIC_1: response_field = create_response_field(name=response_name, type_=model)
response_field = ModelField(
name=response_name,
type_=model,
class_validators=None,
default=None,
required=False,
model_config=BaseConfig,
field_info=FieldInfo(None),
)
else:
response_field = ModelField( # type: ignore # pragma: nocover
name=response_name,
type_=model,
class_validators=None,
default=None,
required=False,
model_config=BaseConfig,
schema=FieldInfo(None),
)
response_fields[additional_status_code] = response_field response_fields[additional_status_code] = response_field
if response_fields: if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields self.response_fields: Dict[Union[int, str], ModelField] = response_fields

71
fastapi/utils.py

@ -1,17 +1,20 @@
import functools
import re import re
from dataclasses import is_dataclass from dataclasses import is_dataclass
from typing import Any, Dict, List, Sequence, Set, Type, cast from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast
import fastapi
from fastapi import routing from fastapi import routing
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.openapi.constants import REF_PREFIX from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model from pydantic import BaseConfig, BaseModel, create_model
from pydantic.class_validators import Validator
from pydantic.schema import get_flat_models_from_fields, model_process_schema from pydantic.schema import get_flat_models_from_fields, model_process_schema
from pydantic.utils import lenient_issubclass from pydantic.utils import lenient_issubclass
from starlette.routing import BaseRoute from starlette.routing import BaseRoute
try: try:
from pydantic.fields import FieldInfo, ModelField from pydantic.fields import FieldInfo, ModelField, UndefinedType
PYDANTIC_1 = True PYDANTIC_1 = True
except ImportError: # pragma: nocover except ImportError: # pragma: nocover
@ -19,6 +22,10 @@ except ImportError: # pragma: nocover
from pydantic.fields import Field as ModelField # type: ignore from pydantic.fields import Field as ModelField # type: ignore
from pydantic import Schema as FieldInfo # type: ignore from pydantic import Schema as FieldInfo # type: ignore
class UndefinedType: # type: ignore
def __repr__(self) -> str:
return "PydanticUndefined"
logger.warning( logger.warning(
"Pydantic versions < 1.0.0 are deprecated in FastAPI and support will be " "Pydantic versions < 1.0.0 are deprecated in FastAPI and support will be "
"removed soon." "removed soon."
@ -86,6 +93,44 @@ def get_path_param_names(path: str) -> Set[str]:
return {item.strip("{}") for item in re.findall("{[^}]*}", path)} return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
def create_response_field(
name: str,
type_: Type[Any],
class_validators: Optional[Dict[str, Validator]] = None,
default: Optional[Any] = None,
required: Union[bool, UndefinedType] = False,
model_config: Type[BaseConfig] = BaseConfig,
field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None,
) -> ModelField:
"""
Create a new response field. Raises if type_ is invalid.
"""
class_validators = class_validators or {}
field_info = field_info or FieldInfo(None)
response_field = functools.partial(
ModelField,
name=name,
type_=type_,
class_validators=class_validators,
default=default,
required=required,
model_config=model_config,
alias=alias,
)
try:
if PYDANTIC_1:
return response_field(field_info=field_info)
else: # pragma: nocover
return response_field(schema=field_info)
except RuntimeError:
raise fastapi.exceptions.FastAPIError(
f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type"
)
def create_cloned_field(field: ModelField) -> ModelField: def create_cloned_field(field: ModelField) -> ModelField:
original_type = field.type_ original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
@ -96,26 +141,8 @@ def create_cloned_field(field: ModelField) -> ModelField:
use_type = create_model(original_type.__name__, __base__=original_type) use_type = create_model(original_type.__name__, __base__=original_type)
for f in original_type.__fields__.values(): for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(f) use_type.__fields__[f.name] = create_cloned_field(f)
if PYDANTIC_1:
new_field = ModelField( new_field = create_response_field(name=field.name, type_=use_type)
name=field.name,
type_=use_type,
class_validators={},
default=None,
required=False,
model_config=BaseConfig,
field_info=FieldInfo(None),
)
else: # pragma: nocover
new_field = ModelField( # type: ignore
name=field.name,
type_=use_type,
class_validators={},
default=None,
required=False,
model_config=BaseConfig,
schema=FieldInfo(None),
)
new_field.has_alias = field.has_alias new_field.has_alias = field.has_alias
new_field.alias = field.alias new_field.alias = field.alias
new_field.class_validators = field.class_validators new_field.class_validators = field.class_validators

45
tests/test_response_model_invalid.py

@ -0,0 +1,45 @@
from typing import List
import pytest
from fastapi import FastAPI
from fastapi.exceptions import FastAPIError
class NonPydanticModel:
pass
def test_invalid_response_model_raises():
with pytest.raises(FastAPIError):
app = FastAPI()
@app.get("/", response_model=NonPydanticModel)
def read_root():
pass # pragma: nocover
def test_invalid_response_model_sub_type_raises():
with pytest.raises(FastAPIError):
app = FastAPI()
@app.get("/", response_model=List[NonPydanticModel])
def read_root():
pass # pragma: nocover
def test_invalid_response_model_in_responses_raises():
with pytest.raises(FastAPIError):
app = FastAPI()
@app.get("/", responses={"500": {"model": NonPydanticModel}})
def read_root():
pass # pragma: nocover
def test_invalid_response_model_sub_type_in_responses_raises():
with pytest.raises(FastAPIError):
app = FastAPI()
@app.get("/", responses={"500": {"model": List[NonPydanticModel]}})
def read_root():
pass # pragma: nocover

160
tests/test_response_model_sub_types.py

@ -0,0 +1,160 @@
from typing import List
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.testclient import TestClient
class Model(BaseModel):
name: str
app = FastAPI()
@app.get("/valid1", responses={"500": {"model": int}})
def valid1():
pass
@app.get("/valid2", responses={"500": {"model": List[int]}})
def valid2():
pass
@app.get("/valid3", responses={"500": {"model": Model}})
def valid3():
pass
@app.get("/valid4", responses={"500": {"model": List[Model]}})
def valid4():
pass
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/valid1": {
"get": {
"summary": "Valid1",
"operationId": "valid1_valid1_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"title": "Response 500 Valid1 Valid1 Get",
"type": "integer",
}
}
},
},
},
}
},
"/valid2": {
"get": {
"summary": "Valid2",
"operationId": "valid2_valid2_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"title": "Response 500 Valid2 Valid2 Get",
"type": "array",
"items": {"type": "integer"},
}
}
},
},
},
}
},
"/valid3": {
"get": {
"summary": "Valid3",
"operationId": "valid3_valid3_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/Model"}
}
},
},
},
}
},
"/valid4": {
"get": {
"summary": "Valid4",
"operationId": "valid4_valid4_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"title": "Response 500 Valid4 Valid4 Get",
"type": "array",
"items": {"$ref": "#/components/schemas/Model"},
}
}
},
},
},
}
},
},
"components": {
"schemas": {
"Model": {
"title": "Model",
"required": ["name"],
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
}
}
},
}
client = TestClient(app)
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_path_operations():
response = client.get("/valid1")
assert response.status_code == 200
response = client.get("/valid2")
assert response.status_code == 200
response = client.get("/valid3")
assert response.status_code == 200
response = client.get("/valid4")
assert response.status_code == 200
Loading…
Cancel
Save