diff --git a/tests/test_dependency_after_yield_raise.py b/tests/test_dependency_after_yield_raise.py index 31c646e08..b56140277 100644 --- a/tests/test_dependency_after_yield_raise.py +++ b/tests/test_dependency_after_yield_raise.py @@ -43,14 +43,19 @@ def test_catching(): assert response.json() == {"detail": "Session error"} -def test_broken_catch(): +def test_broken_raise(): + with pytest.raises(ValueError, match="Broken after yield"): + client.get("/broken") + + +def test_broken_no_raise(): """ When a dependency with yield raises after the yield (not in an except), the response is already "successfully" sent back to the client, but there's still an error in the server afterwards, an exception is raised and captured or shown in the server logs. """ - with pytest.raises(ValueError, match="Broken after yield"): + with TestClient(app, raise_server_exceptions=False) as client: response = client.get("/broken") assert response.status_code == 200 assert response.json() == {"message": "all good?"} diff --git a/tests/test_dependency_after_yield_streaming.py b/tests/test_dependency_after_yield_streaming.py index 1491c0017..5235d5249 100644 --- a/tests/test_dependency_after_yield_streaming.py +++ b/tests/test_dependency_after_yield_streaming.py @@ -1,6 +1,7 @@ from contextlib import contextmanager from typing import Annotated, Any, Generator +import pytest from fastapi import Depends, FastAPI from fastapi.responses import StreamingResponse from fastapi.testclient import TestClient @@ -33,7 +34,14 @@ def dep_session() -> Any: yield s +def broken_dep_session() -> Any: + with acquire_session() as s: + s.open = False + yield s + + SessionDep = Annotated[Session, Depends(dep_session)] +BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] app = FastAPI() @@ -60,6 +68,19 @@ def get_stream_session(session: SessionDep) -> Any: return StreamingResponse(iter_data()) +@app.get("/broken-session-data") +def get_broken_session_data(session: BrokenSessionDep) -> Any: + return list(session) + + +@app.get("/broken-session-stream") +def get_broken_session_stream(session: BrokenSessionDep) -> Any: + def iter_data(): + yield from session + + return StreamingResponse(iter_data()) + + client = TestClient(app) @@ -76,3 +97,33 @@ def test_stream_simple(): def test_stream_session(): response = client.get("/stream-session") assert response.text == "foobarbaz" + + +def test_broken_session_data(): + with pytest.raises(ValueError, match="Session closed"): + client.get("/broken-session-data") + + +def test_broken_session_data_no_raise(): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/broken-session-data") + assert response.status_code == 500 + assert response.text == "Internal Server Error" + + +def test_broken_session_stream_raise(): + # Can raise ValueError on Pydantic v2 and ExceptionGroup on Pydantic v1 + with pytest.raises((ValueError, Exception)): + client.get("/broken-session-stream") + + +def test_broken_session_stream_no_raise(): + """ + When a dependency with yield raises after the streaming response already started + the 200 status code is already sent, but there's still an error in the server + afterwards, an exception is raised and captured or shown in the server logs. + """ + with TestClient(app, raise_server_exceptions=False) as client: + response = client.get("/broken-session-stream") + assert response.status_code == 200 + assert response.text == "" diff --git a/tests/test_dependency_after_yield_websockets.py b/tests/test_dependency_after_yield_websockets.py new file mode 100644 index 000000000..3fa1b051b --- /dev/null +++ b/tests/test_dependency_after_yield_websockets.py @@ -0,0 +1,78 @@ +from contextlib import contextmanager +from typing import Annotated, Any, Generator + +import pytest +from fastapi import Depends, FastAPI, WebSocket +from fastapi.testclient import TestClient + + +class Session: + def __init__(self) -> None: + self.data = ["foo", "bar", "baz"] + self.open = True + + def __iter__(self) -> Generator[str, None, None]: + for item in self.data: + if self.open: + yield item + else: + raise ValueError("Session closed") + + +@contextmanager +def acquire_session() -> Generator[Session, None, None]: + session = Session() + try: + yield session + finally: + session.open = False + + +def dep_session() -> Any: + with acquire_session() as s: + yield s + + +def broken_dep_session() -> Any: + with acquire_session() as s: + s.open = False + yield s + + +SessionDep = Annotated[Session, Depends(dep_session)] +BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] + +app = FastAPI() + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket, session: SessionDep): + await websocket.accept() + for item in session: + await websocket.send_text(f"{item}") + + +@app.websocket("/ws-broken") +async def websocket_endpoint_broken(websocket: WebSocket, session: BrokenSessionDep): + await websocket.accept() + for item in session: + await websocket.send_text(f"{item}") # pragma no cover + + +client = TestClient(app) + + +def test_websocket_dependency_after_yield(): + with client.websocket_connect("/ws") as websocket: + data = websocket.receive_text() + assert data == "foo" + data = websocket.receive_text() + assert data == "bar" + data = websocket.receive_text() + assert data == "baz" + + +def test_websocket_dependency_after_yield_broken(): + with pytest.raises(ValueError, match="Session closed"): + with client.websocket_connect("/ws-broken"): + pass # pragma no cover diff --git a/tests/test_route_scope.py b/tests/test_route_scope.py index 2021c828f..792ea66c3 100644 --- a/tests/test_route_scope.py +++ b/tests/test_route_scope.py @@ -47,4 +47,4 @@ def test_websocket(): def test_websocket_invalid_path_doesnt_match(): with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/itemsx/portal-gun"): - pass + pass # pragma: no cover