Browse Source

Add support for multiple Annotated annotations, e.g. `Annotated[str, Field(), Query()]` (#10773)

pull/10774/head
Sebastián Ramírez 1 year ago
committed by GitHub
parent
commit
6f5aa81c07
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      .github/workflows/test.yml
  2. 7
      fastapi/_compat.py
  3. 53
      fastapi/dependencies/utils.py
  4. 27
      tests/test_ambiguous_params.py
  5. 2
      tests/test_annotated.py

4
.github/workflows/test.yml

@ -29,7 +29,7 @@ jobs:
id: cache id: cache
with: with:
path: ${{ env.pythonLocation }} path: ${{ env.pythonLocation }}
key: ${{ runner.os }}-python-${{ env.pythonLocation }}-pydantic-v2-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v06 key: ${{ runner.os }}-python-${{ env.pythonLocation }}-pydantic-v2-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v07
- name: Install Dependencies - name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true' if: steps.cache.outputs.cache-hit != 'true'
run: pip install -r requirements-tests.txt run: pip install -r requirements-tests.txt
@ -62,7 +62,7 @@ jobs:
id: cache id: cache
with: with:
path: ${{ env.pythonLocation }} path: ${{ env.pythonLocation }}
key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ matrix.pydantic-version }}-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v06 key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ matrix.pydantic-version }}-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v07
- name: Install Dependencies - name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true' if: steps.cache.outputs.cache-hit != 'true'
run: pip install -r requirements-tests.txt run: pip install -r requirements-tests.txt

7
fastapi/_compat.py

@ -249,7 +249,12 @@ if PYDANTIC_V2:
return is_bytes_sequence_annotation(field.type_) return is_bytes_sequence_annotation(field.type_)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return type(field_info).from_annotation(annotation) cls = type(field_info)
merged_field_info = cls.from_annotation(annotation)
new_field_info = copy(field_info)
new_field_info.metadata = merged_field_info.metadata
new_field_info.annotation = merged_field_info.annotation
return new_field_info
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
origin_type = ( origin_type = (

53
fastapi/dependencies/utils.py

@ -325,10 +325,11 @@ def analyze_param(
field_info = None field_info = None
depends = None depends = None
type_annotation: Any = Any type_annotation: Any = Any
if ( use_annotation: Any = Any
annotation is not inspect.Signature.empty if annotation is not inspect.Signature.empty:
and get_origin(annotation) is Annotated use_annotation = annotation
): type_annotation = annotation
if get_origin(use_annotation) is Annotated:
annotated_args = get_args(annotation) annotated_args = get_args(annotation)
type_annotation = annotated_args[0] type_annotation = annotated_args[0]
fastapi_annotations = [ fastapi_annotations = [
@ -336,14 +337,21 @@ def analyze_param(
for arg in annotated_args[1:] for arg in annotated_args[1:]
if isinstance(arg, (FieldInfo, params.Depends)) if isinstance(arg, (FieldInfo, params.Depends))
] ]
assert ( fastapi_specific_annotations = [
len(fastapi_annotations) <= 1 arg
), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}" for arg in fastapi_annotations
fastapi_annotation = next(iter(fastapi_annotations), None) if isinstance(arg, (params.Param, params.Body, params.Depends))
]
if fastapi_specific_annotations:
fastapi_annotation: Union[
FieldInfo, params.Depends, None
] = fastapi_specific_annotations[-1]
else:
fastapi_annotation = None
if isinstance(fastapi_annotation, FieldInfo): if isinstance(fastapi_annotation, FieldInfo):
# Copy `field_info` because we mutate `field_info.default` below. # Copy `field_info` because we mutate `field_info.default` below.
field_info = copy_field_info( field_info = copy_field_info(
field_info=fastapi_annotation, annotation=annotation field_info=fastapi_annotation, annotation=use_annotation
) )
assert field_info.default is Undefined or field_info.default is Required, ( assert field_info.default is Undefined or field_info.default is Required, (
f"`{field_info.__class__.__name__}` default value cannot be set in" f"`{field_info.__class__.__name__}` default value cannot be set in"
@ -356,8 +364,6 @@ def analyze_param(
field_info.default = Required field_info.default = Required
elif isinstance(fastapi_annotation, params.Depends): elif isinstance(fastapi_annotation, params.Depends):
depends = fastapi_annotation depends = fastapi_annotation
elif annotation is not inspect.Signature.empty:
type_annotation = annotation
if isinstance(value, params.Depends): if isinstance(value, params.Depends):
assert depends is None, ( assert depends is None, (
@ -402,15 +408,15 @@ def analyze_param(
# We might check here that `default_value is Required`, but the fact is that the same # We might check here that `default_value is Required`, but the fact is that the same
# parameter might sometimes be a path parameter and sometimes not. See # parameter might sometimes be a path parameter and sometimes not. See
# `tests/test_infer_param_optionality.py` for an example. # `tests/test_infer_param_optionality.py` for an example.
field_info = params.Path(annotation=type_annotation) field_info = params.Path(annotation=use_annotation)
elif is_uploadfile_or_nonable_uploadfile_annotation( elif is_uploadfile_or_nonable_uploadfile_annotation(
type_annotation type_annotation
) or is_uploadfile_sequence_annotation(type_annotation): ) or is_uploadfile_sequence_annotation(type_annotation):
field_info = params.File(annotation=type_annotation, default=default_value) field_info = params.File(annotation=use_annotation, default=default_value)
elif not field_annotation_is_scalar(annotation=type_annotation): elif not field_annotation_is_scalar(annotation=type_annotation):
field_info = params.Body(annotation=type_annotation, default=default_value) field_info = params.Body(annotation=use_annotation, default=default_value)
else: else:
field_info = params.Query(annotation=type_annotation, default=default_value) field_info = params.Query(annotation=use_annotation, default=default_value)
field = None field = None
if field_info is not None: if field_info is not None:
@ -424,8 +430,8 @@ def analyze_param(
and getattr(field_info, "in_", None) is None and getattr(field_info, "in_", None) is None
): ):
field_info.in_ = params.ParamTypes.query field_info.in_ = params.ParamTypes.query
use_annotation = get_annotation_from_field_info( use_annotation_from_field_info = get_annotation_from_field_info(
type_annotation, use_annotation,
field_info, field_info,
param_name, param_name,
) )
@ -436,7 +442,7 @@ def analyze_param(
field_info.alias = alias field_info.alias = alias
field = create_response_field( field = create_response_field(
name=param_name, name=param_name,
type_=use_annotation, type_=use_annotation_from_field_info,
default=field_info.default, default=field_info.default,
alias=alias, alias=alias,
required=field_info.default in (Required, Undefined), required=field_info.default in (Required, Undefined),
@ -466,16 +472,17 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
field_info = cast(params.Param, field.field_info) field_info = field.field_info
if field_info.in_ == params.ParamTypes.path: field_info_in = getattr(field_info, "in_", None)
if field_info_in == params.ParamTypes.path:
dependant.path_params.append(field) dependant.path_params.append(field)
elif field_info.in_ == params.ParamTypes.query: elif field_info_in == params.ParamTypes.query:
dependant.query_params.append(field) dependant.query_params.append(field)
elif field_info.in_ == params.ParamTypes.header: elif field_info_in == params.ParamTypes.header:
dependant.header_params.append(field) dependant.header_params.append(field)
else: else:
assert ( assert (
field_info.in_ == params.ParamTypes.cookie field_info_in == params.ParamTypes.cookie
), f"non-body parameters must be in path, query, header or cookie: {field.name}" ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
dependant.cookie_params.append(field) dependant.cookie_params.append(field)

27
tests/test_ambiguous_params.py

@ -1,6 +1,8 @@
import pytest import pytest
from fastapi import Depends, FastAPI, Path from fastapi import Depends, FastAPI, Path
from fastapi.param_functions import Query from fastapi.param_functions import Query
from fastapi.testclient import TestClient
from fastapi.utils import PYDANTIC_V2
from typing_extensions import Annotated from typing_extensions import Annotated
app = FastAPI() app = FastAPI()
@ -28,18 +30,13 @@ def test_no_annotated_defaults():
pass # pragma: nocover pass # pragma: nocover
def test_no_multiple_annotations(): def test_multiple_annotations():
async def dep(): async def dep():
pass # pragma: nocover pass # pragma: nocover
with pytest.raises( @app.get("/multi-query")
AssertionError, async def get(foo: Annotated[int, Query(gt=2), Query(lt=10)]):
match="Cannot specify multiple `Annotated` FastAPI arguments for 'foo'", return foo
):
@app.get("/")
async def get(foo: Annotated[int, Query(min_length=1), Query()]):
pass # pragma: nocover
with pytest.raises( with pytest.raises(
AssertionError, AssertionError,
@ -64,3 +61,15 @@ def test_no_multiple_annotations():
@app.get("/") @app.get("/")
async def get3(foo: Annotated[int, Query(min_length=1)] = Depends(dep)): async def get3(foo: Annotated[int, Query(min_length=1)] = Depends(dep)):
pass # pragma: nocover pass # pragma: nocover
client = TestClient(app)
response = client.get("/multi-query", params={"foo": "5"})
assert response.status_code == 200
assert response.json() == 5
response = client.get("/multi-query", params={"foo": "123"})
assert response.status_code == 422
if PYDANTIC_V2:
response = client.get("/multi-query", params={"foo": "1"})
assert response.status_code == 422

2
tests/test_annotated.py

@ -57,7 +57,7 @@ foo_is_short = {
{ {
"ctx": {"min_length": 1}, "ctx": {"min_length": 1},
"loc": ["query", "foo"], "loc": ["query", "foo"],
"msg": "String should have at least 1 characters", "msg": "String should have at least 1 character",
"type": "string_too_short", "type": "string_too_short",
"input": "", "input": "",
"url": match_pydantic_error_url("string_too_short"), "url": match_pydantic_error_url("string_too_short"),

Loading…
Cancel
Save