Browse Source

Add WebSocket handling support for HTTP security dependencies

pull/10147/head
Mix 2 years ago
parent
commit
cad7564c69
  1. 12
      fastapi/security/api_key.py
  2. 18
      fastapi/security/http.py
  3. 13
      fastapi/security/oauth2.py
  4. 6
      fastapi/security/open_id_connect_url.py
  5. 29
      fastapi/security/utils.py

12
fastapi/security/api_key.py

@ -2,8 +2,9 @@ from typing import Optional
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import handle_exc_for_ws
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
@ -28,7 +29,8 @@ class APIKeyQuery(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name) api_key = request.query_params.get(self.model.name)
if not api_key: if not api_key:
if self.auto_error: if self.auto_error:
@ -57,7 +59,8 @@ class APIKeyHeader(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name) api_key = request.headers.get(self.model.name)
if not api_key: if not api_key:
if self.auto_error: if self.auto_error:
@ -86,7 +89,8 @@ class APIKeyCookie(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name) api_key = request.cookies.get(self.model.name)
if not api_key: if not api_key:
if self.auto_error: if self.auto_error:

18
fastapi/security/http.py

@ -6,9 +6,9 @@ from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
@ -35,8 +35,9 @@ class HTTPBase(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
@ -64,9 +65,8 @@ class HTTPBasic(HTTPBase):
self.realm = realm self.realm = realm
self.auto_error = auto_error self.auto_error = auto_error
async def __call__( # type: ignore @handle_exc_for_ws
self, request: Request async def __call__(self, request: HTTPConnection) -> Optional[HTTPBasicCredentials]:
) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if self.realm: if self.realm:
@ -110,8 +110,9 @@ class HTTPBearer(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
@ -145,8 +146,9 @@ class HTTPDigest(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)

13
fastapi/security/oauth2.py

@ -5,8 +5,8 @@ from fastapi.openapi.models import OAuth2 as OAuth2Model
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
# TODO: import from typing when deprecating Python 3.9 # TODO: import from typing when deprecating Python 3.9
@ -131,7 +131,8 @@ class OAuth2(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
@ -164,7 +165,8 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
@ -210,7 +212,8 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":

6
fastapi/security/open_id_connect_url.py

@ -2,8 +2,9 @@ from typing import Optional
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import handle_exc_for_ws
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
@ -22,7 +23,8 @@ class OpenIdConnect(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:

29
fastapi/security/utils.py

@ -1,4 +1,10 @@
from typing import Optional, Tuple from functools import wraps
from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar
from fastapi.exceptions import HTTPException, WebSocketException
from starlette.requests import HTTPConnection
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.websockets import WebSocket
def get_authorization_scheme_param( def get_authorization_scheme_param(
@ -8,3 +14,24 @@ def get_authorization_scheme_param(
return "", "" return "", ""
scheme, _, param = authorization_header_value.partition(" ") scheme, _, param = authorization_header_value.partition(" ")
return scheme, param return scheme, param
_SecurityDepFunc = TypeVar(
"_SecurityDepFunc", bound=Callable[[Any, HTTPConnection], Awaitable]
)
def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc:
@wraps(func)
async def wrapper(self, request: HTTPConnection, *args, **kwargs):
try:
return await func(self, request, *args, **kwargs)
except HTTPException as e:
if not isinstance(request, WebSocket):
raise e
await request.accept()
raise WebSocketException(
code=WS_1008_POLICY_VIOLATION, reason=e.detail
) from None
return wrapper # type: ignore

Loading…
Cancel
Save