Browse Source

Merge d2ec5d6620 into 8032e21418

pull/6038/merge
Yuri Klimenko 3 days ago
committed by GitHub
parent
commit
41cce00a1d
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 11
      fastapi/dependencies/protocols.py
  2. 37
      fastapi/dependencies/utils.py
  3. 66
      tests/test_dependency_generic_class.py

11
fastapi/dependencies/protocols.py

@ -0,0 +1,11 @@
from typing import Any, Tuple
from typing_extensions import Protocol
class GenericTypeProtocol(Protocol):
class OriginTypeProtocol(Protocol):
__parameters__: Tuple[Any]
__origin__: OriginTypeProtocol
__args__: Tuple[Any]

37
fastapi/dependencies/utils.py

@ -52,6 +52,7 @@ from fastapi.concurrency import (
contextmanager_in_threadpool, contextmanager_in_threadpool,
) )
from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.dependencies.protocols import GenericTypeProtocol
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.oauth2 import OAuth2, SecurityScopes
@ -228,7 +229,32 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
return path_params + query_params + header_params + cookie_params return path_params + query_params + header_params + cookie_params
def is_generic_type(obj: Any) -> bool:
return hasattr(obj, "__origin__") and hasattr(obj, "__args__")
def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any:
collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple}
if is_generic_type(annotation):
args = tuple(
substitute_generic_type(arg, typevars) for arg in annotation.__args__
)
annotation = collection_shells.get(annotation.__origin__, annotation)
return annotation[args]
return typevars.get(annotation.__name__, annotation)
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
typevars = None
if is_generic_type(call):
generic: GenericTypeProtocol = cast(GenericTypeProtocol, call)
origin: Any = generic.__origin__
typevars = {
typevar.__name__: value
for typevar, value in zip(origin.__parameters__, generic.__args__)
}
call = origin.__init__
signature = inspect.signature(call) signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {}) globalns = getattr(call, "__globals__", {})
typed_params = [ typed_params = [
@ -236,18 +262,25 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
name=param.name, name=param.name,
kind=param.kind, kind=param.kind,
default=param.default, default=param.default,
annotation=get_typed_annotation(param.annotation, globalns), annotation=get_typed_annotation(param.annotation, globalns, typevars),
) )
for param in signature.parameters.values() for param in signature.parameters.values()
if param.name != "self"
] ]
typed_signature = inspect.Signature(typed_params) typed_signature = inspect.Signature(typed_params)
return typed_signature return typed_signature
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: def get_typed_annotation(
annotation: Any,
globalns: Dict[str, Any],
typevars: Optional[Dict[str, type]] = None,
) -> Any:
if isinstance(annotation, str): if isinstance(annotation, str):
annotation = ForwardRef(annotation) annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns) annotation = evaluate_forwardref(annotation, globalns, globalns)
if typevars:
annotation = substitute_generic_type(annotation, typevars)
return annotation return annotation

66
tests/test_dependency_generic_class.py

@ -0,0 +1,66 @@
from typing import Dict, Generic, List, TypeVar
from fastapi import Depends, FastAPI
from starlette.testclient import TestClient
T = TypeVar("T")
C = TypeVar("C")
class FirstGenericType(Generic[T]):
def __init__(self, simple: T, lst: List[T]):
self.simple = simple
self.lst = lst
class SecondGenericType(Generic[T, C]):
def __init__(
self,
simple: T,
lst: List[T],
dct: Dict[T, C],
custom_class: FirstGenericType[T] = Depends(),
):
self.simple = simple
self.lst = lst
self.dct = dct
self.custom_class = custom_class
app = FastAPI()
@app.post("/test_generic_class")
def depend_generic_type(obj: SecondGenericType[str, int] = Depends()):
return {
"simple": obj.simple,
"lst": obj.lst,
"dct": obj.dct,
"custom_class": {
"simple": obj.custom_class.simple,
"lst": obj.custom_class.lst,
},
}
client = TestClient(app)
def test_generic_class_dependency():
response = client.post(
"/test_generic_class?simple=simple",
json={
"lst": ["string_1", "string_2"],
"dct": {"key": 1},
},
)
assert response.status_code == 200, response.json()
assert response.json() == {
"custom_class": {
"lst": ["string_1", "string_2"],
"simple": "simple",
},
"lst": ["string_1", "string_2"],
"dct": {"key": 1},
"simple": "simple",
}
Loading…
Cancel
Save