Browse Source

Add exception handler for `WebSocketRequestValidationError` (which also allows to override it) (#6030)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sebastián Ramírez <[email protected]>
pull/9657/head
Kristján Valur Jónsson 2 years ago
committed by GitHub
parent
commit
ab03f22635
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      fastapi/applications.py
  2. 13
      fastapi/exception_handlers.py
  3. 2
      fastapi/routing.py
  4. 152
      tests/test_ws_router.py

8
fastapi/applications.py

@ -19,8 +19,9 @@ from fastapi.encoders import DictIntStrAny, SetIntStr
from fastapi.exception_handlers import ( from fastapi.exception_handlers import (
http_exception_handler, http_exception_handler,
request_validation_exception_handler, request_validation_exception_handler,
websocket_request_validation_exception_handler,
) )
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
from fastapi.openapi.docs import ( from fastapi.openapi.docs import (
@ -145,6 +146,11 @@ class FastAPI(Starlette):
self.exception_handlers.setdefault( self.exception_handlers.setdefault(
RequestValidationError, request_validation_exception_handler RequestValidationError, request_validation_exception_handler
) )
self.exception_handlers.setdefault(
WebSocketRequestValidationError,
# Starlette still has incorrect type specification for the handlers
websocket_request_validation_exception_handler, # type: ignore
)
self.user_middleware: List[Middleware] = ( self.user_middleware: List[Middleware] = (
[] if middleware is None else list(middleware) [] if middleware is None else list(middleware)

13
fastapi/exception_handlers.py

@ -1,10 +1,11 @@
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.utils import is_body_allowed_for_status_code from fastapi.utils import is_body_allowed_for_status_code
from fastapi.websockets import WebSocket
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
async def http_exception_handler(request: Request, exc: HTTPException) -> Response: async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
@ -23,3 +24,11 @@ async def request_validation_exception_handler(
status_code=HTTP_422_UNPROCESSABLE_ENTITY, status_code=HTTP_422_UNPROCESSABLE_ENTITY,
content={"detail": jsonable_encoder(exc.errors())}, content={"detail": jsonable_encoder(exc.errors())},
) )
async def websocket_request_validation_exception_handler(
websocket: WebSocket, exc: WebSocketRequestValidationError
) -> None:
await websocket.close(
code=WS_1008_POLICY_VIOLATION, reason=jsonable_encoder(exc.errors())
)

2
fastapi/routing.py

@ -56,7 +56,6 @@ from starlette.routing import (
request_response, request_response,
websocket_session, websocket_session,
) )
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.types import ASGIApp, Lifespan, Scope from starlette.types import ASGIApp, Lifespan, Scope
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
@ -283,7 +282,6 @@ def get_websocket_app(
) )
values, errors, _, _2, _3 = solved_result values, errors, _, _2, _3 = solved_result
if errors: if errors:
await websocket.close(code=WS_1008_POLICY_VIOLATION)
raise WebSocketRequestValidationError(errors) raise WebSocketRequestValidationError(errors)
assert dependant.call is not None, "dependant.call must be a function" assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**values) await dependant.call(**values)

152
tests/test_ws_router.py

@ -1,4 +1,16 @@
from fastapi import APIRouter, Depends, FastAPI, WebSocket import functools
import pytest
from fastapi import (
APIRouter,
Depends,
FastAPI,
Header,
WebSocket,
WebSocketDisconnect,
status,
)
from fastapi.middleware import Middleware
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
router = APIRouter() router = APIRouter()
@ -63,9 +75,44 @@ async def router_native_prefix_ws(websocket: WebSocket):
await websocket.close() await websocket.close()
app.include_router(router) async def ws_dependency_err():
app.include_router(prefix_router, prefix="/prefix") raise NotImplementedError()
app.include_router(native_prefix_route)
@router.websocket("/depends-err/")
async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)):
pass # pragma: no cover
async def ws_dependency_validate(x_missing: str = Header()):
pass # pragma: no cover
@router.websocket("/depends-validate/")
async def router_ws_depends_validate(
websocket: WebSocket, data=Depends(ws_dependency_validate)
):
pass # pragma: no cover
class CustomError(Exception):
pass
@router.websocket("/custom_error/")
async def router_ws_custom_error(websocket: WebSocket):
raise CustomError()
def make_app(app=None, **kwargs):
app = app or FastAPI(**kwargs)
app.include_router(router)
app.include_router(prefix_router, prefix="/prefix")
app.include_router(native_prefix_route)
return app
app = make_app(app)
def test_app(): def test_app():
@ -125,3 +172,100 @@ def test_router_with_params():
assert data == "path/to/file" assert data == "path/to/file"
data = websocket.receive_text() data = websocket.receive_text()
assert data == "a_query_param" assert data == "a_query_param"
def test_wrong_uri():
"""
Verify that a websocket connection to a non-existent endpoing returns in a shutdown
"""
client = TestClient(app)
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/no-router/"):
pass # pragma: no cover
assert e.value.code == status.WS_1000_NORMAL_CLOSURE
def websocket_middleware(middleware_func):
"""
Helper to create a Starlette pure websocket middleware
"""
def middleware_constructor(app):
@functools.wraps(app)
async def wrapped_app(scope, receive, send):
if scope["type"] != "websocket":
return await app(scope, receive, send) # pragma: no cover
async def call_next():
return await app(scope, receive, send)
websocket = WebSocket(scope, receive=receive, send=send)
return await middleware_func(websocket, call_next)
return wrapped_app
return middleware_constructor
def test_depend_validation():
"""
Verify that a validation in a dependency invokes the correct exception handler
"""
caught = []
@websocket_middleware
async def catcher(websocket, call_next):
try:
return await call_next()
except Exception as e: # pragma: no cover
caught.append(e)
raise
myapp = make_app(middleware=[Middleware(catcher)])
client = TestClient(myapp)
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/depends-validate/"):
pass # pragma: no cover
# the validation error does produce a close message
assert e.value.code == status.WS_1008_POLICY_VIOLATION
# and no error is leaked
assert caught == []
def test_depend_err_middleware():
"""
Verify that it is possible to write custom WebSocket middleware to catch errors
"""
@websocket_middleware
async def errorhandler(websocket: WebSocket, call_next):
try:
return await call_next()
except Exception as e:
await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e))
myapp = make_app(middleware=[Middleware(errorhandler)])
client = TestClient(myapp)
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/depends-err/"):
pass # pragma: no cover
assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE
assert "NotImplementedError" in e.value.reason
def test_depend_err_handler():
"""
Verify that it is possible to write custom WebSocket middleware to catch errors
"""
async def custom_handler(websocket: WebSocket, exc: CustomError) -> None:
await websocket.close(1002, "foo")
myapp = make_app(exception_handlers={CustomError: custom_handler})
client = TestClient(myapp)
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/custom_error/"):
pass # pragma: no cover
assert e.value.code == 1002
assert "foo" in e.value.reason

Loading…
Cancel
Save