|
|
@ -3,7 +3,9 @@ from typing import List, Optional |
|
|
|
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel |
|
|
|
from fastapi.security.base import SecurityBase |
|
|
|
from pydantic import BaseModel, Schema |
|
|
|
from starlette.exceptions import HTTPException |
|
|
|
from starlette.requests import Request |
|
|
|
from starlette.status import HTTP_403_FORBIDDEN |
|
|
|
|
|
|
|
|
|
|
|
class OAuth2PasswordRequestData(BaseModel): |
|
|
@ -45,3 +47,20 @@ class OAuth2(SecurityBase): |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> str: |
|
|
|
return request.headers.get("Authorization") |
|
|
|
|
|
|
|
|
|
|
|
class OAuth2PasswordBearer(OAuth2): |
|
|
|
def __init__(self, tokenUrl: str, scheme_name: str = None, scopes: dict = None): |
|
|
|
if not scopes: |
|
|
|
scopes = {} |
|
|
|
flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes}) |
|
|
|
super().__init__(flows=flows, scheme_name=scheme_name) |
|
|
|
|
|
|
|
async def __call__(self, request: Request) -> str: |
|
|
|
authorization: str = request.headers.get("Authorization") |
|
|
|
if not authorization or "Bearer " not in authorization: |
|
|
|
raise HTTPException( |
|
|
|
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" |
|
|
|
) |
|
|
|
token = authorization.replace("Bearer ", "") |
|
|
|
return token |
|
|
|