From 1e71e31fd6f49d5473b277fdf3e5ef9d8e82af45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 3 Sep 2025 17:26:06 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=AA=20Add=20test=20for=20dependencies?= =?UTF-8?q?=20with=20yield=20with=20streaming?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_dependency_after_yield_streaming.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/test_dependency_after_yield_streaming.py diff --git a/tests/test_dependency_after_yield_streaming.py b/tests/test_dependency_after_yield_streaming.py new file mode 100644 index 000000000..0010453bf --- /dev/null +++ b/tests/test_dependency_after_yield_streaming.py @@ -0,0 +1,80 @@ +from contextlib import contextmanager +from typing import Annotated, Any, Generator + +from fastapi import Depends, FastAPI +from fastapi.responses import StreamingResponse +from fastapi.testclient import TestClient + + +class Session: + def __init__(self) -> None: + self.data = ["foo", "bar", "baz"] + self.open = True + + def __iter__(self) -> Generator[str, None, None]: + for item in self.data: + if self.open: + yield item + else: + raise ValueError("Session closed") + + +@contextmanager +def acquire_session() -> Generator[Session, None, None]: + session = Session() + try: + yield session + finally: + session.open = False + + +def dep_session() -> Any: + with acquire_session() as s: + yield s + + +SessionDep = Annotated[Session, Depends(dep_session)] + +app = FastAPI() + + +@app.get("/data") +def get_data(session: SessionDep) -> Any: + data = list(session) + return data + + +@app.get("/stream-simple") +def get_stream_simple(session: SessionDep) -> Any: + def iter_data(): + yield from ["x", "y", "z"] + + return StreamingResponse(iter_data()) + + +@app.get("/stream-session") +def get_stream_session(session: SessionDep) -> Any: + def iter_data(): + yield from session + + return StreamingResponse(iter_data()) + + +def session_explode() -> None: + with acquire_session() as s: + iter_s = iter(s) + print(next(iter_s)) + print(next(iter_s)) + + +client = TestClient(app) + + +def test_stream_simple(): + response = client.get("/stream-simple") + assert response.text == "xyz" + + +def test_stream_session(): + response = client.get("/stream-session") + assert response.text == "foobarbaz"