Browse Source

🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

pull/6031/head
pre-commit-ci[bot] 2 years ago
parent
commit
410424b4d1
  1. 23
      fastapi/dependencies/utils.py
  2. 32
      tests/test_dependency_generic_class.py

23
fastapi/dependencies/utils.py

@ -246,12 +246,10 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
def is_generic_type(obj: Any) -> bool: def is_generic_type(obj: Any) -> bool:
return hasattr(obj, '__origin__') and hasattr(obj, '__args__') return hasattr(obj, "__origin__") and hasattr(obj, "__args__")
def substitute_generic_type( def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any:
annotation: Any, typevars: Dict[str, Any]
) -> Any:
collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple}
if is_generic_type(annotation): if is_generic_type(annotation):
args = tuple( args = tuple(
@ -265,11 +263,10 @@ def substitute_generic_type(
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
typevars = None typevars = None
if is_generic_type(call): if is_generic_type(call):
origin = getattr(call, '__origin__') origin = call.__origin__
typevars = { typevars = {
typevar.__name__: value for typevar, value in zip( typevar.__name__: value
origin.__parameters__, getattr(call, '__args__') for typevar, value in zip(origin.__parameters__, call.__args__)
)
} }
call = origin.__init__ call = origin.__init__
@ -283,16 +280,17 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
default=param.default, default=param.default,
annotation=get_typed_annotation(param.annotation, globalns, typevars), annotation=get_typed_annotation(param.annotation, globalns, typevars),
) )
for param in signature.parameters.values() if param.name != 'self' for param in signature.parameters.values()
if param.name != "self"
] ]
typed_signature = inspect.Signature(typed_params) typed_signature = inspect.Signature(typed_params)
return typed_signature return typed_signature
def get_typed_annotation( def get_typed_annotation(
annotation: Any, annotation: Any,
globalns: Dict[str, Any], globalns: Dict[str, Any],
typevars: Optional[Dict[str, type]] = None, typevars: Optional[Dict[str, type]] = None,
) -> Any: ) -> Any:
if isinstance(annotation, str): if isinstance(annotation, str):
annotation = ForwardRef(annotation) annotation = ForwardRef(annotation)
@ -799,4 +797,3 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
) )
check_file_field(final_field) check_file_field(final_field)
return final_field return final_field

32
tests/test_dependency_generic_class.py

@ -1,28 +1,25 @@
from typing import Generic, List, TypeVar, Dict from typing import Dict, Generic, List, TypeVar
from starlette.testclient import TestClient
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
from starlette.testclient import TestClient
T = TypeVar("T") T = TypeVar("T")
C = TypeVar("C") C = TypeVar("C")
class FirstGenericType(Generic[T]): class FirstGenericType(Generic[T]):
def __init__(self, simple: T, lst: List[T]): def __init__(self, simple: T, lst: List[T]):
self.simple = simple self.simple = simple
self.lst = lst self.lst = lst
class SecondGenericType(Generic[T, C]): class SecondGenericType(Generic[T, C]):
def __init__( def __init__(
self, self,
simple: T, simple: T,
lst: List[T], lst: List[T],
dct: Dict[T, C], dct: Dict[T, C],
custom_class: FirstGenericType[T] = Depends() custom_class: FirstGenericType[T] = Depends(),
): ):
self.simple = simple self.simple = simple
self.lst = lst self.lst = lst
@ -41,8 +38,8 @@ def depend_generic_type(obj: SecondGenericType[str, int] = Depends()):
"dct": obj.dct, "dct": obj.dct,
"custom_class": { "custom_class": {
"simple": obj.custom_class.simple, "simple": obj.custom_class.simple,
"lst": obj.custom_class.lst "lst": obj.custom_class.lst,
} },
} }
@ -50,10 +47,13 @@ client = TestClient(app)
def test_generic_class_dependency(): def test_generic_class_dependency():
response = client.post("/test_generic_class?simple=simple", json={ response = client.post(
"lst": ["string_1", "string_2"], "/test_generic_class?simple=simple",
"dct": {"key": 1}, json={
}) "lst": ["string_1", "string_2"],
"dct": {"key": 1},
},
)
assert response.status_code == 200, response.json() assert response.status_code == 200, response.json()
assert response.json() == { assert response.json() == {
"custom_class": { "custom_class": {

Loading…
Cancel
Save