pythonasyncioapiasyncfastapiframeworkjsonjson-schemaopenapiopenapi3pydanticpython-typespython3redocreststarletteswaggerswagger-uiuvicornweb
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
67 lines
2.0 KiB
67 lines
2.0 KiB
from functools import update_wrapper
|
|
from typing import Any, Callable
|
|
|
|
from fastapi import Depends, FastAPI
|
|
from fastapi.routing import APIRoute
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
from fastapi.testclient import TestClient
|
|
from starlette.exceptions import HTTPException
|
|
|
|
|
|
class EndpointWrapper(Callable[..., Any]):
    """Callable stand-in for an endpoint that can carry extra route metadata.

    ``update_wrapper`` copies the wrapped endpoint's ``__name__``, ``__doc__``,
    ``__module__``, ``__qualname__`` and ``__dict__`` onto the instance and sets
    ``__wrapped__``. As a result, ``inspect.signature`` on the wrapper reports
    the original endpoint's parameters.

    Attributes:
        endpoint: the wrapped callable (``async def`` or plain ``def``).
        protected: flag read by ``CustomAPIRoute``; set to ``True`` by ``protect``.
    """

    def __init__(self, endpoint: Callable[..., Any]):
        self.endpoint = endpoint
        self.protected = False
        # Runs after ``protected`` is initialised. If ``endpoint`` is itself an
        # EndpointWrapper, its ``__dict__`` (including ``protected``) is copied
        # over, so an already-protected endpoint stays protected when re-wrapped.
        update_wrapper(self, endpoint)

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped endpoint, awaiting its result only if needed.

        Supports both coroutine and plain endpoints. NOTE(review): because this
        method is ``async``, the framework will presumably run a wrapped sync
        endpoint on the event loop rather than in a threadpool, so such
        endpoints should not block.
        """
        import inspect

        result = self.endpoint(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        return result
|
|
|
|
|
|
def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
    """Reject the request unless its bearer token equals the fixed test token.

    ``HTTPBearer()`` parses the ``Authorization`` header before this function
    runs. A request with no such header is rejected by it, not by this check.

    Args:
        token: credentials parsed by ``HTTPBearer``. Only ``token.credentials``,
            the token string, is inspected.

    Raises:
        HTTPException: 401 when the token does not match. The response carries
            ``WWW-Authenticate: Bearer``, which RFC 7235 requires on 401s.
    """
    import secrets

    # Use a constant-time comparison so response timing does not reveal how
    # much of the token matched. Compare bytes, not str: compare_digest raises
    # TypeError on non-ASCII str input, which would otherwise surface as a 500.
    supplied = token.credentials.encode("utf-8")
    if not secrets.compare_digest(supplied, b"fake-token"):
        raise HTTPException(
            status_code=401,
            detail="Unauthorized",
            headers={"WWW-Authenticate": "Bearer"},
        )
|
|
|
|
|
|
def protect(endpoint: Callable[..., Any]):
    """Decorator that marks *endpoint* as requiring the bearer-token check.

    A plain callable is wrapped in ``EndpointWrapper``. An existing wrapper is
    reused as-is, so applying ``protect`` more than once is harmless. The
    ``protected`` flag is what ``CustomAPIRoute`` inspects when the route is
    built.

    Returns:
        The ``EndpointWrapper`` instance, with ``protected`` set to ``True``.
    """
    wrapper = (
        endpoint
        if isinstance(endpoint, EndpointWrapper)
        else EndpointWrapper(endpoint)
    )
    wrapper.protected = True
    return wrapper
|
|
|
|
|
|
class CustomAPIRoute(APIRoute):
    """APIRoute that adds the bearer-token check to endpoints marked by ``protect``.

    The class is installed via ``app.router.route_class``, so every route
    registered afterwards is constructed through it.
    """

    def __init__(
        self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs
    ) -> None:
        # Build a fresh list instead of appending to the caller's sequence,
        # which may be shared or immutable (e.g. a tuple).
        dependencies = list(dependencies or [])
        if isinstance(endpoint, EndpointWrapper) and endpoint.protected:
            # Avoid adding the check twice when a route is rebuilt with its own
            # dependencies passed back in. NOTE(review): FastAPI's
            # ``include_router`` appears to do this when copying routes.
            already_checked = any(
                getattr(dep, "dependency", None) is dummy_secruity_check
                for dep in dependencies
            )
            if not already_checked:
                dependencies.append(Depends(dummy_secruity_check))
        super().__init__(path, endpoint, dependencies=dependencies, **kwargs)
|
|
|
|
|
|
# Set the route class before any endpoint is registered, so routes added
# below are built by CustomAPIRoute and therefore honour ``@protect``.
app = FastAPI()
app.router.route_class = CustomAPIRoute
|
|
|
|
|
|
# Decorator order matters: ``protect`` is applied first (innermost), so
# ``app.get`` registers the flagged EndpointWrapper, not the bare function.
@app.get("/protected")
@protect
async def protected_route():
    """Return a fixed payload; guarded by the bearer-token dependency."""
    payload = {"message": "This is a protected route"}
    return payload
|
|
|
|
|
|
# Test client bound to ``app``; shared by the tests in this module.
client = TestClient(app)


def test_protected_route():
    """GET /protected: no header -> 403, wrong token -> 401, right token -> 200."""
    cases = (
        (None, 403),
        ({"Authorization": "Bearer some-token"}, 401),
        ({"Authorization": "Bearer fake-token"}, 200),
    )
    for headers, expected_status in cases:
        if headers is None:
            response = client.get("/protected")
        else:
            response = client.get("/protected", headers=headers)
        assert response.status_code == expected_status
        if expected_status == 200:
            assert response.json() == {"message": "This is a protected route"}
|
|
|