diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..247b07bab 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -2,6 +2,7 @@ import inspect from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass +from functools import partial from typing import ( Any, Callable, @@ -528,10 +529,10 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: def is_coroutine_callable(call: Callable[..., Any]) -> bool: - if inspect.isroutine(call): - return inspect.iscoroutinefunction(call) - if inspect.isclass(call): - return False + if inspect.iscoroutinefunction(call): + return True + if isinstance(call, partial): + return is_coroutine_callable(call.func) dunder_call = getattr(call, "__call__", None) # noqa: B004 return inspect.iscoroutinefunction(dunder_call) @@ -539,6 +540,8 @@ def is_coroutine_callable(call: Callable[..., Any]) -> bool: def is_async_gen_callable(call: Callable[..., Any]) -> bool: if inspect.isasyncgenfunction(call): return True + if isinstance(call, partial): + return is_async_gen_callable(call.func) dunder_call = getattr(call, "__call__", None) # noqa: B004 return inspect.isasyncgenfunction(dunder_call) @@ -546,6 +549,8 @@ def is_async_gen_callable(call: Callable[..., Any]) -> bool: def is_gen_callable(call: Callable[..., Any]) -> bool: if inspect.isgeneratorfunction(call): return True + if isinstance(call, partial): + return is_gen_callable(call.func) dunder_call = getattr(call, "__call__", None) # noqa: B004 return inspect.isgeneratorfunction(dunder_call) diff --git a/tests/test_dependency_partial.py b/tests/test_dependency_partial.py new file mode 100644 index 000000000..5caffeec1 --- /dev/null +++ b/tests/test_dependency_partial.py @@ -0,0 +1,217 @@ +from functools import partial +from typing import AsyncGenerator, Generator + +import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() + + +def function_dependency(value: str) -> str: + return value + + +async def async_function_dependency(value: str) -> str: + return value + + +def gen_dependency(value: str) -> Generator[str, None, None]: + yield value + + +async def async_gen_dependency(value: str) -> AsyncGenerator[str, None]: + yield value + + +class CallableDependency: + def __call__(self, value: str) -> str: + return value + + +class CallableGenDependency: + def __call__(self, value: str) -> Generator[str, None, None]: + yield value + + +class AsyncCallableDependency: + async def __call__(self, value: str) -> str: + return value + + +class AsyncCallableGenDependency: + async def __call__(self, value: str) -> AsyncGenerator[str, None]: + yield value + + +class MethodsDependency: + def synchronous(self, value: str) -> str: + return value + + async def asynchronous(self, value: str) -> str: + return value + + def synchronous_gen(self, value: str) -> Generator[str, None, None]: + yield value + + async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]: + yield value + + +callable_dependency = CallableDependency() +callable_gen_dependency = CallableGenDependency() +async_callable_dependency = AsyncCallableDependency() +async_callable_gen_dependency = AsyncCallableGenDependency() +methods_dependency = MethodsDependency() + + +@app.get("/partial-function-dependency") +async def get_partial_function_dependency( + value: str = Depends(partial(function_dependency, "partial-function-dependency")), +) -> str: + return value + + +@app.get("/partial-async-function-dependency") +async def get_partial_async_function_dependency( + value: str = Depends( + partial(async_function_dependency, "partial-async-function-dependency") + ), +) -> str: + return value + + +@app.get("/partial-gen-dependency") +async def get_partial_gen_dependency( + value: str = Depends(partial(gen_dependency, "partial-gen-dependency")), +) -> str: + return value + + +@app.get("/partial-async-gen-dependency") +async def get_partial_async_gen_dependency( + value: str = Depends(partial(async_gen_dependency, "partial-async-gen-dependency")), +) -> str: + return value + + +@app.get("/partial-callable-dependency") +async def get_partial_callable_dependency( + value: str = Depends(partial(callable_dependency, "partial-callable-dependency")), +) -> str: + return value + + +@app.get("/partial-callable-gen-dependency") +async def get_partial_callable_gen_dependency( + value: str = Depends( + partial(callable_gen_dependency, "partial-callable-gen-dependency") + ), +) -> str: + return value + + +@app.get("/partial-async-callable-dependency") +async def get_partial_async_callable_dependency( + value: str = Depends( + partial(async_callable_dependency, "partial-async-callable-dependency") + ), +) -> str: + return value + + +@app.get("/partial-async-callable-gen-dependency") +async def get_partial_async_callable_gen_dependency( + value: str = Depends( + partial(async_callable_gen_dependency, "partial-async-callable-gen-dependency") + ), +) -> str: + return value + + +@app.get("/partial-synchronous-method-dependency") +async def get_partial_synchronous_method_dependency( + value: str = Depends( + partial(methods_dependency.synchronous, "partial-synchronous-method-dependency") + ), +) -> str: + return value + + +@app.get("/partial-synchronous-method-gen-dependency") +async def get_partial_synchronous_method_gen_dependency( + value: str = Depends( + partial( + methods_dependency.synchronous_gen, + "partial-synchronous-method-gen-dependency", + ) + ), +) -> str: + return value + + +@app.get("/partial-asynchronous-method-dependency") +async def get_partial_asynchronous_method_dependency( + value: str = Depends( + partial( + methods_dependency.asynchronous, "partial-asynchronous-method-dependency" + ) + ), +) -> str: + return value + + +@app.get("/partial-asynchronous-method-gen-dependency") +async def get_partial_asynchronous_method_gen_dependency( + value: str = Depends( + partial( + methods_dependency.asynchronous_gen, + "partial-asynchronous-method-gen-dependency", + ) + ), +) -> str: + return value + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "route,value", + [ + ("/partial-function-dependency", "partial-function-dependency"), + ( + "/partial-async-function-dependency", + "partial-async-function-dependency", + ), + ("/partial-gen-dependency", "partial-gen-dependency"), + ("/partial-async-gen-dependency", "partial-async-gen-dependency"), + ("/partial-callable-dependency", "partial-callable-dependency"), + ("/partial-callable-gen-dependency", "partial-callable-gen-dependency"), + ("/partial-async-callable-dependency", "partial-async-callable-dependency"), + ( + "/partial-async-callable-gen-dependency", + "partial-async-callable-gen-dependency", + ), + ( + "/partial-synchronous-method-dependency", + "partial-synchronous-method-dependency", + ), + ( + "/partial-synchronous-method-gen-dependency", + "partial-synchronous-method-gen-dependency", + ), + ( + "/partial-asynchronous-method-dependency", + "partial-asynchronous-method-dependency", + ), + ( + "/partial-asynchronous-method-gen-dependency", + "partial-asynchronous-method-gen-dependency", + ), + ], +) +def test_dependency_types_with_partial(route: str, value: str) -> None: + response = client.get(route) + assert response.status_code == 200, response.text + assert response.json() == value