Yuri Klimenko 2 days ago
committed by GitHub
parent
commit
63a21a3e7b
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 11
      fastapi/dependencies/protocols.py
  2. 37
      fastapi/dependencies/utils.py
  3. 66
      tests/test_dependency_generic_class.py

11
fastapi/dependencies/protocols.py

@ -0,0 +1,11 @@
from typing import Any, Tuple
from typing_extensions import Protocol
class GenericTypeProtocol(Protocol):
class OriginTypeProtocol(Protocol):
__parameters__: Tuple[Any]
__origin__: OriginTypeProtocol
__args__: Tuple[Any]

37
fastapi/dependencies/utils.py

@ -52,6 +52,7 @@ from fastapi.concurrency import (
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.dependencies.protocols import GenericTypeProtocol
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
@ -228,7 +229,32 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
return path_params + query_params + header_params + cookie_params
def is_generic_type(obj: Any) -> bool:
return hasattr(obj, "__origin__") and hasattr(obj, "__args__")
def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any:
collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple}
if is_generic_type(annotation):
args = tuple(
substitute_generic_type(arg, typevars) for arg in annotation.__args__
)
annotation = collection_shells.get(annotation.__origin__, annotation)
return annotation[args]
return typevars.get(annotation.__name__, annotation)
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
typevars = None
if is_generic_type(call):
generic: GenericTypeProtocol = cast(GenericTypeProtocol, call)
origin: Any = generic.__origin__
typevars = {
typevar.__name__: value
for typevar, value in zip(origin.__parameters__, generic.__args__)
}
call = origin.__init__
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
@ -236,18 +262,25 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param.annotation, globalns),
annotation=get_typed_annotation(param.annotation, globalns, typevars),
)
for param in signature.parameters.values()
if param.name != "self"
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
def get_typed_annotation(
annotation: Any,
globalns: Dict[str, Any],
typevars: Optional[Dict[str, type]] = None,
) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
if typevars:
annotation = substitute_generic_type(annotation, typevars)
return annotation

66
tests/test_dependency_generic_class.py

@ -0,0 +1,66 @@
from typing import Dict, Generic, List, TypeVar
from fastapi import Depends, FastAPI
from starlette.testclient import TestClient
T = TypeVar("T")
C = TypeVar("C")
class FirstGenericType(Generic[T]):
def __init__(self, simple: T, lst: List[T]):
self.simple = simple
self.lst = lst
class SecondGenericType(Generic[T, C]):
def __init__(
self,
simple: T,
lst: List[T],
dct: Dict[T, C],
custom_class: FirstGenericType[T] = Depends(),
):
self.simple = simple
self.lst = lst
self.dct = dct
self.custom_class = custom_class
app = FastAPI()
@app.post("/test_generic_class")
def depend_generic_type(obj: SecondGenericType[str, int] = Depends()):
return {
"simple": obj.simple,
"lst": obj.lst,
"dct": obj.dct,
"custom_class": {
"simple": obj.custom_class.simple,
"lst": obj.custom_class.lst,
},
}
client = TestClient(app)
def test_generic_class_dependency():
response = client.post(
"/test_generic_class?simple=simple",
json={
"lst": ["string_1", "string_2"],
"dct": {"key": 1},
},
)
assert response.status_code == 200, response.json()
assert response.json() == {
"custom_class": {
"lst": ["string_1", "string_2"],
"simple": "simple",
},
"lst": ["string_1", "string_2"],
"dct": {"key": 1},
"simple": "simple",
}
Loading…
Cancel
Save