python asyncio api async fastapi framework json json-schema openapi openapi3 pydantic python-types python3 redoc rest starlette swagger swagger-ui uvicorn web
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
519 lines
16 KiB
519 lines
16 KiB
from datetime import date
|
|
from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union
|
|
|
|
import pytest
|
|
from dirty_equals import IsDict
|
|
from fastapi import Body, FastAPI, Header, Path, Query
|
|
from fastapi._compat import PYDANTIC_V2
|
|
from fastapi.testclient import TestClient
|
|
from pydantic import BaseModel, Field
|
|
from typing_extensions import Annotated
|
|
|
|
from .utils import needs_pydanticv2
|
|
|
|
# Application shared by the module-level tests: every root-model route below
# is registered on it, and `client` wraps it.
app = FastAPI()
|
|
|
|
if not PYDANTIC_V2:
    # Pydantic v1 has no RootModel. Emulate just the subscription syntax used
    # in this module: RootModel[X] yields a BaseModel with a `__root__: X` field.
    from pydantic import BaseModel

    T = TypeVar("T")

    class RootModel(Generic[T]):
        """Pydantic v1 stand-in for ``pydantic.RootModel[...]`` subscription."""

        def __class_getitem__(cls, root_type):
            # A brand-new custom-root model is built on every subscription, so
            # evaluating RootModel[int] twice yields two distinct classes.
            class Inner(BaseModel):
                __root__: root_type

            return Inner

else:
    from pydantic import RootModel
|
|
|
|
|
|
# Root model wrapping a bare int. Under Pydantic v1 this is the shim's
# `__root__: int` model defined above.
Basic = RootModel[int]
|
|
|
|
|
|
# Root model over a str -> int mapping, used by the /echo/dictwrap route.
# NOTE(review): intentionally left without a class docstring. Pydantic likely
# copies a docstring into the JSON schema as "description", and
# test_openapi_schema_dictwrap asserts the exact schema.
class DictWrap(RootModel[Dict[str, int]]):
    pass
|
|
|
|
|
|
# Two string root models, defined once per Pydantic major version.
# NOTE(review): comments only, no class docstrings. A docstring may surface as a
# schema "description", and test_openapi_schema compares the schemas exactly.
if PYDANTIC_V2:
    from pydantic import ConfigDict, Field, field_validator, model_serializer

    # String root constrained by a regex. The description supplied via
    # json_schema_extra is what test_openapi_schema expects.
    class FieldWrap(RootModel[str]):
        model_config = ConfigDict(
            json_schema_extra={"description": "parameter starts with bar_"}
        )
        root: str = Field(pattern=r"^bar_.*$")

    # String root with a prefix validator, a custom serializer, and a
    # `parse` helper that the route handlers call.
    class CustomParsed(RootModel[str]):
        model_config = ConfigDict(
            json_schema_extra={"description": "parameter starts with foo_"}
        )

        # Reject any value that does not start with "foo_". The message text is
        # asserted by test_root_model_customparsed_422.
        @field_validator("root")
        @classmethod
        def validator(cls, v: str) -> str:
            if not v.startswith("foo_"):
                raise ValueError("must start with foo_")
            return v

        # Append "_serialized" on output, unless it is already present. The
        # guard keeps repeated serialization from stacking the suffix;
        # /echo/customparsed is expected to return "foo_bar_serialized".
        @model_serializer
        def serialize(self):
            return (
                self.root
                if self.root.endswith("_serialized")
                else f"{self.root}_serialized"
            )

        # Return the value with its leading "foo_" stripped (validated above).
        def parse(self):
            return self.root[len("foo_") :]  # removeprefix

else:
    from pydantic import validator

    # Pydantic v1 equivalents of the models above, using `__root__`.
    class FieldWrap(BaseModel):
        class Config:
            schema_extra = {"description": "parameter starts with bar_"}

        __root__: str = Field(regex=r"^bar_.*$")

    class CustomParsed(BaseModel):
        class Config:
            schema_extra = {"description": "parameter starts with foo_"}

        __root__: str

        # Same prefix check as the v2 field_validator. The method name shadows
        # the imported `validator` only inside this class body.
        @validator("__root__")
        @classmethod
        def validator(cls, v: str) -> str:
            if not v.startswith("foo_"):
                raise ValueError("must start with foo_")
            return v

        # Mirrors the v2 model_serializer by overriding dict(). The keyword
        # arguments are accepted but ignored.
        def dict(self, **kwargs: Any) -> Dict[str, Any]:
            return {
                "__root__": self.__root__
                if self.__root__.endswith("_serialized")
                else f"{self.__root__}_serialized"
            }

        # Return the value with its leading "foo_" stripped (validated above).
        def parse(self):
            return self.__root__[len("foo_") :]  # removeprefix
|
|
|
|
|
|
@app.get("/query/basic")
def query_basic(q: Annotated[Basic, Query()]):
    # Hand the validated root model straight back under the "q" key.
    return dict(q=q)
|
|
|
|
|
|
@app.get("/query/fieldwrap")
def query_fieldwrap(q: Annotated[FieldWrap, Query()]):
    # Hand the validated root model straight back under the "q" key.
    return dict(q=q)
|
|
|
|
|
|
@app.get("/query/customparsed")
def query_customparsed(q: Annotated[CustomParsed, Query()]):
    # Respond with the parsed (prefix-stripped) string, not the model itself.
    stripped = q.parse()
    return dict(q=stripped)
|
|
|
|
|
|
@app.get("/path/basic/{p}")
def path_basic(p: Annotated[Basic, Path()]):
    # Hand the validated root model straight back under the "p" key.
    return dict(p=p)
|
|
|
|
|
|
@app.get("/path/fieldwrap/{p}")
def path_fieldwrap(p: Annotated[FieldWrap, Path()]):
    # Hand the validated root model straight back under the "p" key.
    return dict(p=p)
|
|
|
|
|
|
@app.get("/path/customparsed/{p}")
def path_customparsed(p: Annotated[CustomParsed, Path()]):
    # Respond with the parsed (prefix-stripped) string, not the model itself.
    stripped = p.parse()
    return dict(p=stripped)
|
|
|
|
|
|
@app.post("/body/basic")
def body_basic(b: Annotated[Basic, Body()]):
    # Hand the validated root model straight back under the "b" key.
    return dict(b=b)
|
|
|
|
|
|
@app.post("/body/fieldwrap")
def body_fieldwrap(b: Annotated[FieldWrap, Body()]):
    # Hand the validated root model straight back under the "b" key.
    return dict(b=b)
|
|
|
|
|
|
@app.post("/body/customparsed")
def body_customparsed(b: Annotated[CustomParsed, Body()]):
    # Respond with the parsed (prefix-stripped) string, not the model itself.
    stripped = b.parse()
    return dict(b=stripped)
|
|
|
|
|
|
@app.post("/body_default/basic")
def body_default_basic(b: Basic):
    # No explicit Body() marker: the model-typed parameter is read from the body.
    return dict(b=b)
|
|
|
|
|
|
@app.post("/body_default/fieldwrap")
def body_default_fieldwrap(b: FieldWrap):
    # No explicit Body() marker: the model-typed parameter is read from the body.
    return dict(b=b)
|
|
|
|
|
|
@app.post("/body_default/customparsed")
def body_default_customparsed(b: CustomParsed):
    # Respond with the parsed (prefix-stripped) string, not the model itself.
    stripped = b.parse()
    return dict(b=stripped)
|
|
|
|
|
|
# Return the query model unchanged. The `-> Basic` annotation is what
# test_openapi_schema expects as the documented 200 response schema.
@app.get("/echo/basic")
def echo_basic(q: Annotated[Basic, Query()]) -> Basic:
    return q
|
|
|
|
|
|
# Return the query model unchanged. The `-> FieldWrap` annotation is what
# test_openapi_schema expects as the documented 200 response schema.
@app.get("/echo/fieldwrap")
def echo_fieldwrap(q: Annotated[FieldWrap, Query()]) -> FieldWrap:
    return q
|
|
|
|
|
|
# Return the query model unchanged. The response is expected to go through the
# model's custom serializer ("foo_bar" -> "foo_bar_serialized" in
# test_root_model_200).
@app.get("/echo/customparsed")
def echo_customparsed(q: Annotated[CustomParsed, Query()]) -> CustomParsed:
    return q
|
|
|
|
|
|
@app.post("/echo/dictwrap")
def echo_dictwrap(b: DictWrap) -> DictWrap:
    # Return the plain mapping rather than the model instance. Presumably this
    # exercises coercion of a raw dict through the `-> DictWrap` response model.
    if PYDANTIC_V2:
        return b.model_dump()
    return b.dict()["__root__"]
|
|
|
|
|
|
# Test client bound to the shared `app`; used by the module-level tests below.
client = TestClient(app)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "url, response_json, request_body",
    [
        ("/query/basic?q=42", {"q": 42}, None),
        ("/query/fieldwrap?q=bar_baz", {"q": "bar_baz"}, None),
        ("/query/customparsed?q=foo_bar", {"q": "bar"}, None),
        ("/path/basic/42", {"p": 42}, None),
        ("/path/fieldwrap/bar_baz", {"p": "bar_baz"}, None),
        ("/path/customparsed/foo_bar", {"p": "bar"}, None),
        ("/body/basic", {"b": 42}, "42"),
        ("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"),
        ("/body/customparsed", {"b": "bar"}, "foo_bar"),
        ("/body_default/basic", {"b": 42}, "42"),
        ("/body_default/fieldwrap", {"b": "bar_baz"}, "bar_baz"),
        ("/body_default/customparsed", {"b": "bar"}, "foo_bar"),
        ("/echo/basic?q=42", 42, None),
        ("/echo/fieldwrap?q=bar_baz", "bar_baz", None),
        ("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None),
        ("/echo/dictwrap", {"test": 1}, {"test": 1}),
    ],
)
def test_root_model_200(url: str, response_json: Any, request_body: Any):
    """Valid input on every root-model route returns the expected JSON.

    Cases with a ``request_body`` are sent as JSON POSTs; the others are GETs.
    """
    # Dispatch on `is not None`, not truthiness. Otherwise a falsy JSON body
    # (0, "", {}) would silently be sent as a GET to a POST-only route.
    if request_body is not None:
        response = client.post(url, json=request_body)
    else:
        response = client.get(url)
    assert response.status_code == 200, response.text
    assert response.json() == response_json
|
|
|
|
|
|
@pytest.mark.parametrize("type_", [int, str, bytes, Optional[int]])
@pytest.mark.parametrize("outer", [None, List, Sequence])
def test2_root_model_200__basic(type_: Type, outer: Optional[Type]):
    """RootModel over scalars and sequences works as path/query/header/body.

    A sequence-typed root model is rejected as a path parameter when the route
    is registered, so the /path route only exists in the scalar case.
    """
    inner = outer[type_] if outer else type_
    Model = RootModel[inner]

    app2 = FastAPI()

    if outer is None:

        @app2.get("/path/{p}")
        def path_basic2(p: Annotated[Model, Path()]) -> Model:
            return p

    else:
        # Registration raises, so no /path route is added (it 404s below).
        with pytest.raises(
            AssertionError, match="Path params must be of one of the supported types"
        ):

            @app2.get("/path/{p}")
            def path_basic2(p: Annotated[Model, Path()]) -> Model:
                return p  # pragma: nocover

    @app2.get("/query")
    def query_basic2(q: Annotated[Model, Query()]) -> Model:
        return q

    # Distinct name so this handler does not redefine `path_basic2` above.
    @app2.get("/header")
    def header_basic2(h: Annotated[Model, Header()]) -> Model:
        return h

    @app2.post("/body")
    def body_basic2(b: Annotated[Model, Body()]) -> Model:
        return b

    client2 = TestClient(app2)

    # str and bytes root values come back as JSON strings; the rest as ints.
    if outer is None:
        expected = "42" if type_ in (str, bytes) else 42
        assert client2.get("/path/42").json() == expected
        assert client2.get("/query?q=42").json() == expected
        assert client2.get("/header", headers={"h": "42"}).json() == expected
    else:
        expected = ["42", "43"] if type_ in (str, bytes) else [42, 43]
        assert client2.get("/path/42").json()["detail"] == "Not Found"
        assert client2.get("/query?q=42&q=43").json() == expected
        assert (
            client2.get("/header", headers=[("h", "42"), ("h", "43")]).json()
            == expected
        )
    assert client2.post("/body", json=expected).json() == expected
|
|
|
|
|
|
@pytest.mark.parametrize("left", [RootModel[date], date])
@pytest.mark.parametrize("right", [RootModel[int], int])
@needs_pydanticv2
def test_root_model_union(left: Any, right: Any):
    """Unions mixing root models and plain types validate as query and body."""
    Model = Union[left, right]

    union_app = FastAPI()

    @union_app.get("/query")
    def handler1(q: Annotated[Model, Query()]):
        return {"q": q}

    @union_app.post("/body")
    def handler2(b: Annotated[Model, Body()]):
        return {"b": b}

    union_client = TestClient(union_app)

    # One value per union member: a date string and an int.
    for value in ("2025-01-02", 42):
        query_resp = union_client.get(f"/query?q={value}")
        assert query_resp.status_code == 200, query_resp.text
        assert query_resp.json() == {"q": value}

        body_resp = union_client.post("/body", json=value)
        assert body_resp.status_code == 200, body_resp.text
        assert body_resp.json() == {"b": value}

    # A value matching neither member reports one error per member.
    bad = union_client.get("/query?q=bad")
    assert bad.status_code == 422, bad.text
    expected_messages = {
        "Input should be a valid date or datetime, input is too short",
        "Input should be a valid integer, unable to parse string as an integer",
    }
    assert {detail["msg"] for detail in bad.json()["detail"]} == expected_messages
|
|
|
|
|
|
@pytest.mark.parametrize(
    "url, error_path, request_body",
    [
        ("/query/basic?q=my_bad_not_int", ["query", "q"], None),
        ("/path/basic/my_bad_not_int", ["path", "p"], None),
        ("/body/basic", ["body"], "my_bad_not_int"),
        ("/body_default/basic", ["body"], "my_bad_not_int"),
    ],
)
def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any):
    """A non-integer value for the int root model yields a 422 at `error_path`."""
    # Dispatch on `is not None` (not truthiness) so any JSON body is POSTed.
    if request_body is not None:
        response = client.post(url, json=request_body)
    else:
        response = client.get(url)
    assert response.status_code == 422, response.text
    assert response.json() == IsDict(
        {
            "detail": [
                {
                    "type": "int_parsing",
                    "loc": error_path,
                    "msg": "Input should be a valid integer, unable to parse string as an integer",
                    "input": "my_bad_not_int",
                }
            ]
        }
    ) | IsDict(
        # TODO: remove when deprecating Pydantic v1
        {
            "detail": [
                {
                    "loc": [*error_path, "__root__"],
                    "msg": "value is not a valid integer",
                    "type": "type_error.integer",
                }
            ]
        }
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    "url, error_path, request_body",
    [
        ("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None),
        ("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None),
        ("/body/fieldwrap", ["body"], "my_bad_prefix_val"),
        ("/body_default/fieldwrap", ["body"], "my_bad_prefix_val"),
    ],
)
def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any):
    """A value not matching FieldWrap's ^bar_ pattern yields a 422 at `error_path`."""
    # Dispatch on `is not None` (not truthiness) so any JSON body is POSTed.
    if request_body is not None:
        response = client.post(url, json=request_body)
    else:
        response = client.get(url)
    assert response.status_code == 422, response.text
    assert response.json() == IsDict(
        {
            "detail": [
                {
                    "type": "string_pattern_mismatch",
                    "loc": error_path,
                    "msg": "String should match pattern '^bar_.*$'",
                    "input": "my_bad_prefix_val",
                    "ctx": {"pattern": "^bar_.*$"},
                }
            ]
        }
    ) | IsDict(
        # TODO: remove when deprecating Pydantic v1
        {
            "detail": [
                {
                    "loc": [*error_path, "__root__"],
                    "msg": 'string does not match regex "^bar_.*$"',
                    "type": "value_error.str.regex",
                    "ctx": {"pattern": "^bar_.*$"},
                }
            ]
        }
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    "url, error_path, request_body",
    [
        ("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None),
        ("/path/customparsed/my_bad_prefix_val", ["path", "p"], None),
        ("/body/customparsed", ["body"], "my_bad_prefix_val"),
        ("/body_default/customparsed", ["body"], "my_bad_prefix_val"),
    ],
)
def test_root_model_customparsed_422(
    url: str, error_path: List[str], request_body: Any
):
    """CustomParsed's validator rejects values lacking the foo_ prefix with a 422."""
    # Dispatch on `is not None` (not truthiness) so any JSON body is POSTed.
    if request_body is not None:
        response = client.post(url, json=request_body)
    else:
        response = client.get(url)
    assert response.status_code == 422, response.text
    assert response.json() == IsDict(
        {
            "detail": [
                {
                    "type": "value_error",
                    "loc": error_path,
                    "msg": "Value error, must start with foo_",
                    "input": "my_bad_prefix_val",
                    "ctx": {"error": {}},
                }
            ]
        }
    ) | IsDict(
        # TODO: remove when deprecating Pydantic v1
        {
            "detail": [
                {
                    "loc": [*error_path, "__root__"],
                    "msg": "must start with foo_",
                    "type": "value_error",
                }
            ]
        }
    )
|
|
|
|
|
|
def test_root_model_dictwrap_422():
    """A non-integer mapping value is rejected with the nested field location."""
    resp = client.post("/echo/dictwrap", json={"test": "fail_not_int"})
    assert resp.status_code == 422, resp.text

    pydantic_v2_error = {
        "detail": [
            {
                "type": "int_parsing",
                "loc": ["body", "test"],
                "msg": "Input should be a valid integer, unable to parse string as an integer",
                "input": "fail_not_int",
            }
        ]
    }
    # TODO: remove when deprecating Pydantic v1
    pydantic_v1_error = {
        "detail": [
            {
                "loc": ["body", "__root__", "test"],
                "msg": "value is not a valid integer",
                "type": "type_error.integer",
            }
        ]
    }
    assert resp.json() == IsDict(pydantic_v2_error) | IsDict(pydantic_v1_error)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "model, path_name, expected_model_schema",
    [
        (Basic, "basic", {"type": "integer"}),
        (
            FieldWrap,
            "fieldwrap",
            {
                "type": "string",
                "pattern": "^bar_.*$",
                "description": "parameter starts with bar_",
            },
        ),
        (
            CustomParsed,
            "customparsed",
            {
                "type": "string",
                "description": "parameter starts with foo_",
            },
        ),
    ],
)
def test_openapi_schema(
    model: Type, path_name: str, expected_model_schema: Dict[str, Any]
):
    """Every route family documents the root model through one shared `$ref`."""
    resp = client.get("/openapi.json")
    assert resp.status_code == 200, resp.text
    spec = resp.json()
    paths = spec["paths"]

    # The component name is expected to be the model name with "[" and "]"
    # turned into underscores.
    ref_name = model.__name__.replace("[", "_").replace("]", "_")
    schema_ref = {"schema": {"$ref": f"#/components/schemas/{ref_name}"}}
    json_content = {"application/json": schema_ref}

    # Query and path parameters both reference the component schema.
    for url, location, param_name in (
        (f"/query/{path_name}", "query", "q"),
        (f"/path/{path_name}/{{p}}", "path", "p"),
    ):
        expected_param = {"in": location, "name": param_name, "required": True}
        expected_param.update(schema_ref)
        assert paths[url]["get"]["parameters"] == [expected_param]

    # Explicit Body() and implicit body parameters are documented identically.
    for prefix in ("body", "body_default"):
        assert paths[f"/{prefix}/{path_name}"]["post"]["requestBody"] == {
            "content": json_content,
            "required": True,
        }

    assert paths[f"/echo/{path_name}"]["get"]["responses"]["200"] == {
        "content": json_content,
        "description": "Successful Response",
    }

    component = spec["components"]["schemas"][ref_name]
    component.pop("title")
    assert component == expected_model_schema
|
|
|
|
|
|
def test_openapi_schema_dictwrap():
    """DictWrap is documented as an int-valued object for request and response."""
    resp = client.get("/openapi.json")
    assert resp.status_code == 200, resp.text
    spec = resp.json()

    dictwrap_ref = {"schema": {"$ref": "#/components/schemas/DictWrap"}}
    operation = spec["paths"]["/echo/dictwrap"]["post"]

    assert operation["requestBody"] == {
        "content": {"application/json": dictwrap_ref},
        "required": True,
    }
    assert operation["responses"]["200"] == {
        "content": {"application/json": dictwrap_ref},
        "description": "Successful Response",
    }
    assert spec["components"]["schemas"]["DictWrap"] == {
        "title": "DictWrap",
        "type": "object",
        "additionalProperties": {"type": "integer"},
    }
|
|
|