|
|
@ -1,11 +1,13 @@ |
|
|
|
from typing import Any, Callable |
|
|
|
from functools import update_wrapper |
|
|
|
from typing import Any, Callable |
|
|
|
|
|
|
|
from fastapi import Depends, FastAPI |
|
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer |
|
|
|
from fastapi.routing import APIRoute |
|
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer |
|
|
|
from fastapi.testclient import TestClient |
|
|
|
from starlette.exceptions import HTTPException |
|
|
|
|
|
|
|
|
|
|
|
class EndpointWrapper(Callable[..., Any]): |
|
|
|
def __init__(self, endpoint: Callable[..., Any]): |
|
|
|
self.endpoint = endpoint |
|
|
@ -14,39 +16,45 @@ class EndpointWrapper(Callable[..., Any]): |
|
|
|
|
|
|
|
async def __call__(self, *args, **kwargs): |
|
|
|
return await self.endpoint(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): |
|
|
|
if token.credentials != "fake-token": |
|
|
|
raise HTTPException(status_code=401, detail="Unauthorized") |
|
|
|
|
|
|
|
|
|
|
|
def protect(endpoint: Callable[..., Any]): |
|
|
|
if not isinstance(endpoint, EndpointWrapper): |
|
|
|
endpoint = EndpointWrapper(endpoint) |
|
|
|
endpoint.protected = True |
|
|
|
return endpoint |
|
|
|
|
|
|
|
|
|
|
|
class CustomAPIRoute(APIRoute): |
|
|
|
def __init__(self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs) -> None: |
|
|
|
def __init__( |
|
|
|
self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs |
|
|
|
) -> None: |
|
|
|
if dependencies is None: |
|
|
|
dependencies = [] |
|
|
|
if ( |
|
|
|
isinstance(endpoint, EndpointWrapper) |
|
|
|
and endpoint.protected |
|
|
|
): |
|
|
|
if isinstance(endpoint, EndpointWrapper) and endpoint.protected: |
|
|
|
dependencies.append(Depends(dummy_secruity_check)) |
|
|
|
super().__init__(path, endpoint, dependencies=dependencies, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI() |
|
|
|
|
|
|
|
app.router.route_class = CustomAPIRoute |
|
|
|
|
|
|
|
|
|
|
|
@app.get("/protected") |
|
|
|
@protect |
|
|
|
async def protected_route(): |
|
|
|
return {"message": "This is a protected route"} |
|
|
|
|
|
|
|
|
|
|
|
client = TestClient(app) |
|
|
|
|
|
|
|
|
|
|
|
def test_protected_route(): |
|
|
|
response = client.get("/protected") |
|
|
|
assert response.status_code == 403 |
|
|
|