Mix 1 month ago
committed by GitHub
parent
commit
9a1613ddd8
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 14
      fastapi/security/api_key.py
  2. 24
      fastapi/security/http.py
  3. 14
      fastapi/security/oauth2.py
  4. 6
      fastapi/security/open_id_connect_url.py
  5. 29
      tests/test_security_http_base_optional.py

14
fastapi/security/api_key.py

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -139,8 +139,8 @@ class APIKeyQuery(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
api_key = request.query_params.get(self.model.name) api_key = conn.query_params.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)
@ -227,8 +227,8 @@ class APIKeyHeader(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
api_key = request.headers.get(self.model.name) api_key = conn.headers.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)
@ -315,6 +315,6 @@ class APIKeyCookie(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
api_key = request.cookies.get(self.model.name) api_key = conn.cookies.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)

24
fastapi/security/http.py

@ -9,7 +9,7 @@ from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -91,8 +91,10 @@ class HTTPBase(SecurityBase):
headers=self.make_authenticate_headers(), headers=self.make_authenticate_headers(),
) )
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: async def __call__(
authorization = request.headers.get("Authorization") self, conn: HTTPConnection
) -> HTTPAuthorizationCredentials | None:
authorization = conn.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:
@ -200,9 +202,9 @@ class HTTPBasic(HTTPBase):
return {"WWW-Authenticate": "Basic"} return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore async def __call__( # type: ignore
self, request: Request self, conn: HTTPConnection
) -> HTTPBasicCredentials | None: ) -> HTTPBasicCredentials | None:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "basic": if not authorization or scheme.lower() != "basic":
if self.auto_error: if self.auto_error:
@ -300,8 +302,10 @@ class HTTPBearer(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: async def __call__(
authorization = request.headers.get("Authorization") self, conn: HTTPConnection
) -> HTTPAuthorizationCredentials | None:
authorization = conn.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:
@ -401,8 +405,10 @@ class HTTPDigest(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: async def __call__(
authorization = request.headers.get("Authorization") self, conn: HTTPConnection
) -> HTTPAuthorizationCredentials | None:
authorization = conn.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:

14
fastapi/security/oauth2.py

@ -7,7 +7,7 @@ from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -420,8 +420,8 @@ class OAuth2(SecurityBase):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
raise self.make_not_authenticated_error() raise self.make_not_authenticated_error()
@ -533,8 +533,8 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
if self.auto_error: if self.auto_error:
@ -639,8 +639,8 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
if self.auto_error: if self.auto_error:

6
fastapi/security/open_id_connect_url.py

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -84,8 +84,8 @@ class OpenIdConnect(SecurityBase):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
raise self.make_not_authenticated_error() raise self.make_not_authenticated_error()

29
tests/test_security_http_base_optional.py

@ -1,4 +1,4 @@
from fastapi import FastAPI, Security from fastapi import FastAPI, Security, WebSocket
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from inline_snapshot import snapshot from inline_snapshot import snapshot
@ -17,6 +17,19 @@ def read_current_user(
return {"scheme": credentials.scheme, "credentials": credentials.credentials} return {"scheme": credentials.scheme, "credentials": credentials.credentials}
@app.websocket("/users/timeline")
async def read_user_timeline(
websocket: WebSocket,
credentials: HTTPAuthorizationCredentials | None = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
if credentials
else {"msg": "Create an account first"}
)
client = TestClient(app) client = TestClient(app)
@ -32,6 +45,20 @@ def test_security_http_base_no_credentials():
assert response.json() == {"msg": "Create an account first"} assert response.json() == {"msg": "Create an account first"}
def test_security_http_base_with_ws():
with client.websocket_connect(
"/users/timeline", headers={"Authorization": "Other foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Other", "credentials": "foobar"}
def test_security_http_base_with_ws_no_credentials():
with client.websocket_connect("/users/timeline") as websocket:
data = websocket.receive_json()
assert data == {"msg": "Create an account first"}
def test_openapi_schema(): def test_openapi_schema():
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text

Loading…
Cancel
Save