diff --git a/tests/test_forms_defaults.py b/tests/test_forms_defaults.py index 2fc0976e6..2beda62da 100644 --- a/tests/test_forms_defaults.py +++ b/tests/test_forms_defaults.py @@ -1,29 +1,49 @@ -from typing import Annotated, Optional +from importlib.metadata import version +from typing import Optional import pytest from fastapi import FastAPI, Form -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field from starlette.testclient import TestClient +from typing_extensions import Annotated + +PYDANTIC_V2 = int(version("pydantic")[0]) >= 2 + +if PYDANTIC_V2: + from pydantic import model_validator +else: + from pydantic import root_validator + + +def _validate_input(value: dict) -> dict: + """ + model validators in before mode should receive values passed + to model instantiation before any further validation + """ + # we should not be double-instantiating the models + assert isinstance(value, dict) + value["init_input"] = value.copy() + + # differentiate between explicit Nones and unpassed values + if "true_if_unset" not in value: + value["true_if_unset"] = True + return value class Parent(BaseModel): init_input: dict # importantly, no default here - @model_validator(mode="before") - def validate_inputs(cls, value: dict) -> dict: - """ - model validators in before mode should receive values passed - to model instantiation before any further validation - """ - # we should not be double-instantiating the models - assert isinstance(value, dict) - value["init_input"] = value.copy() - - # differentiate between explicit Nones and unpassed values - if "true_if_unset" not in value: - value["true_if_unset"] = True - return value + if PYDANTIC_V2: + + @model_validator(mode="before") + def validate_inputs(cls, value: dict) -> dict: + return _validate_input(value) + else: + + @root_validator(pre=True) + def validate_inputs(cls, value: dict) -> dict: + return _validate_input(value) class StandardModel(Parent):