From f16dd38b478c0a88ae4ae79fd73094fdb3ffa425 Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Mon, 22 Jan 2024 12:13:05 +1100 Subject: [PATCH] Add support for functools.partial()-wrapped dependables --- fastapi/dependencies/utils.py | 13 ++-- tests/test_dependency_types.py | 116 +++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 4 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index b73473484..a08bb9742 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,6 +1,7 @@ import inspect from contextlib import AsyncExitStack, contextmanager from copy import deepcopy +from functools import partial from typing import ( Any, Callable, @@ -487,10 +488,10 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: def is_coroutine_callable(call: Callable[..., Any]) -> bool: - if inspect.isroutine(call): - return inspect.iscoroutinefunction(call) - if inspect.isclass(call): - return False + if inspect.iscoroutinefunction(call): + return True + if isinstance(call, partial): + return is_coroutine_callable(call.func) dunder_call = getattr(call, "__call__", None) # noqa: B004 return inspect.iscoroutinefunction(dunder_call) @@ -498,6 +499,8 @@ def is_coroutine_callable(call: Callable[..., Any]) -> bool: def is_async_gen_callable(call: Callable[..., Any]) -> bool: if inspect.isasyncgenfunction(call): return True + if isinstance(call, partial): + return is_async_gen_callable(call.func) dunder_call = getattr(call, "__call__", None) # noqa: B004 return inspect.isasyncgenfunction(dunder_call) @@ -505,6 +508,8 @@ def is_async_gen_callable(call: Callable[..., Any]) -> bool: def is_gen_callable(call: Callable[..., Any]) -> bool: if inspect.isgeneratorfunction(call): return True + if isinstance(call, partial): + return is_gen_callable(call.func) dunder_call = getattr(call, "__call__", None) # noqa: B004 return inspect.isgeneratorfunction(dunder_call) diff --git a/tests/test_dependency_types.py b/tests/test_dependency_types.py index 3e4e5d7d1..4bd37afa5 100644 --- a/tests/test_dependency_types.py +++ b/tests/test_dependency_types.py @@ -1,4 +1,5 @@ from typing import AsyncGenerator, Generator +from functools import partial import pytest from fastapi import Depends, FastAPI @@ -64,6 +65,26 @@ async_callable_gen_dependency = AsyncCallableGenDependency() methods_dependency = MethodsDependency() +@app.get("/function-dependency") +async def get_function_dependency(value: str = Depends(function_dependency)) -> str: + return value + + +@app.get("/async-function-dependency") +async def get_async_function_dependency(value: str = Depends(async_function_dependency)) -> str: + return value + + +@app.get("/gen-dependency") +async def get_gen_dependency(value: str = Depends(gen_dependency)) -> str: + return value + + +@app.get("/async-gen-dependency") +async def get_async_gen_dependency(value: str = Depends(async_gen_dependency)) -> str: + return value + + @app.get("/callable-dependency") async def get_callable_dependency(value: str = Depends(callable_dependency)) -> str: return value @@ -116,6 +137,78 @@ async def get_asynchronous_method_gen_dependency( return value +@app.get("/partial-function-dependency") +async def get_partial_function_dependency(value: str = Depends(partial(function_dependency, "partial-function-dependency"))) -> str: + return value + + +@app.get("/partial-async-function-dependency") +async def get_partial_async_function_dependency(value: str = Depends(partial(async_function_dependency, "partial-async-function-dependency"))) -> str: + return value + + +@app.get("/partial-gen-dependency") +async def get_partial_gen_dependency(value: str = Depends(partial(gen_dependency, "partial-gen-dependency"))) -> str: + return value + + +@app.get("/partial-async-gen-dependency") +async def get_partial_async_gen_dependency(value: str = Depends(partial(async_gen_dependency, "partial-async-gen-dependency"))) -> str: + return value + + +@app.get("/partial-callable-dependency") +async def get_partial_callable_dependency(value: str = Depends(partial(callable_dependency, "partial-callable-dependency"))) -> str: + return value + + +@app.get("/partial-callable-gen-dependency") +async def get_partial_callable_gen_dependency(value: str = Depends(partial(callable_gen_dependency, "partial-callable-gen-dependency"))) -> str: + return value + + +@app.get("/partial-async-callable-dependency") +async def get_partial_async_callable_dependency( + value: str = Depends(partial(async_callable_dependency, "partial-async-callable-dependency")), +) -> str: + return value + + +@app.get("/partial-async-callable-gen-dependency") +async def get_partial_async_callable_gen_dependency( + value: str = Depends(partial(async_callable_gen_dependency, "partial-async-callable-gen-dependency")), +) -> str: + return value + + +@app.get("/partial-synchronous-method-dependency") +async def get_partial_synchronous_method_dependency( + value: str = Depends(partial(methods_dependency.synchronous, "partial-synchronous-method-dependency")), +) -> str: + return value + + +@app.get("/partial-synchronous-method-gen-dependency") +async def get_partial_synchronous_method_gen_dependency( + value: str = Depends(partial(methods_dependency.synchronous_gen, "partial-synchronous-method-gen-dependency")), +) -> str: + return value + + +@app.get("/partial-asynchronous-method-dependency") +async def get_partial_asynchronous_method_dependency( + value: str = Depends(partial(methods_dependency.asynchronous, "partial-asynchronous-method-dependency")), +) -> str: + return value + + +@app.get("/partial-asynchronous-method-gen-dependency") +async def get_partial_asynchronous_method_gen_dependency( + value: str = Depends(partial(methods_dependency.asynchronous_gen, "partial-asynchronous-method-gen-dependency")), +) -> str: + return value + + client = TestClient(app) @@ -140,3 +233,26 @@ def test_dependency_types(route: str, value: str) -> None: response = client.get(route, params={"value": value}) assert response.status_code == 200, response.text assert response.json() == value + + +@pytest.mark.parametrize( + "route,value", + [ + ("/partial-function-dependency", "partial-function-dependency"), + ("/partial-async-function-dependency", "partial-async-function-dependency"), + ("/partial-gen-dependency", "partial-gen-dependency"), + ("/partial-async-gen-dependency", "partial-async-gen-dependency"), + ("/partial-callable-dependency", "partial-callable-dependency"), + ("/partial-callable-gen-dependency", "partial-callable-gen-dependency"), + ("/partial-async-callable-dependency", "partial-async-callable-dependency"), + ("/partial-async-callable-gen-dependency", "partial-async-callable-gen-dependency"), + ("/partial-synchronous-method-dependency", "partial-synchronous-method-dependency"), + ("/partial-synchronous-method-gen-dependency", "partial-synchronous-method-gen-dependency"), + ("/partial-asynchronous-method-dependency", "partial-asynchronous-method-dependency"), + ("/partial-asynchronous-method-gen-dependency", "partial-asynchronous-method-gen-dependency"), + ], +) +def test_dependency_types_with_partial(route: str, value: str) -> None: + response = client.get(route) + assert response.status_code == 200, response.text + assert response.json() == value