Browse Source

respect localns

pull/10647/head
Nikita Pastukhov 2 years ago
parent
commit
fdc3a25ac9
  1. 23
      fastapi/dependencies/utils.py
  2. 3
      tests/test_get_request_body.py

23
fastapi/dependencies/utils.py

@ -1,6 +1,7 @@
import inspect
from contextlib import contextmanager
from copy import deepcopy
from types import FrameType
from typing import (
Any,
Callable,
@ -208,12 +209,14 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
call_frame = get_first_outer_frame()
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param.annotation, globalns),
annotation=get_typed_annotation(param.annotation, globalns, getattr(call_frame, "f_locals", {}),),
)
for param in signature.parameters.values()
]
@ -221,22 +224,34 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
return typed_signature
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
def get_first_outer_frame() -> Optional[FrameType]:
frame = inspect.currentframe()
while frame is not None:
if frame.f_code.co_name == "decorator":
return frame.f_back
frame = frame.f_back
return None
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any], localns: Dict[str, Any]) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
annotation = evaluate_forwardref(annotation, globalns, localns)
return annotation
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
signature = inspect.signature(call)
annotation = signature.return_annotation
call_frame = get_first_outer_frame()
if annotation is inspect.Signature.empty:
return None
globalns = getattr(call, "__globals__", {})
return get_typed_annotation(annotation, globalns)
return get_typed_annotation(annotation, globalns, getattr(call_frame, "f_locals", {}),)
def get_dependant(

3
tests/test_get_request_body.py

@ -118,9 +118,8 @@ def test_get_with_local_declared_body():
description: str = None # type: ignore
price: float
@app.get("/product")
async def create_item(product: LocalProduct):
async def create_item(product: LocalProduct) -> LocalProduct:
return product
return app

Loading…
Cancel
Save