From afad59dfbb0a74a02a6aeefebf0eefbfff228f7b Mon Sep 17 00:00:00 2001 From: Patrick McKenna Date: Sat, 29 Feb 2020 05:04:35 -0800 Subject: [PATCH] :bug: Admit valid types for Pydantic fields as responses models (#1017) --- fastapi/dependencies/utils.py | 103 +++++----------- fastapi/exceptions.py | 6 + fastapi/routing.py | 53 ++------ fastapi/utils.py | 71 +++++++---- tests/test_response_model_invalid.py | 45 +++++++ tests/test_response_model_sub_types.py | 160 +++++++++++++++++++++++++ 6 files changed, 299 insertions(+), 139 deletions(-) create mode 100644 tests/test_response_model_invalid.py create mode 100644 tests/test_response_model_sub_types.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 33130a90e..543479be8 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -27,7 +27,12 @@ from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.open_id_connect_url import OpenIdConnect -from fastapi.utils import PYDANTIC_1, get_field_info, get_path_param_names +from fastapi.utils import ( + PYDANTIC_1, + create_response_field, + get_field_info, + get_path_param_names, +) from pydantic import BaseConfig, BaseModel, create_model from pydantic.error_wrappers import ErrorWrapper from pydantic.errors import MissingError @@ -362,31 +367,15 @@ def get_param_field( alias = param.name.replace("_", "-") else: alias = field_info.alias or param.name - if PYDANTIC_1: - field = ModelField( - name=param.name, - type_=annotation, - default=None if required else default_value, - alias=alias, - required=required, - model_config=BaseConfig, - class_validators={}, - field_info=field_info, - ) - # TODO: remove when removing support for Pydantic < 1.2.0 - field.required = required - else: # pragma: nocover - field = ModelField( # type: ignore - name=param.name, - type_=annotation, - default=None if required else default_value, - alias=alias, - required=required, - model_config=BaseConfig, - class_validators={}, - schema=field_info, - ) - field.required = required + field = create_response_field( + name=param.name, + type_=annotation, + default=None if required else default_value, + alias=alias, + required=required, + field_info=field_info, + ) + field.required = required if not had_schema and not is_scalar_field(field=field): if PYDANTIC_1: field.field_info = params.Body(field_info.default) @@ -694,28 +683,16 @@ def get_schema_compatible_field(*, field: ModelField) -> ModelField: use_type: type = bytes if field.shape in sequence_shapes: use_type = List[bytes] - if PYDANTIC_1: - out_field = ModelField( - name=field.name, - type_=use_type, - class_validators=field.class_validators, - model_config=field.model_config, - default=field.default, - required=field.required, - alias=field.alias, - field_info=field.field_info, - ) - else: # pragma: nocover - out_field = ModelField( # type: ignore - name=field.name, - type_=use_type, - class_validators=field.class_validators, - model_config=field.model_config, - default=field.default, - required=field.required, - alias=field.alias, - schema=field.schema, # type: ignore - ) + out_field = create_response_field( + name=field.name, + type_=use_type, + class_validators=field.class_validators, + model_config=field.model_config, + default=field.default, + required=field.required, + alias=field.alias, + field_info=field.field_info if PYDANTIC_1 else field.schema, # type: ignore + ) return out_field @@ -754,26 +731,10 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: ] if len(set(body_param_media_types)) == 1: BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] - if PYDANTIC_1: - field = ModelField( - name="body", - type_=BodyModel, - default=None, - required=required, - model_config=BaseConfig, - class_validators={}, - alias="body", - field_info=BodyFieldInfo(**BodyFieldInfo_kwargs), - ) - else: # pragma: nocover - field = ModelField( # type: ignore - name="body", - type_=BodyModel, - default=None, - required=required, - model_config=BaseConfig, - class_validators={}, - alias="body", - schema=BodyFieldInfo(**BodyFieldInfo_kwargs), - ) - return field + return create_response_field( + name="body", + type_=BodyModel, + required=required, + alias="body", + field_info=BodyFieldInfo(**BodyFieldInfo_kwargs), + ) diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index ac002205a..be196d0cb 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -20,6 +20,12 @@ RequestErrorModel = create_model("Request") WebSocketErrorModel = create_model("WebSocket") +class FastAPIError(RuntimeError): + """ + A generic, FastAPI-specific error. + """ + + class RequestValidationError(ValidationError): def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None: self.body = body diff --git a/fastapi/routing.py b/fastapi/routing.py index bbc2b4133..7c626af41 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -17,13 +17,13 @@ from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY from fastapi.utils import ( PYDANTIC_1, create_cloned_field, + create_response_field, generate_operation_id_for_path, get_field_info, warning_response_model_skip_defaults_deprecated, ) -from pydantic import BaseConfig, BaseModel +from pydantic import BaseModel from pydantic.error_wrappers import ErrorWrapper, ValidationError -from pydantic.utils import lenient_issubclass from starlette import routing from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException @@ -243,26 +243,9 @@ class APIRoute(routing.Route): status_code not in STATUS_CODES_WITH_NO_BODY ), f"Status code {status_code} must not have a response body" response_name = "Response_" + self.unique_id - if PYDANTIC_1: - self.response_field: Optional[ModelField] = ModelField( - name=response_name, - type_=self.response_model, - class_validators={}, - default=None, - required=False, - model_config=BaseConfig, - field_info=FieldInfo(None), - ) - else: - self.response_field: Optional[ModelField] = ModelField( # type: ignore # pragma: nocover - name=response_name, - type_=self.response_model, - class_validators={}, - default=None, - required=False, - model_config=BaseConfig, - schema=FieldInfo(None), - ) + self.response_field = create_response_field( + name=response_name, type_=self.response_model + ) # Create a clone of the field, so that a Pydantic submodel is not returned # as is just because it's an instance of a subclass of a more limited class # e.g. UserInDB (containing hashed_password) could be a subclass of User @@ -274,7 +257,7 @@ class APIRoute(routing.Route): ModelField ] = create_cloned_field(self.response_field) else: - self.response_field = None + self.response_field = None # type: ignore self.secure_cloned_response_field = None self.status_code = status_code self.tags = tags or [] @@ -297,30 +280,8 @@ class APIRoute(routing.Route): assert ( additional_status_code not in STATUS_CODES_WITH_NO_BODY ), f"Status code {additional_status_code} must not have a response body" - assert lenient_issubclass( - model, BaseModel - ), "A response model must be a Pydantic model" response_name = f"Response_{additional_status_code}_{self.unique_id}" - if PYDANTIC_1: - response_field = ModelField( - name=response_name, - type_=model, - class_validators=None, - default=None, - required=False, - model_config=BaseConfig, - field_info=FieldInfo(None), - ) - else: - response_field = ModelField( # type: ignore # pragma: nocover - name=response_name, - type_=model, - class_validators=None, - default=None, - required=False, - model_config=BaseConfig, - schema=FieldInfo(None), - ) + response_field = create_response_field(name=response_name, type_=model) response_fields[additional_status_code] = response_field if response_fields: self.response_fields: Dict[Union[int, str], ModelField] = response_fields diff --git a/fastapi/utils.py b/fastapi/utils.py index e7d3891f4..f24f28073 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -1,17 +1,20 @@ +import functools import re from dataclasses import is_dataclass -from typing import Any, Dict, List, Sequence, Set, Type, cast +from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast +import fastapi from fastapi import routing from fastapi.logger import logger from fastapi.openapi.constants import REF_PREFIX from pydantic import BaseConfig, BaseModel, create_model +from pydantic.class_validators import Validator from pydantic.schema import get_flat_models_from_fields, model_process_schema from pydantic.utils import lenient_issubclass from starlette.routing import BaseRoute try: - from pydantic.fields import FieldInfo, ModelField + from pydantic.fields import FieldInfo, ModelField, UndefinedType PYDANTIC_1 = True except ImportError: # pragma: nocover @@ -19,6 +22,10 @@ except ImportError: # pragma: nocover from pydantic.fields import Field as ModelField # type: ignore from pydantic import Schema as FieldInfo # type: ignore + class UndefinedType: # type: ignore + def __repr__(self) -> str: + return "PydanticUndefined" + logger.warning( "Pydantic versions < 1.0.0 are deprecated in FastAPI and support will be " "removed soon." @@ -86,6 +93,44 @@ def get_path_param_names(path: str) -> Set[str]: return {item.strip("{}") for item in re.findall("{[^}]*}", path)} +def create_response_field( + name: str, + type_: Type[Any], + class_validators: Optional[Dict[str, Validator]] = None, + default: Optional[Any] = None, + required: Union[bool, UndefinedType] = False, + model_config: Type[BaseConfig] = BaseConfig, + field_info: Optional[FieldInfo] = None, + alias: Optional[str] = None, +) -> ModelField: + """ + Create a new response field. Raises if type_ is invalid. + """ + class_validators = class_validators or {} + field_info = field_info or FieldInfo(None) + + response_field = functools.partial( + ModelField, + name=name, + type_=type_, + class_validators=class_validators, + default=default, + required=required, + model_config=model_config, + alias=alias, + ) + + try: + if PYDANTIC_1: + return response_field(field_info=field_info) + else: # pragma: nocover + return response_field(schema=field_info) + except RuntimeError: + raise fastapi.exceptions.FastAPIError( + f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type" + ) + + def create_cloned_field(field: ModelField) -> ModelField: original_type = field.type_ if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): @@ -96,26 +141,8 @@ def create_cloned_field(field: ModelField) -> ModelField: use_type = create_model(original_type.__name__, __base__=original_type) for f in original_type.__fields__.values(): use_type.__fields__[f.name] = create_cloned_field(f) - if PYDANTIC_1: - new_field = ModelField( - name=field.name, - type_=use_type, - class_validators={}, - default=None, - required=False, - model_config=BaseConfig, - field_info=FieldInfo(None), - ) - else: # pragma: nocover - new_field = ModelField( # type: ignore - name=field.name, - type_=use_type, - class_validators={}, - default=None, - required=False, - model_config=BaseConfig, - schema=FieldInfo(None), - ) + + new_field = create_response_field(name=field.name, type_=use_type) new_field.has_alias = field.has_alias new_field.alias = field.alias new_field.class_validators = field.class_validators diff --git a/tests/test_response_model_invalid.py b/tests/test_response_model_invalid.py new file mode 100644 index 000000000..88b55a436 --- /dev/null +++ b/tests/test_response_model_invalid.py @@ -0,0 +1,45 @@ +from typing import List + +import pytest +from fastapi import FastAPI +from fastapi.exceptions import FastAPIError + + +class NonPydanticModel: + pass + + +def test_invalid_response_model_raises(): + with pytest.raises(FastAPIError): + app = FastAPI() + + @app.get("/", response_model=NonPydanticModel) + def read_root(): + pass # pragma: nocover + + +def test_invalid_response_model_sub_type_raises(): + with pytest.raises(FastAPIError): + app = FastAPI() + + @app.get("/", response_model=List[NonPydanticModel]) + def read_root(): + pass # pragma: nocover + + +def test_invalid_response_model_in_responses_raises(): + with pytest.raises(FastAPIError): + app = FastAPI() + + @app.get("/", responses={"500": {"model": NonPydanticModel}}) + def read_root(): + pass # pragma: nocover + + +def test_invalid_response_model_sub_type_in_responses_raises(): + with pytest.raises(FastAPIError): + app = FastAPI() + + @app.get("/", responses={"500": {"model": List[NonPydanticModel]}}) + def read_root(): + pass # pragma: nocover diff --git a/tests/test_response_model_sub_types.py b/tests/test_response_model_sub_types.py new file mode 100644 index 000000000..ac1209837 --- /dev/null +++ b/tests/test_response_model_sub_types.py @@ -0,0 +1,160 @@ +from typing import List + +from fastapi import FastAPI +from pydantic import BaseModel +from starlette.testclient import TestClient + + +class Model(BaseModel): + name: str + + +app = FastAPI() + + +@app.get("/valid1", responses={"500": {"model": int}}) +def valid1(): + pass + + +@app.get("/valid2", responses={"500": {"model": List[int]}}) +def valid2(): + pass + + +@app.get("/valid3", responses={"500": {"model": Model}}) +def valid3(): + pass + + +@app.get("/valid4", responses={"500": {"model": List[Model]}}) +def valid4(): + pass + + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/valid1": { + "get": { + "summary": "Valid1", + "operationId": "valid1_valid1_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "title": "Response 500 Valid1 Valid1 Get", + "type": "integer", + } + } + }, + }, + }, + } + }, + "/valid2": { + "get": { + "summary": "Valid2", + "operationId": "valid2_valid2_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "title": "Response 500 Valid2 Valid2 Get", + "type": "array", + "items": {"type": "integer"}, + } + } + }, + }, + }, + } + }, + "/valid3": { + "get": { + "summary": "Valid3", + "operationId": "valid3_valid3_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Model"} + } + }, + }, + }, + } + }, + "/valid4": { + "get": { + "summary": "Valid4", + "operationId": "valid4_valid4_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "title": "Response 500 Valid4 Valid4 Get", + "type": "array", + "items": {"$ref": "#/components/schemas/Model"}, + } + } + }, + }, + }, + } + }, + }, + "components": { + "schemas": { + "Model": { + "title": "Model", + "required": ["name"], + "type": "object", + "properties": {"name": {"title": "Name", "type": "string"}}, + } + } + }, +} + +client = TestClient(app) + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_path_operations(): + response = client.get("/valid1") + assert response.status_code == 200 + response = client.get("/valid2") + assert response.status_code == 200 + response = client.get("/valid3") + assert response.status_code == 200 + response = client.get("/valid4") + assert response.status_code == 200