Browse Source

Add tests for websocket with authorization

pull/10147/head
Mix 2 years ago
parent
commit
6a4ab615c8
  1. 3
      fastapi/security/utils.py
  2. 29
      tests/test_security_http_base.py

3
fastapi/security/utils.py

@ -29,7 +29,8 @@ def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc:
except HTTPException as e: except HTTPException as e:
if not isinstance(request, WebSocket): if not isinstance(request, WebSocket):
raise e raise e
await request.accept() # close before accepted with result a HTTP 403 so the exception argument is ignored
# ref: https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event
raise WebSocketException( raise WebSocketException(
code=WS_1008_POLICY_VIOLATION, reason=e.detail code=WS_1008_POLICY_VIOLATION, reason=e.detail
) from None ) from None

29
tests/test_security_http_base.py

@ -1,6 +1,8 @@
from fastapi import FastAPI, Security import pytest
from fastapi import FastAPI, Security, WebSocket
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
app = FastAPI() app = FastAPI()
@ -12,6 +14,16 @@ def read_current_user(credentials: HTTPAuthorizationCredentials = Security(secur
return {"scheme": credentials.scheme, "credentials": credentials.credentials} return {"scheme": credentials.scheme, "credentials": credentials.credentials}
@app.websocket("/users/timeline")
async def read_user_timeline(
websocket: WebSocket, credentials: HTTPAuthorizationCredentials = Security(security)
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
)
client = TestClient(app) client = TestClient(app)
@ -27,6 +39,21 @@ def test_security_http_base_no_credentials():
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
def test_security_http_base_with_ws():
with client.websocket_connect(
"/users/timeline", headers={"Authorization": "Other foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Other", "credentials": "foobar"}
def test_security_http_base_with_ws_no_credentials():
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/users/timeline"):
pass
assert e.value.reason == "Not authenticated"
def test_openapi_schema(): def test_openapi_schema():
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text

Loading…
Cancel
Save