@ -1,11 +1,16 @@
import inspect
from collections . abc import Awaitable , Callable
from typing import Any
from starlette . concurrency import run_in_threadpool
from starlette . exceptions import HTTPException
from starlette . requests import Request
from starlette . responses import PlainTextResponse , Response
from starlette . staticfiles import StaticFiles as StaticFiles # noqa
from starlette . types import Receive , Scope , Send
AuthCallable = Callable [ [ Request ] , Awaitable [ Any ] | Any ]
class AuthStaticFiles ( StaticFiles ) :
"""
@ -42,9 +47,10 @@ class AuthStaticFiles(StaticFiles):
## Parameters
* ` auth ` : An async callable that takes a ` Request ` object and performs
authentication . It should raise an ` HTTPException ` if authentication
fails , or return ` None ` if authentication succeeds .
* ` auth ` : A sync or async callable that takes a ` Request ` object and
performs authentication . It should raise an ` HTTPException ` if
authentication fails , or return ` None ` if authentication succeeds .
Sync callables are automatically run in a threadpool .
* ` on_error ` : An optional callable that takes a ` Request ` and an
` HTTPException ` and returns a ` Response ` . Use this to customize
error responses ( e . g . , redirect to login , return HTML instead of
@ -73,8 +79,8 @@ class AuthStaticFiles(StaticFiles):
html : bool = False ,
check_dir : bool = True ,
follow_symlink : bool = False ,
auth : Callable [ [ Request ] , Awaitable [ Any ] ] ,
on_error : Callable [ [ Request , Any ] , Awaitable [ Response ] ] | None = None ,
auth : AuthCallable ,
on_error : Callable [ [ Request , HTTPException ] , Awaitable [ Response ] ] | None = None ,
) - > None :
super ( ) . __init__ (
directory = directory ,
@ -84,26 +90,26 @@ class AuthStaticFiles(StaticFiles):
follow_symlink = follow_symlink ,
)
self . auth = auth
self . _auth_is_async = inspect . iscoroutinefunction ( auth )
self . on_error = on_error
async def __call__ ( self , scope : Scope , receive : Receive , send : Send ) - > None :
if scope [ " type " ] == " http " :
request = Request ( scope , receive )
try :
await self . auth ( request )
except Exception as exc :
from fastapi . exceptions import HTTPException
if isinstance ( exc , HTTPException ) :
if self . on_error is not None :
response = await self . on_error ( request , exc )
else :
response = PlainTextResponse (
str ( exc . detail ) ,
status_code = exc . status_code ,
headers = getattr ( exc , " headers " , None ) ,
)
await response ( scope , receive , send )
return
raise
if self . _auth_is_async :
await self . auth ( request )
else :
await run_in_threadpool ( self . auth , request )
except HTTPException as exc :
if self . on_error is not None :
response = await self . on_error ( request , exc )
else :
response = PlainTextResponse (
str ( exc . detail ) ,
status_code = exc . status_code ,
headers = getattr ( exc , " headers " , None ) ,
)
await response ( scope , receive , send )
return
await super ( ) . __call__ ( scope , receive , send )