From 2c46e890b5f7b17eeb6d863675cc311582f4b416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=A1=D0=B2=D1=8F=D1=82=D0=BE=D1=81=D0=BB=D0=B0=D0=B2=20?= =?UTF-8?q?=D0=97=D0=B0=D0=B9=D1=86=D0=B5=D0=B2?= Date: Mon, 17 Oct 2022 14:58:05 +0300 Subject: [PATCH] Add supporting for callable object --- fastapi/dependencies/utils.py | 11 +++++++++++ fastapi/routing.py | 4 ++-- tests/test_callable_endpoint.py | 22 ++++++++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 64a6c1276..cc3a10002 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,4 +1,6 @@ +import asyncio import dataclasses +import functools import inspect from contextlib import contextmanager from copy import deepcopy @@ -261,6 +263,15 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: return typed_signature +def is_coroutine_function(obj: Any) -> bool: + while isinstance(obj, functools.partial): + obj = obj.func + + return asyncio.iscoroutinefunction(obj) or ( + callable(obj) and asyncio.iscoroutinefunction(obj.__call__) + ) + + def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any: annotation = param.annotation if isinstance(annotation, str): diff --git a/fastapi/routing.py b/fastapi/routing.py index 7caf018b5..219d3eec2 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,4 +1,3 @@ -import asyncio import dataclasses import email.message import inspect @@ -26,6 +25,7 @@ from fastapi.dependencies.utils import ( get_dependant, get_parameterless_sub_dependant, solve_dependencies, + is_coroutine_function, ) from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError @@ -177,7 +177,7 @@ def get_request_handler( dependency_overrides_provider: Optional[Any] = None, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" - is_coroutine = asyncio.iscoroutinefunction(dependant.call) + is_coroutine = is_coroutine_function(dependant.call) is_body_form = body_field and isinstance(body_field.field_info, params.Form) if isinstance(response_class, DefaultPlaceholder): actual_response_class: Type[Response] = response_class.value diff --git a/tests/test_callable_endpoint.py b/tests/test_callable_endpoint.py index 1882e9053..615f5a616 100644 --- a/tests/test_callable_endpoint.py +++ b/tests/test_callable_endpoint.py @@ -5,15 +5,29 @@ from fastapi import FastAPI from fastapi.testclient import TestClient +class CallableObjectEndpoint: + def __call__(self, some_arg, q: Optional[str] = None): + return {"some_arg": some_arg, "q": q} + + +class AsyncCallableObjectEndpoint: + async def __call__(self, some_arg, q: Optional[str] = None): + return {"some_arg": some_arg, "q": q} + + def main(some_arg, q: Optional[str] = None): return {"some_arg": some_arg, "q": q} endpoint = partial(main, "foo") +obj_endpoint = partial(CallableObjectEndpoint(), "foo") +async_obj_endpoint = partial(AsyncCallableObjectEndpoint(), "foo") app = FastAPI() app.get("/")(endpoint) +app.get("/obj")(obj_endpoint) +app.get("/async_obj")(async_obj_endpoint) client = TestClient(app) @@ -23,3 +37,11 @@ def test_partial(): response = client.get("/?q=bar") data = response.json() assert data == {"some_arg": "foo", "q": "bar"} + + response = client.get("/obj?q=bar") + data = response.json() + assert data == {"some_arg": "foo", "q": "bar"} + + response = client.get("/async_obj?q=bar") + data = response.json() + assert data == {"some_arg": "foo", "q": "bar"}