diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 64a6c1276..cc3a10002 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,4 +1,6 @@ +import asyncio import dataclasses +import functools import inspect from contextlib import contextmanager from copy import deepcopy @@ -261,6 +263,15 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: return typed_signature +def is_coroutine_function(obj: Any) -> bool: + while isinstance(obj, functools.partial): + obj = obj.func + + return asyncio.iscoroutinefunction(obj) or ( + callable(obj) and asyncio.iscoroutinefunction(obj.__call__) + ) + + def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any: annotation = param.annotation if isinstance(annotation, str): diff --git a/fastapi/routing.py b/fastapi/routing.py index 7caf018b5..219d3eec2 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,4 +1,3 @@ -import asyncio import dataclasses import email.message import inspect @@ -26,6 +25,7 @@ from fastapi.dependencies.utils import ( get_dependant, get_parameterless_sub_dependant, solve_dependencies, + is_coroutine_function, ) from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError @@ -177,7 +177,7 @@ def get_request_handler( dependency_overrides_provider: Optional[Any] = None, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" - is_coroutine = asyncio.iscoroutinefunction(dependant.call) + is_coroutine = is_coroutine_function(dependant.call) is_body_form = body_field and isinstance(body_field.field_info, params.Form) if isinstance(response_class, DefaultPlaceholder): actual_response_class: Type[Response] = response_class.value diff --git a/tests/test_callable_endpoint.py b/tests/test_callable_endpoint.py index 1882e9053..615f5a616 100644 --- a/tests/test_callable_endpoint.py +++ b/tests/test_callable_endpoint.py @@ -5,15 +5,29 @@ from fastapi import FastAPI from fastapi.testclient import TestClient +class CallableObjectEndpoint: + def __call__(self, some_arg, q: Optional[str] = None): + return {"some_arg": some_arg, "q": q} + + +class AsyncCallableObjectEndpoint: + async def __call__(self, some_arg, q: Optional[str] = None): + return {"some_arg": some_arg, "q": q} + + def main(some_arg, q: Optional[str] = None): return {"some_arg": some_arg, "q": q} endpoint = partial(main, "foo") +obj_endpoint = partial(CallableObjectEndpoint(), "foo") +async_obj_endpoint = partial(AsyncCallableObjectEndpoint(), "foo") app = FastAPI() app.get("/")(endpoint) +app.get("/obj")(obj_endpoint) +app.get("/async_obj")(async_obj_endpoint) client = TestClient(app) @@ -23,3 +37,11 @@ def test_partial(): response = client.get("/?q=bar") data = response.json() assert data == {"some_arg": "foo", "q": "bar"} + + response = client.get("/obj?q=bar") + data = response.json() + assert data == {"some_arg": "foo", "q": "bar"} + + response = client.get("/async_obj?q=bar") + data = response.json() + assert data == {"some_arg": "foo", "q": "bar"}