1 changed files with 217 additions and 0 deletions
@ -0,0 +1,217 @@ |
|||||
|
from collections.abc import AsyncGenerator, Generator |
||||
|
from contextlib import asynccontextmanager, contextmanager |
||||
|
from typing import Annotated, Any |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import Depends, FastAPI |
||||
|
from fastapi.responses import EventSourceResponse |
||||
|
from fastapi.testclient import TestClient |
||||
|
from pydantic import BaseModel |
||||
|
|
||||
|
|
||||
|
class Item(BaseModel): |
||||
|
name: str |
||||
|
|
||||
|
|
||||
|
class Session: |
||||
|
def __init__(self) -> None: |
||||
|
self.items = [Item(name="foo"), Item(name="bar"), Item(name="baz")] |
||||
|
self.open = True |
||||
|
|
||||
|
def __iter__(self) -> Generator[Item, None, None]: |
||||
|
for item in self.items: |
||||
|
if self.open: |
||||
|
yield item |
||||
|
else: |
||||
|
raise ValueError("Session closed") |
||||
|
|
||||
|
def __aiter__(self) -> AsyncGenerator[Item, None]: |
||||
|
return self._async_iter() |
||||
|
|
||||
|
async def _async_iter(self) -> AsyncGenerator[Item, None]: |
||||
|
for item in self.items: |
||||
|
if self.open: |
||||
|
yield item |
||||
|
else: |
||||
|
raise ValueError("Session closed") |
||||
|
|
||||
|
|
||||
|
@contextmanager |
||||
|
def acquire_session() -> Generator[Session, None, None]: |
||||
|
session = Session() |
||||
|
try: |
||||
|
yield session |
||||
|
finally: |
||||
|
session.open = False |
||||
|
|
||||
|
|
||||
|
@asynccontextmanager |
||||
|
async def acquire_async_session() -> AsyncGenerator[Session, None]: |
||||
|
session = Session() |
||||
|
try: |
||||
|
yield session |
||||
|
finally: |
||||
|
session.open = False |
||||
|
|
||||
|
|
||||
|
def dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
async def async_dep_session() -> Any: |
||||
|
async with acquire_async_session() as s: |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
def broken_dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
s.open = False |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
async def async_broken_dep_session() -> Any: |
||||
|
async with acquire_async_session() as s: |
||||
|
s.open = False |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
SessionDep = Annotated[Session, Depends(dep_session)] |
||||
|
AsyncSessionDep = Annotated[Session, Depends(async_dep_session)] |
||||
|
BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] |
||||
|
AsyncBrokenSessionDep = Annotated[Session, Depends(async_broken_dep_session)] |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
@app.get("/sse-sync", response_class=EventSourceResponse) |
||||
|
def sse_sync(session: SessionDep) -> Any: |
||||
|
def gen() -> Generator[Item, None, None]: |
||||
|
yield from session |
||||
|
|
||||
|
return gen() |
||||
|
|
||||
|
|
||||
|
@app.get("/sse-async", response_class=EventSourceResponse) |
||||
|
async def sse_async(session: AsyncSessionDep) -> AsyncGenerator[Item, None]: |
||||
|
async for item in session: |
||||
|
yield item |
||||
|
|
||||
|
|
||||
|
@app.get("/sse-broken-sync", response_class=EventSourceResponse) |
||||
|
def sse_broken_sync(session: BrokenSessionDep) -> Any: |
||||
|
def gen() -> Generator[Item, None, None]: |
||||
|
yield from session |
||||
|
|
||||
|
return gen() |
||||
|
|
||||
|
|
||||
|
@app.get("/sse-broken-async", response_class=EventSourceResponse) |
||||
|
async def sse_broken_async( |
||||
|
session: AsyncBrokenSessionDep, |
||||
|
) -> AsyncGenerator[Item, None]: |
||||
|
async for item in session: |
||||
|
yield item |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def _parse_sse_data_lines(text: str) -> list[str]: |
||||
|
return [ |
||||
|
line[len("data: ") :] |
||||
|
for line in text.strip().splitlines() |
||||
|
if line.startswith("data: ") |
||||
|
] |
||||
|
|
||||
|
|
||||
|
def test_sse_sync_streams_items(): |
||||
|
response = client.get("/sse-sync") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.headers["content-type"] == "text/event-stream; charset=utf-8" |
||||
|
data_lines = _parse_sse_data_lines(response.text) |
||||
|
assert len(data_lines) == 3 |
||||
|
|
||||
|
|
||||
|
def test_sse_sync_dependency_cleaned_up(): |
||||
|
"""Yield dependency cleanup runs after the SSE stream completes.""" |
||||
|
sessions: list[Session] = [] |
||||
|
|
||||
|
def tracking_dep() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
sessions.append(s) |
||||
|
yield s |
||||
|
|
||||
|
app.dependency_overrides[dep_session] = tracking_dep |
||||
|
try: |
||||
|
response = client.get("/sse-sync") |
||||
|
assert response.status_code == 200 |
||||
|
finally: |
||||
|
app.dependency_overrides.clear() |
||||
|
|
||||
|
assert len(sessions) == 1 |
||||
|
# The session's open flag must be False after stream ends - |
||||
|
# meaning the finally block in acquire_session() ran. |
||||
|
assert sessions[0].open is False |
||||
|
|
||||
|
|
||||
|
def test_sse_async_streams_items(): |
||||
|
response = client.get("/sse-async") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.headers["content-type"] == "text/event-stream; charset=utf-8" |
||||
|
data_lines = _parse_sse_data_lines(response.text) |
||||
|
assert len(data_lines) == 3 |
||||
|
|
||||
|
|
||||
|
def test_sse_async_dependency_cleaned_up(): |
||||
|
"""Async yield dependency cleanup runs after the SSE stream completes.""" |
||||
|
sessions: list[Session] = [] |
||||
|
|
||||
|
async def tracking_dep() -> Any: |
||||
|
async with acquire_async_session() as s: |
||||
|
sessions.append(s) |
||||
|
yield s |
||||
|
|
||||
|
app.dependency_overrides[async_dep_session] = tracking_dep |
||||
|
try: |
||||
|
response = client.get("/sse-async") |
||||
|
assert response.status_code == 200 |
||||
|
finally: |
||||
|
app.dependency_overrides.clear() |
||||
|
|
||||
|
assert len(sessions) == 1 |
||||
|
assert sessions[0].open is False |
||||
|
|
||||
|
|
||||
|
def test_sse_broken_sync_raises(): |
||||
|
"""When a sync yield dependency is broken the stream fails.""" |
||||
|
with pytest.raises((ValueError, Exception)): |
||||
|
client.get("/sse-broken-sync") |
||||
|
|
||||
|
|
||||
|
def test_sse_broken_sync_no_raise(): |
||||
|
""" |
||||
|
When a sync yield dependency raises after streaming has started, |
||||
|
the 200 status code is already sent but the body is empty. |
||||
|
""" |
||||
|
with TestClient(app, raise_server_exceptions=False) as c: |
||||
|
response = c.get("/sse-broken-sync") |
||||
|
assert response.status_code == 200 |
||||
|
assert _parse_sse_data_lines(response.text) == [] |
||||
|
|
||||
|
|
||||
|
def test_sse_broken_async_raises(): |
||||
|
"""When an async yield dependency is broken the stream fails.""" |
||||
|
with pytest.raises((ValueError, Exception)): |
||||
|
client.get("/sse-broken-async") |
||||
|
|
||||
|
|
||||
|
def test_sse_broken_async_no_raise(): |
||||
|
""" |
||||
|
When an async yield dependency raises after streaming has started, |
||||
|
the 200 status code is already sent but the body is empty. |
||||
|
""" |
||||
|
with TestClient(app, raise_server_exceptions=False) as c: |
||||
|
response = c.get("/sse-broken-async") |
||||
|
assert response.status_code == 200 |
||||
|
assert _parse_sse_data_lines(response.text) == [] |
||||
Loading…
Reference in new issue