|
|
@ -1,4 +1,16 @@ |
|
|
|
from fastapi import APIRouter, Depends, FastAPI, WebSocket |
|
|
|
import functools |
|
|
|
|
|
|
|
import pytest |
|
|
|
from fastapi import ( |
|
|
|
APIRouter, |
|
|
|
Depends, |
|
|
|
FastAPI, |
|
|
|
Header, |
|
|
|
WebSocket, |
|
|
|
WebSocketDisconnect, |
|
|
|
status, |
|
|
|
) |
|
|
|
from fastapi.middleware import Middleware |
|
|
|
from fastapi.testclient import TestClient |
|
|
|
|
|
|
|
router = APIRouter() |
|
|
@ -63,9 +75,44 @@ async def router_native_prefix_ws(websocket: WebSocket): |
|
|
|
await websocket.close() |
|
|
|
|
|
|
|
|
|
|
|
app.include_router(router) |
|
|
|
app.include_router(prefix_router, prefix="/prefix") |
|
|
|
app.include_router(native_prefix_route) |
|
|
|
async def ws_dependency_err(): |
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
|
|
|
@router.websocket("/depends-err/") |
|
|
|
async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)): |
|
|
|
pass # pragma: no cover |
|
|
|
|
|
|
|
|
|
|
|
async def ws_dependency_validate(x_missing: str = Header()): |
|
|
|
pass # pragma: no cover |
|
|
|
|
|
|
|
|
|
|
|
@router.websocket("/depends-validate/") |
|
|
|
async def router_ws_depends_validate( |
|
|
|
websocket: WebSocket, data=Depends(ws_dependency_validate) |
|
|
|
): |
|
|
|
pass # pragma: no cover |
|
|
|
|
|
|
|
|
|
|
|
class CustomError(Exception): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
@router.websocket("/custom_error/") |
|
|
|
async def router_ws_custom_error(websocket: WebSocket): |
|
|
|
raise CustomError() |
|
|
|
|
|
|
|
|
|
|
|
def make_app(app=None, **kwargs): |
|
|
|
app = app or FastAPI(**kwargs) |
|
|
|
app.include_router(router) |
|
|
|
app.include_router(prefix_router, prefix="/prefix") |
|
|
|
app.include_router(native_prefix_route) |
|
|
|
return app |
|
|
|
|
|
|
|
|
|
|
|
app = make_app(app) |
|
|
|
|
|
|
|
|
|
|
|
def test_app(): |
|
|
@ -125,3 +172,100 @@ def test_router_with_params(): |
|
|
|
assert data == "path/to/file" |
|
|
|
data = websocket.receive_text() |
|
|
|
assert data == "a_query_param" |
|
|
|
|
|
|
|
|
|
|
|
def test_wrong_uri(): |
|
|
|
""" |
|
|
|
Verify that a websocket connection to a non-existent endpoing returns in a shutdown |
|
|
|
""" |
|
|
|
client = TestClient(app) |
|
|
|
with pytest.raises(WebSocketDisconnect) as e: |
|
|
|
with client.websocket_connect("/no-router/"): |
|
|
|
pass # pragma: no cover |
|
|
|
assert e.value.code == status.WS_1000_NORMAL_CLOSURE |
|
|
|
|
|
|
|
|
|
|
|
def websocket_middleware(middleware_func): |
|
|
|
""" |
|
|
|
Helper to create a Starlette pure websocket middleware |
|
|
|
""" |
|
|
|
|
|
|
|
def middleware_constructor(app): |
|
|
|
@functools.wraps(app) |
|
|
|
async def wrapped_app(scope, receive, send): |
|
|
|
if scope["type"] != "websocket": |
|
|
|
return await app(scope, receive, send) # pragma: no cover |
|
|
|
|
|
|
|
async def call_next(): |
|
|
|
return await app(scope, receive, send) |
|
|
|
|
|
|
|
websocket = WebSocket(scope, receive=receive, send=send) |
|
|
|
return await middleware_func(websocket, call_next) |
|
|
|
|
|
|
|
return wrapped_app |
|
|
|
|
|
|
|
return middleware_constructor |
|
|
|
|
|
|
|
|
|
|
|
def test_depend_validation(): |
|
|
|
""" |
|
|
|
Verify that a validation in a dependency invokes the correct exception handler |
|
|
|
""" |
|
|
|
caught = [] |
|
|
|
|
|
|
|
@websocket_middleware |
|
|
|
async def catcher(websocket, call_next): |
|
|
|
try: |
|
|
|
return await call_next() |
|
|
|
except Exception as e: # pragma: no cover |
|
|
|
caught.append(e) |
|
|
|
raise |
|
|
|
|
|
|
|
myapp = make_app(middleware=[Middleware(catcher)]) |
|
|
|
|
|
|
|
client = TestClient(myapp) |
|
|
|
with pytest.raises(WebSocketDisconnect) as e: |
|
|
|
with client.websocket_connect("/depends-validate/"): |
|
|
|
pass # pragma: no cover |
|
|
|
# the validation error does produce a close message |
|
|
|
assert e.value.code == status.WS_1008_POLICY_VIOLATION |
|
|
|
# and no error is leaked |
|
|
|
assert caught == [] |
|
|
|
|
|
|
|
|
|
|
|
def test_depend_err_middleware(): |
|
|
|
""" |
|
|
|
Verify that it is possible to write custom WebSocket middleware to catch errors |
|
|
|
""" |
|
|
|
|
|
|
|
@websocket_middleware |
|
|
|
async def errorhandler(websocket: WebSocket, call_next): |
|
|
|
try: |
|
|
|
return await call_next() |
|
|
|
except Exception as e: |
|
|
|
await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e)) |
|
|
|
|
|
|
|
myapp = make_app(middleware=[Middleware(errorhandler)]) |
|
|
|
client = TestClient(myapp) |
|
|
|
with pytest.raises(WebSocketDisconnect) as e: |
|
|
|
with client.websocket_connect("/depends-err/"): |
|
|
|
pass # pragma: no cover |
|
|
|
assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE |
|
|
|
assert "NotImplementedError" in e.value.reason |
|
|
|
|
|
|
|
|
|
|
|
def test_depend_err_handler(): |
|
|
|
""" |
|
|
|
Verify that it is possible to write custom WebSocket middleware to catch errors |
|
|
|
""" |
|
|
|
|
|
|
|
async def custom_handler(websocket: WebSocket, exc: CustomError) -> None: |
|
|
|
await websocket.close(1002, "foo") |
|
|
|
|
|
|
|
myapp = make_app(exception_handlers={CustomError: custom_handler}) |
|
|
|
client = TestClient(myapp) |
|
|
|
with pytest.raises(WebSocketDisconnect) as e: |
|
|
|
with client.websocket_connect("/custom_error/"): |
|
|
|
pass # pragma: no cover |
|
|
|
assert e.value.code == 1002 |
|
|
|
assert "foo" in e.value.reason |
|
|
|