From 9c7438ed2d0bca773c56dd8b31e85c63cf1456b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20K=C3=B6tter?= Date: Fri, 30 Jun 2023 16:00:44 +0200 Subject: [PATCH] Params - add style & explode parameter allows deepObject encoded Query parameters --- fastapi/dependencies/utils.py | 79 +++++++++++++++++++++++++++++++-- fastapi/params.py | 26 ++++++++++- pyproject.toml | 1 + tests/test_param_style.py | 83 +++++++++++++++++++++++++++++++++++ 4 files changed, 184 insertions(+), 5 deletions(-) create mode 100644 tests/test_param_style.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 84dfa4d03..ee6aed474 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,4 +1,6 @@ +import collections import inspect +import re from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass @@ -737,6 +739,62 @@ def _get_multidict_value( return value +class ParameterCodec: + @staticmethod + def _default() -> Dict[str, Any]: + return collections.defaultdict(lambda: ParameterCodec._default()) + + @staticmethod + def decode( + field_info: params.Param, + received_params: Union[Mapping[str, Any], QueryParams, Headers], + field: ModelField, + ) -> Dict[str, Any]: + fn: Callable[ + [params.Param, Union[Mapping[str, Any], QueryParams, Headers], ModelField], + Dict[str, Any], + ] + fn = getattr(ParameterCodec, f"decode_{field_info.style}") + return fn(field_info, received_params, field) + + @staticmethod + def decode_deepObject( + field_info: params.Param, + received_params: Union[Mapping[str, Any], QueryParams, Headers], + field: ModelField, + ) -> Dict[str, Any]: + data: List[Tuple[str, str]] = [] + for k, v in received_params.items(): + if k.startswith(f"{field.alias}["): + data.append((k, v)) + + r = ParameterCodec._default() + + for k, v in data: + """ + k: name[attr0][attr1] + v: "5" + -> {"name":{"attr0":{"attr1":"5"}}} + """ + # p = tuple(map(lambda x: x[:-1] if x[-1] == ']' else x, k.split("["))) + # would do as well, but add basic validation … + p0 = re.split(r"(\[|\]\[|\]$)", k) + s = p0[1::2] + assert ( + p0[-1] == "" + and s[0] == "[" + and s[-1] == "]" + and all(x == "][" for x in s[1:-1]) + ) + p1 = tuple(p0[::2][:-1]) + + o = r + for i in p1[:-1]: + o = o[i] + o[p1[-1]] = v + return r + + def request_params_to_args( fields: Sequence[ModelField], received_params: Union[Mapping[str, Any], QueryParams, Headers], @@ -794,17 +852,30 @@ def request_params_to_args( "Params must be subclasses of Param" ) loc: Tuple[str, ...] = (field_info.in_.value,) - v_, errors_ = _validate_value_with_model_field( - field=first_field, value=params_to_process, values=values, loc=loc - ) + + if field_info.style == "deepObject": + value = ParameterCodec.decode(field_info, received_params, first_field) + value = value[first_field.alias] + v_, errors_ = _validate_value_with_model_field( + field=first_field, value=value, values=value, loc=loc + ) + else: + v_, errors_ = _validate_value_with_model_field( + field=first_field, value=params_to_process, values=values, loc=loc + ) return {first_field.name: v_}, errors_ for field in fields: - value = _get_multidict_value(field, received_params) field_info = field.field_info assert isinstance(field_info, params.Param), ( "Params must be subclasses of Param" ) + + if field_info.style == "deepObject": + value = ParameterCodec.decode(field_info, received_params, field) + value = value[field.alias] + else: + value = _get_multidict_value(field, received_params) loc = (field_info.in_.value, field.alias) v_, errors_ = _validate_value_with_model_field( field=field, value=value, values=values, loc=loc diff --git a/fastapi/params.py b/fastapi/params.py index 8f5601dd3..a332d4557 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union from fastapi.openapi.models import Example from pydantic.fields import FieldInfo -from typing_extensions import Annotated, deprecated +from typing_extensions import Annotated, Literal, deprecated from ._compat import ( PYDANTIC_V2, @@ -70,6 +70,8 @@ class Param(FieldInfo): deprecated: Union[deprecated, str, bool, None] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, + style: str = _Unset, + explode: bool = _Unset, **extra: Any, ): if example is not _Unset: @@ -131,6 +133,8 @@ class Param(FieldInfo): use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} super().__init__(**use_kwargs) + self.style = style + self.explode = explode def __repr__(self) -> str: return f"{self.__class__.__name__}({self.default})" @@ -184,6 +188,8 @@ class Path(Param): deprecated: Union[deprecated, str, bool, None] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, + style: Literal["matrix", "label", "simple"] = "simple", + explode: bool = False, **extra: Any, ): assert default is ..., "Path parameters cannot have a default value" @@ -218,6 +224,8 @@ class Path(Param): openapi_examples=openapi_examples, include_in_schema=include_in_schema, json_schema_extra=json_schema_extra, + style=style, + explode=explode, **extra, ) @@ -270,8 +278,14 @@ class Query(Param): deprecated: Union[deprecated, str, bool, None] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, + style: Literal[ + "form", "spaceDelimited", "pipeDelimited", "deepObject" + ] = "form", + explode: bool = _Unset, **extra: Any, ): + if explode is _Unset: + explode = False if style != "form" else True super().__init__( default=default, default_factory=default_factory, @@ -302,6 +316,8 @@ class Query(Param): openapi_examples=openapi_examples, include_in_schema=include_in_schema, json_schema_extra=json_schema_extra, + style=style, + explode=explode, **extra, ) @@ -355,6 +371,8 @@ class Header(Param): deprecated: Union[deprecated, str, bool, None] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, + style: Literal["simple"] = "simple", + explode: bool = False, **extra: Any, ): self.convert_underscores = convert_underscores @@ -388,6 +406,8 @@ class Header(Param): openapi_examples=openapi_examples, include_in_schema=include_in_schema, json_schema_extra=json_schema_extra, + style=style, + explode=explode, **extra, ) @@ -440,6 +460,8 @@ class Cookie(Param): deprecated: Union[deprecated, str, bool, None] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, + style: Literal["form"] = "form", + explode: bool = False, **extra: Any, ): super().__init__( @@ -472,6 +494,8 @@ class Cookie(Param): openapi_examples=openapi_examples, include_in_schema=include_in_schema, json_schema_extra=json_schema_extra, + style=style, + explode=explode, **extra, ) diff --git a/pyproject.toml b/pyproject.toml index 1c540e2f6..51cc38ba8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ all = [ "pydantic-settings >=2.0.0", # Extra Pydantic data types "pydantic-extra-types >=2.0.0", + "more-itertools" ] [project.scripts] diff --git a/tests/test_param_style.py b/tests/test_param_style.py new file mode 100644 index 000000000..d4824dfd8 --- /dev/null +++ b/tests/test_param_style.py @@ -0,0 +1,83 @@ +from typing import List, Optional + +import pydantic +import pytest +from fastapi import FastAPI, Query +from fastapi._compat import PYDANTIC_V2 +from fastapi.testclient import TestClient +from pydantic import BaseModel +from typing_extensions import Literal + + +class Dog(BaseModel): + pet_type: Literal["dog"] + name: str + + +class Matrjoschka(BaseModel): + size: str = 0 # without type coecerion Query parameters are limited to str + inner: Optional["Matrjoschka"] = None + + +app = FastAPI() + + +@app.post( + "/pet", + operation_id="createPet", +) +def createPet(pet: Dog = Query(style="deepObject")) -> Dog: + return pet + + +@app.post( + "/toy", + operation_id="createToy", +) +def createToy(toy: Matrjoschka = Query(style="deepObject")) -> Matrjoschka: + return toy + + +@app.post("/multi", operation_id="createMulti") +def createMulti( + a: Matrjoschka = Query(style="deepObject"), + b: Matrjoschka = Query(style="deepObject"), +) -> List[Matrjoschka]: + return [a, b] + + +client = TestClient(app) + + +def test_pet(): + response = client.post("""/pet?pet[pet_type]=dog&pet[name]=doggy""") + if PYDANTIC_V2: + dog = Dog.model_validate(response.json()) + else: + dog = Dog.parse_obj(response.json()) + assert response.status_code == 200 + assert dog.pet_type == "dog" and dog.name == "doggy" + + +def test_matrjoschka(): + response = client.post( + """/toy?toy[size]=3&toy[inner][size]=2&toy[inner][inner][size]=1""" + ) + print(response) + if PYDANTIC_V2: + toy = Matrjoschka.model_validate(response.json()) + else: + toy = Matrjoschka.parse_obj(response.json()) + assert response.status_code == 200 + assert toy + assert toy.inner.size == "2" + + +@pytest.mark.skipif(not PYDANTIC_V2, reason="Only for Pydantic v2") +def test_multi(): + response = client.post("""/multi?a[size]=1&b[size]=1""") + print(response) + + t = pydantic.TypeAdapter(List[Matrjoschka]) + v = t.validate_python(response.json()) + assert all(i.size == "1" for i in v), v