|
@ -2,6 +2,12 @@ import binascii |
|
|
from base64 import b64decode |
|
|
from base64 import b64decode |
|
|
from typing import Optional |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
|
|
|
from fastapi.security.base import SecurityBase |
|
|
|
|
|
from fastapi import HTTPException, status, Depends |
|
|
|
|
|
from fastapi.security.utils import get_authorization_scheme_param |
|
|
|
|
|
from starlette.requests import Request |
|
|
|
|
|
from starlette.status import HTTP_401_UNAUTHORIZED |
|
|
|
|
|
from typing import Optional, Tuple |
|
|
from fastapi.exceptions import HTTPException |
|
|
from fastapi.exceptions import HTTPException |
|
|
from fastapi.openapi.models import HTTPBase as HTTPBaseModel |
|
|
from fastapi.openapi.models import HTTPBase as HTTPBaseModel |
|
|
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel |
|
|
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel |
|
@ -64,34 +70,49 @@ class HTTPAuthorizationCredentials(BaseModel): |
|
|
""" |
|
|
""" |
|
|
), |
|
|
), |
|
|
] |
|
|
] |
|
|
|
|
|
class HTTPBasicCredentials(BaseModel): |
|
|
|
|
|
username: str |
|
|
|
|
|
password: str |
|
|
|
|
|
|
|
|
|
|
|
class HTTPBasic(SecurityBase): |
|
|
class HTTPBase(SecurityBase): |
|
|
|
|
|
def __init__( |
|
|
def __init__( |
|
|
self, |
|
|
self, |
|
|
*, |
|
|
*, |
|
|
scheme: str, |
|
|
|
|
|
scheme_name: Optional[str] = None, |
|
|
scheme_name: Optional[str] = None, |
|
|
description: Optional[str] = None, |
|
|
realm: str, # REQUIRED parameter (RFC 7617 compliance) |
|
|
auto_error: bool = True, |
|
|
auto_error: bool = True |
|
|
): |
|
|
): |
|
|
self.model = HTTPBaseModel(scheme=scheme, description=description) |
|
|
|
|
|
self.scheme_name = scheme_name or self.__class__.__name__ |
|
|
self.scheme_name = scheme_name or self.__class__.__name__ |
|
|
|
|
|
self.realm = realm |
|
|
self.auto_error = auto_error |
|
|
self.auto_error = auto_error |
|
|
|
|
|
|
|
|
async def __call__( |
|
|
async def __call__(self, request: Request) -> Optional[HTTPBasicCredentials]: |
|
|
self, request: Request |
|
|
authorization_header = request.headers.get("Authorization") |
|
|
) -> Optional[HTTPAuthorizationCredentials]: |
|
|
scheme, credentials = get_authorization_scheme_param(authorization_header) |
|
|
authorization = request.headers.get("Authorization") |
|
|
|
|
|
scheme, credentials = get_authorization_scheme_param(authorization) |
|
|
# Handle missing/invalid scheme |
|
|
if not (authorization and scheme and credentials): |
|
|
if not authorization_header or scheme.lower() != "basic": |
|
|
if self.auto_error: |
|
|
if self.auto_error: |
|
|
raise HTTPException( |
|
|
raise HTTPException( |
|
|
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" |
|
|
status_code=HTTP_401_UNAUTHORIZED, |
|
|
|
|
|
headers={"WWW-Authenticate": f'Basic realm="{self.realm}"'}, |
|
|
|
|
|
detail="Invalid authentication credentials", |
|
|
) |
|
|
) |
|
|
else: |
|
|
return None |
|
|
return None |
|
|
|
|
|
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) |
|
|
# Decode credentials (base64) |
|
|
|
|
|
try: |
|
|
|
|
|
decoded = base64.b64decode(credentials).decode("utf-8") |
|
|
|
|
|
username, _, password = decoded.partition(":") |
|
|
|
|
|
return HTTPBasicCredentials(username=username, password=password) |
|
|
|
|
|
except (ValueError, UnicodeDecodeError): |
|
|
|
|
|
if self.auto_error: |
|
|
|
|
|
raise HTTPException( |
|
|
|
|
|
status_code=HTTP_401_UNAUTHORIZED, |
|
|
|
|
|
headers={"WWW-Authenticate": f'Basic realm="{self.realm}"'}, |
|
|
|
|
|
detail="Invalid authentication credentials", |
|
|
|
|
|
) |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HTTPBasic(HTTPBase): |
|
|
class HTTPBasic(HTTPBase): |
|
|