From f8b86bf71483d3db441b6ae6db8939eb52532f6b Mon Sep 17 00:00:00 2001 From: Yurii Motov Date: Mon, 14 Apr 2025 09:07:10 +0200 Subject: [PATCH] Added tests for passing multiple codes or exceptons to `exception_handler` --- tests/test_exception_handlers.py | 67 ++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/test_exception_handlers.py b/tests/test_exception_handlers.py index 6a3cbd830..9f3055d00 100644 --- a/tests/test_exception_handlers.py +++ b/tests/test_exception_handlers.py @@ -5,6 +5,14 @@ from fastapi.testclient import TestClient from starlette.responses import JSONResponse +class CustomException1(HTTPException): + pass + + +class CustomException2(HTTPException): + pass + + def http_exception_handler(request, exception): return JSONResponse({"exception": "http-exception"}) @@ -86,3 +94,62 @@ def test_traceback_for_dependency_with_yield(): last_frame = exc_info.traceback[-1] assert str(last_frame.path) == __file__ assert last_frame.lineno == raise_value_error.__code__.co_firstlineno + + +def test_exception_handler_with_single_exception(): + local_app = FastAPI() + + @local_app.exception_handler(CustomException1) + def custom_exception_handler(request, exception): + pass # pragma: no cover + + assert ( + local_app.exception_handlers.get(CustomException1) == custom_exception_handler + ) + + +@pytest.mark.parametrize( + "exceptions", + [ + (CustomException1, CustomException2), # Tuple of exceptions + [CustomException1, CustomException2], # List of exceptions + ], +) +def test_exception_handler_with_multiple_exceptions(exceptions): + local_app = FastAPI() + + @local_app.exception_handler(exceptions) + def custom_exception_handler(request, exception): + pass # pragma: no cover + + assert local_app.exception_handlers.get(exceptions[0]) == custom_exception_handler + + assert local_app.exception_handlers.get(exceptions[1]) == custom_exception_handler + + +def test_exception_handler_with_single_status_code(): + local_app = FastAPI() + + @local_app.exception_handler(409) + def http_409_status_code_handler(request, exception): + pass # pragma: no cover + + assert local_app.exception_handlers.get(409) == http_409_status_code_handler + + +@pytest.mark.parametrize( + "status_codes", + [ + (401, 403), # Tuple of status codes + [401, 403], # List of status codes + ], +) +def test_exception_handler_with_multiple_status_codes(status_codes): + local_app = FastAPI() + + @local_app.exception_handler(status_codes) + def auth_errors_handler(request, exception): + pass # pragma: no cover + + assert local_app.exception_handlers.get(status_codes[0]) == auth_errors_handler + assert local_app.exception_handlers.get(status_codes[1]) == auth_errors_handler