From 054139d15cb03ee2e8ec1031fa23dd5a6750adf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 21 Sep 2025 15:47:13 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20Add=20tests=20for=20dependency=20wi?= =?UTF-8?q?th=20yield=20catching=20exceptions=20and=20re-raising?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_dependency_after_yield_raise.py | 63 ++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tests/test_dependency_after_yield_raise.py diff --git a/tests/test_dependency_after_yield_raise.py b/tests/test_dependency_after_yield_raise.py new file mode 100644 index 000000000..31c646e08 --- /dev/null +++ b/tests/test_dependency_after_yield_raise.py @@ -0,0 +1,63 @@ +from typing import Annotated, Any + +import pytest +from fastapi import Depends, FastAPI, HTTPException +from fastapi.testclient import TestClient + + +class CustomError(Exception): + pass + + +def catching_dep() -> Any: + try: + yield "s" + except CustomError as err: + raise HTTPException(status_code=418, detail="Session error") from err + + +def broken_dep() -> Any: + yield "s" + raise ValueError("Broken after yield") + + +app = FastAPI() + + +@app.get("/catching") +def catching(d: Annotated[str, Depends(catching_dep)]) -> Any: + raise CustomError("Simulated error during streaming") + + +@app.get("/broken") +def broken(d: Annotated[str, Depends(broken_dep)]) -> Any: + return {"message": "all good?"} + + +client = TestClient(app) + + +def test_catching(): + response = client.get("/catching") + assert response.status_code == 418 + assert response.json() == {"detail": "Session error"} + + +def test_broken_catch(): + """ + When a dependency with yield raises after the yield (not in an except), the + response is already "successfully" sent back to the client, but there's still + an error in the server afterwards, an exception is raised and captured or shown + in the server logs. + """ + with pytest.raises(ValueError, match="Broken after yield"): + response = client.get("/broken") + assert response.status_code == 200 + assert response.json() == {"message": "all good?"} + + +def test_broken_return_finishes(): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/broken") + assert response.status_code == 200 + assert response.json() == {"message": "all good?"}