diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 96e07a45c..c3d892726 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,6 +1,7 @@ import inspect from contextlib import contextmanager from copy import deepcopy +from types import FrameType from typing import ( Any, Callable, @@ -208,12 +209,14 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) + call_frame = get_first_outer_frame() + typed_params = [ inspect.Parameter( name=param.name, kind=param.kind, default=param.default, - annotation=get_typed_annotation(param.annotation, globalns), + annotation=get_typed_annotation(param.annotation, globalns, getattr(call_frame, "f_locals", {}),), ) for param in signature.parameters.values() ] @@ -221,22 +224,34 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: return typed_signature -def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: +def get_first_outer_frame() -> Optional[FrameType]: + frame = inspect.currentframe() + while frame is not None: + if frame.f_code.co_name == "decorator": + return frame.f_back + + frame = frame.f_back + + return None + + +def get_typed_annotation(annotation: Any, globalns: Dict[str, Any], localns: Dict[str, Any]) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) - annotation = evaluate_forwardref(annotation, globalns, globalns) + annotation = evaluate_forwardref(annotation, globalns, localns) return annotation def get_typed_return_annotation(call: Callable[..., Any]) -> Any: signature = inspect.signature(call) annotation = signature.return_annotation + call_frame = get_first_outer_frame() if annotation is inspect.Signature.empty: return None globalns = getattr(call, "__globals__", {}) - return get_typed_annotation(annotation, globalns) + return get_typed_annotation(annotation, globalns, getattr(call_frame, "f_locals", {}),) def get_dependant( diff --git a/tests/test_get_request_body.py b/tests/test_get_request_body.py index a4df82cb2..e334e1867 100644 --- a/tests/test_get_request_body.py +++ b/tests/test_get_request_body.py @@ -118,9 +118,8 @@ def test_get_with_local_declared_body(): description: str = None # type: ignore price: float - @app.get("/product") - async def create_item(product: LocalProduct): + async def create_item(product: LocalProduct) -> LocalProduct: return product return app