From 410424b4d1ab6d36a76900b4212e9b48d554a60e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Feb 2023 16:19:05 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20format?= =?UTF-8?q?=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 23 ++++++++---------- tests/test_dependency_generic_class.py | 32 +++++++++++++------------- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index dd10a2a2a..1bf780cd6 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -246,12 +246,10 @@ def is_scalar_sequence_field(field: ModelField) -> bool: def is_generic_type(obj: Any) -> bool: - return hasattr(obj, '__origin__') and hasattr(obj, '__args__') + return hasattr(obj, "__origin__") and hasattr(obj, "__args__") -def substitute_generic_type( - annotation: Any, typevars: Dict[str, Any] -) -> Any: +def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any: collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} if is_generic_type(annotation): args = tuple( @@ -265,11 +263,10 @@ def substitute_generic_type( def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: typevars = None if is_generic_type(call): - origin = getattr(call, '__origin__') + origin = call.__origin__ typevars = { - typevar.__name__: value for typevar, value in zip( - origin.__parameters__, getattr(call, '__args__') - ) + typevar.__name__: value + for typevar, value in zip(origin.__parameters__, call.__args__) } call = origin.__init__ @@ -283,16 +280,17 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: default=param.default, annotation=get_typed_annotation(param.annotation, globalns, typevars), ) - for param in signature.parameters.values() if param.name != 'self' + for param in signature.parameters.values() + if param.name != "self" ] typed_signature = inspect.Signature(typed_params) return typed_signature def get_typed_annotation( - annotation: Any, - globalns: Dict[str, Any], - typevars: Optional[Dict[str, type]] = None, + annotation: Any, + globalns: Dict[str, Any], + typevars: Optional[Dict[str, type]] = None, ) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) @@ -799,4 +797,3 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: ) check_file_field(final_field) return final_field - diff --git a/tests/test_dependency_generic_class.py b/tests/test_dependency_generic_class.py index 1cddbf242..3eeb8291e 100644 --- a/tests/test_dependency_generic_class.py +++ b/tests/test_dependency_generic_class.py @@ -1,28 +1,25 @@ -from typing import Generic, List, TypeVar, Dict - -from starlette.testclient import TestClient +from typing import Dict, Generic, List, TypeVar from fastapi import Depends, FastAPI +from starlette.testclient import TestClient T = TypeVar("T") C = TypeVar("C") class FirstGenericType(Generic[T]): - def __init__(self, simple: T, lst: List[T]): self.simple = simple self.lst = lst class SecondGenericType(Generic[T, C]): - def __init__( - self, - simple: T, - lst: List[T], - dct: Dict[T, C], - custom_class: FirstGenericType[T] = Depends() + self, + simple: T, + lst: List[T], + dct: Dict[T, C], + custom_class: FirstGenericType[T] = Depends(), ): self.simple = simple self.lst = lst @@ -41,8 +38,8 @@ def depend_generic_type(obj: SecondGenericType[str, int] = Depends()): "dct": obj.dct, "custom_class": { "simple": obj.custom_class.simple, - "lst": obj.custom_class.lst - } + "lst": obj.custom_class.lst, + }, } @@ -50,10 +47,13 @@ client = TestClient(app) def test_generic_class_dependency(): - response = client.post("/test_generic_class?simple=simple", json={ - "lst": ["string_1", "string_2"], - "dct": {"key": 1}, - }) + response = client.post( + "/test_generic_class?simple=simple", + json={ + "lst": ["string_1", "string_2"], + "dct": {"key": 1}, + }, + ) assert response.status_code == 200, response.json() assert response.json() == { "custom_class": {