from datetime import date from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union import pytest from dirty_equals import IsDict from fastapi import Body, FastAPI, Header, Path, Query from fastapi._compat import PYDANTIC_V2 from fastapi.testclient import TestClient from pydantic import BaseModel, Field from typing_extensions import Annotated from .utils import needs_pydanticv2 app = FastAPI() if PYDANTIC_V2: from pydantic import RootModel else: from pydantic import BaseModel T = TypeVar("T") class RootModel(Generic[T]): def __class_getitem__(cls, type_): class Inner(BaseModel): __root__: type_ return Inner Basic = RootModel[int] class DictWrap(RootModel[Dict[str, int]]): pass if PYDANTIC_V2: from pydantic import ConfigDict, Field, field_validator, model_serializer class FieldWrap(RootModel[str]): model_config = ConfigDict( json_schema_extra={"description": "parameter starts with bar_"} ) root: str = Field(pattern=r"^bar_.*$") class CustomParsed(RootModel[str]): model_config = ConfigDict( json_schema_extra={"description": "parameter starts with foo_"} ) @field_validator("root") @classmethod def validator(cls, v: str) -> str: if not v.startswith("foo_"): raise ValueError("must start with foo_") return v @model_serializer def serialize(self): return ( self.root if self.root.endswith("_serialized") else f"{self.root}_serialized" ) def parse(self): return self.root[len("foo_") :] # removeprefix else: from pydantic import validator class FieldWrap(BaseModel): class Config: schema_extra = {"description": "parameter starts with bar_"} __root__: str = Field(regex=r"^bar_.*$") class CustomParsed(BaseModel): class Config: schema_extra = {"description": "parameter starts with foo_"} __root__: str @validator("__root__") @classmethod def validator(cls, v: str) -> str: if not v.startswith("foo_"): raise ValueError("must start with foo_") return v def dict(self, **kwargs: Any) -> Dict[str, Any]: return { "__root__": self.__root__ if self.__root__.endswith("_serialized") else f"{self.__root__}_serialized" } def parse(self): return self.__root__[len("foo_") :] # removeprefix @app.get("/query/basic") def query_basic(q: Annotated[Basic, Query()]): return {"q": q} @app.get("/query/fieldwrap") def query_fieldwrap(q: Annotated[FieldWrap, Query()]): return {"q": q} @app.get("/query/customparsed") def query_customparsed(q: Annotated[CustomParsed, Query()]): return {"q": q.parse()} @app.get("/path/basic/{p}") def path_basic(p: Annotated[Basic, Path()]): return {"p": p} @app.get("/path/fieldwrap/{p}") def path_fieldwrap(p: Annotated[FieldWrap, Path()]): return {"p": p} @app.get("/path/customparsed/{p}") def path_customparsed(p: Annotated[CustomParsed, Path()]): return {"p": p.parse()} @app.post("/body/basic") def body_basic(b: Annotated[Basic, Body()]): return {"b": b} @app.post("/body/fieldwrap") def body_fieldwrap(b: Annotated[FieldWrap, Body()]): return {"b": b} @app.post("/body/customparsed") def body_customparsed(b: Annotated[CustomParsed, Body()]): return {"b": b.parse()} @app.post("/body_default/basic") def body_default_basic(b: Basic): return {"b": b} @app.post("/body_default/fieldwrap") def body_default_fieldwrap(b: FieldWrap): return {"b": b} @app.post("/body_default/customparsed") def body_default_customparsed(b: CustomParsed): return {"b": b.parse()} @app.get("/echo/basic") def echo_basic(q: Annotated[Basic, Query()]) -> Basic: return q @app.get("/echo/fieldwrap") def echo_fieldwrap(q: Annotated[FieldWrap, Query()]) -> FieldWrap: return q @app.get("/echo/customparsed") def echo_customparsed(q: Annotated[CustomParsed, Query()]) -> CustomParsed: return q @app.post("/echo/dictwrap") def echo_dictwrap(b: DictWrap) -> DictWrap: return b.model_dump() if PYDANTIC_V2 else b.dict()["__root__"] client = TestClient(app) @pytest.mark.parametrize( "url, response_json, request_body", [ ("/query/basic?q=42", {"q": 42}, None), ("/query/fieldwrap?q=bar_baz", {"q": "bar_baz"}, None), ("/query/customparsed?q=foo_bar", {"q": "bar"}, None), ("/path/basic/42", {"p": 42}, None), ("/path/fieldwrap/bar_baz", {"p": "bar_baz"}, None), ("/path/customparsed/foo_bar", {"p": "bar"}, None), ("/body/basic", {"b": 42}, "42"), ("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"), ("/body/customparsed", {"b": "bar"}, "foo_bar"), ("/body_default/basic", {"b": 42}, "42"), ("/body_default/fieldwrap", {"b": "bar_baz"}, "bar_baz"), ("/body_default/customparsed", {"b": "bar"}, "foo_bar"), ("/echo/basic?q=42", 42, None), ("/echo/fieldwrap?q=bar_baz", "bar_baz", None), ("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None), ("/echo/dictwrap", {"test": 1}, {"test": 1}), ], ) def test_root_model_200(url: str, response_json: Any, request_body: Any): response = client.post(url, json=request_body) if request_body else client.get(url) assert response.status_code == 200, response.text assert response.json() == response_json @pytest.mark.parametrize("type_", [int, str, bytes, Optional[int]]) @pytest.mark.parametrize("outer", [None, List, Sequence]) def test2_root_model_200__basic(type_: Type, outer: Optional[Type]): inner = outer[type_] if outer else type_ Model = RootModel[inner] app2 = FastAPI() if outer is None: @app2.get("/path/{p}") def path_basic2(p: Annotated[Model, Path()]) -> Model: return p else: with pytest.raises( AssertionError, match="Path params must be of one of the supported types" ): @app2.get("/path/{p}") def path_basic2(p: Annotated[Model, Path()]) -> Model: return p # pragma: nocover @app2.get("/query") def query_basic2(q: Annotated[Model, Query()]) -> Model: return q @app2.get("/header") def path_basic2(h: Annotated[Model, Header()]) -> Model: return h @app2.post("/body") def body_basic2(b: Annotated[Model, Body()]) -> Model: return b client2 = TestClient(app2) if outer is None: expected = "42" if type_ in (str, bytes) else 42 else: expected = ["42", "43"] if type_ in (str, bytes) else [42, 43] if outer is None: assert client2.get("/path/42").json() == expected assert client2.get("/query?q=42").json() == expected assert client2.get("/header", headers={"h": "42"}).json() == expected else: assert client2.get("/path/42").json()["detail"] == "Not Found" assert client2.get("/query?q=42&q=43").json() == expected assert ( client2.get("/header", headers=[("h", "42"), ("h", "43")]).json() == expected ) assert client2.post("/body", json=expected).json() == expected @pytest.mark.parametrize("left", [RootModel[date], date]) @pytest.mark.parametrize("right", [RootModel[int], int]) @needs_pydanticv2 def test_root_model_union(left: Any, right: Any): Model = Union[left, right] app2 = FastAPI() @app2.get("/query") def handler1(q: Annotated[Model, Query()]): return {"q": q} @app2.post("/body") def handler2(b: Annotated[Model, Body()]): return {"b": b} client2 = TestClient(app2) for val in ["2025-01-02", 42]: response = client2.get(f"/query?q={val}") assert response.status_code == 200, response.text assert response.json() == {"q": val} response = client2.post("/body", json=val) assert response.status_code == 200, response.text assert response.json() == {"b": val} response = client2.get("/query?q=bad") assert response.status_code == 422, response.text assert {detail["msg"] for detail in response.json()["detail"]} == { "Input should be a valid date or datetime, input is too short", "Input should be a valid integer, unable to parse string as an integer", } @pytest.mark.parametrize( "url, error_path, request_body", [ ("/query/basic?q=my_bad_not_int", ["query", "q"], None), ("/path/basic/my_bad_not_int", ["path", "p"], None), ("/body/basic", ["body"], "my_bad_not_int"), ("/body_default/basic", ["body"], "my_bad_not_int"), ], ) def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any): response = client.post(url, json=request_body) if request_body else client.get(url) assert response.status_code == 422, response.text assert response.json() == IsDict( { "detail": [ { "type": "int_parsing", "loc": error_path, "msg": "Input should be a valid integer, unable to parse string as an integer", "input": "my_bad_not_int", } ] } ) | IsDict( # TODO: remove when deprecating Pydantic v1 { "detail": [ { "loc": [*error_path, "__root__"], "msg": "value is not a valid integer", "type": "type_error.integer", } ] } ) @pytest.mark.parametrize( "url, error_path, request_body", [ ("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None), ("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None), ("/body/fieldwrap", ["body"], "my_bad_prefix_val"), ("/body_default/fieldwrap", ["body"], "my_bad_prefix_val"), ], ) def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any): response = client.post(url, json=request_body) if request_body else client.get(url) assert response.status_code == 422, response.text assert response.json() == IsDict( { "detail": [ { "type": "string_pattern_mismatch", "loc": error_path, "msg": "String should match pattern '^bar_.*$'", "input": "my_bad_prefix_val", "ctx": {"pattern": "^bar_.*$"}, } ] } ) | IsDict( # TODO: remove when deprecating Pydantic v1 { "detail": [ { "loc": [*error_path, "__root__"], "msg": 'string does not match regex "^bar_.*$"', "type": "value_error.str.regex", "ctx": {"pattern": "^bar_.*$"}, } ] } ) @pytest.mark.parametrize( "url, error_path, request_body", [ ("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None), ("/path/customparsed/my_bad_prefix_val", ["path", "p"], None), ("/body/customparsed", ["body"], "my_bad_prefix_val"), ("/body_default/customparsed", ["body"], "my_bad_prefix_val"), ], ) def test_root_model_customparsed_422( url: str, error_path: List[str], request_body: Any ): response = client.post(url, json=request_body) if request_body else client.get(url) assert response.status_code == 422, response.text assert response.json() == IsDict( { "detail": [ { "type": "value_error", "loc": error_path, "msg": "Value error, must start with foo_", "input": "my_bad_prefix_val", "ctx": {"error": {}}, } ] } ) | IsDict( # TODO: remove when deprecating Pydantic v1 { "detail": [ { "loc": [*error_path, "__root__"], "msg": "must start with foo_", "type": "value_error", } ] } ) def test_root_model_dictwrap_422(): response = client.post("/echo/dictwrap", json={"test": "fail_not_int"}) assert response.status_code == 422, response.text assert response.json() == IsDict( { "detail": [ { "type": "int_parsing", "loc": ["body", "test"], "msg": "Input should be a valid integer, unable to parse string as an integer", "input": "fail_not_int", } ] } ) | IsDict( # TODO: remove when deprecating Pydantic v1 { "detail": [ { "loc": ["body", "__root__", "test"], "msg": "value is not a valid integer", "type": "type_error.integer", } ] } ) @pytest.mark.parametrize( "model, path_name, expected_model_schema", [ (Basic, "basic", {"type": "integer"}), ( FieldWrap, "fieldwrap", { "type": "string", "pattern": "^bar_.*$", "description": "parameter starts with bar_", }, ), ( CustomParsed, "customparsed", { "type": "string", "description": "parameter starts with foo_", }, ), ], ) def test_openapi_schema( model: Type, path_name: str, expected_model_schema: Dict[str, Any] ): response = client.get("/openapi.json") assert response.status_code == 200, response.text paths = response.json()["paths"] ref_name = model.__name__.replace("[", "_").replace("]", "_") schema_ref = {"schema": {"$ref": f"#/components/schemas/{ref_name}"}} assert paths[f"/query/{path_name}"]["get"]["parameters"] == [ {"in": "query", "name": "q", "required": True, **schema_ref} ] assert paths[f"/path/{path_name}/{{p}}"]["get"]["parameters"] == [ {"in": "path", "name": "p", "required": True, **schema_ref} ] assert paths[f"/body/{path_name}"]["post"]["requestBody"] == { "content": {"application/json": schema_ref}, "required": True, } assert paths[f"/body_default/{path_name}"]["post"]["requestBody"] == { "content": {"application/json": schema_ref}, "required": True, } assert paths[f"/echo/{path_name}"]["get"]["responses"]["200"] == { "content": {"application/json": schema_ref}, "description": "Successful Response", } model_schema = response.json()["components"]["schemas"][ref_name] model_schema.pop("title") assert model_schema == expected_model_schema def test_openapi_schema_dictwrap(): response = client.get("/openapi.json") assert response.status_code == 200, response.text operation = response.json()["paths"]["/echo/dictwrap"]["post"] ref = {"schema": {"$ref": "#/components/schemas/DictWrap"}} assert operation["requestBody"] == { "content": {"application/json": ref}, "required": True, } assert operation["responses"]["200"] == { "content": {"application/json": ref}, "description": "Successful Response", } assert response.json()["components"]["schemas"]["DictWrap"] == { "title": "DictWrap", "type": "object", "additionalProperties": {"type": "integer"}, }