Browse Source

🐛 Fix Pydantic field clone logic with validators (#899)

pull/948/head
Andy Smith 5 years ago
committed by GitHub
parent
commit
70bdade23b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      fastapi/utils.py
  2. 80
      tests/test_filter_pydantic_sub_model.py

5
fastapi/utils.py

@ -93,12 +93,9 @@ def create_cloned_field(field: ModelField) -> ModelField:
use_type = original_type use_type = original_type
if lenient_issubclass(original_type, BaseModel): if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type) original_type = cast(Type[BaseModel], original_type)
use_type = create_model( use_type = create_model(original_type.__name__, __base__=original_type)
original_type.__name__, __config__=original_type.__config__
)
for f in original_type.__fields__.values(): for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(f) use_type.__fields__[f.name] = create_cloned_field(f)
use_type.__validators__ = original_type.__validators__
if PYDANTIC_1: if PYDANTIC_1:
new_field = ModelField( new_field = ModelField(
name=field.name, name=field.name,

80
tests/test_filter_pydantic_sub_model.py

@ -1,5 +1,6 @@
import pytest
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
from pydantic import BaseModel from pydantic import BaseModel, ValidationError, validator
from starlette.testclient import TestClient from starlette.testclient import TestClient
app = FastAPI() app = FastAPI()
@ -18,14 +19,20 @@ class ModelA(BaseModel):
description: str = None description: str = None
model_b: ModelB model_b: ModelB
@validator("name")
def lower_username(cls, name: str, values):
if not name.endswith("A"):
raise ValueError("name must end in A")
return name
async def get_model_c() -> ModelC: async def get_model_c() -> ModelC:
return ModelC(username="test-user", password="test-password") return ModelC(username="test-user", password="test-password")
@app.get("/model", response_model=ModelA) @app.get("/model/{name}", response_model=ModelA)
async def get_model_a(model_c=Depends(get_model_c)): async def get_model_a(name: str, model_c=Depends(get_model_c)):
return {"name": "model-a-name", "description": "model-a-desc", "model_b": model_c} return {"name": name, "description": "model-a-desc", "model_b": model_c}
client = TestClient(app) client = TestClient(app)
@ -35,10 +42,18 @@ openapi_schema = {
"openapi": "3.0.2", "openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"}, "info": {"title": "FastAPI", "version": "0.1.0"},
"paths": { "paths": {
"/model": { "/model/{name}": {
"get": { "get": {
"summary": "Get Model A", "summary": "Get Model A",
"operationId": "get_model_a_model_get", "operationId": "get_model_a_model__name__get",
"parameters": [
{
"required": True,
"schema": {"title": "Name", "type": "string"},
"name": "name",
"in": "path",
}
],
"responses": { "responses": {
"200": { "200": {
"description": "Successful Response", "description": "Successful Response",
@ -47,13 +62,34 @@ openapi_schema = {
"schema": {"$ref": "#/components/schemas/ModelA"} "schema": {"$ref": "#/components/schemas/ModelA"}
} }
}, },
} },
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
}, },
} }
} }
}, },
"components": { "components": {
"schemas": { "schemas": {
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
"ModelA": { "ModelA": {
"title": "ModelA", "title": "ModelA",
"required": ["name", "model_b"], "required": ["name", "model_b"],
@ -70,6 +106,20 @@ openapi_schema = {
"type": "object", "type": "object",
"properties": {"username": {"title": "Username", "type": "string"}}, "properties": {"username": {"title": "Username", "type": "string"}},
}, },
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
} }
}, },
} }
@ -82,10 +132,22 @@ def test_openapi_schema():
def test_filter_sub_model(): def test_filter_sub_model():
response = client.get("/model") response = client.get("/model/modelA")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == {
"name": "model-a-name", "name": "modelA",
"description": "model-a-desc", "description": "model-a-desc",
"model_b": {"username": "test-user"}, "model_b": {"username": "test-user"},
} }
def test_validator_is_cloned():
with pytest.raises(ValidationError) as err:
client.get("/model/modelX")
assert err.value.errors() == [
{
"loc": ("response", "name"),
"msg": "name must end in A",
"type": "value_error",
}
]

Loading…
Cancel
Save