|
|
@ -1,4 +1,4 @@ |
|
|
|
from typing import Optional |
|
|
|
from typing import Optional, Union |
|
|
|
|
|
|
|
from fastapi.openapi.models import APIKey, APIKeyIn |
|
|
|
from fastapi.security.base import SecurityBase |
|
|
@ -9,10 +9,27 @@ from typing_extensions import Annotated, Doc |
|
|
|
|
|
|
|
|
|
|
|
class APIKeyBase(SecurityBase): |
|
|
|
@staticmethod |
|
|
|
def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]: |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
location: APIKeyIn, |
|
|
|
name: str, |
|
|
|
description: Union[str, None], |
|
|
|
scheme_name: Union[str, None], |
|
|
|
auto_error: bool, |
|
|
|
): |
|
|
|
self.parameter_location = location.value |
|
|
|
self.parameter_name = name |
|
|
|
self.auto_error = auto_error |
|
|
|
self.model: APIKey = APIKey( |
|
|
|
**{"in": location}, # type: ignore[arg-type] |
|
|
|
name=name, |
|
|
|
description=description, |
|
|
|
) |
|
|
|
self.scheme_name = scheme_name or self.__class__.__name__ |
|
|
|
|
|
|
|
def check_api_key(self, api_key: Optional[str]) -> Optional[str]: |
|
|
|
if not api_key: |
|
|
|
if auto_error: |
|
|
|
if self.auto_error: |
|
|
|
raise HTTPException( |
|
|
|
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" |
|
|
|
) |
|
|
@ -99,17 +116,17 @@ class APIKeyQuery(APIKeyBase): |
|
|
|
), |
|
|
|
] = True, |
|
|
|
): |
|
|
|
self.model: APIKey = APIKey( |
|
|
|
**{"in": APIKeyIn.query}, # type: ignore[arg-type] |
|
|
|
super().__init__( |
|
|
|
location=APIKeyIn.query, |
|
|
|
name=name, |
|
|
|
scheme_name=scheme_name, |
|
|
|
description=description, |
|
|
|
auto_error=auto_error, |
|
|
|
) |
|
|
|
self.scheme_name = scheme_name or self.__class__.__name__ |
|
|
|
self.auto_error = auto_error |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> Optional[str]: |
|
|
|
api_key = request.query_params.get(self.model.name) |
|
|
|
return self.check_api_key(api_key, self.auto_error) |
|
|
|
return self.check_api_key(api_key) |
|
|
|
|
|
|
|
|
|
|
|
class APIKeyHeader(APIKeyBase): |
|
|
@ -187,17 +204,17 @@ class APIKeyHeader(APIKeyBase): |
|
|
|
), |
|
|
|
] = True, |
|
|
|
): |
|
|
|
self.model: APIKey = APIKey( |
|
|
|
**{"in": APIKeyIn.header}, # type: ignore[arg-type] |
|
|
|
super().__init__( |
|
|
|
location=APIKeyIn.header, |
|
|
|
name=name, |
|
|
|
scheme_name=scheme_name, |
|
|
|
description=description, |
|
|
|
auto_error=auto_error, |
|
|
|
) |
|
|
|
self.scheme_name = scheme_name or self.__class__.__name__ |
|
|
|
self.auto_error = auto_error |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> Optional[str]: |
|
|
|
api_key = request.headers.get(self.model.name) |
|
|
|
return self.check_api_key(api_key, self.auto_error) |
|
|
|
return self.check_api_key(api_key) |
|
|
|
|
|
|
|
|
|
|
|
class APIKeyCookie(APIKeyBase): |
|
|
@ -275,14 +292,14 @@ class APIKeyCookie(APIKeyBase): |
|
|
|
), |
|
|
|
] = True, |
|
|
|
): |
|
|
|
self.model: APIKey = APIKey( |
|
|
|
**{"in": APIKeyIn.cookie}, # type: ignore[arg-type] |
|
|
|
super().__init__( |
|
|
|
location=APIKeyIn.cookie, |
|
|
|
name=name, |
|
|
|
scheme_name=scheme_name, |
|
|
|
description=description, |
|
|
|
auto_error=auto_error, |
|
|
|
) |
|
|
|
self.scheme_name = scheme_name or self.__class__.__name__ |
|
|
|
self.auto_error = auto_error |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> Optional[str]: |
|
|
|
api_key = request.cookies.get(self.model.name) |
|
|
|
return self.check_api_key(api_key, self.auto_error) |
|
|
|
return self.check_api_key(api_key) |
|
|
|