diff --git a/tests/test_sse.py b/tests/test_sse.py index 6dfec61838..4317a8f322 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -6,6 +6,7 @@ import fastapi.routing import pytest from fastapi import APIRouter, FastAPI from fastapi.responses import EventSourceResponse +from fastapi.routing import APIRoute from fastapi.sse import ServerSentEvent from fastapi.testclient import TestClient from pydantic import BaseModel @@ -99,6 +100,20 @@ async def stream_events(): app.include_router(router, prefix="/api") +# Router with a typed SSE route to test stream_item_type propagation +typed_router = APIRouter() + + +@typed_router.get("/typed-events", response_class=EventSourceResponse) +async def stream_typed_events() -> AsyncIterable[Item]: + for item in items: + yield item + + +typed_app = FastAPI() +typed_app.include_router(typed_router, prefix="/api") + + @pytest.fixture(name="client") def client_fixture(): with TestClient(app) as c: @@ -237,9 +252,7 @@ def test_raw_data_sent_without_json_encoding(client: TestClient): assert response.status_code == 200 text = response.text - # raw_data should appear without JSON quotes assert "data: plain text without quotes\n" in text - # Not JSON-quoted assert 'data: "plain text without quotes"' not in text assert "event: html\n" in text @@ -265,6 +278,23 @@ def test_sse_on_router_included_in_app(client: TestClient): assert len(data_lines) == 2 +def test_stream_item_type_propagated_through_include_router(): + included_route = next( + r for r in typed_app.routes if getattr(r, "path", None) == "/api/typed-events" + ) + assert isinstance(included_route, APIRoute) + assert included_route.stream_item_type is Item + + with TestClient(typed_app) as client: + response = client.get("/api/typed-events") + assert response.status_code == 200 + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + assert '"name"' in data_lines[0] + + # Keepalive ping tests @@ -274,8 +304,6 @@ keepalive_app = FastAPI() @keepalive_app.get("/slow-async", response_class=EventSourceResponse) async def slow_async_stream(): yield {"n": 1} - # Sleep longer than the (monkeypatched) ping interval so a keepalive - # comment is emitted before the next item. await asyncio.sleep(0.3) yield {"n": 2} @@ -293,7 +321,6 @@ def test_keepalive_ping_async(monkeypatch: pytest.MonkeyPatch): response = c.get("/slow-async") assert response.status_code == 200 text = response.text - # The keepalive comment ": ping" should appear between the two data events assert ": ping\n" in text data_lines = [line for line in text.split("\n") if line.startswith("data: ")] assert len(data_lines) == 2 @@ -314,5 +341,4 @@ def test_no_keepalive_when_fast(client: TestClient): """No keepalive comment when items arrive quickly.""" response = client.get("/items/stream") assert response.status_code == 200 - # KEEPALIVE_COMMENT is ": ping\n\n". assert ": ping\n" not in response.text