committed by
GitHub
2 changed files with 70 additions and 1 deletions
@ -0,0 +1,65 @@ |
|||||
|
from functools import update_wrapper |
||||
|
from typing import Any, Callable |
||||
|
|
||||
|
from fastapi import Depends, FastAPI |
||||
|
from fastapi.routing import APIRoute |
||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer |
||||
|
from fastapi.testclient import TestClient |
||||
|
from starlette.exceptions import HTTPException |
||||
|
|
||||
|
|
||||
|
class EndpointWrapper(Callable[..., Any]): |
||||
|
def __init__(self, endpoint: Callable[..., Any]): |
||||
|
self.endpoint = endpoint |
||||
|
self.protected = False |
||||
|
update_wrapper(self, endpoint) |
||||
|
|
||||
|
async def __call__(self, *args, **kwargs): |
||||
|
return await self.endpoint(*args, **kwargs) |
||||
|
|
||||
|
|
||||
|
def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): |
||||
|
if token.credentials != "fake-token": |
||||
|
raise HTTPException(status_code=401, detail="Unauthorized") |
||||
|
|
||||
|
|
||||
|
def protect(endpoint: Callable[..., Any]): |
||||
|
if not isinstance(endpoint, EndpointWrapper): |
||||
|
endpoint = EndpointWrapper(endpoint) |
||||
|
endpoint.protected = True |
||||
|
return endpoint |
||||
|
|
||||
|
|
||||
|
class CustomAPIRoute(APIRoute): |
||||
|
def __init__( |
||||
|
self, path: str, endpoint: Callable[..., Any], dependencies, **kwargs |
||||
|
) -> None: |
||||
|
if isinstance(endpoint, EndpointWrapper) and endpoint.protected: |
||||
|
dependencies.append(Depends(dummy_secruity_check)) |
||||
|
super().__init__(path, endpoint, dependencies=dependencies, **kwargs) |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
app.router.route_class = CustomAPIRoute |
||||
|
|
||||
|
|
||||
|
@app.get("/protected") |
||||
|
@protect |
||||
|
async def protected_route(): |
||||
|
return {"message": "This is a protected route"} |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_protected_route(): |
||||
|
response = client.get("/protected") |
||||
|
assert response.status_code == 403 |
||||
|
|
||||
|
response = client.get("/protected", headers={"Authorization": "Bearer some-token"}) |
||||
|
assert response.status_code == 401 |
||||
|
|
||||
|
response = client.get("/protected", headers={"Authorization": "Bearer fake-token"}) |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == {"message": "This is a protected route"} |
Loading…
Reference in new issue