Browse Source

Security dependencies now working with websockets

pull/14561/head
philipmunch 6 months ago
parent
commit
bf7e2b7cab
  1. 8
      fastapi/security/api_key.py
  2. 10
      fastapi/security/http.py
  3. 8
      fastapi/security/oauth2.py
  4. 4
      fastapi/security/open_id_connect_url.py
  5. 42
      tests/test_security_api_key_cookie_websocket.py
  6. 45
      tests/test_security_api_key_header_websocket.py
  7. 43
      tests/test_security_api_key_query_websocket.py
  8. 38
      tests/test_security_http_base_websocket.py
  9. 50
      tests/test_security_http_basic_websocket.py
  10. 46
      tests/test_security_http_bearer_websocket.py
  11. 46
      tests/test_security_http_digest_websocket.py
  12. 45
      tests/test_security_oauth2_authorization_code_bearer_websocket.py
  13. 41
      tests/test_security_oauth2_password_bearer_websocket.py
  14. 53
      tests/test_security_openid_connect_websocket.py

8
fastapi/security/api_key.py

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated from typing_extensions import Annotated
@ -138,7 +138,7 @@ class APIKeyQuery(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name) api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)
@ -226,7 +226,7 @@ class APIKeyHeader(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name) api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)
@ -314,6 +314,6 @@ class APIKeyCookie(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name) api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)

10
fastapi/security/http.py

@ -9,7 +9,7 @@ from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated from typing_extensions import Annotated
@ -93,7 +93,7 @@ class HTTPBase(SecurityBase):
) )
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
@ -203,7 +203,7 @@ class HTTPBasic(HTTPBase):
return {"WWW-Authenticate": "Basic"} return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore async def __call__( # type: ignore
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPBasicCredentials]: ) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
@ -304,7 +304,7 @@ class HTTPBearer(HTTPBase):
self.auto_error = auto_error self.auto_error = auto_error
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
@ -407,7 +407,7 @@ class HTTPDigest(HTTPBase):
self.auto_error = auto_error self.auto_error = auto_error
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)

8
fastapi/security/oauth2.py

@ -7,7 +7,7 @@ from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
# TODO: import from typing when deprecating Python 3.9 # TODO: import from typing when deprecating Python 3.9
@ -399,7 +399,7 @@ class OAuth2(SecurityBase):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
@ -506,7 +506,7 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
@ -612,7 +612,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":

4
fastapi/security/open_id_connect_url.py

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated from typing_extensions import Annotated
@ -85,7 +85,7 @@ class OpenIdConnect(SecurityBase):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:

42
tests/test_security_api_key_cookie_websocket.py

@ -0,0 +1,42 @@
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyCookie
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
api_key = APIKeyCookie(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
user = User(username=oauth_header)
return user
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket, current_user: User = Depends(get_current_user)
):
await websocket.accept()
await websocket.send_text(current_user.username)
def test_security_api_key_ws():
client = TestClient(app, cookies={"key": "secret"})
with client.websocket_connect("/ws/users/me") as websocket:
data = websocket.receive_text()
assert data == "secret"
def test_security_api_key_no_key_ws():
client = TestClient(app)
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass

45
tests/test_security_api_key_header_websocket.py

@ -0,0 +1,45 @@
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyHeader
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
api_key = APIKeyHeader(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
user = User(username=oauth_header)
return user
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket, current_user: User = Depends(get_current_user)
):
await websocket.accept()
await websocket.send_text(current_user.username)
client = TestClient(app)
def test_security_api_key_ws():
with client.websocket_connect(
"/ws/users/me", headers={"key": "secret"}
) as websocket:
data = websocket.receive_text()
assert data == "secret"
def test_security_api_key_no_key_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass

43
tests/test_security_api_key_query_websocket.py

@ -0,0 +1,43 @@
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyQuery
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
api_key = APIKeyQuery(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
user = User(username=oauth_header)
return user
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket, current_user: User = Depends(get_current_user)
):
await websocket.accept()
await websocket.send_text(current_user.username)
client = TestClient(app)
def test_security_api_key_query_ws():
with client.websocket_connect("/ws/users/me?key=secret") as websocket:
data = websocket.receive_text()
assert data == "secret"
def test_security_api_key_query_no_key_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass

38
tests/test_security_http_base_websocket.py

@ -0,0 +1,38 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
security = HTTPBase(scheme="Other")
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket,
credentials: HTTPAuthorizationCredentials = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
)
client = TestClient(app)
def test_security_http_base_ws():
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Other foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Other", "credentials": "foobar"}
def test_security_http_base_no_credentials_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass

50
tests/test_security_http_basic_websocket.py

@ -0,0 +1,50 @@
from base64 import b64encode
import pytest
from fastapi import FastAPI, Security
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
security = HTTPBasic(realm="simple")
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket, credentials: HTTPBasicCredentials = Security(security)
):
await websocket.accept()
await websocket.send_json(
{"username": credentials.username, "password": credentials.password}
)
client = TestClient(app)
def test_security_http_basic_ws():
# Build Basic header
payload = b64encode(b"john:secret").decode("ascii")
auth_header = f"Basic {payload}"
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": auth_header}
) as websocket:
data = websocket.receive_json()
assert data == {"username": "john", "password": "secret"}
def test_security_http_basic_no_credentials_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass
def test_security_http_basic_invalid_credentials_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Basic notbase64"}
):
pass

46
tests/test_security_http_bearer_websocket.py

@ -0,0 +1,46 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
security = HTTPBearer()
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket,
credentials: HTTPAuthorizationCredentials = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
)
client = TestClient(app)
def test_security_http_bearer_ws():
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Bearer foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Bearer", "credentials": "foobar"}
def test_security_http_bearer_no_credentials_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass
def test_security_http_bearer_incorrect_scheme_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Basic notreally"}
):
pass

46
tests/test_security_http_digest_websocket.py

@ -0,0 +1,46 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPDigest
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
security = HTTPDigest()
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket,
credentials: HTTPAuthorizationCredentials = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
)
client = TestClient(app)
def test_security_http_digest_ws():
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Digest foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Digest", "credentials": "foobar"}
def test_security_http_digest_no_credentials_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass
def test_security_http_digest_incorrect_scheme_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Basic notreally"}
):
pass

45
tests/test_security_oauth2_authorization_code_bearer_websocket.py

@ -0,0 +1,45 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security import OAuth2AuthorizationCodeBearer
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl="/api/oauth/authorize",
tokenUrl="/api/oauth/token",
scopes={"read": "Read access", "write": "Write access"},
)
@app.websocket("/ws/admin")
async def read_admin(websocket: WebSocket, token: str = Security(oauth2_scheme)):
await websocket.accept()
await websocket.send_text(token)
client = TestClient(app)
def test_security_oauth2_authorization_code_bearer_ws():
with client.websocket_connect(
"/ws/admin", headers={"Authorization": "Bearer faketoken"}
) as websocket:
data = websocket.receive_text()
assert data == "faketoken"
def test_security_oauth2_authorization_code_bearer_no_header_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/admin"):
pass
def test_security_oauth2_authorization_code_bearer_wrong_scheme_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect(
"/ws/admin", headers={"Authorization": "Basic nope"}
):
pass

41
tests/test_security_oauth2_password_bearer_websocket.py

@ -0,0 +1,41 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security import OAuth2PasswordBearer
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@app.websocket("/ws/token")
async def read_token(websocket: WebSocket, token: str = Security(oauth2_scheme)):
await websocket.accept()
await websocket.send_text(token)
client = TestClient(app)
def test_security_oauth2_password_bearer_ws():
with client.websocket_connect(
"/ws/token", headers={"Authorization": "Bearer faketoken"}
) as websocket:
data = websocket.receive_text()
assert data == "faketoken"
def test_security_oauth2_password_bearer_no_header_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/token"):
pass
def test_security_oauth2_password_bearer_wrong_scheme_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect(
"/ws/token", headers={"Authorization": "Basic nope"}
):
pass

53
tests/test_security_openid_connect_websocket.py

@ -0,0 +1,53 @@
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.testclient import WebSocketDenialResponse
from starlette.websockets import WebSocket
app = FastAPI()
oid = OpenIdConnect(openIdConnectUrl="/openid")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(oid)):
user = User(username=oauth_header)
return user
@app.websocket("/ws/users/me")
async def read_current_user(
websocket: WebSocket, current_user: User = Depends(get_current_user)
):
await websocket.accept()
await websocket.send_json({"username": current_user.username})
client = TestClient(app)
def test_security_openid_connect_ws():
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Bearer footokenbar"}
) as websocket:
data = websocket.receive_json()
assert data == {"username": "Bearer footokenbar"}
def test_security_openid_connect_other_header_ws():
with client.websocket_connect(
"/ws/users/me", headers={"Authorization": "Other footokenbar"}
) as websocket:
data = websocket.receive_json()
assert data == {"username": "Other footokenbar"}
def test_security_openid_connect_no_header_ws():
with pytest.raises(WebSocketDenialResponse):
with client.websocket_connect("/ws/users/me"):
pass
Loading…
Cancel
Save