pythonasyncioapiasyncfastapiframeworkjsonjson-schemaopenapiopenapi3pydanticpython-typespython3redocreststarletteswaggerswagger-uiuvicornweb
271 lines
7.5 KiB
271 lines
7.5 KiB
import functools
|
|
|
|
import pytest
|
|
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
FastAPI,
|
|
Header,
|
|
WebSocket,
|
|
WebSocketDisconnect,
|
|
status,
|
|
)
|
|
from fastapi.middleware import Middleware
|
|
from fastapi.testclient import TestClient
|
|
|
|
# Routers that make_app() mounts: ``router`` with no prefix, ``prefix_router``
# under "/prefix", and ``native_prefix_route`` carrying its own "/native" prefix.
router, prefix_router = APIRouter(), APIRouter()
native_prefix_route = APIRouter(prefix="/native")

# Module-level application; ``index`` registers on it directly.
app = FastAPI()
|
|
|
|
|
|
@app.websocket_route("/")
async def index(websocket: WebSocket):
    """Greet the client once on "/" and then hang up."""
    message = "Hello, world!"
    await websocket.accept()
    await websocket.send_text(message)
    await websocket.close()
|
|
|
|
|
|
@router.websocket_route("/router")
async def routerindex(websocket: WebSocket):
    """Plain ``websocket_route`` on a router: send one greeting, then close."""
    message = "Hello, router!"
    await websocket.accept()
    await websocket.send_text(message)
    await websocket.close()
|
|
|
|
|
|
@prefix_router.websocket_route("/")
async def routerprefixindex(websocket: WebSocket):
    """Root route of ``prefix_router``; the prefix is supplied at include time."""
    message = "Hello, router with prefix!"
    await websocket.accept()
    await websocket.send_text(message)
    await websocket.close()
|
|
|
|
|
|
@router.websocket("/router2")
async def routerindex2(websocket: WebSocket):
    """Same greeting as ``routerindex``, but registered via ``router.websocket``."""
    message = "Hello, router!"
    await websocket.accept()
    await websocket.send_text(message)
    await websocket.close()
|
|
|
|
|
|
@router.websocket("/router/{pathparam:path}")
async def routerindexparams(websocket: WebSocket, pathparam: str, queryparam: str):
    """Echo the path parameter, then the query parameter, as two text frames."""
    await websocket.accept()
    for value in (pathparam, queryparam):
        await websocket.send_text(value)
    await websocket.close()
|
|
|
|
|
|
async def ws_dependency():
    """Async dependency yielding the fixed text sent by ``router_ws_decorator_depends``."""
    value = "Socket Dependency"
    return value
|
|
|
|
|
|
@router.websocket("/router-ws-depends/")
async def router_ws_decorator_depends(websocket: WebSocket, data=Depends(ws_dependency)):
    """Send whatever ``ws_dependency`` (or an override of it) resolved to."""
    await websocket.accept()
    await websocket.send_text(data)
    await websocket.close()
|
|
|
|
|
|
@native_prefix_route.websocket("/")
async def router_native_prefix_ws(websocket: WebSocket):
    """Root route of a router whose "/native" prefix is set in its constructor."""
    message = "Hello, router with native prefix!"
    await websocket.accept()
    await websocket.send_text(message)
    await websocket.close()
|
|
|
|
|
|
async def ws_dependency_err():
    """Dependency that always fails, so its endpoint body is never reached."""
    raise NotImplementedError
|
|
|
|
|
|
@router.websocket("/depends-err/")
async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)):
    """Never runs: resolving ``ws_dependency_err`` raises first."""
    ...  # pragma: no cover
|
|
|
|
|
|
async def ws_dependency_validate(x_missing: str = Header()):
    """
    Dependency with a required header parameter that the tests never send,
    so request validation fails before this body can run.
    """
    ...  # pragma: no cover
|
|
|
|
|
|
@router.websocket("/depends-validate/")
async def router_ws_depends_validate(websocket: WebSocket, data=Depends(ws_dependency_validate)):
    """Never runs: validation of ``ws_dependency_validate`` fails first."""
    ...  # pragma: no cover
|
|
|
|
|
|
class CustomError(Exception):
    """Raised by ``router_ws_custom_error`` to exercise custom exception handlers."""
|
|
|
|
|
|
@router.websocket("/custom_error/")
async def router_ws_custom_error(websocket: WebSocket):
    """Fail immediately so the app's ``CustomError`` handler (if any) takes over."""
    raise CustomError
|
|
|
|
|
|
def make_app(app=None, **kwargs):
    """
    Mount this module's routers onto an application and return it.

    :param app: existing application to extend; when None a new ``FastAPI``
        instance is created from ``kwargs``.
    :param kwargs: forwarded to ``FastAPI(...)`` (e.g. ``middleware=`` or
        ``exception_handlers=``) only when ``app`` is None; ignored otherwise.
    :return: the application with ``router``, ``prefix_router`` (under
        "/prefix") and ``native_prefix_route`` included.
    """
    # Explicit None check rather than truthiness: only a missing app should
    # trigger creating a fresh one.
    if app is None:
        app = FastAPI(**kwargs)
    app.include_router(router)
    app.include_router(prefix_router, prefix="/prefix")
    app.include_router(native_prefix_route)
    return app
|
|
|
|
|
|
# Attach the shared routers to the module-level app used by the tests below.
app = make_app(app=app)
|
|
|
|
|
|
def test_app():
    """The app-level "/" websocket route sends its greeting."""
    client = TestClient(app)
    with client.websocket_connect("/") as websocket:
        assert websocket.receive_text() == "Hello, world!"
|
|
|
|
|
|
def test_router():
    """A ``websocket_route`` on an included router is reachable."""
    client = TestClient(app)
    with client.websocket_connect("/router") as websocket:
        assert websocket.receive_text() == "Hello, router!"
|
|
|
|
|
|
def test_prefix_router():
    """A router included with ``prefix="/prefix"`` serves its root under it."""
    client = TestClient(app)
    with client.websocket_connect("/prefix/") as websocket:
        assert websocket.receive_text() == "Hello, router with prefix!"
|
|
|
|
|
|
def test_native_prefix_router():
    """A router constructed with its own prefix serves its root under it."""
    client = TestClient(app)
    with client.websocket_connect("/native/") as websocket:
        assert websocket.receive_text() == "Hello, router with native prefix!"
|
|
|
|
|
|
def test_router2():
    """A ``router.websocket``-decorated endpoint is reachable."""
    client = TestClient(app)
    with client.websocket_connect("/router2") as websocket:
        assert websocket.receive_text() == "Hello, router!"
|
|
|
|
|
|
def test_router_ws_depends():
    """``Depends`` works on websocket endpoints and injects the dependency value."""
    client = TestClient(app)
    with client.websocket_connect("/router-ws-depends/") as websocket:
        received = websocket.receive_text()
        assert received == "Socket Dependency"
|
|
|
|
|
|
def test_router_ws_depends_with_override():
    """
    Verify that ``dependency_overrides`` also applies to websocket endpoints.

    The override is registered on the module-level ``app`` shared by every
    test in this file, so it is removed again in ``finally``. Otherwise it
    would leak into tests that run afterwards (e.g. ``test_router_ws_depends``
    under a different test order) and make them fail.
    """
    client = TestClient(app)
    app.dependency_overrides[ws_dependency] = lambda: "Override"  # noqa: E731
    try:
        with client.websocket_connect("/router-ws-depends/") as websocket:
            assert websocket.receive_text() == "Override"
    finally:
        # Remove only the entry this test added; leave any others untouched.
        app.dependency_overrides.pop(ws_dependency, None)
|
|
|
|
|
|
def test_router_with_params():
    """Path and query parameters are bound and echoed back in that order."""
    client = TestClient(app)
    url = "/router/path/to/file?queryparam=a_query_param"
    with client.websocket_connect(url) as websocket:
        assert websocket.receive_text() == "path/to/file"
        assert websocket.receive_text() == "a_query_param"
|
|
|
|
|
|
def test_wrong_uri():
    """
    Verify that a websocket connection to a non-existent endpoint results in a
    shutdown with the normal-closure (1000) close code.
    """
    client = TestClient(app)
    with pytest.raises(WebSocketDisconnect) as e:
        with client.websocket_connect("/no-router/"):
            pass  # pragma: no cover
    assert e.value.code == status.WS_1000_NORMAL_CLOSURE
|
|
|
|
|
|
def websocket_middleware(middleware_func):
    """
    Build a pure ASGI middleware factory from ``middleware_func``.

    ``middleware_func(websocket, call_next)`` is awaited for websocket scopes
    only; ``call_next()`` forwards to the wrapped app. Every other scope type
    goes straight to the wrapped app. The returned factory takes the inner
    app, so it can be passed to ``Middleware(...)``.
    """

    def middleware_constructor(app):
        @functools.wraps(app)
        async def wrapped_app(scope, receive, send):
            if scope["type"] == "websocket":

                async def call_next():
                    return await app(scope, receive, send)

                ws = WebSocket(scope, receive=receive, send=send)
                return await middleware_func(ws, call_next)
            return await app(scope, receive, send)  # pragma: no cover

        return wrapped_app

    return middleware_constructor
|
|
|
|
|
|
def test_depend_validation():
    """
    Check that a dependency failing validation is handled by the correct
    exception handler instead of escaping through the middleware stack.
    """
    leaked = []

    @websocket_middleware
    async def record_errors(websocket, call_next):
        try:
            return await call_next()
        except Exception as exc:  # pragma: no cover
            leaked.append(exc)
            raise

    client = TestClient(make_app(middleware=[Middleware(record_errors)]))
    with pytest.raises(WebSocketDisconnect) as exc_info:
        with client.websocket_connect("/depends-validate/"):
            pass  # pragma: no cover
    # The validation failure reaches the client as a policy-violation close...
    assert exc_info.value.code == status.WS_1008_POLICY_VIOLATION
    # ...and no exception propagates out to the wrapping middleware.
    assert not leaked
|
|
|
|
|
|
def test_depend_err_middleware():
    """
    Verify that it is possible to write custom WebSocket middleware to catch errors
    """

    @websocket_middleware
    async def errorhandler(websocket: WebSocket, call_next):
        try:
            return await call_next()
        except Exception as e:
            # RFC 6455 section 7.4.1 reserves 1006 for reporting an abnormal
            # closure and forbids an endpoint from sending it in a Close frame
            # (real servers reject it). 1011 is the code to send when an
            # unexpected condition stopped the endpoint from serving the request.
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=repr(e))

    myapp = make_app(middleware=[Middleware(errorhandler)])
    client = TestClient(myapp)
    with pytest.raises(WebSocketDisconnect) as e:
        with client.websocket_connect("/depends-err/"):
            pass  # pragma: no cover
    assert e.value.code == status.WS_1011_INTERNAL_ERROR
    assert "NotImplementedError" in e.value.reason
|
|
|
|
|
|
def test_depend_err_handler():
    """
    Verify that an exception raised inside a WebSocket endpoint is dispatched
    to a custom handler registered through ``exception_handlers``. That handler
    can then close the connection with its own code and reason.
    """

    async def custom_handler(websocket: WebSocket, exc: CustomError) -> None:
        await websocket.close(1002, "foo")

    myapp = make_app(exception_handlers={CustomError: custom_handler})
    client = TestClient(myapp)
    with pytest.raises(WebSocketDisconnect) as e:
        with client.websocket_connect("/custom_error/"):
            pass  # pragma: no cover
    assert e.value.code == 1002
    assert "foo" in e.value.reason
|
|
|