Lie Ryan 2 weeks ago
committed by GitHub
parent
commit
89b5650752
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 13
      fastapi/dependencies/utils.py
  2. 217
      tests/test_dependency_partial.py

13
fastapi/dependencies/utils.py

@ -2,6 +2,7 @@ import inspect
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
from functools import partial
from typing import (
Any,
Callable,
@ -528,10 +529,10 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
if inspect.iscoroutinefunction(call):
return True
if isinstance(call, partial):
return is_coroutine_callable(call.func)
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.iscoroutinefunction(dunder_call)
@ -539,6 +540,8 @@ def is_coroutine_callable(call: Callable[..., Any]) -> bool:
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call):
return True
if isinstance(call, partial):
return is_async_gen_callable(call.func)
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)
@ -546,6 +549,8 @@ def is_async_gen_callable(call: Callable[..., Any]) -> bool:
def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call):
return True
if isinstance(call, partial):
return is_gen_callable(call.func)
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)

217
tests/test_dependency_partial.py

@ -0,0 +1,217 @@
from functools import partial
from typing import AsyncGenerator, Generator
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
def function_dependency(value: str) -> str:
return value
async def async_function_dependency(value: str) -> str:
return value
def gen_dependency(value: str) -> Generator[str, None, None]:
yield value
async def async_gen_dependency(value: str) -> AsyncGenerator[str, None]:
yield value
class CallableDependency:
def __call__(self, value: str) -> str:
return value
class CallableGenDependency:
def __call__(self, value: str) -> Generator[str, None, None]:
yield value
class AsyncCallableDependency:
async def __call__(self, value: str) -> str:
return value
class AsyncCallableGenDependency:
async def __call__(self, value: str) -> AsyncGenerator[str, None]:
yield value
class MethodsDependency:
def synchronous(self, value: str) -> str:
return value
async def asynchronous(self, value: str) -> str:
return value
def synchronous_gen(self, value: str) -> Generator[str, None, None]:
yield value
async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]:
yield value
callable_dependency = CallableDependency()
callable_gen_dependency = CallableGenDependency()
async_callable_dependency = AsyncCallableDependency()
async_callable_gen_dependency = AsyncCallableGenDependency()
methods_dependency = MethodsDependency()
@app.get("/partial-function-dependency")
async def get_partial_function_dependency(
value: str = Depends(partial(function_dependency, "partial-function-dependency")),
) -> str:
return value
@app.get("/partial-async-function-dependency")
async def get_partial_async_function_dependency(
value: str = Depends(
partial(async_function_dependency, "partial-async-function-dependency")
),
) -> str:
return value
@app.get("/partial-gen-dependency")
async def get_partial_gen_dependency(
value: str = Depends(partial(gen_dependency, "partial-gen-dependency")),
) -> str:
return value
@app.get("/partial-async-gen-dependency")
async def get_partial_async_gen_dependency(
value: str = Depends(partial(async_gen_dependency, "partial-async-gen-dependency")),
) -> str:
return value
@app.get("/partial-callable-dependency")
async def get_partial_callable_dependency(
value: str = Depends(partial(callable_dependency, "partial-callable-dependency")),
) -> str:
return value
@app.get("/partial-callable-gen-dependency")
async def get_partial_callable_gen_dependency(
value: str = Depends(
partial(callable_gen_dependency, "partial-callable-gen-dependency")
),
) -> str:
return value
@app.get("/partial-async-callable-dependency")
async def get_partial_async_callable_dependency(
value: str = Depends(
partial(async_callable_dependency, "partial-async-callable-dependency")
),
) -> str:
return value
@app.get("/partial-async-callable-gen-dependency")
async def get_partial_async_callable_gen_dependency(
value: str = Depends(
partial(async_callable_gen_dependency, "partial-async-callable-gen-dependency")
),
) -> str:
return value
@app.get("/partial-synchronous-method-dependency")
async def get_partial_synchronous_method_dependency(
value: str = Depends(
partial(methods_dependency.synchronous, "partial-synchronous-method-dependency")
),
) -> str:
return value
@app.get("/partial-synchronous-method-gen-dependency")
async def get_partial_synchronous_method_gen_dependency(
value: str = Depends(
partial(
methods_dependency.synchronous_gen,
"partial-synchronous-method-gen-dependency",
)
),
) -> str:
return value
@app.get("/partial-asynchronous-method-dependency")
async def get_partial_asynchronous_method_dependency(
value: str = Depends(
partial(
methods_dependency.asynchronous, "partial-asynchronous-method-dependency"
)
),
) -> str:
return value
@app.get("/partial-asynchronous-method-gen-dependency")
async def get_partial_asynchronous_method_gen_dependency(
value: str = Depends(
partial(
methods_dependency.asynchronous_gen,
"partial-asynchronous-method-gen-dependency",
)
),
) -> str:
return value
client = TestClient(app)
@pytest.mark.parametrize(
"route,value",
[
("/partial-function-dependency", "partial-function-dependency"),
(
"/partial-async-function-dependency",
"partial-async-function-dependency",
),
("/partial-gen-dependency", "partial-gen-dependency"),
("/partial-async-gen-dependency", "partial-async-gen-dependency"),
("/partial-callable-dependency", "partial-callable-dependency"),
("/partial-callable-gen-dependency", "partial-callable-gen-dependency"),
("/partial-async-callable-dependency", "partial-async-callable-dependency"),
(
"/partial-async-callable-gen-dependency",
"partial-async-callable-gen-dependency",
),
(
"/partial-synchronous-method-dependency",
"partial-synchronous-method-dependency",
),
(
"/partial-synchronous-method-gen-dependency",
"partial-synchronous-method-gen-dependency",
),
(
"/partial-asynchronous-method-dependency",
"partial-asynchronous-method-dependency",
),
(
"/partial-asynchronous-method-gen-dependency",
"partial-asynchronous-method-gen-dependency",
),
],
)
def test_dependency_types_with_partial(route: str, value: str) -> None:
response = client.get(route)
assert response.status_code == 200, response.text
assert response.json() == value
Loading…
Cancel
Save