You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
3.9 KiB

import binascii
from base64 import b64decode
from typing import Optional

from fastapi.openapi.models import (
    HTTPBase as HTTPBaseModel,
    HTTPBearer as HTTPBearerModel,
)
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
class HTTPBasicCredentials(BaseModel):
    """Username/password pair decoded from an HTTP ``Basic`` Authorization header."""

    # Text before the first ':' of the decoded payload.
    username: str
    # Everything after the first ':' (may itself contain ':').
    password: str
class HTTPAuthorizationCredentials(BaseModel):
    """Scheme and credentials split out of an ``Authorization`` header.

    Produced by the HTTP security dependencies in this module from the result
    of ``get_authorization_scheme_param``; the credentials are not decoded.
    """

    # Authentication scheme name as parsed from the header (e.g. "Bearer").
    scheme: str
    # The remainder of the header value following the scheme.
    credentials: str
class HTTPBase(SecurityBase):
    """Generic HTTP ``Authorization``-header security dependency.

    The ``scheme`` given at construction is recorded in the OpenAPI model
    (``self.model``) but is *not* checked against the incoming request: any
    ``<scheme> <credentials>`` header is accepted.
    """

    def __init__(self, *, scheme: str, scheme_name: Optional[str] = None):
        """
        Args:
            scheme: HTTP authentication scheme to advertise in the OpenAPI model.
            scheme_name: Security scheme name; defaults to the class name.
        """
        self.model = HTTPBaseModel(scheme=scheme)
        self.scheme_name = scheme_name or self.__class__.__name__

    async def __call__(self, request: Request) -> HTTPAuthorizationCredentials:
        """Return the scheme and raw credentials from the ``Authorization`` header.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is absent or
                is missing either the scheme or the credentials part.
        """
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
            )
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
class HTTPBasic(HTTPBase):
    """HTTP Basic authentication dependency (RFC 7617).

    Resolves to an :class:`HTTPBasicCredentials` holding the username and
    password decoded from an ``Authorization: Basic <base64>`` header.
    """

    def __init__(
        self, *, scheme_name: Optional[str] = None, realm: Optional[str] = None
    ):
        """
        Args:
            scheme_name: Security scheme name; defaults to the class name.
            realm: Stored on the instance; not used by this class yet
                (presumably intended for a future ``WWW-Authenticate`` challenge).
        """
        self.model = HTTPBaseModel(scheme="basic")
        self.scheme_name = scheme_name or self.__class__.__name__
        self.realm = realm

    async def __call__(  # type: ignore[override]
        self, request: Request
    ) -> HTTPBasicCredentials:
        """Decode Basic credentials from the request's ``Authorization`` header.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is absent or
                its scheme is not ``basic`` (case-insensitive); 403 "Invalid
                authentication credentials" when the payload cannot be
                base64/ASCII-decoded or contains no ``:`` separator.
        """
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, param = get_authorization_scheme_param(authorization)
        # NOTE: a 401 carrying a "WWW-Authenticate: Basic" header would be more
        # appropriate than 403, but per the original author this waits on
        # header support in starlette's HTTPException:
        # https://github.com/encode/starlette/issues/295
        invalid_user_credentials_exc = HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Invalid authentication credentials"
        )
        if not authorization or scheme.lower() != "basic":
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
            )
        try:
            data = b64decode(param).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error):
            raise invalid_user_credentials_exc
        # Split on the first ':' only: RFC 7617 disallows ':' in the user-id
        # but permits it in the password.
        username, separator, password = data.partition(":")
        if not separator:
            raise invalid_user_credentials_exc
        return HTTPBasicCredentials(username=username, password=password)
class HTTPBearer(HTTPBase):
    """HTTP Bearer token authentication dependency.

    Resolves to an :class:`HTTPAuthorizationCredentials` whose ``credentials``
    field is the raw token from an ``Authorization: Bearer <token>`` header.
    """

    def __init__(
        self,
        *,
        bearerFormat: Optional[str] = None,
        scheme_name: Optional[str] = None,
    ):
        """
        Args:
            bearerFormat: Token format hint recorded in the OpenAPI model.
            scheme_name: Security scheme name; defaults to the class name.
        """
        self.model = HTTPBearerModel(bearerFormat=bearerFormat)
        self.scheme_name = scheme_name or self.__class__.__name__

    async def __call__(self, request: Request) -> HTTPAuthorizationCredentials:
        """Return the bearer token from the ``Authorization`` header.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is absent or
                is missing the scheme or the token; 403 "Invalid authentication
                credentials" when the scheme is not ``bearer`` (case-insensitive).
        """
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
            )
        # Reject other schemes (e.g. "Basic ...") instead of handing their
        # payload to the endpoint as if it were a bearer token.
        if scheme.lower() != "bearer":
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail="Invalid authentication credentials",
            )
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
class HTTPDigest(HTTPBase):
    """HTTP Digest authentication dependency.

    Only extracts the raw ``Digest`` credentials string; parsing and verifying
    the digest parameters is left to the caller.
    """

    def __init__(self, *, scheme_name: Optional[str] = None):
        """
        Args:
            scheme_name: Security scheme name; defaults to the class name.
        """
        self.model = HTTPBaseModel(scheme="digest")
        self.scheme_name = scheme_name or self.__class__.__name__

    async def __call__(self, request: Request) -> HTTPAuthorizationCredentials:
        """Return the raw digest credentials from the ``Authorization`` header.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is absent or
                is missing the scheme or credentials; 403 "Invalid
                authentication credentials" when the scheme is not ``digest``
                (case-insensitive).
        """
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
            )
        # Reject headers that use a different scheme (e.g. "Bearer ...").
        if scheme.lower() != "digest":
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail="Invalid authentication credentials",
            )
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)