Browse Source

merge all outer local contexts

pull/10647/head
Nikita Pastukhov 2 years ago
parent
commit
3f5275fbe0
  1. 24
      fastapi/dependencies/utils.py
  2. 14
      tests/test_get_request_body.py

24
fastapi/dependencies/utils.py

@ -1,7 +1,6 @@
import inspect import inspect
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from types import FrameType
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -209,7 +208,6 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call) signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {}) globalns = getattr(call, "__globals__", {})
call_frame = get_first_outer_frame()
typed_params = [ typed_params = [
inspect.Parameter( inspect.Parameter(
@ -219,7 +217,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
annotation=get_typed_annotation( annotation=get_typed_annotation(
param.annotation, param.annotation,
globalns, globalns,
getattr(call_frame, "f_locals", {}), collect_outer_locals(),
), ),
) )
for param in signature.parameters.values() for param in signature.parameters.values()
@ -228,15 +226,26 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
return typed_signature return typed_signature
def get_first_outer_frame() -> Optional[FrameType]: def collect_outer_locals() -> Dict[str, Any]:
frame = inspect.currentframe() frame = inspect.currentframe()
locals = {}
finded = False
while frame is not None: while frame is not None:
# filter all venv frames
if "site-packages" in frame.f_code.co_filename:
break
if finded:
locals.update(frame.f_locals)
# Find first FastAPI outer frame
if frame.f_code.co_name == "decorator": if frame.f_code.co_name == "decorator":
return frame.f_back finded = True
frame = frame.f_back frame = frame.f_back
return None return locals
def get_typed_annotation( def get_typed_annotation(
@ -251,7 +260,6 @@ def get_typed_annotation(
def get_typed_return_annotation(call: Callable[..., Any]) -> Any: def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
signature = inspect.signature(call) signature = inspect.signature(call)
annotation = signature.return_annotation annotation = signature.return_annotation
call_frame = get_first_outer_frame()
if annotation is inspect.Signature.empty: if annotation is inspect.Signature.empty:
return None return None
@ -260,7 +268,7 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation( return get_typed_annotation(
annotation, annotation,
globalns, globalns,
getattr(call_frame, "f_locals", {}), collect_outer_locals(),
) )

14
tests/test_get_request_body.py

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pydantic import BaseModel from pydantic import BaseModel
@ -110,19 +112,25 @@ def test_openapi_schema():
def test_get_with_local_declared_body(): def test_get_with_local_declared_body():
def wrap(application: FastAPI, *args: Any):
def wrapper(func):
return application.get(*args)(func)
return wrapper
def init_app() -> FastAPI: def init_app() -> FastAPI:
app = FastAPI() application = FastAPI()
class LocalProduct(BaseModel): class LocalProduct(BaseModel):
name: str name: str
description: str = None # type: ignore description: str = None # type: ignore
price: float price: float
@app.get("/product") @wrap(application, "/product")
async def create_item(product: LocalProduct) -> LocalProduct: async def create_item(product: LocalProduct) -> LocalProduct:
return product return product
return app return application
client = TestClient(init_app()) client = TestClient(init_app())

Loading…
Cancel
Save