From 743cca58c3816ac3dca5ec3f513f72c25c537826 Mon Sep 17 00:00:00 2001 From: yyklimenko Date: Wed, 22 Feb 2023 21:38:55 +0600 Subject: [PATCH] Added support for injection of custom generic classes via Depends() --- fastapi/dependencies/utils.py | 41 ++++++++++++++-- tests/test_dependency_generic_class.py | 66 ++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 tests/test_dependency_generic_class.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 32e171f18..dd10a2a2a 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -245,26 +245,60 @@ def is_scalar_sequence_field(field: ModelField) -> bool: return False +def is_generic_type(obj: Any) -> bool: + return hasattr(obj, '__origin__') and hasattr(obj, '__args__') + + +def substitute_generic_type( + annotation: Any, typevars: Dict[str, Any] +) -> Any: + collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} + if is_generic_type(annotation): + args = tuple( + substitute_generic_type(arg, typevars) for arg in annotation.__args__ + ) + annotation = collection_shells.get(annotation.__origin__, annotation) + return annotation[args] + return typevars.get(annotation.__name__, annotation) + + def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + typevars = None + if is_generic_type(call): + origin = getattr(call, '__origin__') + typevars = { + typevar.__name__: value for typevar, value in zip( + origin.__parameters__, getattr(call, '__args__') + ) + } + call = origin.__init__ + signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) + typed_params = [ inspect.Parameter( name=param.name, kind=param.kind, default=param.default, - annotation=get_typed_annotation(param.annotation, globalns), + annotation=get_typed_annotation(param.annotation, globalns, typevars), ) - for param in signature.parameters.values() + for param in signature.parameters.values() if param.name != 'self' ] typed_signature = inspect.Signature(typed_params) return typed_signature -def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: +def get_typed_annotation( + annotation: Any, + globalns: Dict[str, Any], + typevars: Optional[Dict[str, type]] = None, +) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) annotation = evaluate_forwardref(annotation, globalns, globalns) + if typevars: + annotation = substitute_generic_type(annotation, typevars) return annotation @@ -765,3 +799,4 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: ) check_file_field(final_field) return final_field + diff --git a/tests/test_dependency_generic_class.py b/tests/test_dependency_generic_class.py new file mode 100644 index 000000000..1cddbf242 --- /dev/null +++ b/tests/test_dependency_generic_class.py @@ -0,0 +1,66 @@ +from typing import Generic, List, TypeVar, Dict + +from starlette.testclient import TestClient + +from fastapi import Depends, FastAPI + +T = TypeVar("T") +C = TypeVar("C") + + +class FirstGenericType(Generic[T]): + + def __init__(self, simple: T, lst: List[T]): + self.simple = simple + self.lst = lst + + +class SecondGenericType(Generic[T, C]): + + def __init__( + self, + simple: T, + lst: List[T], + dct: Dict[T, C], + custom_class: FirstGenericType[T] = Depends() + ): + self.simple = simple + self.lst = lst + self.dct = dct + self.custom_class = custom_class + + +app = FastAPI() + + +@app.post("/test_generic_class") +def depend_generic_type(obj: SecondGenericType[str, int] = Depends()): + return { + "simple": obj.simple, + "lst": obj.lst, + "dct": obj.dct, + "custom_class": { + "simple": obj.custom_class.simple, + "lst": obj.custom_class.lst + } + } + + +client = TestClient(app) + + +def test_generic_class_dependency(): + response = client.post("/test_generic_class?simple=simple", json={ + "lst": ["string_1", "string_2"], + "dct": {"key": 1}, + }) + assert response.status_code == 200, response.json() + assert response.json() == { + "custom_class": { + "lst": ["string_1", "string_2"], + "simple": "simple", + }, + "lst": ["string_1", "string_2"], + "dct": {"key": 1}, + "simple": "simple", + }