Browse Source

Merge a44b4c67f0 into 460f8d2cc8

pull/14561/merge
P-Muench 13 hours ago
committed by GitHub
parent
commit
720a5054ef
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 8
      fastapi/security/api_key.py
  2. 16
      fastapi/security/http.py
  3. 8
      fastapi/security/oauth2.py
  4. 4
      fastapi/security/open_id_connect_url.py
  5. 42
      tests/test_security_api_key_cookie_websocket.py
  6. 45
      tests/test_security_api_key_header_websocket.py
  7. 43
      tests/test_security_api_key_query_websocket.py
  8. 38
      tests/test_security_http_base_websocket.py
  9. 50
      tests/test_security_http_basic_websocket.py
  10. 46
      tests/test_security_http_bearer_websocket.py
  11. 46
      tests/test_security_http_digest_websocket.py
  12. 45
      tests/test_security_oauth2_authorization_code_bearer_websocket.py
  13. 41
      tests/test_security_oauth2_password_bearer_websocket.py
  14. 53
      tests/test_security_openid_connect_websocket.py

8
fastapi/security/api_key.py

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -139,7 +139,7 @@ class APIKeyQuery(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, request: HTTPConnection) -> str | None:
api_key = request.query_params.get(self.model.name) api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)
@ -227,7 +227,7 @@ class APIKeyHeader(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, request: HTTPConnection) -> str | None:
api_key = request.headers.get(self.model.name) api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)
@ -315,6 +315,6 @@ class APIKeyCookie(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, request: HTTPConnection) -> str | None:
api_key = request.cookies.get(self.model.name) api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)

16
fastapi/security/http.py

@ -9,7 +9,7 @@ from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -91,7 +91,9 @@ class HTTPBase(SecurityBase):
headers=self.make_authenticate_headers(), headers=self.make_authenticate_headers(),
) )
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: async def __call__(
self, request: HTTPConnection
) -> HTTPAuthorizationCredentials | None:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
@ -200,7 +202,7 @@ class HTTPBasic(HTTPBase):
return {"WWW-Authenticate": "Basic"} return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore async def __call__( # type: ignore
self, request: Request self, request: HTTPConnection
) -> HTTPBasicCredentials | None: ) -> HTTPBasicCredentials | None:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
@ -300,7 +302,9 @@ class HTTPBearer(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: async def __call__(
self, request: HTTPConnection
) -> HTTPAuthorizationCredentials | None:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
@ -401,7 +405,9 @@ class HTTPDigest(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: async def __call__(
self, request: HTTPConnection
) -> HTTPAuthorizationCredentials | None:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):

8
fastapi/security/oauth2.py

@ -7,7 +7,7 @@ from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -420,7 +420,7 @@ class OAuth2(SecurityBase):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, request: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
@ -533,7 +533,7 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, request: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
@ -639,7 +639,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, request: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":

4
fastapi/security/open_id_connect_url.py

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -84,7 +84,7 @@ class OpenIdConnect(SecurityBase):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, request: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:

42
tests/test_security_api_key_cookie_websocket.py

@ -0,0 +1,42 @@
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyCookie
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
api_key = APIKeyCookie(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
user = User(username=oauth_header)
return user
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket, current_user: User = Depends(get_current_user)
):
await websocket.accept()
await websocket.send_text(current_user.username)
def test_security_api_key_ws():
client = TestClient(app, cookies={"key": "secret"})
with client.websocket_connect("/ws/users/me") as websocket:
data = websocket.receive_text()
assert data == "secret"
def test_security_api_key_no_key_ws():
client = TestClient(app)
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass

45
tests/test_security_api_key_header_websocket.py

@ -0,0 +1,45 @@
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyHeader
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
api_key = APIKeyHeader(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
user = User(username=oauth_header)
return user
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket, current_user: User = Depends(get_current_user)
):
await websocket.accept()
await websocket.send_text(current_user.username)
client = TestClient(app)
def test_security_api_key_ws():
with client.websocket_connect(
"/ws/users/me", headers={"key": "secret"}
) as websocket:
data = websocket.receive_text()
assert data == "secret"
def test_security_api_key_no_key_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass

43
tests/test_security_api_key_query_websocket.py

@ -0,0 +1,43 @@
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyQuery
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
api_key = APIKeyQuery(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
user = User(username=oauth_header)
return user
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket, current_user: User = Depends(get_current_user)
):
await websocket.accept()
await websocket.send_text(current_user.username)
client = TestClient(app)
def test_security_api_key_query_ws():
with client.websocket_connect("/ws/users/me?key=secret") as websocket:
data = websocket.receive_text()
assert data == "secret"
def test_security_api_key_query_no_key_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass

38
tests/test_security_http_base_websocket.py

@ -0,0 +1,38 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
security = HTTPBase(scheme="Other")
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket,
credentials: HTTPAuthorizationCredentials = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
)
client = TestClient(app)
def test_security_http_base_ws():
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Other foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Other", "credentials": "foobar"}
def test_security_http_base_no_credentials_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass

50
tests/test_security_http_basic_websocket.py

@ -0,0 +1,50 @@
from base64 import b64encode
import pytest
from fastapi import FastAPI, Security
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
security = HTTPBasic(realm="simple")
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket, credentials: HTTPBasicCredentials = Security(security)
):
await websocket.accept()
await websocket.send_json(
{"username": credentials.username, "password": credentials.password}
)
client = TestClient(app)
def test_security_http_basic_ws():
# Build Basic header
payload = b64encode(b"john:secret").decode("ascii")
auth_header = f"Basic {payload}"
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": auth_header}
) as websocket:
data = websocket.receive_json()
assert data == {"username": "john", "password": "secret"}
def test_security_http_basic_no_credentials_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass
def test_security_http_basic_invalid_credentials_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Basic notbase64"}
):
pass

46
tests/test_security_http_bearer_websocket.py

@ -0,0 +1,46 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
security = HTTPBearer()
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket,
credentials: HTTPAuthorizationCredentials = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
)
client = TestClient(app)
def test_security_http_bearer_ws():
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Bearer foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Bearer", "credentials": "foobar"}
def test_security_http_bearer_no_credentials_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass
def test_security_http_bearer_incorrect_scheme_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Basic notreally"}
):
pass

46
tests/test_security_http_digest_websocket.py

@ -0,0 +1,46 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPDigest
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
security = HTTPDigest()
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket,
credentials: HTTPAuthorizationCredentials = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
)
client = TestClient(app)
def test_security_http_digest_ws():
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Digest foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Digest", "credentials": "foobar"}
def test_security_http_digest_no_credentials_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass
def test_security_http_digest_incorrect_scheme_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Basic notreally"}
):
pass

45
tests/test_security_oauth2_authorization_code_bearer_websocket.py

@ -0,0 +1,45 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security import OAuth2AuthorizationCodeBearer
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl="/api/oauth/authorize",
tokenUrl="/api/oauth/token",
scopes={"read": "Read access", "write": "Write access"},
)
@app.websocket("/ws/admin")
async def read_admin(websocket: WebSocket, token: str = Security(oauth2_scheme)):
await websocket.accept()
await websocket.send_text(token)
client = TestClient(app)
def test_security_oauth2_authorization_code_bearer_ws():
with client.websocket_connect(
"/ws/admin", headers={"Authorization": "Bearer faketoken"}
) as websocket:
data = websocket.receive_text()
assert data == "faketoken"
def test_security_oauth2_authorization_code_bearer_no_header_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/admin"):
pass
def test_security_oauth2_authorization_code_bearer_wrong_scheme_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect(
"/ws/admin", headers={"Authorization": "Basic nope"}
):
pass

41
tests/test_security_oauth2_password_bearer_websocket.py

@ -0,0 +1,41 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security import OAuth2PasswordBearer
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@app.websocket("/ws/token")
async def read_token(websocket: WebSocket, token: str = Security(oauth2_scheme)):
await websocket.accept()
await websocket.send_text(token)
client = TestClient(app)
def test_security_oauth2_password_bearer_ws():
with client.websocket_connect(
"/ws/token", headers={"Authorization": "Bearer faketoken"}
) as websocket:
data = websocket.receive_text()
assert data == "faketoken"
def test_security_oauth2_password_bearer_no_header_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/token"):
pass
def test_security_oauth2_password_bearer_wrong_scheme_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect(
"/ws/token", headers={"Authorization": "Basic nope"}
):
pass

53
tests/test_security_openid_connect_websocket.py

@ -0,0 +1,53 @@
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
oid = OpenIdConnect(openIdConnectUrl="/openid")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(oid)):
user = User(username=oauth_header)
return user
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket, current_user: User = Depends(get_current_user)
):
await websocket.accept()
await websocket.send_json({"username": current_user.username})
client = TestClient(app)
def test_security_openid_connect_ws():
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Bearer footokenbar"}
) as websocket:
data = websocket.receive_json()
assert data == {"username": "Bearer footokenbar"}
def test_security_openid_connect_other_header_ws():
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Other footokenbar"}
) as websocket:
data = websocket.receive_json()
assert data == {"username": "Other footokenbar"}
def test_security_openid_connect_no_header_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass
Loading…
Cancel
Save