|
|
@ -4,17 +4,28 @@ from fastapi.openapi.models import APIKey, APIKeyIn |
|
|
|
from fastapi.security.base import SecurityBase |
|
|
|
from starlette.exceptions import HTTPException |
|
|
|
from starlette.requests import Request |
|
|
|
from starlette.status import HTTP_403_FORBIDDEN |
|
|
|
from starlette.status import HTTP_401_UNAUTHORIZED |
|
|
|
from typing_extensions import Annotated, Doc |
|
|
|
|
|
|
|
|
|
|
|
class APIKeyBase(SecurityBase): |
|
|
|
@staticmethod |
|
|
|
def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]: |
|
|
|
def check_api_key( |
|
|
|
api_key: Optional[str], auto_error: bool, key_name: str, key_in: APIKeyIn |
|
|
|
) -> Optional[str]: |
|
|
|
if not api_key: |
|
|
|
if auto_error: |
|
|
|
# Customize header based on where the API key should be |
|
|
|
auth_header = { |
|
|
|
APIKeyIn.query: f'ApiKey name="{key_name}", in="query"', |
|
|
|
APIKeyIn.header: f'ApiKey name="{key_name}", in="header"', |
|
|
|
APIKeyIn.cookie: f'ApiKey name="{key_name}", in="cookie"', |
|
|
|
}.get(key_in, "ApiKey") |
|
|
|
|
|
|
|
raise HTTPException( |
|
|
|
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" |
|
|
|
status_code=HTTP_401_UNAUTHORIZED, |
|
|
|
detail="Not authenticated", |
|
|
|
headers={"WWW-Authenticate": auth_header}, |
|
|
|
) |
|
|
|
return None |
|
|
|
return api_key |
|
|
@ -109,7 +120,9 @@ class APIKeyQuery(APIKeyBase): |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> Optional[str]: |
|
|
|
api_key = request.query_params.get(self.model.name) |
|
|
|
return self.check_api_key(api_key, self.auto_error) |
|
|
|
return self.check_api_key( |
|
|
|
api_key, self.auto_error, self.model.name, APIKeyIn.query |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class APIKeyHeader(APIKeyBase): |
|
|
@ -197,7 +210,9 @@ class APIKeyHeader(APIKeyBase): |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> Optional[str]: |
|
|
|
api_key = request.headers.get(self.model.name) |
|
|
|
return self.check_api_key(api_key, self.auto_error) |
|
|
|
return self.check_api_key( |
|
|
|
api_key, self.auto_error, self.model.name, APIKeyIn.header |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class APIKeyCookie(APIKeyBase): |
|
|
@ -285,4 +300,6 @@ class APIKeyCookie(APIKeyBase): |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> Optional[str]: |
|
|
|
api_key = request.cookies.get(self.model.name) |
|
|
|
return self.check_api_key(api_key, self.auto_error) |
|
|
|
return self.check_api_key( |
|
|
|
api_key, self.auto_error, self.model.name, APIKeyIn.cookie |
|
|
|
) |
|
|
|