diff --git a/fastapi/applications.py b/fastapi/applications.py index f3ce08e9f..24a242a4e 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -66,6 +66,10 @@ class FastAPI(Starlette): self.exception_handlers = ( {} if exception_handlers is None else dict(exception_handlers) ) + self.exception_handlers.setdefault(HTTPException, http_exception_handler) + self.exception_handlers.setdefault( + RequestValidationError, request_validation_exception_handler + ) self.user_middleware = [] if middleware is None else list(middleware) self.middleware_stack = self.build_middleware_stack() @@ -165,10 +169,6 @@ class FastAPI(Starlette): ) self.add_route(self.redoc_url, redoc_html, include_in_schema=False) - self.add_exception_handler(HTTPException, http_exception_handler) - self.add_exception_handler( - RequestValidationError, request_validation_exception_handler - ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.root_path: diff --git a/tests/test_exception_handlers.py b/tests/test_exception_handlers.py new file mode 100644 index 000000000..6153f7ab9 --- /dev/null +++ b/tests/test_exception_handlers.py @@ -0,0 +1,44 @@ +from fastapi import FastAPI, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.testclient import TestClient +from starlette.responses import JSONResponse + + +def http_exception_handler(request, exception): + return JSONResponse({"exception": "http-exception"}) + + +def request_validation_exception_handler(request, exception): + return JSONResponse({"exception": "request-validation"}) + + +app = FastAPI( + exception_handlers={ + HTTPException: http_exception_handler, + RequestValidationError: request_validation_exception_handler, + } +) + +client = TestClient(app) + + +@app.get("/http-exception") +def route_with_http_exception(): + raise HTTPException(status_code=400) + + +@app.get("/request-validation/{param}/") +def route_with_request_validation_exception(param: int): + pass # pragma: no cover + + +def test_override_http_exception(): + response = client.get("/http-exception") + assert response.status_code == 200 + assert response.json() == {"exception": "http-exception"} + + +def test_override_request_validation_exception(): + response = client.get("/request-validation/invalid") + assert response.status_code == 200 + assert response.json() == {"exception": "request-validation"}