Browse Source

Added support for injection of custom generic classes via Depends()

pull/6031/head
yyklimenko 2 years ago
parent
commit
743cca58c3
  1. 41
      fastapi/dependencies/utils.py
  2. 66
      tests/test_dependency_generic_class.py

41
fastapi/dependencies/utils.py

@ -245,26 +245,60 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
return False
def is_generic_type(obj: Any) -> bool:
return hasattr(obj, '__origin__') and hasattr(obj, '__args__')
def substitute_generic_type(
annotation: Any, typevars: Dict[str, Any]
) -> Any:
collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple}
if is_generic_type(annotation):
args = tuple(
substitute_generic_type(arg, typevars) for arg in annotation.__args__
)
annotation = collection_shells.get(annotation.__origin__, annotation)
return annotation[args]
return typevars.get(annotation.__name__, annotation)
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
typevars = None
if is_generic_type(call):
origin = getattr(call, '__origin__')
typevars = {
typevar.__name__: value for typevar, value in zip(
origin.__parameters__, getattr(call, '__args__')
)
}
call = origin.__init__
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param.annotation, globalns),
annotation=get_typed_annotation(param.annotation, globalns, typevars),
)
for param in signature.parameters.values()
for param in signature.parameters.values() if param.name != 'self'
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
def get_typed_annotation(
annotation: Any,
globalns: Dict[str, Any],
typevars: Optional[Dict[str, type]] = None,
) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
if typevars:
annotation = substitute_generic_type(annotation, typevars)
return annotation
@ -765,3 +799,4 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
)
check_file_field(final_field)
return final_field

66
tests/test_dependency_generic_class.py

@ -0,0 +1,66 @@
from typing import Generic, List, TypeVar, Dict
from starlette.testclient import TestClient
from fastapi import Depends, FastAPI
T = TypeVar("T")
C = TypeVar("C")
class FirstGenericType(Generic[T]):
def __init__(self, simple: T, lst: List[T]):
self.simple = simple
self.lst = lst
class SecondGenericType(Generic[T, C]):
def __init__(
self,
simple: T,
lst: List[T],
dct: Dict[T, C],
custom_class: FirstGenericType[T] = Depends()
):
self.simple = simple
self.lst = lst
self.dct = dct
self.custom_class = custom_class
app = FastAPI()
@app.post("/test_generic_class")
def depend_generic_type(obj: SecondGenericType[str, int] = Depends()):
return {
"simple": obj.simple,
"lst": obj.lst,
"dct": obj.dct,
"custom_class": {
"simple": obj.custom_class.simple,
"lst": obj.custom_class.lst
}
}
client = TestClient(app)
def test_generic_class_dependency():
response = client.post("/test_generic_class?simple=simple", json={
"lst": ["string_1", "string_2"],
"dct": {"key": 1},
})
assert response.status_code == 200, response.json()
assert response.json() == {
"custom_class": {
"lst": ["string_1", "string_2"],
"simple": "simple",
},
"lst": ["string_1", "string_2"],
"dct": {"key": 1},
"simple": "simple",
}
Loading…
Cancel
Save