import binascii
from base64 import b64decode
from typing import Optional

from fastapi.openapi.models import (
    HTTPBase as HTTPBaseModel,
    HTTPBearer as HTTPBearerModel,
)
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN


class HTTPBasicCredentials(BaseModel):
    """Username/password pair decoded from an HTTP Basic ``Authorization`` header."""

    username: str
    password: str


class HTTPAuthorizationCredentials(BaseModel):
    """Raw ``<scheme> <credentials>`` pair taken from an ``Authorization`` header."""

    scheme: str
    credentials: str


def _get_authorization_credentials(request: Request) -> HTTPAuthorizationCredentials:
    """Split the request's ``Authorization`` header into scheme and credentials.

    Raises:
        HTTPException: 403 "Not authenticated" when the header is absent or
            either the scheme or the credentials part is empty.
    """
    authorization: Optional[str] = request.headers.get("Authorization")
    scheme, credentials = get_authorization_scheme_param(authorization)
    if not (authorization and scheme and credentials):
        raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")
    return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)


def _invalid_credentials() -> HTTPException:
    """Build the 403 error used when a header is present but unusable."""
    return HTTPException(
        status_code=HTTP_403_FORBIDDEN, detail="Invalid authentication credentials"
    )


class HTTPBase(SecurityBase):
    """Generic HTTP authentication dependency that accepts any scheme.

    Args:
        scheme: HTTP auth scheme advertised in the OpenAPI model.
        scheme_name: Name of the security scheme; defaults to the class name.
    """

    def __init__(self, *, scheme: str, scheme_name: Optional[str] = None):
        self.model = HTTPBaseModel(scheme=scheme)
        self.scheme_name = scheme_name or self.__class__.__name__

    async def __call__(self, request: Request) -> HTTPAuthorizationCredentials:
        """Return the raw header credentials, whatever the scheme.

        Raises:
            HTTPException: 403 when the ``Authorization`` header is missing or malformed.
        """
        return _get_authorization_credentials(request)


class HTTPBasic(HTTPBase):
    """HTTP Basic authentication dependency yielding ``HTTPBasicCredentials``.

    Args:
        scheme_name: Name of the security scheme; defaults to the class name.
        realm: Stored on the instance; not used by ``__call__`` in this module.
    """

    def __init__(self, *, scheme_name: Optional[str] = None, realm: Optional[str] = None):
        self.model = HTTPBaseModel(scheme="basic")
        self.scheme_name = scheme_name or self.__class__.__name__
        self.realm = realm

    async def __call__(self, request: Request) -> HTTPBasicCredentials:
        """Decode ``Basic <base64(user:password)>`` into username and password.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is missing or
                its scheme is not ``basic``. 403 "Invalid authentication
                credentials" when the payload is not valid base64, is not ASCII,
                or has no ``:`` separator.
        """
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, param = get_authorization_scheme_param(authorization)
        # NOTE: failures use 403 rather than 401 + "WWW-Authenticate: Basic";
        # switching waits on https://github.com/encode/starlette/issues/295
        if not authorization or scheme.lower() != "basic":
            raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")
        try:
            data = b64decode(param).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error):
            raise _invalid_credentials()
        # Split on the first colon only, so passwords may themselves contain ':'.
        username, separator, password = data.partition(":")
        if not separator:
            raise _invalid_credentials()
        return HTTPBasicCredentials(username=username, password=password)


class HTTPBearer(HTTPBase):
    """HTTP Bearer token authentication dependency.

    Args:
        bearerFormat: Optional token format hint for the OpenAPI model.
        scheme_name: Name of the security scheme; defaults to the class name.
    """

    def __init__(self, *, bearerFormat: Optional[str] = None, scheme_name: Optional[str] = None):
        self.model = HTTPBearerModel(bearerFormat=bearerFormat)
        self.scheme_name = scheme_name or self.__class__.__name__

    async def __call__(self, request: Request) -> HTTPAuthorizationCredentials:
        """Return the bearer token from ``Authorization: Bearer <token>``.

        Raises:
            HTTPException: 403 when the header is missing or malformed, or when
                its scheme is not ``bearer`` (compared case-insensitively).
        """
        result = _get_authorization_credentials(request)
        # Without this check a "Basic ..." header would be accepted as a bearer token.
        if result.scheme.lower() != "bearer":
            raise _invalid_credentials()
        return result


class HTTPDigest(HTTPBase):
    """HTTP Digest authentication dependency returning the raw header credentials.

    Args:
        scheme_name: Name of the security scheme; defaults to the class name.
    """

    def __init__(self, *, scheme_name: Optional[str] = None):
        self.model = HTTPBaseModel(scheme="digest")
        self.scheme_name = scheme_name or self.__class__.__name__

    async def __call__(self, request: Request) -> HTTPAuthorizationCredentials:
        """Return the credentials from ``Authorization: Digest <params>``.

        The digest parameters are not parsed or verified here.

        Raises:
            HTTPException: 403 when the header is missing or malformed, or when
                its scheme is not ``digest`` (compared case-insensitively).
        """
        result = _get_authorization_credentials(request)
        if result.scheme.lower() != "digest":
            raise _invalid_credentials()
        return result