Browse Source

test: add test for fix/allow-callable_get_request_handler

pull/11508/head
reton2 11 months ago
parent
commit
4c5c0f60d2
  1. 59
      tests/test_endpoint_decorator.py

59
tests/test_endpoint_decorator.py

@ -0,0 +1,59 @@
from typing import Any, Callable
from functools import update_wrapper
from fastapi import Depends, FastAPI
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient
from starlette.exceptions import HTTPException
class EndpointWrapper(Callable[..., Any]):
def __init__(self, endpoint: Callable[..., Any]):
self.endpoint = endpoint
self.protected = False
update_wrapper(self, endpoint)
async def __call__(self, *args, **kwargs):
return await self.endpoint(*args, **kwargs)
def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
if token.credentials != "fake-token":
raise HTTPException(status_code=401, detail="Unauthorized")
def protect(endpoint: Callable[..., Any]):
if not isinstance(endpoint, EndpointWrapper):
endpoint = EndpointWrapper(endpoint)
endpoint.protected = True
return endpoint
class CustomAPIRoute(APIRoute):
def __init__(self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs) -> None:
if dependencies is None:
dependencies = []
if (
isinstance(endpoint, EndpointWrapper)
and endpoint.protected
):
dependencies.append(Depends(dummy_secruity_check))
super().__init__(path, endpoint, dependencies=dependencies, **kwargs)
app = FastAPI()
app.router.route_class = CustomAPIRoute
@app.get("/protected")
@protect
async def protected_route():
return {"message": "This is a protected route"}
client = TestClient(app)
def test_protected_route():
response = client.get("/protected")
assert response.status_code == 403
response = client.get("/protected", headers={"Authorization": "Bearer some-token"})
assert response.status_code == 401
response = client.get("/protected", headers={"Authorization": "Bearer fake-token"})
assert response.status_code == 200
assert response.json() == {"message": "This is a protected route"}
Loading…
Cancel
Save