Browse Source

Add WebSocket handling support for HTTP security dependencies

pull/10147/head
Mix 2 years ago
parent
commit
cad7564c69
  1. 12
      fastapi/security/api_key.py
  2. 18
      fastapi/security/http.py
  3. 13
      fastapi/security/oauth2.py
  4. 6
      fastapi/security/open_id_connect_url.py
  5. 29
      fastapi/security/utils.py

12
fastapi/security/api_key.py

@ -2,8 +2,9 @@ from typing import Optional
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from fastapi.security.utils import handle_exc_for_ws
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN
@ -28,7 +29,8 @@ class APIKeyQuery(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name)
if not api_key:
if self.auto_error:
@ -57,7 +59,8 @@ class APIKeyHeader(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name)
if not api_key:
if self.auto_error:
@ -86,7 +89,8 @@ class APIKeyCookie(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name)
if not api_key:
if self.auto_error:

18
fastapi/security/http.py

@ -6,9 +6,9 @@ from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws
from pydantic import BaseModel
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
@ -35,8 +35,9 @@ class HTTPBase(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
@handle_exc_for_ws
async def __call__(
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
@ -64,9 +65,8 @@ class HTTPBasic(HTTPBase):
self.realm = realm
self.auto_error = auto_error
async def __call__( # type: ignore
self, request: Request
) -> Optional[HTTPBasicCredentials]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if self.realm:
@ -110,8 +110,9 @@ class HTTPBearer(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
@handle_exc_for_ws
async def __call__(
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
@ -145,8 +146,9 @@ class HTTPDigest(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
@handle_exc_for_ws
async def __call__(
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)

13
fastapi/security/oauth2.py

@ -5,8 +5,8 @@ from fastapi.openapi.models import OAuth2 as OAuth2Model
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request
from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws
from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
# TODO: import from typing when deprecating Python 3.9
@ -131,7 +131,8 @@ class OAuth2(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
@ -164,7 +165,8 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error,
)
async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
@ -210,7 +212,8 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error,
)
async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":

6
fastapi/security/open_id_connect_url.py

@ -2,8 +2,9 @@ from typing import Optional
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
from fastapi.security.utils import handle_exc_for_ws
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN
@ -22,7 +23,8 @@ class OpenIdConnect(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:

29
fastapi/security/utils.py

@ -1,4 +1,10 @@
from typing import Optional, Tuple
from functools import wraps
from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar
from fastapi.exceptions import HTTPException, WebSocketException
from starlette.requests import HTTPConnection
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.websockets import WebSocket
def get_authorization_scheme_param(
@ -8,3 +14,24 @@ def get_authorization_scheme_param(
return "", ""
scheme, _, param = authorization_header_value.partition(" ")
return scheme, param
_SecurityDepFunc = TypeVar(
"_SecurityDepFunc", bound=Callable[[Any, HTTPConnection], Awaitable]
)
def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc:
@wraps(func)
async def wrapper(self, request: HTTPConnection, *args, **kwargs):
try:
return await func(self, request, *args, **kwargs)
except HTTPException as e:
if not isinstance(request, WebSocket):
raise e
await request.accept()
raise WebSocketException(
code=WS_1008_POLICY_VIOLATION, reason=e.detail
) from None
return wrapper # type: ignore

Loading…
Cancel
Save