You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

313 lines
8.8 KiB

import functools
import pytest
from fastapi import (
APIRouter,
Depends,
FastAPI,
Header,
WebSocket,
WebSocketDisconnect,
status,
)
from fastapi.middleware import Middleware
from fastapi.testclient import TestClient
# Application and routers shared by every endpoint and test in this module.
app = FastAPI()
# Included without any prefix by make_app.
router = APIRouter(tags=["base"])
# Its "/prefix" prefix is supplied when make_app includes it.
prefix_router = APIRouter(tags=["prefix"])
# Carries its own "/native" prefix.
native_prefix_router = APIRouter(prefix="/native", tags=["native"])
@app.websocket_route("/")
async def index(websocket: WebSocket):
    """
    Greet a client connected to the application root, then hang up.
    """
    message = "Hello, world!"
    await websocket.accept()
    await websocket.send_text(message)
    await websocket.close()
@router.websocket_route("/router")
async def routerindex(websocket: WebSocket):
    """
    Router endpoint registered via ``websocket_route``: send one greeting and close.
    """
    message = "Hello, router!"
    await websocket.accept()
    await websocket.send_text(message)
    await websocket.close()
@prefix_router.websocket_route("/")
async def routerprefixindex(websocket: WebSocket):
    """
    Root of ``prefix_router``: send one greeting and close.
    """
    message = "Hello, router with prefix!"
    await websocket.accept()
    await websocket.send_text(message)
    await websocket.close()
@router.websocket("/router2")
async def routerindex2(websocket: WebSocket):
    """
    Same greeting as ``routerindex``, but registered through ``router.websocket``.
    """
    message = "Hello, router!"
    await websocket.accept()
    await websocket.send_text(message)
    await websocket.close()
@router.websocket("/router/{pathparam:path}")
async def routerindexparams(websocket: WebSocket, pathparam: str, queryparam: str):
    """
    Echo the path parameter, then the query parameter, as two text frames.
    """
    await websocket.accept()
    for value in (pathparam, queryparam):
        await websocket.send_text(value)
    await websocket.close()
async def ws_dependency():
    """
    Dependency that supplies a fixed string to websocket endpoints.
    """
    value = "Socket Dependency"
    return value
@router.websocket("/router-ws-depends/")
async def router_ws_decorator_depends(
    websocket: WebSocket, data=Depends(ws_dependency)
):
    """
    Send whatever value the injected ``ws_dependency`` resolved to, then close.
    """
    payload = data
    await websocket.accept()
    await websocket.send_text(payload)
    await websocket.close()
@native_prefix_router.websocket("/")
async def router_native_prefix_ws(websocket: WebSocket):
    """
    Root of ``native_prefix_router`` (which carries its own prefix): greet and close.
    """
    message = "Hello, router with native prefix!"
    await websocket.accept()
    await websocket.send_text(message)
    await websocket.close()
async def ws_dependency_err():
    """
    Dependency that always fails, used to exercise error propagation.
    """
    raise NotImplementedError
@router.websocket("/depends-err/")
async def router_ws_depends_err(
    websocket: WebSocket, data=Depends(ws_dependency_err)
):
    """
    Never reached: resolving ``ws_dependency_err`` raises before the body runs.
    """
    ...  # pragma: no cover
async def ws_dependency_validate(x_missing: str = Header()):
    """
    Dependency requiring an ``x-missing`` header that the tests never send.
    """
    ...  # pragma: no cover
@router.websocket("/depends-validate/")
async def router_ws_depends_validate(
    websocket: WebSocket, data=Depends(ws_dependency_validate)
):
    """
    Never reached in tests: validation of ``ws_dependency_validate`` fails first.
    """
    ...  # pragma: no cover
class CustomError(Exception):
    """
    Test-only exception raised by a websocket endpoint to exercise exception handlers.
    """
@router.websocket("/custom_error/")
async def router_ws_custom_error(websocket: WebSocket):
    """
    Fail immediately with ``CustomError`` so an exception handler can react.
    """
    raise CustomError
@app.websocket("/test_tags", name="test-app-tags", tags=["test-app-tags"])
@router.websocket("/test_tags/", name="test-router-tags", tags=["test-router-tags"])
async def router_ws_test_tags(websocket: WebSocket):
    """
    Route registered on both the app and ``router``; only its tags are inspected.
    """
    ...  # pragma: no cover
@prefix_router.websocket(
    "/test_tags/",
    name="test-prefix-router-tags",
    tags=["test-prefix-router-tags"],
)
async def prefix_router_ws_test_tags(websocket: WebSocket):
    """
    Route on ``prefix_router``; only its merged tags are inspected.
    """
    ...  # pragma: no cover
@native_prefix_router.websocket(
    "/test_tags/",
    name="test-native-prefix-router-tags",
    tags=["test-native-prefix-router-tags"],
)
async def native_prefix_router_ws_test_tags(websocket: WebSocket):
    """
    Route on ``native_prefix_router``; only its merged tags are inspected.
    """
    ...  # pragma: no cover
def make_app(app=None, **kwargs):
    """
    Include this module's routers into a FastAPI application and return it.

    :param app: existing application to extend. When ``None``, a new
        ``FastAPI(**kwargs)`` is created.
    :param kwargs: forwarded to ``FastAPI`` only when a new application is
        created. They are ignored if ``app`` is given.
    :return: the application with ``router`` (no prefix), ``prefix_router``
        (under ``/prefix``) and ``native_prefix_router`` included.
    """
    # Explicit None check: ``app`` is a sentinel-defaulted argument, so test
    # identity rather than relying on the truthiness of the passed object.
    if app is None:
        app = FastAPI(**kwargs)
    app.include_router(router)
    app.include_router(prefix_router, prefix="/prefix")
    app.include_router(native_prefix_router)
    return app
# Attach the routers to the module-level app used by the plain tests below.
app = make_app(app=app)
def test_app():
    """
    The application-level root websocket greets the client.
    """
    with TestClient(app).websocket_connect("/") as websocket:
        assert websocket.receive_text() == "Hello, world!"
def test_router():
    """
    A ``websocket_route`` on an included router is reachable.
    """
    with TestClient(app).websocket_connect("/router") as websocket:
        assert websocket.receive_text() == "Hello, router!"
def test_prefix_router():
    """
    A router included with ``prefix="/prefix"`` serves its root under that prefix.
    """
    with TestClient(app).websocket_connect("/prefix/") as websocket:
        assert websocket.receive_text() == "Hello, router with prefix!"
def test_native_prefix_router():
    """
    A router declared with its own prefix serves its root under that prefix.
    """
    with TestClient(app).websocket_connect("/native/") as websocket:
        assert websocket.receive_text() == "Hello, router with native prefix!"
def test_router2():
    """
    A route registered with ``router.websocket`` is reachable.
    """
    with TestClient(app).websocket_connect("/router2") as websocket:
        assert websocket.receive_text() == "Hello, router!"
def test_router_ws_depends():
    """
    ``Depends`` on a websocket route injects the dependency's return value.
    """
    expected = "Socket Dependency"
    with TestClient(app).websocket_connect("/router-ws-depends/") as websocket:
        assert websocket.receive_text() == expected
def test_router_ws_depends_with_override():
    """
    Verify that ``app.dependency_overrides`` also applies to websocket dependencies.

    The override is installed on the shared module-level ``app``. It is removed
    in ``finally`` so it cannot leak into tests that run afterwards, e.g.
    ``test_router_ws_depends`` under a different test order.
    """
    client = TestClient(app)
    app.dependency_overrides[ws_dependency] = lambda: "Override"  # noqa: E731
    try:
        with client.websocket_connect("/router-ws-depends/") as websocket:
            assert websocket.receive_text() == "Override"
    finally:
        # Remove only our own override; leave any others untouched.
        app.dependency_overrides.pop(ws_dependency, None)
def test_router_with_params():
    """
    Path and query parameters are passed to the endpoint and echoed back in order.
    """
    client = TestClient(app)
    url = "/router/path/to/file?queryparam=a_query_param"
    with client.websocket_connect(url) as websocket:
        for expected in ("path/to/file", "a_query_param"):
            assert websocket.receive_text() == expected
def test_wrong_uri():
    """
    Verify that a websocket connection to a non-existent endpoint is closed
    with the normal-closure code (1000).
    """
    client = TestClient(app)
    with pytest.raises(WebSocketDisconnect) as e:
        with client.websocket_connect("/no-router/"):
            pass  # pragma: no cover
    assert e.value.code == status.WS_1000_NORMAL_CLOSURE
def websocket_middleware(middleware_func):
    """
    Helper to create a Starlette pure websocket middleware

    For websocket scopes, ``middleware_func`` is awaited as
    ``middleware_func(websocket, call_next)``, where ``call_next()`` runs the
    wrapped ASGI app. Any other scope type is passed straight to the wrapped
    app.
    """

    def middleware_constructor(app):
        @functools.wraps(app)
        async def wrapped_app(scope, receive, send):
            if scope["type"] == "websocket":

                async def forward_to_app():
                    return await app(scope, receive, send)

                ws = WebSocket(scope, receive=receive, send=send)
                return await middleware_func(ws, forward_to_app)
            return await app(scope, receive, send)  # pragma: no cover

        return wrapped_app

    return middleware_constructor
def test_depend_validation():
    """
    Verify that a validation failure in a dependency is turned into a close
    frame by FastAPI itself, without an exception escaping to middleware.
    """
    leaked = []

    @websocket_middleware
    async def catcher(websocket, call_next):
        try:
            return await call_next()
        except Exception as exc:  # pragma: no cover
            leaked.append(exc)
            raise

    client = TestClient(make_app(middleware=[Middleware(catcher)]))
    with pytest.raises(WebSocketDisconnect) as exc_info:
        with client.websocket_connect("/depends-validate/"):
            pass  # pragma: no cover
    # The required header is never sent, so the connection is closed with a
    # policy-violation code...
    assert exc_info.value.code == status.WS_1008_POLICY_VIOLATION
    # ...and the middleware never observed an exception.
    assert leaked == []
def test_depend_err_middleware():
    """
    Verify that it is possible to write custom WebSocket middleware to catch errors
    """

    @websocket_middleware
    async def errorhandler(websocket: WebSocket, call_next):
        try:
            return await call_next()
        except Exception as exc:
            # Report the failure to the client in the close frame's reason.
            await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(exc))

    client = TestClient(make_app(middleware=[Middleware(errorhandler)]))
    with pytest.raises(WebSocketDisconnect) as exc_info:
        with client.websocket_connect("/depends-err/"):
            pass  # pragma: no cover
    disconnect = exc_info.value
    assert disconnect.code == status.WS_1006_ABNORMAL_CLOSURE
    assert "NotImplementedError" in disconnect.reason
def test_depend_err_handler():
    """
    Verify that an exception handler registered through ``exception_handlers``
    is invoked for an exception raised inside a websocket endpoint. The handler
    can close the connection with a custom code and reason.
    """

    async def custom_handler(websocket: WebSocket, exc: CustomError) -> None:
        await websocket.close(status.WS_1002_PROTOCOL_ERROR, "foo")

    myapp = make_app(exception_handlers={CustomError: custom_handler})
    client = TestClient(myapp)
    with pytest.raises(WebSocketDisconnect) as e:
        with client.websocket_connect("/custom_error/"):
            pass  # pragma: no cover
    # The close code and reason come from custom_handler, proving it ran.
    assert e.value.code == status.WS_1002_PROTOCOL_ERROR
    assert "foo" in e.value.reason
@pytest.mark.parametrize(
    ("route_name", "route_tags"),
    [
        ("test-app-tags", ["test-app-tags"]),
        ("test-router-tags", ["base", "test-router-tags"]),
        ("test-prefix-router-tags", ["prefix", "test-prefix-router-tags"]),
        (
            "test-native-prefix-router-tags",
            ["native", "test-native-prefix-router-tags"],
        ),
    ],
)
def test_websocket_tags(route_name, route_tags):
    """
    Verify that websocket routes carry both their own tags and their router's tags.
    """
    route = next(filter(lambda candidate: candidate.name == route_name, app.routes))
    actual = sorted(route.tags)
    assert actual == sorted(route_tags)