diff --git a/fastapi/dependencies/protocols.py b/fastapi/dependencies/protocols.py new file mode 100644 index 000000000..857e2169e --- /dev/null +++ b/fastapi/dependencies/protocols.py @@ -0,0 +1,11 @@ +from typing import Any, Tuple + +from typing_extensions import Protocol + + +class GenericTypeProtocol(Protocol): + class OriginTypeProtocol(Protocol): + __parameters__: Tuple[Any] + + __origin__: OriginTypeProtocol + __args__: Tuple[Any] diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 84dfa4d03..19a06b8bc 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -52,6 +52,7 @@ from fastapi.concurrency import ( contextmanager_in_threadpool, ) from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.dependencies.protocols import GenericTypeProtocol from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes @@ -228,7 +229,32 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]: return path_params + query_params + header_params + cookie_params +def is_generic_type(obj: Any) -> bool: + return hasattr(obj, "__origin__") and hasattr(obj, "__args__") + + +def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any: + collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} + if is_generic_type(annotation): + args = tuple( + substitute_generic_type(arg, typevars) for arg in annotation.__args__ + ) + annotation = collection_shells.get(annotation.__origin__, annotation) + return annotation[args] + return typevars.get(annotation.__name__, annotation) + + def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + typevars = None + if is_generic_type(call): + generic: GenericTypeProtocol = cast(GenericTypeProtocol, call) + origin: Any = generic.__origin__ + typevars = { + typevar.__name__: value + for typevar, value in zip(origin.__parameters__, generic.__args__) + } + call = origin.__init__ + signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) typed_params = [ @@ -236,18 +262,25 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: name=param.name, kind=param.kind, default=param.default, - annotation=get_typed_annotation(param.annotation, globalns), + annotation=get_typed_annotation(param.annotation, globalns, typevars), ) for param in signature.parameters.values() + if param.name != "self" ] typed_signature = inspect.Signature(typed_params) return typed_signature -def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: +def get_typed_annotation( + annotation: Any, + globalns: Dict[str, Any], + typevars: Optional[Dict[str, type]] = None, +) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) annotation = evaluate_forwardref(annotation, globalns, globalns) + if typevars: + annotation = substitute_generic_type(annotation, typevars) return annotation diff --git a/tests/test_dependency_generic_class.py b/tests/test_dependency_generic_class.py new file mode 100644 index 000000000..3eeb8291e --- /dev/null +++ b/tests/test_dependency_generic_class.py @@ -0,0 +1,66 @@ +from typing import Dict, Generic, List, TypeVar + +from fastapi import Depends, FastAPI +from starlette.testclient import TestClient + +T = TypeVar("T") +C = TypeVar("C") + + +class FirstGenericType(Generic[T]): + def __init__(self, simple: T, lst: List[T]): + self.simple = simple + self.lst = lst + + +class SecondGenericType(Generic[T, C]): + def __init__( + self, + simple: T, + lst: List[T], + dct: Dict[T, C], + custom_class: FirstGenericType[T] = Depends(), + ): + self.simple = simple + self.lst = lst + self.dct = dct + self.custom_class = custom_class + + +app = FastAPI() + + +@app.post("/test_generic_class") +def depend_generic_type(obj: SecondGenericType[str, int] = Depends()): + return { + "simple": obj.simple, + "lst": obj.lst, + "dct": obj.dct, + "custom_class": { + "simple": obj.custom_class.simple, + "lst": obj.custom_class.lst, + }, + } + + +client = TestClient(app) + + +def test_generic_class_dependency(): + response = client.post( + "/test_generic_class?simple=simple", + json={ + "lst": ["string_1", "string_2"], + "dct": {"key": 1}, + }, + ) + assert response.status_code == 200, response.json() + assert response.json() == { + "custom_class": { + "lst": ["string_1", "string_2"], + "simple": "simple", + }, + "lst": ["string_1", "string_2"], + "dct": {"key": 1}, + "simple": "simple", + }