From ab03f226353394da467b77ff08cad4cbf94463e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 11 Jun 2023 19:08:14 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20exception=20handler=20for=20`?= =?UTF-8?q?WebSocketRequestValidationError`=20(which=20also=20allows=20to?= =?UTF-8?q?=20override=20it)=20(#6030)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sebastián Ramírez --- fastapi/applications.py | 8 +- fastapi/exception_handlers.py | 13 ++- fastapi/routing.py | 2 - tests/test_ws_router.py | 152 +++++++++++++++++++++++++++++++++- 4 files changed, 166 insertions(+), 9 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 8b3a74d3c..d5ea1d72a 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -19,8 +19,9 @@ from fastapi.encoders import DictIntStrAny, SetIntStr from fastapi.exception_handlers import ( http_exception_handler, request_validation_exception_handler, + websocket_request_validation_exception_handler, ) -from fastapi.exceptions import RequestValidationError +from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError from fastapi.logger import logger from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware from fastapi.openapi.docs import ( @@ -145,6 +146,11 @@ class FastAPI(Starlette): self.exception_handlers.setdefault( RequestValidationError, request_validation_exception_handler ) + self.exception_handlers.setdefault( + WebSocketRequestValidationError, + # Starlette still has incorrect type specification for the handlers + websocket_request_validation_exception_handler, # type: ignore + ) self.user_middleware: List[Middleware] = ( [] if middleware is None else list(middleware) diff --git a/fastapi/exception_handlers.py b/fastapi/exception_handlers.py index 4d7ea5ec2..6c2ba7fed 100644 --- a/fastapi/exception_handlers.py +++ b/fastapi/exception_handlers.py @@ -1,10 +1,11 @@ from fastapi.encoders import jsonable_encoder -from fastapi.exceptions import RequestValidationError +from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError from fastapi.utils import is_body_allowed_for_status_code +from fastapi.websockets import WebSocket from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response -from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY +from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION async def http_exception_handler(request: Request, exc: HTTPException) -> Response: @@ -23,3 +24,11 @@ async def request_validation_exception_handler( status_code=HTTP_422_UNPROCESSABLE_ENTITY, content={"detail": jsonable_encoder(exc.errors())}, ) + + +async def websocket_request_validation_exception_handler( + websocket: WebSocket, exc: WebSocketRequestValidationError +) -> None: + await websocket.close( + code=WS_1008_POLICY_VIOLATION, reason=jsonable_encoder(exc.errors()) + ) diff --git a/fastapi/routing.py b/fastapi/routing.py index 06c71bffa..7f1936f7f 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -56,7 +56,6 @@ from starlette.routing import ( request_response, websocket_session, ) -from starlette.status import WS_1008_POLICY_VIOLATION from starlette.types import ASGIApp, Lifespan, Scope from starlette.websockets import WebSocket @@ -283,7 +282,6 @@ def get_websocket_app( ) values, errors, _, _2, _3 = solved_result if errors: - await websocket.close(code=WS_1008_POLICY_VIOLATION) raise WebSocketRequestValidationError(errors) assert dependant.call is not None, "dependant.call must be a function" await dependant.call(**values) diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index c312821e9..240a42bb0 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -1,4 +1,16 @@ -from fastapi import APIRouter, Depends, FastAPI, WebSocket +import functools + +import pytest +from fastapi import ( + APIRouter, + Depends, + FastAPI, + Header, + WebSocket, + WebSocketDisconnect, + status, +) +from fastapi.middleware import Middleware from fastapi.testclient import TestClient router = APIRouter() @@ -63,9 +75,44 @@ async def router_native_prefix_ws(websocket: WebSocket): await websocket.close() -app.include_router(router) -app.include_router(prefix_router, prefix="/prefix") -app.include_router(native_prefix_route) +async def ws_dependency_err(): + raise NotImplementedError() + + +@router.websocket("/depends-err/") +async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)): + pass # pragma: no cover + + +async def ws_dependency_validate(x_missing: str = Header()): + pass # pragma: no cover + + +@router.websocket("/depends-validate/") +async def router_ws_depends_validate( + websocket: WebSocket, data=Depends(ws_dependency_validate) +): + pass # pragma: no cover + + +class CustomError(Exception): + pass + + +@router.websocket("/custom_error/") +async def router_ws_custom_error(websocket: WebSocket): + raise CustomError() + + +def make_app(app=None, **kwargs): + app = app or FastAPI(**kwargs) + app.include_router(router) + app.include_router(prefix_router, prefix="/prefix") + app.include_router(native_prefix_route) + return app + + +app = make_app(app) def test_app(): @@ -125,3 +172,100 @@ def test_router_with_params(): assert data == "path/to/file" data = websocket.receive_text() assert data == "a_query_param" + + +def test_wrong_uri(): + """ + Verify that a websocket connection to a non-existent endpoing returns in a shutdown + """ + client = TestClient(app) + with pytest.raises(WebSocketDisconnect) as e: + with client.websocket_connect("/no-router/"): + pass # pragma: no cover + assert e.value.code == status.WS_1000_NORMAL_CLOSURE + + +def websocket_middleware(middleware_func): + """ + Helper to create a Starlette pure websocket middleware + """ + + def middleware_constructor(app): + @functools.wraps(app) + async def wrapped_app(scope, receive, send): + if scope["type"] != "websocket": + return await app(scope, receive, send) # pragma: no cover + + async def call_next(): + return await app(scope, receive, send) + + websocket = WebSocket(scope, receive=receive, send=send) + return await middleware_func(websocket, call_next) + + return wrapped_app + + return middleware_constructor + + +def test_depend_validation(): + """ + Verify that a validation in a dependency invokes the correct exception handler + """ + caught = [] + + @websocket_middleware + async def catcher(websocket, call_next): + try: + return await call_next() + except Exception as e: # pragma: no cover + caught.append(e) + raise + + myapp = make_app(middleware=[Middleware(catcher)]) + + client = TestClient(myapp) + with pytest.raises(WebSocketDisconnect) as e: + with client.websocket_connect("/depends-validate/"): + pass # pragma: no cover + # the validation error does produce a close message + assert e.value.code == status.WS_1008_POLICY_VIOLATION + # and no error is leaked + assert caught == [] + + +def test_depend_err_middleware(): + """ + Verify that it is possible to write custom WebSocket middleware to catch errors + """ + + @websocket_middleware + async def errorhandler(websocket: WebSocket, call_next): + try: + return await call_next() + except Exception as e: + await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e)) + + myapp = make_app(middleware=[Middleware(errorhandler)]) + client = TestClient(myapp) + with pytest.raises(WebSocketDisconnect) as e: + with client.websocket_connect("/depends-err/"): + pass # pragma: no cover + assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE + assert "NotImplementedError" in e.value.reason + + +def test_depend_err_handler(): + """ + Verify that it is possible to write custom WebSocket middleware to catch errors + """ + + async def custom_handler(websocket: WebSocket, exc: CustomError) -> None: + await websocket.close(1002, "foo") + + myapp = make_app(exception_handlers={CustomError: custom_handler}) + client = TestClient(myapp) + with pytest.raises(WebSocketDisconnect) as e: + with client.websocket_connect("/custom_error/"): + pass # pragma: no cover + assert e.value.code == 1002 + assert "foo" in e.value.reason