Browse Source

Merge 6846708f00 into 76b324d95b

pull/10147/merge
Mix 2 days ago
committed by GitHub
parent
commit
56238393f5
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 8
      fastapi/security/api_key.py
  2. 10
      fastapi/security/http.py
  3. 8
      fastapi/security/oauth2.py
  4. 4
      fastapi/security/open_id_connect_url.py
  5. 29
      tests/test_security_http_base_optional.py

8
fastapi/security/api_key.py

@ -3,7 +3,7 @@ from typing import Optional
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc
@ -107,7 +107,7 @@ class APIKeyQuery(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name) api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(api_key, self.auto_error)
@ -195,7 +195,7 @@ class APIKeyHeader(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name) api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(api_key, self.auto_error)
@ -283,6 +283,6 @@ class APIKeyCookie(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name) api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(api_key, self.auto_error)

10
fastapi/security/http.py

@ -8,7 +8,7 @@ from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc
@ -80,7 +80,7 @@ class HTTPBase(SecurityBase):
self.auto_error = auto_error self.auto_error = auto_error
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
@ -185,7 +185,7 @@ class HTTPBasic(HTTPBase):
self.auto_error = auto_error self.auto_error = auto_error
async def __call__( # type: ignore async def __call__( # type: ignore
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPBasicCredentials]: ) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
@ -299,7 +299,7 @@ class HTTPBearer(HTTPBase):
self.auto_error = auto_error self.auto_error = auto_error
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
@ -401,7 +401,7 @@ class HTTPDigest(HTTPBase):
self.auto_error = auto_error self.auto_error = auto_error
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)

8
fastapi/security/oauth2.py

@ -6,7 +6,7 @@ from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
# TODO: import from typing when deprecating Python 3.9 # TODO: import from typing when deprecating Python 3.9
@ -376,7 +376,7 @@ class OAuth2(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
@ -470,7 +470,7 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
@ -580,7 +580,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":

4
fastapi/security/open_id_connect_url.py

@ -3,7 +3,7 @@ from typing import Optional
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc
@ -72,7 +72,7 @@ class OpenIdConnect(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:

29
tests/test_security_http_base_optional.py

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from fastapi import FastAPI, Security from fastapi import FastAPI, Security, WebSocket
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -18,6 +18,19 @@ def read_current_user(
return {"scheme": credentials.scheme, "credentials": credentials.credentials} return {"scheme": credentials.scheme, "credentials": credentials.credentials}
@app.websocket("/users/timeline")
async def read_user_timeline(
websocket: WebSocket,
credentials: Optional[HTTPAuthorizationCredentials] = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
if credentials
else {"msg": "Create an account first"}
)
client = TestClient(app) client = TestClient(app)
@ -33,6 +46,20 @@ def test_security_http_base_no_credentials():
assert response.json() == {"msg": "Create an account first"} assert response.json() == {"msg": "Create an account first"}
def test_security_http_base_with_ws():
with client.websocket_connect(
"/users/timeline", headers={"Authorization": "Other foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Other", "credentials": "foobar"}
def test_security_http_base_with_ws_no_credentials():
with client.websocket_connect("/users/timeline") as websocket:
data = websocket.receive_json()
assert data == {"msg": "Create an account first"}
def test_openapi_schema(): def test_openapi_schema():
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text

Loading…
Cancel
Save