Browse Source

🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

pull/11508/head
pre-commit-ci[bot] 12 months ago
parent
commit
f644f72306
  1. 24
      tests/test_endpoint_decorator.py

24
tests/test_endpoint_decorator.py

@ -1,11 +1,13 @@
from typing import Any, Callable
from functools import update_wrapper from functools import update_wrapper
from typing import Any, Callable
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
class EndpointWrapper(Callable[..., Any]): class EndpointWrapper(Callable[..., Any]):
def __init__(self, endpoint: Callable[..., Any]): def __init__(self, endpoint: Callable[..., Any]):
self.endpoint = endpoint self.endpoint = endpoint
@ -14,39 +16,45 @@ class EndpointWrapper(Callable[..., Any]):
async def __call__(self, *args, **kwargs): async def __call__(self, *args, **kwargs):
return await self.endpoint(*args, **kwargs) return await self.endpoint(*args, **kwargs)
def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
if token.credentials != "fake-token": if token.credentials != "fake-token":
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
def protect(endpoint: Callable[..., Any]): def protect(endpoint: Callable[..., Any]):
if not isinstance(endpoint, EndpointWrapper): if not isinstance(endpoint, EndpointWrapper):
endpoint = EndpointWrapper(endpoint) endpoint = EndpointWrapper(endpoint)
endpoint.protected = True endpoint.protected = True
return endpoint return endpoint
class CustomAPIRoute(APIRoute): class CustomAPIRoute(APIRoute):
def __init__(self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs) -> None: def __init__(
self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs
) -> None:
if dependencies is None: if dependencies is None:
dependencies = [] dependencies = []
if ( if isinstance(endpoint, EndpointWrapper) and endpoint.protected:
isinstance(endpoint, EndpointWrapper)
and endpoint.protected
):
dependencies.append(Depends(dummy_secruity_check)) dependencies.append(Depends(dummy_secruity_check))
super().__init__(path, endpoint, dependencies=dependencies, **kwargs) super().__init__(path, endpoint, dependencies=dependencies, **kwargs)
app = FastAPI() app = FastAPI()
app.router.route_class = CustomAPIRoute app.router.route_class = CustomAPIRoute
@app.get("/protected") @app.get("/protected")
@protect @protect
async def protected_route(): async def protected_route():
return {"message": "This is a protected route"} return {"message": "This is a protected route"}
client = TestClient(app) client = TestClient(app)
def test_protected_route(): def test_protected_route():
response = client.get("/protected") response = client.get("/protected")
assert response.status_code == 403 assert response.status_code == 403

Loading…
Cancel
Save