Lie Ryan 2 weeks ago
committed by GitHub
parent
commit
89b5650752
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 13
      fastapi/dependencies/utils.py
  2. 217
      tests/test_dependency_partial.py

13
fastapi/dependencies/utils.py

@ -2,6 +2,7 @@ import inspect
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -528,10 +529,10 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
def is_coroutine_callable(call: Callable[..., Any]) -> bool: def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isroutine(call): if inspect.iscoroutinefunction(call):
return inspect.iscoroutinefunction(call) return True
if inspect.isclass(call): if isinstance(call, partial):
return False return is_coroutine_callable(call.func)
dunder_call = getattr(call, "__call__", None) # noqa: B004 dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.iscoroutinefunction(dunder_call) return inspect.iscoroutinefunction(dunder_call)
@ -539,6 +540,8 @@ def is_coroutine_callable(call: Callable[..., Any]) -> bool:
def is_async_gen_callable(call: Callable[..., Any]) -> bool: def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call): if inspect.isasyncgenfunction(call):
return True return True
if isinstance(call, partial):
return is_async_gen_callable(call.func)
dunder_call = getattr(call, "__call__", None) # noqa: B004 dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call) return inspect.isasyncgenfunction(dunder_call)
@ -546,6 +549,8 @@ def is_async_gen_callable(call: Callable[..., Any]) -> bool:
def is_gen_callable(call: Callable[..., Any]) -> bool: def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call): if inspect.isgeneratorfunction(call):
return True return True
if isinstance(call, partial):
return is_gen_callable(call.func)
dunder_call = getattr(call, "__call__", None) # noqa: B004 dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call) return inspect.isgeneratorfunction(dunder_call)

217
tests/test_dependency_partial.py

@ -0,0 +1,217 @@
from functools import partial
from typing import AsyncGenerator, Generator
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
def function_dependency(value: str) -> str:
return value
async def async_function_dependency(value: str) -> str:
return value
def gen_dependency(value: str) -> Generator[str, None, None]:
yield value
async def async_gen_dependency(value: str) -> AsyncGenerator[str, None]:
yield value
class CallableDependency:
def __call__(self, value: str) -> str:
return value
class CallableGenDependency:
def __call__(self, value: str) -> Generator[str, None, None]:
yield value
class AsyncCallableDependency:
async def __call__(self, value: str) -> str:
return value
class AsyncCallableGenDependency:
async def __call__(self, value: str) -> AsyncGenerator[str, None]:
yield value
class MethodsDependency:
def synchronous(self, value: str) -> str:
return value
async def asynchronous(self, value: str) -> str:
return value
def synchronous_gen(self, value: str) -> Generator[str, None, None]:
yield value
async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]:
yield value
callable_dependency = CallableDependency()
callable_gen_dependency = CallableGenDependency()
async_callable_dependency = AsyncCallableDependency()
async_callable_gen_dependency = AsyncCallableGenDependency()
methods_dependency = MethodsDependency()
@app.get("/partial-function-dependency")
async def get_partial_function_dependency(
value: str = Depends(partial(function_dependency, "partial-function-dependency")),
) -> str:
return value
@app.get("/partial-async-function-dependency")
async def get_partial_async_function_dependency(
value: str = Depends(
partial(async_function_dependency, "partial-async-function-dependency")
),
) -> str:
return value
@app.get("/partial-gen-dependency")
async def get_partial_gen_dependency(
value: str = Depends(partial(gen_dependency, "partial-gen-dependency")),
) -> str:
return value
@app.get("/partial-async-gen-dependency")
async def get_partial_async_gen_dependency(
value: str = Depends(partial(async_gen_dependency, "partial-async-gen-dependency")),
) -> str:
return value
@app.get("/partial-callable-dependency")
async def get_partial_callable_dependency(
value: str = Depends(partial(callable_dependency, "partial-callable-dependency")),
) -> str:
return value
@app.get("/partial-callable-gen-dependency")
async def get_partial_callable_gen_dependency(
value: str = Depends(
partial(callable_gen_dependency, "partial-callable-gen-dependency")
),
) -> str:
return value
@app.get("/partial-async-callable-dependency")
async def get_partial_async_callable_dependency(
value: str = Depends(
partial(async_callable_dependency, "partial-async-callable-dependency")
),
) -> str:
return value
@app.get("/partial-async-callable-gen-dependency")
async def get_partial_async_callable_gen_dependency(
value: str = Depends(
partial(async_callable_gen_dependency, "partial-async-callable-gen-dependency")
),
) -> str:
return value
@app.get("/partial-synchronous-method-dependency")
async def get_partial_synchronous_method_dependency(
value: str = Depends(
partial(methods_dependency.synchronous, "partial-synchronous-method-dependency")
),
) -> str:
return value
@app.get("/partial-synchronous-method-gen-dependency")
async def get_partial_synchronous_method_gen_dependency(
value: str = Depends(
partial(
methods_dependency.synchronous_gen,
"partial-synchronous-method-gen-dependency",
)
),
) -> str:
return value
@app.get("/partial-asynchronous-method-dependency")
async def get_partial_asynchronous_method_dependency(
value: str = Depends(
partial(
methods_dependency.asynchronous, "partial-asynchronous-method-dependency"
)
),
) -> str:
return value
@app.get("/partial-asynchronous-method-gen-dependency")
async def get_partial_asynchronous_method_gen_dependency(
value: str = Depends(
partial(
methods_dependency.asynchronous_gen,
"partial-asynchronous-method-gen-dependency",
)
),
) -> str:
return value
client = TestClient(app)
@pytest.mark.parametrize(
"route,value",
[
("/partial-function-dependency", "partial-function-dependency"),
(
"/partial-async-function-dependency",
"partial-async-function-dependency",
),
("/partial-gen-dependency", "partial-gen-dependency"),
("/partial-async-gen-dependency", "partial-async-gen-dependency"),
("/partial-callable-dependency", "partial-callable-dependency"),
("/partial-callable-gen-dependency", "partial-callable-gen-dependency"),
("/partial-async-callable-dependency", "partial-async-callable-dependency"),
(
"/partial-async-callable-gen-dependency",
"partial-async-callable-gen-dependency",
),
(
"/partial-synchronous-method-dependency",
"partial-synchronous-method-dependency",
),
(
"/partial-synchronous-method-gen-dependency",
"partial-synchronous-method-gen-dependency",
),
(
"/partial-asynchronous-method-dependency",
"partial-asynchronous-method-dependency",
),
(
"/partial-asynchronous-method-gen-dependency",
"partial-asynchronous-method-gen-dependency",
),
],
)
def test_dependency_types_with_partial(route: str, value: str) -> None:
response = client.get(route)
assert response.status_code == 200, response.text
assert response.json() == value
Loading…
Cancel
Save