Browse Source

Add tests for websocket with authorization

pull/10147/head
Mix 2 years ago
parent
commit
6a4ab615c8
  1. 3
      fastapi/security/utils.py
  2. 29
      tests/test_security_http_base.py

3
fastapi/security/utils.py

@ -29,7 +29,8 @@ def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc:
except HTTPException as e:
if not isinstance(request, WebSocket):
raise e
await request.accept()
# close before accepted with result a HTTP 403 so the exception argument is ignored
# ref: https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event
raise WebSocketException(
code=WS_1008_POLICY_VIOLATION, reason=e.detail
) from None

29
tests/test_security_http_base.py

@ -1,6 +1,8 @@
from fastapi import FastAPI, Security
import pytest
from fastapi import FastAPI, Security, WebSocket
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from fastapi.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
app = FastAPI()
@ -12,6 +14,16 @@ def read_current_user(credentials: HTTPAuthorizationCredentials = Security(secur
return {"scheme": credentials.scheme, "credentials": credentials.credentials}
@app.websocket("/users/timeline")
async def read_user_timeline(
websocket: WebSocket, credentials: HTTPAuthorizationCredentials = Security(security)
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
)
client = TestClient(app)
@ -27,6 +39,21 @@ def test_security_http_base_no_credentials():
assert response.json() == {"detail": "Not authenticated"}
def test_security_http_base_with_ws():
with client.websocket_connect(
"/users/timeline", headers={"Authorization": "Other foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Other", "credentials": "foobar"}
def test_security_http_base_with_ws_no_credentials():
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/users/timeline"):
pass
assert e.value.reason == "Not authenticated"
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200, response.text

Loading…
Cancel
Save