Browse Source

Add supporting for callable object

pull/5506/head
Святослав Зайцев 3 years ago
parent
commit
2c46e890b5
  1. 11
      fastapi/dependencies/utils.py
  2. 4
      fastapi/routing.py
  3. 22
      tests/test_callable_endpoint.py

11
fastapi/dependencies/utils.py

@ -1,4 +1,6 @@
import asyncio
import dataclasses
import functools
import inspect
from contextlib import contextmanager
from copy import deepcopy
@ -261,6 +263,15 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
return typed_signature
def is_coroutine_function(obj: Any) -> bool:
while isinstance(obj, functools.partial):
obj = obj.func
return asyncio.iscoroutinefunction(obj) or (
callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
)
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
annotation = param.annotation
if isinstance(annotation, str):

4
fastapi/routing.py

@ -1,4 +1,3 @@
import asyncio
import dataclasses
import email.message
import inspect
@ -26,6 +25,7 @@ from fastapi.dependencies.utils import (
get_dependant,
get_parameterless_sub_dependant,
solve_dependencies,
is_coroutine_function,
)
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
@ -177,7 +177,7 @@ def get_request_handler(
dependency_overrides_provider: Optional[Any] = None,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_coroutine = is_coroutine_function(dependant.call)
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
if isinstance(response_class, DefaultPlaceholder):
actual_response_class: Type[Response] = response_class.value

22
tests/test_callable_endpoint.py

@ -5,15 +5,29 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
class CallableObjectEndpoint:
def __call__(self, some_arg, q: Optional[str] = None):
return {"some_arg": some_arg, "q": q}
class AsyncCallableObjectEndpoint:
async def __call__(self, some_arg, q: Optional[str] = None):
return {"some_arg": some_arg, "q": q}
def main(some_arg, q: Optional[str] = None):
return {"some_arg": some_arg, "q": q}
endpoint = partial(main, "foo")
obj_endpoint = partial(CallableObjectEndpoint(), "foo")
async_obj_endpoint = partial(AsyncCallableObjectEndpoint(), "foo")
app = FastAPI()
app.get("/")(endpoint)
app.get("/obj")(obj_endpoint)
app.get("/async_obj")(async_obj_endpoint)
client = TestClient(app)
@ -23,3 +37,11 @@ def test_partial():
response = client.get("/?q=bar")
data = response.json()
assert data == {"some_arg": "foo", "q": "bar"}
response = client.get("/obj?q=bar")
data = response.json()
assert data == {"some_arg": "foo", "q": "bar"}
response = client.get("/async_obj?q=bar")
data = response.json()
assert data == {"some_arg": "foo", "q": "bar"}

Loading…
Cancel
Save