Browse Source

🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

pull/6031/head
pre-commit-ci[bot] 2 years ago
parent
commit
410424b4d1
  1. 23
      fastapi/dependencies/utils.py
  2. 32
      tests/test_dependency_generic_class.py

23
fastapi/dependencies/utils.py

@ -246,12 +246,10 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
def is_generic_type(obj: Any) -> bool:
return hasattr(obj, '__origin__') and hasattr(obj, '__args__')
return hasattr(obj, "__origin__") and hasattr(obj, "__args__")
def substitute_generic_type(
annotation: Any, typevars: Dict[str, Any]
) -> Any:
def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any:
collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple}
if is_generic_type(annotation):
args = tuple(
@ -265,11 +263,10 @@ def substitute_generic_type(
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
typevars = None
if is_generic_type(call):
origin = getattr(call, '__origin__')
origin = call.__origin__
typevars = {
typevar.__name__: value for typevar, value in zip(
origin.__parameters__, getattr(call, '__args__')
)
typevar.__name__: value
for typevar, value in zip(origin.__parameters__, call.__args__)
}
call = origin.__init__
@ -283,16 +280,17 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
default=param.default,
annotation=get_typed_annotation(param.annotation, globalns, typevars),
)
for param in signature.parameters.values() if param.name != 'self'
for param in signature.parameters.values()
if param.name != "self"
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def get_typed_annotation(
annotation: Any,
globalns: Dict[str, Any],
typevars: Optional[Dict[str, type]] = None,
annotation: Any,
globalns: Dict[str, Any],
typevars: Optional[Dict[str, type]] = None,
) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
@ -799,4 +797,3 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
)
check_file_field(final_field)
return final_field

32
tests/test_dependency_generic_class.py

@ -1,28 +1,25 @@
from typing import Generic, List, TypeVar, Dict
from starlette.testclient import TestClient
from typing import Dict, Generic, List, TypeVar
from fastapi import Depends, FastAPI
from starlette.testclient import TestClient
T = TypeVar("T")
C = TypeVar("C")
class FirstGenericType(Generic[T]):
def __init__(self, simple: T, lst: List[T]):
self.simple = simple
self.lst = lst
class SecondGenericType(Generic[T, C]):
def __init__(
self,
simple: T,
lst: List[T],
dct: Dict[T, C],
custom_class: FirstGenericType[T] = Depends()
self,
simple: T,
lst: List[T],
dct: Dict[T, C],
custom_class: FirstGenericType[T] = Depends(),
):
self.simple = simple
self.lst = lst
@ -41,8 +38,8 @@ def depend_generic_type(obj: SecondGenericType[str, int] = Depends()):
"dct": obj.dct,
"custom_class": {
"simple": obj.custom_class.simple,
"lst": obj.custom_class.lst
}
"lst": obj.custom_class.lst,
},
}
@ -50,10 +47,13 @@ client = TestClient(app)
def test_generic_class_dependency():
response = client.post("/test_generic_class?simple=simple", json={
"lst": ["string_1", "string_2"],
"dct": {"key": 1},
})
response = client.post(
"/test_generic_class?simple=simple",
json={
"lst": ["string_1", "string_2"],
"dct": {"key": 1},
},
)
assert response.status_code == 200, response.json()
assert response.json() == {
"custom_class": {

Loading…
Cancel
Save