Topics: python, asyncio, api, async, fastapi, framework, json, json-schema, openapi, openapi3, pydantic, python-types, python3, redoc, rest, starlette, swagger, swagger-ui, uvicorn, web
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
138 lines
4.7 KiB
138 lines
4.7 KiB
from typing import Annotated, Any, Dict, Optional
|
|
|
|
import httpx
|
|
from cachetools import TTLCache
|
|
from fastapi import Depends, FastAPI, HTTPException, Security
|
|
from fastapi.security import (
|
|
HTTPAuthorizationCredentials,
|
|
HTTPBearer,
|
|
OpenIdConnect,
|
|
SecurityScopes,
|
|
)
|
|
from jose import JWTError, jwt
|
|
from pydantic import Field
|
|
from pydantic_settings import BaseSettings
|
|
from starlette.requests import Request
|
|
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_403_FORBIDDEN
|
|
|
|
|
|
class AccessTokenCredentials(HTTPAuthorizationCredentials):
    """Bearer credentials that additionally carry the decoded JWT claims."""

    # Claim set of the access token, as produced by ``jwt.decode`` in the validator.
    token: Dict[str, Any]
|
|
|
|
|
|
class AccessTokenValidator(HTTPBearer):
    """Generic HTTPBearer validator for JWT access tokens signed by keys from ``jwks_url``.

    The JWKS document is fetched with ``httpx`` and cached in a ``TTLCache`` for
    ``expire_seconds`` seconds. Tokens are decoded with ``jose.jwt.decode`` against
    that key set, the configured ``audience`` and ``issuer``. When the dependency is
    declared with ``Security(..., scopes=[...])``, every requested scope must be
    present in the token claim named by ``roles_claim``.
    """

    def __init__(
        self,
        *,
        jwks_url: str,
        audience: str,
        issuer: str,
        expire_seconds: int = 3600,
        roles_claim: str = "groups",
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        algorithms: Optional[list[str]] = None,
    ):
        """Configure the validator.

        Args:
            jwks_url: URL of the authorization server's JSON Web Key Set.
            audience: expected audience passed to ``jwt.decode``.
            issuer: expected issuer passed to ``jwt.decode``.
            expire_seconds: how long a fetched key set stays cached.
            roles_claim: name of the token claim that holds granted roles/scopes.
            scheme_name: OpenAPI security-scheme name, forwarded to ``HTTPBearer``.
            description: OpenAPI description, forwarded to ``HTTPBearer``.
            algorithms: JWT algorithms to accept. ``None`` (the default) keeps the
                previous behavior of not restricting the algorithm; pinning the
                provider's algorithm (e.g. ``["RS256"]``) is recommended.
        """
        super().__init__(scheme_name=scheme_name, description=description)
        self.uri = jwks_url
        self.audience = audience
        self.issuer = issuer
        self.roles_claim = roles_claim
        self.algorithms = algorithms
        # Only one key (self.uri) is ever stored; the TTL forces periodic re-fetches
        # so rotated signing keys are picked up.
        self.keyset_cache: TTLCache[str, str] = TTLCache(16, expire_seconds)

    async def get_jwt_keyset(self) -> str:
        """Return the raw JWKS document, fetching it when not cached or expired.

        Raises:
            httpx.HTTPError: if the JWKS endpoint is unreachable or answers with a
                non-2xx status. Failed responses are never cached.
        """
        result: Optional[str] = self.keyset_cache.get(self.uri)
        if result is None:
            async with httpx.AsyncClient() as client:
                response = await client.get(self.uri)
                # Without this check an error page would be cached as the key set
                # for expire_seconds, breaking every validation until it expires.
                response.raise_for_status()
                result = self.keyset_cache[self.uri] = response.text
        return result

    async def __call__(
        self, request: Request, security_scopes: SecurityScopes
    ) -> AccessTokenCredentials:  # type: ignore
        """Validate the JWT access token and, if scopes are requested, its roles claim.

        Args:
            request: incoming request carrying the ``Authorization`` header.
            security_scopes: scopes declared via ``Security(..., scopes=[...])``.

        Returns:
            The bearer credentials together with the verified token claims.

        Raises:
            HTTPException: 400 when no credentials are extracted, when the token fails
                verification, or when scopes are required but the roles claim is
                missing or unusable; 403 when a required scope is not granted.
            httpx.HTTPError: when the key set cannot be fetched.
        """
        # The `# type: ignore` above is needed because this signature adds
        # `security_scopes` compared to HTTPBearer.__call__(request).

        # 1. Unpack bearer token
        unverified_token = await super().__call__(request)
        if not unverified_token:
            raise HTTPException(HTTP_400_BAD_REQUEST, "Invalid Access Token")
        access_token = unverified_token.credentials

        try:
            # 2. Get keyset from authorization server so that we can validate the JWT Access Token
            keyset = await self.get_jwt_keyset()
            # 3. Perform validation (signature, audience, issuer)
            verified_token = jwt.decode(
                token=access_token,
                key=keyset,
                algorithms=self.algorithms,
                audience=self.audience,
                issuer=self.issuer,
            )
        except JWTError:
            raise HTTPException(
                status_code=HTTP_400_BAD_REQUEST,
                detail="Unsupported authorization code",
            ) from None

        # 4. if security scopes are present, validate them
        if security_scopes and security_scopes.scopes:
            granted = self._granted_roles(verified_token.get(self.roles_claim))
            # 4.1 the roles_claim must be present (and a string or a list of strings)
            if granted is None:
                raise HTTPException(
                    status_code=HTTP_400_BAD_REQUEST, detail="Unsupported Access Token"
                )
            # 4.2 all required roles must be present in the roles_claim
            if not set(security_scopes.scopes).issubset(granted):
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not Authorized"
                )

        return AccessTokenCredentials(
            # The HTTP auth scheme from the header (e.g. "Bearer"), not the
            # OpenAPI security-scheme name held in self.scheme_name.
            scheme=unverified_token.scheme,
            credentials=access_token,
            token=verified_token,
        )

    @staticmethod
    def _granted_roles(claim: Any) -> Optional[set[str]]:
        """Normalize a roles/scopes claim value to a set of strings, or None if unusable."""
        if isinstance(claim, str):
            # OAuth 2.0 "scope" claims are space-delimited strings; calling set()
            # on the raw string would produce a set of single characters.
            return set(claim.split())
        if isinstance(claim, (list, tuple, set, frozenset)):
            # Ignore non-string entries (they can never match a requested scope
            # and may be unhashable).
            return {item for item in claim if isinstance(item, str)}
        # Missing (None) or an unexpected type such as a number or an object.
        return None
|
|
|
|
|
|
class Settings(BaseSettings):
    """Settings will be read from the environment and an ``.env`` file.

    All fields are required; instantiating ``Settings`` fails if any is missing.
    """

    # Base URL of the OIDC issuer; also the expected issuer of access tokens.
    issuer: str = Field(default=...)
    # Expected audience of access tokens.
    audience: str = Field(default=...)
    # OAuth client id handed to the Swagger UI login flow.
    client_id: str = Field(default=...)

    # pydantic-settings (pydantic v2) configuration; replaces the deprecated
    # inner `class Config`, which v2 only supports with a deprecation warning.
    model_config = {"env_file": ".env"}
|
|
|
|
|
|
settings = Settings()

# OpenID Connect Discovery requires any terminating "/" of the issuer to be removed
# before appending the well-known path; otherwise a configured trailing slash
# produces ".../default//.well-known/...". The unmodified settings.issuer is still
# what the token's issuer is compared against.
_issuer_base = settings.issuer.rstrip("/")

# OIDC discovery document (standard location).
oidc_url = f"{_issuer_base}/.well-known/openid-configuration"
# NOTE(review): "/v1/keys" is not an OIDC-standard path; it looks like Okta's layout.
# Other providers advertise the JWKS location as "jwks_uri" in the discovery
# document above — confirm for your provider.
jwks_url = f"{_issuer_base}/v1/keys"

openid_connect = OpenIdConnect(openIdConnectUrl=oidc_url)
|
|
|
|
# OAuth options handed to the Swagger UI (via FastAPI's swagger_ui_init_oauth).
swagger_ui_init_oauth = dict(
    clientId=settings.client_id,
    scopes=["openid"],  # list any further scopes the UI should request here
    appName="Test Application",
    usePkceWithAuthorizationCodeGrant=True,
)
|
|
|
|
# openid_connect is registered as an app-wide dependency so that the Swagger UI
# offers an OpenID Connect login for the endpoints.
app = FastAPI(
    dependencies=[Depends(openid_connect)],
    swagger_ui_init_oauth=swagger_ui_init_oauth,
)
|
|
|
|
# Shared token validator: attach it via Security(...) to every endpoint that
# requires an authorized caller.
oauth2 = AccessTokenValidator(
    issuer=settings.issuer,
    audience=settings.audience,
    jwks_url=jwks_url,
)
|
|
|
|
|
|
@app.get("/hello")
async def hello(
    token: Annotated[
        AccessTokenCredentials, Security(dependency=oauth2, scopes=["Foo"])
    ],
) -> str:
    # Only reached when oauth2 accepted the token and found "Foo" in its roles
    # claim; the injected credentials are not needed for the response itself.
    greeting = "Hi!"
    return greeting
|
|
|