diff --git a/fastapi/params.py b/fastapi/params.py index 860146531..cc2a5c13c 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -91,7 +91,7 @@ class Param(FieldInfo): max_length=max_length, discriminator=discriminator, multiple_of=multiple_of, - allow_nan=allow_inf_nan, + allow_inf_nan=allow_inf_nan, max_digits=max_digits, decimal_places=decimal_places, **extra, @@ -547,7 +547,7 @@ class Body(FieldInfo): max_length=max_length, discriminator=discriminator, multiple_of=multiple_of, - allow_nan=allow_inf_nan, + allow_inf_nan=allow_inf_nan, max_digits=max_digits, decimal_places=decimal_places, **extra, diff --git a/tests/test_allow_inf_nan_in_enforcing.py b/tests/test_allow_inf_nan_in_enforcing.py new file mode 100644 index 000000000..9e855fdf8 --- /dev/null +++ b/tests/test_allow_inf_nan_in_enforcing.py @@ -0,0 +1,83 @@ +import pytest +from fastapi import Body, FastAPI, Query +from fastapi.testclient import TestClient +from typing_extensions import Annotated + +app = FastAPI() + + +@app.post("/") +async def get( + x: Annotated[float, Query(allow_inf_nan=True)] = 0, + y: Annotated[float, Query(allow_inf_nan=False)] = 0, + z: Annotated[float, Query()] = 0, + b: Annotated[float, Body(allow_inf_nan=False)] = 0, +) -> str: + return "OK" + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "value,code", + [ + ("-1", 200), + ("inf", 200), + ("-inf", 200), + ("nan", 200), + ("0", 200), + ("342", 200), + ], +) +def test_allow_inf_nan_param_true(value: str, code: int): + response = client.post(f"/?x={value}") + assert response.status_code == code, response.text + + +@pytest.mark.parametrize( + "value,code", + [ + ("-1", 200), + ("inf", 422), + ("-inf", 422), + ("nan", 422), + ("0", 200), + ("342", 200), + ], +) +def test_allow_inf_nan_param_false(value: str, code: int): + response = client.post(f"/?y={value}") + assert response.status_code == code, response.text + + +@pytest.mark.parametrize( + "value,code", + [ + ("-1", 200), + ("inf", 200), + ("-inf", 200), + ("nan", 200), + ("0", 200), + ("342", 200), + ], +) +def test_allow_inf_nan_param_default(value: str, code: int): + response = client.post(f"/?z={value}") + assert response.status_code == code, response.text + + +@pytest.mark.parametrize( + "value,code", + [ + ("-1", 200), + ("inf", 422), + ("-inf", 422), + ("nan", 422), + ("0", 200), + ("342", 200), + ], +) +def test_allow_inf_nan_body(value: str, code: int): + response = client.post("/", json=value) + assert response.status_code == code, response.text