Browse Source

Remove handle_exc_for_ws function from security utils

pull/10147/head
Mix 1 year ago
parent
commit
7edb2d9953
  1. 4
      fastapi/security/api_key.py
  2. 6
      fastapi/security/http.py
  3. 5
      fastapi/security/oauth2.py
  4. 2
      fastapi/security/open_id_connect_url.py
  5. 30
      fastapi/security/utils.py

4
fastapi/security/api_key.py

@ -2,7 +2,6 @@ from typing import Optional
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import handle_exc_for_ws
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
@ -100,7 +99,6 @@ class APIKeyQuery(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name) api_key = request.query_params.get(self.model.name)
if not api_key: if not api_key:
@ -196,7 +194,6 @@ class APIKeyHeader(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name) api_key = request.headers.get(self.model.name)
if not api_key: if not api_key:
@ -292,7 +289,6 @@ class APIKeyCookie(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name) api_key = request.cookies.get(self.model.name)
if not api_key: if not api_key:

6
fastapi/security/http.py

@ -6,7 +6,7 @@ from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
@ -79,7 +79,6 @@ class HTTPBase(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__( async def __call__(
self, request: HTTPConnection self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
@ -185,7 +184,6 @@ class HTTPBasic(HTTPBase):
self.realm = realm self.realm = realm
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__( # type: ignore async def __call__( # type: ignore
self, request: HTTPConnection self, request: HTTPConnection
) -> Optional[HTTPBasicCredentials]: ) -> Optional[HTTPBasicCredentials]:
@ -300,7 +298,6 @@ class HTTPBearer(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__( async def __call__(
self, request: HTTPConnection self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
@ -403,7 +400,6 @@ class HTTPDigest(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__( async def __call__(
self, request: HTTPConnection self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:

5
fastapi/security/oauth2.py

@ -5,7 +5,7 @@ from fastapi.openapi.models import OAuth2 as OAuth2Model
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
@ -376,7 +376,6 @@ class OAuth2(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
@ -471,7 +470,6 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
@ -582,7 +580,6 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)

2
fastapi/security/open_id_connect_url.py

@ -2,7 +2,6 @@ from typing import Optional
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import handle_exc_for_ws
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
@ -73,7 +72,6 @@ class OpenIdConnect(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:

30
fastapi/security/utils.py

@ -1,10 +1,4 @@
from functools import wraps from typing import Optional, Tuple
from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar
from fastapi.exceptions import HTTPException, WebSocketException
from starlette.requests import HTTPConnection
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.websockets import WebSocket
def get_authorization_scheme_param( def get_authorization_scheme_param(
@ -14,25 +8,3 @@ def get_authorization_scheme_param(
return "", "" return "", ""
scheme, _, param = authorization_header_value.partition(" ") scheme, _, param = authorization_header_value.partition(" ")
return scheme, param return scheme, param
_SecurityDepFunc = TypeVar(
"_SecurityDepFunc", bound=Callable[[Any, HTTPConnection], Awaitable[Any]]
)
def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc:
@wraps(func)
async def wrapper(self: Any, request: HTTPConnection) -> Any:
try:
return await func(self, request)
except HTTPException as e:
if not isinstance(request, WebSocket):
raise e
# close before accepted with result a HTTP 403 so the exception argument is ignored
# ref: https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event
raise WebSocketException(
code=WS_1008_POLICY_VIOLATION, reason=e.detail
) from None
return wrapper # type: ignore

Loading…
Cancel
Save