|
@ -1,18 +1,42 @@ |
|
|
from typing import Any, Dict, List, Type, Union |
|
|
from datetime import date |
|
|
|
|
|
from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union |
|
|
|
|
|
|
|
|
import pytest |
|
|
import pytest |
|
|
from dirty_equals import IsDict |
|
|
from dirty_equals import IsDict |
|
|
from fastapi import Body, FastAPI, Path, Query |
|
|
from fastapi import Body, FastAPI, Header, Path, Query |
|
|
from fastapi._compat import PYDANTIC_V2 |
|
|
from fastapi._compat import PYDANTIC_V2 |
|
|
from fastapi.testclient import TestClient |
|
|
from fastapi.testclient import TestClient |
|
|
from pydantic import BaseModel |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
from typing_extensions import Annotated |
|
|
|
|
|
|
|
|
|
|
|
from .utils import needs_pydanticv2 |
|
|
|
|
|
|
|
|
app = FastAPI() |
|
|
app = FastAPI() |
|
|
|
|
|
|
|
|
if PYDANTIC_V2: |
|
|
if PYDANTIC_V2: |
|
|
from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer |
|
|
from pydantic import RootModel |
|
|
|
|
|
else: |
|
|
|
|
|
from pydantic import BaseModel |
|
|
|
|
|
|
|
|
|
|
|
T = TypeVar("T") |
|
|
|
|
|
|
|
|
|
|
|
class RootModel(Generic[T]): |
|
|
|
|
|
def __class_getitem__(cls, type_): |
|
|
|
|
|
class Inner(BaseModel): |
|
|
|
|
|
__root__: type_ |
|
|
|
|
|
|
|
|
|
|
|
return Inner |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Basic = RootModel[int] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DictWrap(RootModel[Dict[str, int]]): |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Basic = RootModel[int] |
|
|
if PYDANTIC_V2: |
|
|
|
|
|
from pydantic import ConfigDict, Field, field_validator, model_serializer |
|
|
|
|
|
|
|
|
class FieldWrap(RootModel[str]): |
|
|
class FieldWrap(RootModel[str]): |
|
|
model_config = ConfigDict( |
|
|
model_config = ConfigDict( |
|
@ -42,14 +66,8 @@ if PYDANTIC_V2: |
|
|
|
|
|
|
|
|
def parse(self): |
|
|
def parse(self): |
|
|
return self.root[len("foo_") :] # removeprefix |
|
|
return self.root[len("foo_") :] # removeprefix |
|
|
|
|
|
|
|
|
class DictWrap(RootModel[Dict[str, int]]): |
|
|
|
|
|
pass |
|
|
|
|
|
else: |
|
|
else: |
|
|
from pydantic import Field, validator |
|
|
from pydantic import validator |
|
|
|
|
|
|
|
|
class Basic(BaseModel): |
|
|
|
|
|
__root__: int |
|
|
|
|
|
|
|
|
|
|
|
class FieldWrap(BaseModel): |
|
|
class FieldWrap(BaseModel): |
|
|
class Config: |
|
|
class Config: |
|
@ -80,52 +98,49 @@ else: |
|
|
def parse(self): |
|
|
def parse(self): |
|
|
return self.__root__[len("foo_") :] # removeprefix |
|
|
return self.__root__[len("foo_") :] # removeprefix |
|
|
|
|
|
|
|
|
class DictWrap(BaseModel): |
|
|
|
|
|
__root__: Dict[str, int] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/query/basic") |
|
|
@app.get("/query/basic") |
|
|
def query_basic(q: Basic = Query()): |
|
|
def query_basic(q: Annotated[Basic, Query()]): |
|
|
return {"q": q} |
|
|
return {"q": q} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/query/fieldwrap") |
|
|
@app.get("/query/fieldwrap") |
|
|
def query_fieldwrap(q: FieldWrap = Query()): |
|
|
def query_fieldwrap(q: Annotated[FieldWrap, Query()]): |
|
|
return {"q": q} |
|
|
return {"q": q} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/query/customparsed") |
|
|
@app.get("/query/customparsed") |
|
|
def query_customparsed(q: CustomParsed = Query()): |
|
|
def query_customparsed(q: Annotated[CustomParsed, Query()]): |
|
|
return {"q": q.parse()} |
|
|
return {"q": q.parse()} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/path/basic/{p}") |
|
|
@app.get("/path/basic/{p}") |
|
|
def path_basic(p: Basic = Path()): |
|
|
def path_basic(p: Annotated[Basic, Path()]): |
|
|
return {"p": p} |
|
|
return {"p": p} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/path/fieldwrap/{p}") |
|
|
@app.get("/path/fieldwrap/{p}") |
|
|
def path_fieldwrap(p: FieldWrap = Path()): |
|
|
def path_fieldwrap(p: Annotated[FieldWrap, Path()]): |
|
|
return {"p": p} |
|
|
return {"p": p} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/path/customparsed/{p}") |
|
|
@app.get("/path/customparsed/{p}") |
|
|
def path_customparsed(p: CustomParsed = Path()): |
|
|
def path_customparsed(p: Annotated[CustomParsed, Path()]): |
|
|
return {"p": p.parse()} |
|
|
return {"p": p.parse()} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/body/basic") |
|
|
@app.post("/body/basic") |
|
|
def body_basic(b: Basic = Body()): |
|
|
def body_basic(b: Annotated[Basic, Body()]): |
|
|
return {"b": b} |
|
|
return {"b": b} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/body/fieldwrap") |
|
|
@app.post("/body/fieldwrap") |
|
|
def body_fieldwrap(b: FieldWrap = Body()): |
|
|
def body_fieldwrap(b: Annotated[FieldWrap, Body()]): |
|
|
return {"b": b} |
|
|
return {"b": b} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/body/customparsed") |
|
|
@app.post("/body/customparsed") |
|
|
def body_customparsed(b: CustomParsed = Body()): |
|
|
def body_customparsed(b: Annotated[CustomParsed, Body()]): |
|
|
return {"b": b.parse()} |
|
|
return {"b": b.parse()} |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -145,17 +160,17 @@ def body_default_customparsed(b: CustomParsed): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/echo/basic") |
|
|
@app.get("/echo/basic") |
|
|
def echo_basic(q: Basic = Query()) -> Basic: |
|
|
def echo_basic(q: Annotated[Basic, Query()]) -> Basic: |
|
|
return q |
|
|
return q |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/echo/fieldwrap") |
|
|
@app.get("/echo/fieldwrap") |
|
|
def echo_fieldwrap(q: FieldWrap = Query()) -> FieldWrap: |
|
|
def echo_fieldwrap(q: Annotated[FieldWrap, Query()]) -> FieldWrap: |
|
|
return q |
|
|
return q |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/echo/customparsed") |
|
|
@app.get("/echo/customparsed") |
|
|
def echo_customparsed(q: CustomParsed = Query()) -> CustomParsed: |
|
|
def echo_customparsed(q: Annotated[CustomParsed, Query()]) -> CustomParsed: |
|
|
return q |
|
|
return q |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -194,43 +209,91 @@ def test_root_model_200(url: str, response_json: Any, request_body: Any): |
|
|
assert response.json() == response_json |
|
|
assert response.json() == response_json |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_root_model_union(): |
|
|
@pytest.mark.parametrize("type_", [int, str, bytes, Optional[int]]) |
|
|
if PYDANTIC_V2: |
|
|
@pytest.mark.parametrize("outer", [None, List, Sequence]) |
|
|
from pydantic import RootModel |
|
|
def test2_root_model_200__basic(type_: Type, outer: Optional[Type]): |
|
|
|
|
|
inner = outer[type_] if outer else type_ |
|
|
|
|
|
Model = RootModel[inner] |
|
|
|
|
|
|
|
|
|
|
|
app2 = FastAPI() |
|
|
|
|
|
|
|
|
RootModelInt = RootModel[int] |
|
|
if outer is None: |
|
|
RootModelStr = RootModel[str] |
|
|
|
|
|
|
|
|
@app2.get("/path/{p}") |
|
|
|
|
|
def path_basic2(p: Annotated[Model, Path()]) -> Model: |
|
|
|
|
|
return p |
|
|
else: |
|
|
else: |
|
|
|
|
|
with pytest.raises( |
|
|
|
|
|
AssertionError, match="Path params must be of one of the supported types" |
|
|
|
|
|
): |
|
|
|
|
|
|
|
|
class RootModelInt(BaseModel): |
|
|
@app2.get("/path/{p}") |
|
|
__root__: int |
|
|
def path_basic2(p: Annotated[Model, Path()]) -> Model: |
|
|
|
|
|
return p # pragma: nocover |
|
|
|
|
|
|
|
|
|
|
|
@app2.get("/query") |
|
|
|
|
|
def query_basic2(q: Annotated[Model, Query()]) -> Model: |
|
|
|
|
|
return q |
|
|
|
|
|
|
|
|
|
|
|
@app2.get("/header") |
|
|
|
|
|
def path_basic2(h: Annotated[Model, Header()]) -> Model: |
|
|
|
|
|
return h |
|
|
|
|
|
|
|
|
|
|
|
@app2.post("/body") |
|
|
|
|
|
def body_basic2(b: Annotated[Model, Body()]) -> Model: |
|
|
|
|
|
return b |
|
|
|
|
|
|
|
|
|
|
|
client2 = TestClient(app2) |
|
|
|
|
|
|
|
|
|
|
|
if outer is None: |
|
|
|
|
|
expected = "42" if type_ in (str, bytes) else 42 |
|
|
|
|
|
else: |
|
|
|
|
|
expected = ["42", "43"] if type_ in (str, bytes) else [42, 43] |
|
|
|
|
|
|
|
|
|
|
|
if outer is None: |
|
|
|
|
|
assert client2.get("/path/42").json() == expected |
|
|
|
|
|
assert client2.get("/query?q=42").json() == expected |
|
|
|
|
|
assert client2.get("/header", headers={"h": "42"}).json() == expected |
|
|
|
|
|
else: |
|
|
|
|
|
assert client2.get("/path/42").json()["detail"] == "Not Found" |
|
|
|
|
|
assert client2.get("/query?q=42&q=43").json() == expected |
|
|
|
|
|
assert ( |
|
|
|
|
|
client2.get("/header", headers=[("h", "42"), ("h", "43")]).json() |
|
|
|
|
|
== expected |
|
|
|
|
|
) |
|
|
|
|
|
assert client2.post("/body", json=expected).json() == expected |
|
|
|
|
|
|
|
|
class RootModelStr(BaseModel): |
|
|
|
|
|
__root__: str |
|
|
@pytest.mark.parametrize("left", [RootModel[date], date]) |
|
|
|
|
|
@pytest.mark.parametrize("right", [RootModel[int], int]) |
|
|
|
|
|
@needs_pydanticv2 |
|
|
|
|
|
def test_root_model_union(left: Any, right: Any): |
|
|
|
|
|
Model = Union[left, right] |
|
|
|
|
|
|
|
|
app2 = FastAPI() |
|
|
app2 = FastAPI() |
|
|
|
|
|
|
|
|
@app2.post("/union") |
|
|
@app2.get("/query") |
|
|
def union_handler(b: Union[RootModelInt, RootModelStr]): |
|
|
def handler1(q: Annotated[Model, Query()]): |
|
|
|
|
|
return {"q": q} |
|
|
|
|
|
|
|
|
|
|
|
@app2.post("/body") |
|
|
|
|
|
def handler2(b: Annotated[Model, Body()]): |
|
|
return {"b": b} |
|
|
return {"b": b} |
|
|
|
|
|
|
|
|
client2 = TestClient(app2) |
|
|
client2 = TestClient(app2) |
|
|
for body in [42, "foo"]: |
|
|
for val in ["2025-01-02", 42]: |
|
|
response = client2.post("/union", json=body) |
|
|
response = client2.get(f"/query?q={val}") |
|
|
|
|
|
assert response.status_code == 200, response.text |
|
|
|
|
|
assert response.json() == {"q": val} |
|
|
|
|
|
response = client2.post("/body", json=val) |
|
|
assert response.status_code == 200, response.text |
|
|
assert response.status_code == 200, response.text |
|
|
assert response.json() == {"b": body} |
|
|
assert response.json() == {"b": val} |
|
|
|
|
|
|
|
|
response = client2.post("/union", json=["bad_list"]) |
|
|
response = client2.get("/query?q=bad") |
|
|
assert response.status_code == 422, response.text |
|
|
assert response.status_code == 422, response.text |
|
|
if PYDANTIC_V2: |
|
|
|
|
|
assert {detail["msg"] for detail in response.json()["detail"]} == { |
|
|
|
|
|
"Input should be a valid integer", |
|
|
|
|
|
"Input should be a valid string", |
|
|
|
|
|
} |
|
|
|
|
|
else: |
|
|
|
|
|
assert {detail["msg"] for detail in response.json()["detail"]} == { |
|
|
assert {detail["msg"] for detail in response.json()["detail"]} == { |
|
|
"value is not a valid integer", |
|
|
"Input should be a valid date or datetime, input is too short", |
|
|
"str type expected", |
|
|
"Input should be a valid integer, unable to parse string as an integer", |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|