Browse Source

Params - add style & explode parameter

allows deepObject encoded Query parameters
pull/9867/head
Markus Kötter 2 years ago
committed by commonism
parent
commit
9c7438ed2d
  1. 79
      fastapi/dependencies/utils.py
  2. 26
      fastapi/params.py
  3. 1
      pyproject.toml
  4. 83
      tests/test_param_style.py

79
fastapi/dependencies/utils.py

@ -1,4 +1,6 @@
import collections
import inspect import inspect
import re
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
@ -737,6 +739,62 @@ def _get_multidict_value(
return value return value
class ParameterCodec:
@staticmethod
def _default() -> Dict[str, Any]:
return collections.defaultdict(lambda: ParameterCodec._default())
@staticmethod
def decode(
field_info: params.Param,
received_params: Union[Mapping[str, Any], QueryParams, Headers],
field: ModelField,
) -> Dict[str, Any]:
fn: Callable[
[params.Param, Union[Mapping[str, Any], QueryParams, Headers], ModelField],
Dict[str, Any],
]
fn = getattr(ParameterCodec, f"decode_{field_info.style}")
return fn(field_info, received_params, field)
@staticmethod
def decode_deepObject(
field_info: params.Param,
received_params: Union[Mapping[str, Any], QueryParams, Headers],
field: ModelField,
) -> Dict[str, Any]:
data: List[Tuple[str, str]] = []
for k, v in received_params.items():
if k.startswith(f"{field.alias}["):
data.append((k, v))
r = ParameterCodec._default()
for k, v in data:
"""
k: name[attr0][attr1]
v: "5"
-> {"name":{"attr0":{"attr1":"5"}}}
"""
# p = tuple(map(lambda x: x[:-1] if x[-1] == ']' else x, k.split("[")))
# would do as well, but add basic validation …
p0 = re.split(r"(\[|\]\[|\]$)", k)
s = p0[1::2]
assert (
p0[-1] == ""
and s[0] == "["
and s[-1] == "]"
and all(x == "][" for x in s[1:-1])
)
p1 = tuple(p0[::2][:-1])
o = r
for i in p1[:-1]:
o = o[i]
o[p1[-1]] = v
return r
def request_params_to_args( def request_params_to_args(
fields: Sequence[ModelField], fields: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers], received_params: Union[Mapping[str, Any], QueryParams, Headers],
@ -794,17 +852,30 @@ def request_params_to_args(
"Params must be subclasses of Param" "Params must be subclasses of Param"
) )
loc: Tuple[str, ...] = (field_info.in_.value,) loc: Tuple[str, ...] = (field_info.in_.value,)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=params_to_process, values=values, loc=loc if field_info.style == "deepObject":
) value = ParameterCodec.decode(field_info, received_params, first_field)
value = value[first_field.alias]
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=value, values=value, loc=loc
)
else:
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=params_to_process, values=values, loc=loc
)
return {first_field.name: v_}, errors_ return {first_field.name: v_}, errors_
for field in fields: for field in fields:
value = _get_multidict_value(field, received_params)
field_info = field.field_info field_info = field.field_info
assert isinstance(field_info, params.Param), ( assert isinstance(field_info, params.Param), (
"Params must be subclasses of Param" "Params must be subclasses of Param"
) )
if field_info.style == "deepObject":
value = ParameterCodec.decode(field_info, received_params, field)
value = value[field.alias]
else:
value = _get_multidict_value(field, received_params)
loc = (field_info.in_.value, field.alias) loc = (field_info.in_.value, field.alias)
v_, errors_ = _validate_value_with_model_field( v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc field=field, value=value, values=values, loc=loc

26
fastapi/params.py

@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from fastapi.openapi.models import Example from fastapi.openapi.models import Example
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import Annotated, deprecated from typing_extensions import Annotated, Literal, deprecated
from ._compat import ( from ._compat import (
PYDANTIC_V2, PYDANTIC_V2,
@ -70,6 +70,8 @@ class Param(FieldInfo):
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: str = _Unset,
explode: bool = _Unset,
**extra: Any, **extra: Any,
): ):
if example is not _Unset: if example is not _Unset:
@ -131,6 +133,8 @@ class Param(FieldInfo):
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs) super().__init__(**use_kwargs)
self.style = style
self.explode = explode
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})" return f"{self.__class__.__name__}({self.default})"
@ -184,6 +188,8 @@ class Path(Param):
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal["matrix", "label", "simple"] = "simple",
explode: bool = False,
**extra: Any, **extra: Any,
): ):
assert default is ..., "Path parameters cannot have a default value" assert default is ..., "Path parameters cannot have a default value"
@ -218,6 +224,8 @@ class Path(Param):
openapi_examples=openapi_examples, openapi_examples=openapi_examples,
include_in_schema=include_in_schema, include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra, **extra,
) )
@ -270,8 +278,14 @@ class Query(Param):
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal[
"form", "spaceDelimited", "pipeDelimited", "deepObject"
] = "form",
explode: bool = _Unset,
**extra: Any, **extra: Any,
): ):
if explode is _Unset:
explode = False if style != "form" else True
super().__init__( super().__init__(
default=default, default=default,
default_factory=default_factory, default_factory=default_factory,
@ -302,6 +316,8 @@ class Query(Param):
openapi_examples=openapi_examples, openapi_examples=openapi_examples,
include_in_schema=include_in_schema, include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra, **extra,
) )
@ -355,6 +371,8 @@ class Header(Param):
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal["simple"] = "simple",
explode: bool = False,
**extra: Any, **extra: Any,
): ):
self.convert_underscores = convert_underscores self.convert_underscores = convert_underscores
@ -388,6 +406,8 @@ class Header(Param):
openapi_examples=openapi_examples, openapi_examples=openapi_examples,
include_in_schema=include_in_schema, include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra, **extra,
) )
@ -440,6 +460,8 @@ class Cookie(Param):
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal["form"] = "form",
explode: bool = False,
**extra: Any, **extra: Any,
): ):
super().__init__( super().__init__(
@ -472,6 +494,8 @@ class Cookie(Param):
openapi_examples=openapi_examples, openapi_examples=openapi_examples,
include_in_schema=include_in_schema, include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra, **extra,
) )

1
pyproject.toml

@ -100,6 +100,7 @@ all = [
"pydantic-settings >=2.0.0", "pydantic-settings >=2.0.0",
# Extra Pydantic data types # Extra Pydantic data types
"pydantic-extra-types >=2.0.0", "pydantic-extra-types >=2.0.0",
"more-itertools"
] ]
[project.scripts] [project.scripts]

83
tests/test_param_style.py

@ -0,0 +1,83 @@
from typing import List, Optional
import pydantic
import pytest
from fastapi import FastAPI, Query
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient
from pydantic import BaseModel
from typing_extensions import Literal
class Dog(BaseModel):
pet_type: Literal["dog"]
name: str
class Matrjoschka(BaseModel):
size: str = 0 # without type coecerion Query parameters are limited to str
inner: Optional["Matrjoschka"] = None
app = FastAPI()
@app.post(
"/pet",
operation_id="createPet",
)
def createPet(pet: Dog = Query(style="deepObject")) -> Dog:
return pet
@app.post(
"/toy",
operation_id="createToy",
)
def createToy(toy: Matrjoschka = Query(style="deepObject")) -> Matrjoschka:
return toy
@app.post("/multi", operation_id="createMulti")
def createMulti(
a: Matrjoschka = Query(style="deepObject"),
b: Matrjoschka = Query(style="deepObject"),
) -> List[Matrjoschka]:
return [a, b]
client = TestClient(app)
def test_pet():
response = client.post("""/pet?pet[pet_type]=dog&pet[name]=doggy""")
if PYDANTIC_V2:
dog = Dog.model_validate(response.json())
else:
dog = Dog.parse_obj(response.json())
assert response.status_code == 200
assert dog.pet_type == "dog" and dog.name == "doggy"
def test_matrjoschka():
response = client.post(
"""/toy?toy[size]=3&toy[inner][size]=2&toy[inner][inner][size]=1"""
)
print(response)
if PYDANTIC_V2:
toy = Matrjoschka.model_validate(response.json())
else:
toy = Matrjoschka.parse_obj(response.json())
assert response.status_code == 200
assert toy
assert toy.inner.size == "2"
@pytest.mark.skipif(not PYDANTIC_V2, reason="Only for Pydantic v2")
def test_multi():
response = client.post("""/multi?a[size]=1&b[size]=1""")
print(response)
t = pydantic.TypeAdapter(List[Matrjoschka])
v = t.validate_python(response.json())
assert all(i.size == "1" for i in v), v
Loading…
Cancel
Save