|
|
@ -9,7 +9,15 @@ from typing_extensions import Annotated, Doc |
|
|
|
|
|
|
|
|
|
|
|
class APIKeyBase(SecurityBase): |
|
|
|
pass |
|
|
|
@staticmethod |
|
|
|
def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]: |
|
|
|
if not api_key: |
|
|
|
if auto_error: |
|
|
|
raise HTTPException( |
|
|
|
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" |
|
|
|
) |
|
|
|
return None |
|
|
|
return api_key |
|
|
|
|
|
|
|
|
|
|
|
class APIKeyQuery(APIKeyBase): |
|
|
@ -101,14 +109,7 @@ class APIKeyQuery(APIKeyBase): |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> Optional[str]: |
|
|
|
api_key = request.query_params.get(self.model.name) |
|
|
|
if not api_key: |
|
|
|
if self.auto_error: |
|
|
|
raise HTTPException( |
|
|
|
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" |
|
|
|
) |
|
|
|
else: |
|
|
|
return None |
|
|
|
return api_key |
|
|
|
return self.check_api_key(api_key, self.auto_error) |
|
|
|
|
|
|
|
|
|
|
|
class APIKeyHeader(APIKeyBase): |
|
|
@ -196,14 +197,7 @@ class APIKeyHeader(APIKeyBase): |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> Optional[str]: |
|
|
|
api_key = request.headers.get(self.model.name) |
|
|
|
if not api_key: |
|
|
|
if self.auto_error: |
|
|
|
raise HTTPException( |
|
|
|
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" |
|
|
|
) |
|
|
|
else: |
|
|
|
return None |
|
|
|
return api_key |
|
|
|
return self.check_api_key(api_key, self.auto_error) |
|
|
|
|
|
|
|
|
|
|
|
class APIKeyCookie(APIKeyBase): |
|
|
@ -291,11 +285,4 @@ class APIKeyCookie(APIKeyBase): |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> Optional[str]: |
|
|
|
api_key = request.cookies.get(self.model.name) |
|
|
|
if not api_key: |
|
|
|
if self.auto_error: |
|
|
|
raise HTTPException( |
|
|
|
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" |
|
|
|
) |
|
|
|
else: |
|
|
|
return None |
|
|
|
return api_key |
|
|
|
return self.check_api_key(api_key, self.auto_error) |
|
|
|