From f644f72306efb0cf18369fd065811a59d3437f3a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 May 2024 10:22:21 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20format?= =?UTF-8?q?=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_endpoint_decorator.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/test_endpoint_decorator.py b/tests/test_endpoint_decorator.py index c420bb486..33f03ec54 100644 --- a/tests/test_endpoint_decorator.py +++ b/tests/test_endpoint_decorator.py @@ -1,11 +1,13 @@ -from typing import Any, Callable from functools import update_wrapper +from typing import Any, Callable + from fastapi import Depends, FastAPI -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.routing import APIRoute +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.testclient import TestClient from starlette.exceptions import HTTPException + class EndpointWrapper(Callable[..., Any]): def __init__(self, endpoint: Callable[..., Any]): self.endpoint = endpoint @@ -14,39 +16,45 @@ class EndpointWrapper(Callable[..., Any]): async def __call__(self, *args, **kwargs): return await self.endpoint(*args, **kwargs) - + + def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): if token.credentials != "fake-token": raise HTTPException(status_code=401, detail="Unauthorized") + def protect(endpoint: Callable[..., Any]): if not isinstance(endpoint, EndpointWrapper): endpoint = EndpointWrapper(endpoint) endpoint.protected = True return endpoint + class CustomAPIRoute(APIRoute): - def __init__(self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs) -> None: + def __init__( + self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs + ) -> None: if dependencies is None: dependencies = [] - if ( - isinstance(endpoint, EndpointWrapper) - and endpoint.protected - ): + if isinstance(endpoint, EndpointWrapper) and endpoint.protected: dependencies.append(Depends(dummy_secruity_check)) super().__init__(path, endpoint, dependencies=dependencies, **kwargs) + app = FastAPI() app.router.route_class = CustomAPIRoute + @app.get("/protected") @protect async def protected_route(): return {"message": "This is a protected route"} + client = TestClient(app) + def test_protected_route(): response = client.get("/protected") assert response.status_code == 403