3 changed files with 427 additions and 6 deletions
@ -0,0 +1,393 @@ |
|||||
|
from typing import Any, Dict, List, Type |
||||
|
|
||||
|
import pytest |
||||
|
from dirty_equals import IsDict |
||||
|
from fastapi import Body, FastAPI, Path, Query |
||||
|
from fastapi._compat import PYDANTIC_V2 |
||||
|
from fastapi.testclient import TestClient |
||||
|
from fastapi.utils import match_pydantic_error_url |
||||
|
from pydantic import BaseModel |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if PYDANTIC_V2: |
||||
|
from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer |
||||
|
|
||||
|
class Basic(RootModel[int]): |
||||
|
pass |
||||
|
|
||||
|
class FieldWrap(RootModel[str]): |
||||
|
model_config = ConfigDict( |
||||
|
json_schema_extra={"description": "parameter starts with bar_"} |
||||
|
) |
||||
|
root: str = Field(pattern=r"^bar_.*$") |
||||
|
|
||||
|
class CustomParsed(RootModel[str]): |
||||
|
model_config = ConfigDict( |
||||
|
json_schema_extra={"description": "parameter starts with foo_"} |
||||
|
) |
||||
|
|
||||
|
@field_validator("root") |
||||
|
@classmethod |
||||
|
def validator(cls, v: str) -> str: |
||||
|
if not v.startswith("foo_"): |
||||
|
raise ValueError("must start with foo_") |
||||
|
return v |
||||
|
|
||||
|
@model_serializer |
||||
|
def serialize(self): |
||||
|
return ( |
||||
|
self.root |
||||
|
if self.root.endswith("_serialized") |
||||
|
else f"{self.root}_serialized" |
||||
|
) |
||||
|
|
||||
|
def parse(self): |
||||
|
return self.root[len("foo_") :] # removeprefix |
||||
|
|
||||
|
class DictWrap(RootModel[Dict[str, int]]): |
||||
|
pass |
||||
|
else: |
||||
|
from pydantic import Field, validator |
||||
|
|
||||
|
class Basic(BaseModel): |
||||
|
__root__: int |
||||
|
|
||||
|
class FieldWrap(BaseModel): |
||||
|
class Config: |
||||
|
schema_extra = {"description": "parameter starts with bar_"} |
||||
|
|
||||
|
__root__: str = Field(regex=r"^bar_.*$") |
||||
|
|
||||
|
class CustomParsed(BaseModel): |
||||
|
class Config: |
||||
|
schema_extra = {"description": "parameter starts with foo_"} |
||||
|
|
||||
|
__root__: str |
||||
|
|
||||
|
@validator("__root__") |
||||
|
@classmethod |
||||
|
def validator(cls, v: str) -> str: |
||||
|
if not v.startswith("foo_"): |
||||
|
raise ValueError("must start with foo_") |
||||
|
return v |
||||
|
|
||||
|
def dict(self, **kwargs: Any) -> Dict[str, Any]: |
||||
|
return { |
||||
|
"__root__": self.__root__ |
||||
|
if self.__root__.endswith("_serialized") |
||||
|
else f"{self.__root__}_serialized" |
||||
|
} |
||||
|
|
||||
|
def parse(self): |
||||
|
return self.__root__[len("foo_") :] # removeprefix |
||||
|
|
||||
|
class DictWrap(BaseModel): |
||||
|
__root__: Dict[str, int] |
||||
|
|
||||
|
|
||||
|
@app.get("/query/basic") |
||||
|
def query_basic(q: Basic = Query()): |
||||
|
return {"q": q} |
||||
|
|
||||
|
|
||||
|
@app.get("/query/fieldwrap") |
||||
|
def query_fieldwrap(q: FieldWrap = Query()): |
||||
|
return {"q": q} |
||||
|
|
||||
|
|
||||
|
@app.get("/query/customparsed") |
||||
|
def query_customparsed(q: CustomParsed = Query()): |
||||
|
return {"q": q.parse()} |
||||
|
|
||||
|
|
||||
|
@app.get("/path/basic/{p}") |
||||
|
def path_basic(p: Basic = Path()): |
||||
|
return {"p": p} |
||||
|
|
||||
|
|
||||
|
@app.get("/path/fieldwrap/{p}") |
||||
|
def path_fieldwrap(p: FieldWrap = Path()): |
||||
|
return {"p": p} |
||||
|
|
||||
|
|
||||
|
@app.get("/path/customparsed/{p}") |
||||
|
def path_customparsed(p: CustomParsed = Path()): |
||||
|
return {"p": p.parse()} |
||||
|
|
||||
|
|
||||
|
@app.post("/body/basic") |
||||
|
def body_basic(b: Basic = Body()): |
||||
|
return {"b": b} |
||||
|
|
||||
|
|
||||
|
@app.post("/body/fieldwrap") |
||||
|
def body_fieldwrap(b: FieldWrap = Body()): |
||||
|
return {"b": b} |
||||
|
|
||||
|
|
||||
|
@app.post("/body/customparsed") |
||||
|
def body_customparsed(b: CustomParsed = Body()): |
||||
|
return {"b": b.parse()} |
||||
|
|
||||
|
|
||||
|
@app.get("/echo/basic") |
||||
|
def echo_basic(q: Basic = Query()) -> Basic: |
||||
|
return q |
||||
|
|
||||
|
|
||||
|
@app.get("/echo/fieldwrap") |
||||
|
def echo_fieldwrap(q: FieldWrap = Query()) -> FieldWrap: |
||||
|
return q |
||||
|
|
||||
|
|
||||
|
@app.get("/echo/customparsed") |
||||
|
def echo_customparsed(q: CustomParsed = Query()) -> CustomParsed: |
||||
|
return q |
||||
|
|
||||
|
|
||||
|
@app.post("/echo/dictwrap") |
||||
|
def echo_dictwrap(b: DictWrap) -> DictWrap: |
||||
|
return b.model_dump() if PYDANTIC_V2 else b.dict()["__root__"] |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize( |
||||
|
"url, response_json, request_body", |
||||
|
[ |
||||
|
("/query/basic?q=42", {"q": 42}, None), |
||||
|
("/query/fieldwrap?q=bar_baz", {"q": "bar_baz"}, None), |
||||
|
("/query/customparsed?q=foo_bar", {"q": "bar"}, None), |
||||
|
("/path/basic/42", {"p": 42}, None), |
||||
|
("/path/fieldwrap/bar_baz", {"p": "bar_baz"}, None), |
||||
|
("/path/customparsed/foo_bar", {"p": "bar"}, None), |
||||
|
("/body/basic", {"b": 42}, "42"), |
||||
|
("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"), |
||||
|
("/body/customparsed", {"b": "bar"}, "foo_bar"), |
||||
|
("/echo/basic?q=42", 42, None), |
||||
|
("/echo/fieldwrap?q=bar_baz", "bar_baz", None), |
||||
|
("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None), |
||||
|
("/echo/dictwrap", {"test": 1}, {"test": 1}), |
||||
|
], |
||||
|
) |
||||
|
def test_root_model_200(url: str, response_json: Any, request_body: Any): |
||||
|
response = client.post(url, json=request_body) if request_body else client.get(url) |
||||
|
assert response.status_code == 200, response.text |
||||
|
assert response.json() == response_json |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize( |
||||
|
"url, error_path, request_body", |
||||
|
[ |
||||
|
("/query/basic?q=my_bad_not_int", ["query", "q"], None), |
||||
|
("/path/basic/my_bad_not_int", ["path", "p"], None), |
||||
|
("/body/basic", ["body"], "my_bad_not_int"), |
||||
|
], |
||||
|
) |
||||
|
def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any): |
||||
|
response = client.post(url, json=request_body) if request_body else client.get(url) |
||||
|
assert response.status_code == 422, response.text |
||||
|
assert response.json() == IsDict( |
||||
|
{ |
||||
|
"detail": [ |
||||
|
{ |
||||
|
"type": "int_parsing", |
||||
|
"loc": error_path, |
||||
|
"msg": "Input should be a valid integer, unable to parse string as an integer", |
||||
|
"input": "my_bad_not_int", |
||||
|
"url": match_pydantic_error_url("int_parsing"), |
||||
|
} |
||||
|
] |
||||
|
} |
||||
|
) | IsDict( |
||||
|
# TODO: remove when deprecating Pydantic v1 |
||||
|
{ |
||||
|
"detail": [ |
||||
|
{ |
||||
|
"loc": [*error_path, "__root__"], |
||||
|
"msg": "value is not a valid integer", |
||||
|
"type": "type_error.integer", |
||||
|
} |
||||
|
] |
||||
|
} |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize( |
||||
|
"url, error_path, request_body", |
||||
|
[ |
||||
|
("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None), |
||||
|
("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None), |
||||
|
("/body/fieldwrap", ["body"], "my_bad_prefix_val"), |
||||
|
], |
||||
|
) |
||||
|
def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any): |
||||
|
response = client.post(url, json=request_body) if request_body else client.get(url) |
||||
|
assert response.status_code == 422, response.text |
||||
|
assert response.json() == IsDict( |
||||
|
{ |
||||
|
"detail": [ |
||||
|
{ |
||||
|
"type": "string_pattern_mismatch", |
||||
|
"loc": error_path, |
||||
|
"msg": "String should match pattern '^bar_.*$'", |
||||
|
"input": "my_bad_prefix_val", |
||||
|
"url": match_pydantic_error_url("string_pattern_mismatch"), |
||||
|
"ctx": {"pattern": "^bar_.*$"}, |
||||
|
} |
||||
|
] |
||||
|
} |
||||
|
) | IsDict( |
||||
|
# TODO: remove when deprecating Pydantic v1 |
||||
|
{ |
||||
|
"detail": [ |
||||
|
{ |
||||
|
"loc": [*error_path, "__root__"], |
||||
|
"msg": 'string does not match regex "^bar_.*$"', |
||||
|
"type": "value_error.str.regex", |
||||
|
"ctx": {"pattern": "^bar_.*$"}, |
||||
|
} |
||||
|
] |
||||
|
} |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize( |
||||
|
"url, error_path, request_body", |
||||
|
[ |
||||
|
("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None), |
||||
|
("/path/customparsed/my_bad_prefix_val", ["path", "p"], None), |
||||
|
("/body/customparsed", ["body"], "my_bad_prefix_val"), |
||||
|
], |
||||
|
) |
||||
|
def test_root_model_customparsed_422( |
||||
|
url: str, error_path: List[str], request_body: Any |
||||
|
): |
||||
|
response = client.post(url, json=request_body) if request_body else client.get(url) |
||||
|
assert response.status_code == 422, response.text |
||||
|
assert response.json() == IsDict( |
||||
|
{ |
||||
|
"detail": [ |
||||
|
{ |
||||
|
"type": "value_error", |
||||
|
"loc": error_path, |
||||
|
"msg": "Value error, must start with foo_", |
||||
|
"input": "my_bad_prefix_val", |
||||
|
"ctx": {"error": {}}, |
||||
|
"url": match_pydantic_error_url("value_error"), |
||||
|
} |
||||
|
] |
||||
|
} |
||||
|
) | IsDict( |
||||
|
# TODO: remove when deprecating Pydantic v1 |
||||
|
{ |
||||
|
"detail": [ |
||||
|
{ |
||||
|
"loc": [*error_path, "__root__"], |
||||
|
"msg": "must start with foo_", |
||||
|
"type": "value_error", |
||||
|
} |
||||
|
] |
||||
|
} |
||||
|
) |
||||
|
|
||||
|
|
||||
|
def test_root_model_dictwrap_422(): |
||||
|
response = client.post("/echo/dictwrap", json={"test": "fail_not_int"}) |
||||
|
assert response.status_code == 422, response.text |
||||
|
assert response.json() == IsDict( |
||||
|
{ |
||||
|
"detail": [ |
||||
|
{ |
||||
|
"type": "int_parsing", |
||||
|
"loc": ["body", "test"], |
||||
|
"msg": "Input should be a valid integer, unable to parse string as an integer", |
||||
|
"input": "fail_not_int", |
||||
|
"url": match_pydantic_error_url("int_parsing"), |
||||
|
} |
||||
|
] |
||||
|
} |
||||
|
) | IsDict( |
||||
|
# TODO: remove when deprecating Pydantic v1 |
||||
|
{ |
||||
|
"detail": [ |
||||
|
{ |
||||
|
"loc": ["body", "__root__", "test"], |
||||
|
"msg": "value is not a valid integer", |
||||
|
"type": "type_error.integer", |
||||
|
} |
||||
|
] |
||||
|
} |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize( |
||||
|
"model, model_schema", |
||||
|
[ |
||||
|
(Basic, {"title": "Basic", "type": "integer"}), |
||||
|
( |
||||
|
FieldWrap, |
||||
|
{ |
||||
|
"title": "FieldWrap", |
||||
|
"type": "string", |
||||
|
"pattern": "^bar_.*$", |
||||
|
"description": "parameter starts with bar_", |
||||
|
}, |
||||
|
), |
||||
|
( |
||||
|
CustomParsed, |
||||
|
{ |
||||
|
"title": "CustomParsed", |
||||
|
"type": "string", |
||||
|
"description": "parameter starts with foo_", |
||||
|
}, |
||||
|
), |
||||
|
], |
||||
|
) |
||||
|
def test_openapi_schema(model: Type, model_schema: Dict[str, Any]): |
||||
|
response = client.get("/openapi.json") |
||||
|
assert response.status_code == 200, response.text |
||||
|
|
||||
|
paths = response.json()["paths"] |
||||
|
ref = {"schema": {"$ref": f"#/components/schemas/{model.__name__}"}} |
||||
|
assert paths[f"/query/{model.__name__.lower()}"]["get"]["parameters"] == [ |
||||
|
{"in": "query", "name": "q", "required": True, **ref} |
||||
|
] |
||||
|
assert paths[f"/path/{model.__name__.lower()}/{{p}}"]["get"]["parameters"] == [ |
||||
|
{"in": "path", "name": "p", "required": True, **ref} |
||||
|
] |
||||
|
assert paths[f"/body/{model.__name__.lower()}"]["post"]["requestBody"] == { |
||||
|
"content": {"application/json": ref}, |
||||
|
"required": True, |
||||
|
} |
||||
|
assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == { |
||||
|
"content": {"application/json": ref}, |
||||
|
"description": "Successful Response", |
||||
|
} |
||||
|
assert response.json()["components"]["schemas"][model.__name__] == { |
||||
|
"title": model.__name__, |
||||
|
**model_schema, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
def test_openapi_schema_dictwrap(): |
||||
|
response = client.get("/openapi.json") |
||||
|
assert response.status_code == 200, response.text |
||||
|
|
||||
|
operation = response.json()["paths"]["/echo/dictwrap"]["post"] |
||||
|
ref = {"schema": {"$ref": "#/components/schemas/DictWrap"}} |
||||
|
assert operation["requestBody"] == { |
||||
|
"content": {"application/json": ref}, |
||||
|
"required": True, |
||||
|
} |
||||
|
assert operation["responses"]["200"] == { |
||||
|
"content": {"application/json": ref}, |
||||
|
"description": "Successful Response", |
||||
|
} |
||||
|
assert response.json()["components"]["schemas"]["DictWrap"] == { |
||||
|
"title": "DictWrap", |
||||
|
"type": "object", |
||||
|
"additionalProperties": {"type": "integer"}, |
||||
|
} |
Loading…
Reference in new issue