|
|
@ -1,5 +1,6 @@ |
|
|
|
import pytest |
|
|
|
from fastapi import Depends, FastAPI |
|
|
|
from pydantic import BaseModel |
|
|
|
from pydantic import BaseModel, ValidationError, validator |
|
|
|
from starlette.testclient import TestClient |
|
|
|
|
|
|
|
app = FastAPI() |
|
|
@ -18,14 +19,20 @@ class ModelA(BaseModel): |
|
|
|
description: str = None |
|
|
|
model_b: ModelB |
|
|
|
|
|
|
|
@validator("name") |
|
|
|
def lower_username(cls, name: str, values): |
|
|
|
if not name.endswith("A"): |
|
|
|
raise ValueError("name must end in A") |
|
|
|
return name |
|
|
|
|
|
|
|
|
|
|
|
async def get_model_c() -> ModelC: |
|
|
|
return ModelC(username="test-user", password="test-password") |
|
|
|
|
|
|
|
|
|
|
|
@app.get("/model", response_model=ModelA) |
|
|
|
async def get_model_a(model_c=Depends(get_model_c)): |
|
|
|
return {"name": "model-a-name", "description": "model-a-desc", "model_b": model_c} |
|
|
|
@app.get("/model/{name}", response_model=ModelA) |
|
|
|
async def get_model_a(name: str, model_c=Depends(get_model_c)): |
|
|
|
return {"name": name, "description": "model-a-desc", "model_b": model_c} |
|
|
|
|
|
|
|
|
|
|
|
client = TestClient(app) |
|
|
@ -35,10 +42,18 @@ openapi_schema = { |
|
|
|
"openapi": "3.0.2", |
|
|
|
"info": {"title": "FastAPI", "version": "0.1.0"}, |
|
|
|
"paths": { |
|
|
|
"/model": { |
|
|
|
"/model/{name}": { |
|
|
|
"get": { |
|
|
|
"summary": "Get Model A", |
|
|
|
"operationId": "get_model_a_model_get", |
|
|
|
"operationId": "get_model_a_model__name__get", |
|
|
|
"parameters": [ |
|
|
|
{ |
|
|
|
"required": True, |
|
|
|
"schema": {"title": "Name", "type": "string"}, |
|
|
|
"name": "name", |
|
|
|
"in": "path", |
|
|
|
} |
|
|
|
], |
|
|
|
"responses": { |
|
|
|
"200": { |
|
|
|
"description": "Successful Response", |
|
|
@ -47,13 +62,34 @@ openapi_schema = { |
|
|
|
"schema": {"$ref": "#/components/schemas/ModelA"} |
|
|
|
} |
|
|
|
}, |
|
|
|
} |
|
|
|
}, |
|
|
|
"422": { |
|
|
|
"description": "Validation Error", |
|
|
|
"content": { |
|
|
|
"application/json": { |
|
|
|
"schema": { |
|
|
|
"$ref": "#/components/schemas/HTTPValidationError" |
|
|
|
} |
|
|
|
} |
|
|
|
}, |
|
|
|
}, |
|
|
|
}, |
|
|
|
} |
|
|
|
} |
|
|
|
}, |
|
|
|
"components": { |
|
|
|
"schemas": { |
|
|
|
"HTTPValidationError": { |
|
|
|
"title": "HTTPValidationError", |
|
|
|
"type": "object", |
|
|
|
"properties": { |
|
|
|
"detail": { |
|
|
|
"title": "Detail", |
|
|
|
"type": "array", |
|
|
|
"items": {"$ref": "#/components/schemas/ValidationError"}, |
|
|
|
} |
|
|
|
}, |
|
|
|
}, |
|
|
|
"ModelA": { |
|
|
|
"title": "ModelA", |
|
|
|
"required": ["name", "model_b"], |
|
|
@ -70,6 +106,20 @@ openapi_schema = { |
|
|
|
"type": "object", |
|
|
|
"properties": {"username": {"title": "Username", "type": "string"}}, |
|
|
|
}, |
|
|
|
"ValidationError": { |
|
|
|
"title": "ValidationError", |
|
|
|
"required": ["loc", "msg", "type"], |
|
|
|
"type": "object", |
|
|
|
"properties": { |
|
|
|
"loc": { |
|
|
|
"title": "Location", |
|
|
|
"type": "array", |
|
|
|
"items": {"type": "string"}, |
|
|
|
}, |
|
|
|
"msg": {"title": "Message", "type": "string"}, |
|
|
|
"type": {"title": "Error Type", "type": "string"}, |
|
|
|
}, |
|
|
|
}, |
|
|
|
} |
|
|
|
}, |
|
|
|
} |
|
|
@ -82,10 +132,22 @@ def test_openapi_schema(): |
|
|
|
|
|
|
|
|
|
|
|
def test_filter_sub_model(): |
|
|
|
response = client.get("/model") |
|
|
|
response = client.get("/model/modelA") |
|
|
|
assert response.status_code == 200 |
|
|
|
assert response.json() == { |
|
|
|
"name": "model-a-name", |
|
|
|
"name": "modelA", |
|
|
|
"description": "model-a-desc", |
|
|
|
"model_b": {"username": "test-user"}, |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def test_validator_is_cloned(): |
|
|
|
with pytest.raises(ValidationError) as err: |
|
|
|
client.get("/model/modelX") |
|
|
|
assert err.value.errors() == [ |
|
|
|
{ |
|
|
|
"loc": ("response", "name"), |
|
|
|
"msg": "name must end in A", |
|
|
|
"type": "value_error", |
|
|
|
} |
|
|
|
] |
|
|
|