diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 8b2c5c080..83b404b80 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -2,8 +2,9 @@ from typing import Optional from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.base import SecurityBase +from fastapi.security.utils import handle_exc_for_ws from starlette.exceptions import HTTPException -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_403_FORBIDDEN @@ -28,7 +29,8 @@ class APIKeyQuery(APIKeyBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.query_params.get(self.model.name) if not api_key: if self.auto_error: @@ -57,7 +59,8 @@ class APIKeyHeader(APIKeyBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.headers.get(self.model.name) if not api_key: if self.auto_error: @@ -86,7 +89,8 @@ class APIKeyCookie(APIKeyBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.cookies.get(self.model.name) if not api_key: if self.auto_error: diff --git a/fastapi/security/http.py b/fastapi/security/http.py index 8fc0aafd9..764dfac07 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -6,9 +6,9 @@ from fastapi.exceptions import HTTPException from fastapi.openapi.models import HTTPBase as HTTPBaseModel from fastapi.openapi.models import HTTPBearer as HTTPBearerModel from fastapi.security.base import SecurityBase -from fastapi.security.utils import get_authorization_scheme_param +from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws from pydantic import BaseModel -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN @@ -35,8 +35,9 @@ class HTTPBase(SecurityBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + @handle_exc_for_ws async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) @@ -64,9 +65,8 @@ class HTTPBasic(HTTPBase): self.realm = realm self.auto_error = auto_error - async def __call__( # type: ignore - self, request: Request - ) -> Optional[HTTPBasicCredentials]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[HTTPBasicCredentials]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if self.realm: @@ -110,8 +110,9 @@ class HTTPBearer(HTTPBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + @handle_exc_for_ws async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) @@ -145,8 +146,9 @@ class HTTPDigest(HTTPBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + @handle_exc_for_ws async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index e4c4357e7..6b4d20dcc 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -5,8 +5,8 @@ from fastapi.openapi.models import OAuth2 as OAuth2Model from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel from fastapi.param_functions import Form from fastapi.security.base import SecurityBase -from fastapi.security.utils import get_authorization_scheme_param -from starlette.requests import Request +from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN # TODO: import from typing when deprecating Python 3.9 @@ -131,7 +131,8 @@ class OAuth2(SecurityBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: @@ -164,7 +165,8 @@ class OAuth2PasswordBearer(OAuth2): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": @@ -210,7 +212,8 @@ class OAuth2AuthorizationCodeBearer(OAuth2): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index 4e65f1f6c..390e99749 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -2,8 +2,9 @@ from typing import Optional from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.security.base import SecurityBase +from fastapi.security.utils import handle_exc_for_ws from starlette.exceptions import HTTPException -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_403_FORBIDDEN @@ -22,7 +23,8 @@ class OpenIdConnect(SecurityBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: diff --git a/fastapi/security/utils.py b/fastapi/security/utils.py index fa7a450b7..2a0849303 100644 --- a/fastapi/security/utils.py +++ b/fastapi/security/utils.py @@ -1,4 +1,10 @@ -from typing import Optional, Tuple +from functools import wraps +from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar + +from fastapi.exceptions import HTTPException, WebSocketException +from starlette.requests import HTTPConnection +from starlette.status import WS_1008_POLICY_VIOLATION +from starlette.websockets import WebSocket def get_authorization_scheme_param( @@ -8,3 +14,24 @@ def get_authorization_scheme_param( return "", "" scheme, _, param = authorization_header_value.partition(" ") return scheme, param + + +_SecurityDepFunc = TypeVar( + "_SecurityDepFunc", bound=Callable[[Any, HTTPConnection], Awaitable] +) + + +def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc: + @wraps(func) + async def wrapper(self, request: HTTPConnection, *args, **kwargs): + try: + return await func(self, request, *args, **kwargs) + except HTTPException as e: + if not isinstance(request, WebSocket): + raise e + await request.accept() + raise WebSocketException( + code=WS_1008_POLICY_VIOLATION, reason=e.detail + ) from None + + return wrapper # type: ignore