Browse Source

🐛 Admit valid types for Pydantic fields as responses models (#1017)

pull/1060/head
Patrick McKenna 5 years ago
committed by GitHub
parent
commit
afad59dfbb
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 59
      fastapi/dependencies/utils.py
  2. 6
      fastapi/exceptions.py
  3. 51
      fastapi/routing.py
  4. 71
      fastapi/utils.py
  5. 45
      tests/test_response_model_invalid.py
  6. 160
      tests/test_response_model_sub_types.py

59
fastapi/dependencies/utils.py

@ -27,7 +27,12 @@ from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.utils import PYDANTIC_1, get_field_info, get_path_param_names
from fastapi.utils import (
PYDANTIC_1,
create_response_field,
get_field_info,
get_path_param_names,
)
from pydantic import BaseConfig, BaseModel, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
@ -362,30 +367,14 @@ def get_param_field(
alias = param.name.replace("_", "-")
else:
alias = field_info.alias or param.name
if PYDANTIC_1:
field = ModelField(
field = create_response_field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=alias,
required=required,
model_config=BaseConfig,
class_validators={},
field_info=field_info,
)
# TODO: remove when removing support for Pydantic < 1.2.0
field.required = required
else: # pragma: nocover
field = ModelField( # type: ignore
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=alias,
required=required,
model_config=BaseConfig,
class_validators={},
schema=field_info,
)
field.required = required
if not had_schema and not is_scalar_field(field=field):
if PYDANTIC_1:
@ -694,19 +683,7 @@ def get_schema_compatible_field(*, field: ModelField) -> ModelField:
use_type: type = bytes
if field.shape in sequence_shapes:
use_type = List[bytes]
if PYDANTIC_1:
out_field = ModelField(
name=field.name,
type_=use_type,
class_validators=field.class_validators,
model_config=field.model_config,
default=field.default,
required=field.required,
alias=field.alias,
field_info=field.field_info,
)
else: # pragma: nocover
out_field = ModelField( # type: ignore
out_field = create_response_field(
name=field.name,
type_=use_type,
class_validators=field.class_validators,
@ -714,7 +691,7 @@ def get_schema_compatible_field(*, field: ModelField) -> ModelField:
default=field.default,
required=field.required,
alias=field.alias,
schema=field.schema, # type: ignore
field_info=field.field_info if PYDANTIC_1 else field.schema, # type: ignore
)
return out_field
@ -754,26 +731,10 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
]
if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
if PYDANTIC_1:
field = ModelField(
return create_response_field(
name="body",
type_=BodyModel,
default=None,
required=required,
model_config=BaseConfig,
class_validators={},
alias="body",
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
)
else: # pragma: nocover
field = ModelField( # type: ignore
name="body",
type_=BodyModel,
default=None,
required=required,
model_config=BaseConfig,
class_validators={},
alias="body",
schema=BodyFieldInfo(**BodyFieldInfo_kwargs),
)
return field

6
fastapi/exceptions.py

@ -20,6 +20,12 @@ RequestErrorModel = create_model("Request")
WebSocketErrorModel = create_model("WebSocket")
class FastAPIError(RuntimeError):
"""
A generic, FastAPI-specific error.
"""
class RequestValidationError(ValidationError):
def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None:
self.body = body

51
fastapi/routing.py

@ -17,13 +17,13 @@ from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
from fastapi.utils import (
PYDANTIC_1,
create_cloned_field,
create_response_field,
generate_operation_id_for_path,
get_field_info,
warning_response_model_skip_defaults_deprecated,
)
from pydantic import BaseConfig, BaseModel
from pydantic import BaseModel
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.utils import lenient_issubclass
from starlette import routing
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
@ -243,25 +243,8 @@ class APIRoute(routing.Route):
status_code not in STATUS_CODES_WITH_NO_BODY
), f"Status code {status_code} must not have a response body"
response_name = "Response_" + self.unique_id
if PYDANTIC_1:
self.response_field: Optional[ModelField] = ModelField(
name=response_name,
type_=self.response_model,
class_validators={},
default=None,
required=False,
model_config=BaseConfig,
field_info=FieldInfo(None),
)
else:
self.response_field: Optional[ModelField] = ModelField( # type: ignore # pragma: nocover
name=response_name,
type_=self.response_model,
class_validators={},
default=None,
required=False,
model_config=BaseConfig,
schema=FieldInfo(None),
self.response_field = create_response_field(
name=response_name, type_=self.response_model
)
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class
@ -274,7 +257,7 @@ class APIRoute(routing.Route):
ModelField
] = create_cloned_field(self.response_field)
else:
self.response_field = None
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
self.status_code = status_code
self.tags = tags or []
@ -297,30 +280,8 @@ class APIRoute(routing.Route):
assert (
additional_status_code not in STATUS_CODES_WITH_NO_BODY
), f"Status code {additional_status_code} must not have a response body"
assert lenient_issubclass(
model, BaseModel
), "A response model must be a Pydantic model"
response_name = f"Response_{additional_status_code}_{self.unique_id}"
if PYDANTIC_1:
response_field = ModelField(
name=response_name,
type_=model,
class_validators=None,
default=None,
required=False,
model_config=BaseConfig,
field_info=FieldInfo(None),
)
else:
response_field = ModelField( # type: ignore # pragma: nocover
name=response_name,
type_=model,
class_validators=None,
default=None,
required=False,
model_config=BaseConfig,
schema=FieldInfo(None),
)
response_field = create_response_field(name=response_name, type_=model)
response_fields[additional_status_code] = response_field
if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields

71
fastapi/utils.py

@ -1,17 +1,20 @@
import functools
import re
from dataclasses import is_dataclass
from typing import Any, Dict, List, Sequence, Set, Type, cast
from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast
import fastapi
from fastapi import routing
from fastapi.logger import logger
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model
from pydantic.class_validators import Validator
from pydantic.schema import get_flat_models_from_fields, model_process_schema
from pydantic.utils import lenient_issubclass
from starlette.routing import BaseRoute
try:
from pydantic.fields import FieldInfo, ModelField
from pydantic.fields import FieldInfo, ModelField, UndefinedType
PYDANTIC_1 = True
except ImportError: # pragma: nocover
@ -19,6 +22,10 @@ except ImportError: # pragma: nocover
from pydantic.fields import Field as ModelField # type: ignore
from pydantic import Schema as FieldInfo # type: ignore
class UndefinedType: # type: ignore
def __repr__(self) -> str:
return "PydanticUndefined"
logger.warning(
"Pydantic versions < 1.0.0 are deprecated in FastAPI and support will be "
"removed soon."
@ -86,6 +93,44 @@ def get_path_param_names(path: str) -> Set[str]:
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
def create_response_field(
name: str,
type_: Type[Any],
class_validators: Optional[Dict[str, Validator]] = None,
default: Optional[Any] = None,
required: Union[bool, UndefinedType] = False,
model_config: Type[BaseConfig] = BaseConfig,
field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None,
) -> ModelField:
"""
Create a new response field. Raises if type_ is invalid.
"""
class_validators = class_validators or {}
field_info = field_info or FieldInfo(None)
response_field = functools.partial(
ModelField,
name=name,
type_=type_,
class_validators=class_validators,
default=default,
required=required,
model_config=model_config,
alias=alias,
)
try:
if PYDANTIC_1:
return response_field(field_info=field_info)
else: # pragma: nocover
return response_field(schema=field_info)
except RuntimeError:
raise fastapi.exceptions.FastAPIError(
f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type"
)
def create_cloned_field(field: ModelField) -> ModelField:
original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
@ -96,26 +141,8 @@ def create_cloned_field(field: ModelField) -> ModelField:
use_type = create_model(original_type.__name__, __base__=original_type)
for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(f)
if PYDANTIC_1:
new_field = ModelField(
name=field.name,
type_=use_type,
class_validators={},
default=None,
required=False,
model_config=BaseConfig,
field_info=FieldInfo(None),
)
else: # pragma: nocover
new_field = ModelField( # type: ignore
name=field.name,
type_=use_type,
class_validators={},
default=None,
required=False,
model_config=BaseConfig,
schema=FieldInfo(None),
)
new_field = create_response_field(name=field.name, type_=use_type)
new_field.has_alias = field.has_alias
new_field.alias = field.alias
new_field.class_validators = field.class_validators

45
tests/test_response_model_invalid.py

@ -0,0 +1,45 @@
from typing import List
import pytest
from fastapi import FastAPI
from fastapi.exceptions import FastAPIError
class NonPydanticModel:
pass
def test_invalid_response_model_raises():
with pytest.raises(FastAPIError):
app = FastAPI()
@app.get("/", response_model=NonPydanticModel)
def read_root():
pass # pragma: nocover
def test_invalid_response_model_sub_type_raises():
with pytest.raises(FastAPIError):
app = FastAPI()
@app.get("/", response_model=List[NonPydanticModel])
def read_root():
pass # pragma: nocover
def test_invalid_response_model_in_responses_raises():
with pytest.raises(FastAPIError):
app = FastAPI()
@app.get("/", responses={"500": {"model": NonPydanticModel}})
def read_root():
pass # pragma: nocover
def test_invalid_response_model_sub_type_in_responses_raises():
with pytest.raises(FastAPIError):
app = FastAPI()
@app.get("/", responses={"500": {"model": List[NonPydanticModel]}})
def read_root():
pass # pragma: nocover

160
tests/test_response_model_sub_types.py

@ -0,0 +1,160 @@
from typing import List
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.testclient import TestClient
class Model(BaseModel):
name: str
app = FastAPI()
@app.get("/valid1", responses={"500": {"model": int}})
def valid1():
pass
@app.get("/valid2", responses={"500": {"model": List[int]}})
def valid2():
pass
@app.get("/valid3", responses={"500": {"model": Model}})
def valid3():
pass
@app.get("/valid4", responses={"500": {"model": List[Model]}})
def valid4():
pass
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/valid1": {
"get": {
"summary": "Valid1",
"operationId": "valid1_valid1_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"title": "Response 500 Valid1 Valid1 Get",
"type": "integer",
}
}
},
},
},
}
},
"/valid2": {
"get": {
"summary": "Valid2",
"operationId": "valid2_valid2_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"title": "Response 500 Valid2 Valid2 Get",
"type": "array",
"items": {"type": "integer"},
}
}
},
},
},
}
},
"/valid3": {
"get": {
"summary": "Valid3",
"operationId": "valid3_valid3_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/Model"}
}
},
},
},
}
},
"/valid4": {
"get": {
"summary": "Valid4",
"operationId": "valid4_valid4_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"title": "Response 500 Valid4 Valid4 Get",
"type": "array",
"items": {"$ref": "#/components/schemas/Model"},
}
}
},
},
},
}
},
},
"components": {
"schemas": {
"Model": {
"title": "Model",
"required": ["name"],
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
}
}
},
}
client = TestClient(app)
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_path_operations():
response = client.get("/valid1")
assert response.status_code == 200
response = client.get("/valid2")
assert response.status_code == 200
response = client.get("/valid3")
assert response.status_code == 200
response = client.get("/valid4")
assert response.status_code == 200
Loading…
Cancel
Save