Rafay Siddiqui 5 days ago
committed by GitHub
parent
commit
4c8debceea
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 6
      fastapi/routing.py
  2. 65
      tests/test_endpoint_decorator.py

6
fastapi/routing.py

@ -230,7 +230,11 @@ def get_request_handler(
embed_body_fields: bool = False,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_coroutine = (
asyncio.iscoroutinefunction(dependant.call)
or callable(dependant.call)
and inspect.iscoroutinefunction(dependant.call.__call__) # type: ignore[operator]
)
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
if isinstance(response_class, DefaultPlaceholder):
actual_response_class: Type[Response] = response_class.value

65
tests/test_endpoint_decorator.py

@ -0,0 +1,65 @@
from functools import update_wrapper
from typing import Any, Callable
from fastapi import Depends, FastAPI
from fastapi.routing import APIRoute
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.testclient import TestClient
from starlette.exceptions import HTTPException
class EndpointWrapper(Callable[..., Any]):
def __init__(self, endpoint: Callable[..., Any]):
self.endpoint = endpoint
self.protected = False
update_wrapper(self, endpoint)
async def __call__(self, *args, **kwargs):
return await self.endpoint(*args, **kwargs)
def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
if token.credentials != "fake-token":
raise HTTPException(status_code=401, detail="Unauthorized")
def protect(endpoint: Callable[..., Any]):
if not isinstance(endpoint, EndpointWrapper):
endpoint = EndpointWrapper(endpoint)
endpoint.protected = True
return endpoint
class CustomAPIRoute(APIRoute):
def __init__(
self, path: str, endpoint: Callable[..., Any], dependencies, **kwargs
) -> None:
if isinstance(endpoint, EndpointWrapper) and endpoint.protected:
dependencies.append(Depends(dummy_secruity_check))
super().__init__(path, endpoint, dependencies=dependencies, **kwargs)
app = FastAPI()
app.router.route_class = CustomAPIRoute
@app.get("/protected")
@protect
async def protected_route():
return {"message": "This is a protected route"}
client = TestClient(app)
def test_protected_route():
response = client.get("/protected")
assert response.status_code == 403
response = client.get("/protected", headers={"Authorization": "Bearer some-token"})
assert response.status_code == 401
response = client.get("/protected", headers={"Authorization": "Bearer fake-token"})
assert response.status_code == 200
assert response.json() == {"message": "This is a protected route"}
Loading…
Cancel
Save