Browse Source

Add support for functools.partial()-wrapped dependables

pull/9753/head
Lie Ryan 2 years ago
parent
commit
f16dd38b47
  1. 13
      fastapi/dependencies/utils.py
  2. 116
      tests/test_dependency_types.py

13
fastapi/dependencies/utils.py

@ -1,6 +1,7 @@
import inspect import inspect
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import deepcopy from copy import deepcopy
from functools import partial
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -487,10 +488,10 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
def is_coroutine_callable(call: Callable[..., Any]) -> bool: def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isroutine(call): if inspect.iscoroutinefunction(call):
return inspect.iscoroutinefunction(call) return True
if inspect.isclass(call): if isinstance(call, partial):
return False return is_coroutine_callable(call.func)
dunder_call = getattr(call, "__call__", None) # noqa: B004 dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.iscoroutinefunction(dunder_call) return inspect.iscoroutinefunction(dunder_call)
@ -498,6 +499,8 @@ def is_coroutine_callable(call: Callable[..., Any]) -> bool:
def is_async_gen_callable(call: Callable[..., Any]) -> bool: def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call): if inspect.isasyncgenfunction(call):
return True return True
if isinstance(call, partial):
return is_async_gen_callable(call.func)
dunder_call = getattr(call, "__call__", None) # noqa: B004 dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call) return inspect.isasyncgenfunction(dunder_call)
@ -505,6 +508,8 @@ def is_async_gen_callable(call: Callable[..., Any]) -> bool:
def is_gen_callable(call: Callable[..., Any]) -> bool: def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call): if inspect.isgeneratorfunction(call):
return True return True
if isinstance(call, partial):
return is_gen_callable(call.func)
dunder_call = getattr(call, "__call__", None) # noqa: B004 dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call) return inspect.isgeneratorfunction(dunder_call)

116
tests/test_dependency_types.py

@ -1,4 +1,5 @@
from typing import AsyncGenerator, Generator from typing import AsyncGenerator, Generator
from functools import partial
import pytest import pytest
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
@ -64,6 +65,26 @@ async_callable_gen_dependency = AsyncCallableGenDependency()
methods_dependency = MethodsDependency() methods_dependency = MethodsDependency()
@app.get("/function-dependency")
async def get_function_dependency(value: str = Depends(function_dependency)) -> str:
return value
@app.get("/async-function-dependency")
async def get_async_function_dependency(value: str = Depends(async_function_dependency)) -> str:
return value
@app.get("/gen-dependency")
async def get_gen_dependency(value: str = Depends(gen_dependency)) -> str:
return value
@app.get("/async-gen-dependency")
async def get_async_gen_dependency(value: str = Depends(async_gen_dependency)) -> str:
return value
@app.get("/callable-dependency") @app.get("/callable-dependency")
async def get_callable_dependency(value: str = Depends(callable_dependency)) -> str: async def get_callable_dependency(value: str = Depends(callable_dependency)) -> str:
return value return value
@ -116,6 +137,78 @@ async def get_asynchronous_method_gen_dependency(
return value return value
@app.get("/partial-function-dependency")
async def get_partial_function_dependency(value: str = Depends(partial(function_dependency, "partial-function-dependency"))) -> str:
return value
@app.get("/partial-async-function-dependency")
async def get_partial_async_function_dependency(value: str = Depends(partial(async_function_dependency, "partial-async-function-dependency"))) -> str:
return value
@app.get("/partial-gen-dependency")
async def get_partial_gen_dependency(value: str = Depends(partial(gen_dependency, "partial-gen-dependency"))) -> str:
return value
@app.get("/partial-async-gen-dependency")
async def get_partial_async_gen_dependency(value: str = Depends(partial(async_gen_dependency, "partial-async-gen-dependency"))) -> str:
return value
@app.get("/partial-callable-dependency")
async def get_partial_callable_dependency(value: str = Depends(partial(callable_dependency, "partial-callable-dependency"))) -> str:
return value
@app.get("/partial-callable-gen-dependency")
async def get_partial_callable_gen_dependency(value: str = Depends(partial(callable_gen_dependency, "partial-callable-gen-dependency"))) -> str:
return value
@app.get("/partial-async-callable-dependency")
async def get_partial_async_callable_dependency(
value: str = Depends(partial(async_callable_dependency, "partial-async-callable-dependency")),
) -> str:
return value
@app.get("/partial-async-callable-gen-dependency")
async def get_partial_async_callable_gen_dependency(
value: str = Depends(partial(async_callable_gen_dependency, "partial-async-callable-gen-dependency")),
) -> str:
return value
@app.get("/partial-synchronous-method-dependency")
async def get_partial_synchronous_method_dependency(
value: str = Depends(partial(methods_dependency.synchronous, "partial-synchronous-method-dependency")),
) -> str:
return value
@app.get("/partial-synchronous-method-gen-dependency")
async def get_partial_synchronous_method_gen_dependency(
value: str = Depends(partial(methods_dependency.synchronous_gen, "partial-synchronous-method-gen-dependency")),
) -> str:
return value
@app.get("/partial-asynchronous-method-dependency")
async def get_partial_asynchronous_method_dependency(
value: str = Depends(partial(methods_dependency.asynchronous, "partial-asynchronous-method-dependency")),
) -> str:
return value
@app.get("/partial-asynchronous-method-gen-dependency")
async def get_partial_asynchronous_method_gen_dependency(
value: str = Depends(partial(methods_dependency.asynchronous_gen, "partial-asynchronous-method-gen-dependency")),
) -> str:
return value
client = TestClient(app) client = TestClient(app)
@ -140,3 +233,26 @@ def test_dependency_types(route: str, value: str) -> None:
response = client.get(route, params={"value": value}) response = client.get(route, params={"value": value})
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == value assert response.json() == value
@pytest.mark.parametrize(
"route,value",
[
("/partial-function-dependency", "partial-function-dependency"),
("/partial-async-function-dependency", "partial-async-function-dependency"),
("/partial-gen-dependency", "partial-gen-dependency"),
("/partial-async-gen-dependency", "partial-async-gen-dependency"),
("/partial-callable-dependency", "partial-callable-dependency"),
("/partial-callable-gen-dependency", "partial-callable-gen-dependency"),
("/partial-async-callable-dependency", "partial-async-callable-dependency"),
("/partial-async-callable-gen-dependency", "partial-async-callable-gen-dependency"),
("/partial-synchronous-method-dependency", "partial-synchronous-method-dependency"),
("/partial-synchronous-method-gen-dependency", "partial-synchronous-method-gen-dependency"),
("/partial-asynchronous-method-dependency", "partial-asynchronous-method-dependency"),
("/partial-asynchronous-method-gen-dependency", "partial-asynchronous-method-gen-dependency"),
],
)
def test_dependency_types_with_partial(route: str, value: str) -> None:
response = client.get(route)
assert response.status_code == 200, response.text
assert response.json() == value

Loading…
Cancel
Save