From 70bdade23b25baa17eb6ebdf3b0fdcfcbf7984ad Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Mon, 3 Feb 2020 22:03:51 -0500 Subject: [PATCH] :bug: Fix Pydantic field clone logic with validators (#899) --- fastapi/utils.py | 5 +- tests/test_filter_pydantic_sub_model.py | 80 ++++++++++++++++++++++--- 2 files changed, 72 insertions(+), 13 deletions(-) diff --git a/fastapi/utils.py b/fastapi/utils.py index 6a0c1bfd7..e7d3891f4 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -93,12 +93,9 @@ def create_cloned_field(field: ModelField) -> ModelField: use_type = original_type if lenient_issubclass(original_type, BaseModel): original_type = cast(Type[BaseModel], original_type) - use_type = create_model( - original_type.__name__, __config__=original_type.__config__ - ) + use_type = create_model(original_type.__name__, __base__=original_type) for f in original_type.__fields__.values(): use_type.__fields__[f.name] = create_cloned_field(f) - use_type.__validators__ = original_type.__validators__ if PYDANTIC_1: new_field = ModelField( name=field.name, diff --git a/tests/test_filter_pydantic_sub_model.py b/tests/test_filter_pydantic_sub_model.py index aef635040..1f7d1deed 100644 --- a/tests/test_filter_pydantic_sub_model.py +++ b/tests/test_filter_pydantic_sub_model.py @@ -1,5 +1,6 @@ +import pytest from fastapi import Depends, FastAPI -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError, validator from starlette.testclient import TestClient app = FastAPI() @@ -18,14 +19,20 @@ class ModelA(BaseModel): description: str = None model_b: ModelB + @validator("name") + def lower_username(cls, name: str, values): + if not name.endswith("A"): + raise ValueError("name must end in A") + return name + async def get_model_c() -> ModelC: return ModelC(username="test-user", password="test-password") -@app.get("/model", response_model=ModelA) -async def get_model_a(model_c=Depends(get_model_c)): - return {"name": "model-a-name", "description": "model-a-desc", "model_b": model_c} +@app.get("/model/{name}", response_model=ModelA) +async def get_model_a(name: str, model_c=Depends(get_model_c)): + return {"name": name, "description": "model-a-desc", "model_b": model_c} client = TestClient(app) @@ -35,10 +42,18 @@ openapi_schema = { "openapi": "3.0.2", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { - "/model": { + "/model/{name}": { "get": { "summary": "Get Model A", - "operationId": "get_model_a_model_get", + "operationId": "get_model_a_model__name__get", + "parameters": [ + { + "required": True, + "schema": {"title": "Name", "type": "string"}, + "name": "name", + "in": "path", + } + ], "responses": { "200": { "description": "Successful Response", @@ -47,13 +62,34 @@ openapi_schema = { "schema": {"$ref": "#/components/schemas/ModelA"} } }, - } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, }, } } }, "components": { "schemas": { + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, "ModelA": { "title": "ModelA", "required": ["name", "model_b"], @@ -70,6 +106,20 @@ openapi_schema = { "type": "object", "properties": {"username": {"title": "Username", "type": "string"}}, }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, } }, } @@ -82,10 +132,22 @@ def test_openapi_schema(): def test_filter_sub_model(): - response = client.get("/model") + response = client.get("/model/modelA") assert response.status_code == 200 assert response.json() == { - "name": "model-a-name", + "name": "modelA", "description": "model-a-desc", "model_b": {"username": "test-user"}, } + + +def test_validator_is_cloned(): + with pytest.raises(ValidationError) as err: + client.get("/model/modelX") + assert err.value.errors() == [ + { + "loc": ("response", "name"), + "msg": "name must end in A", + "type": "value_error", + } + ]