diff --git a/tests/test_sse.py b/tests/test_sse.py index 8ca90fe720..97a68214f4 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -64,7 +64,8 @@ async def sse_items_event(): @app.get("/items/stream-mixed", response_class=EventSourceResponse) async def sse_items_mixed() -> AsyncIterable[Item]: - yield items[0] + for item in items: + yield item yield ServerSentEvent(data="custom-event", event="special") yield items[1] @@ -102,12 +103,6 @@ async def stream_events_typed() -> AsyncIterable[Item]: yield item -@router.get("/events-jsonl") -async def stream_events_jsonl() -> AsyncIterable[Item]: - for item in items: - yield item - - app.include_router(router, prefix="/api") @@ -287,14 +282,6 @@ def test_sse_router_typed_stream(client: TestClient): assert len(data_lines) == 3 -def test_jsonl_router_typed_stream(client: TestClient): - response = client.get("/api/events-jsonl") - assert response.status_code == 200 - assert response.headers["content-type"] == "application/jsonl" - lines = response.text.strip().split("\n") - assert len(lines) == 3 - - def test_sse_router_typed_openapi_schema(client: TestClient): """Typed SSE endpoint on a router should preserve itemSchema with contentSchema.""" response = client.get("/openapi.json") @@ -324,20 +311,6 @@ def test_sse_router_typed_openapi_schema(client: TestClient): } -def test_jsonl_router_typed_openapi_schema(client: TestClient): - """Typed JSONL endpoint on a router should preserve itemSchema.""" - response = client.get("/openapi.json") - assert response.status_code == 200 - paths = response.json()["paths"] - jsonl_response = paths["/api/events-jsonl"]["get"]["responses"]["200"] - assert jsonl_response == { - "description": "Successful Response", - "content": { - "application/jsonl": {"itemSchema": {"$ref": "#/components/schemas/Item"}} - }, - } - - # Keepalive ping tests @@ -389,3 +362,97 @@ def test_no_keepalive_when_fast(client: TestClient): assert response.status_code == 200 # KEEPALIVE_COMMENT is ": ping\n\n". assert ": ping\n" not in response.text + + +# default_response_class tests + + +sse_schema_response = { + "description": "Successful Response", + "content": { + "text/event-stream": { + "itemSchema": { + "type": "object", + "properties": { + "data": { + "type": "string", + "contentMediaType": "application/json", + "contentSchema": {"$ref": "#/components/schemas/Item"}, + }, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": {"type": "integer", "minimum": 0}, + }, + "required": ["data"], + } + } + }, +} + + +# default_response_class on app + +default_app_app = FastAPI(default_response_class=EventSourceResponse) +default_app_router = APIRouter() + + +@default_app_router.get("/stream") +async def default_app_stream() -> AsyncIterable[Item]: + for item in items: + yield item + + +default_app_app.include_router(default_app_router, prefix="/api") + + +def test_default_response_class_on_app_stream(): + with TestClient(default_app_app) as client: + response = client.get("/api/stream") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + + +def test_default_response_class_on_app_openapi_schema(): + assert ( + default_app_app.openapi()["paths"]["/api/stream"]["get"]["responses"]["200"] + == sse_schema_response + ) + + +# default_response_class on parent router + +default_parent_app = FastAPI() +parent_router = APIRouter(default_response_class=EventSourceResponse) +child_router = APIRouter() + + +@child_router.get("/stream") +async def default_parent_stream() -> AsyncIterable[Item]: + for item in items: + yield item + + +parent_router.include_router(child_router) +default_parent_app.include_router(parent_router, prefix="/api") + + +def test_default_response_class_on_parent_router_stream(): + with TestClient(default_parent_app) as client: + response = client.get("/api/stream") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + + +def test_default_response_class_on_parent_router_openapi_schema(): + assert ( + default_parent_app.openapi()["paths"]["/api/stream"]["get"]["responses"]["200"] + == sse_schema_response + ) diff --git a/tests/test_stream_bare_type.py b/tests/test_stream_bare_type.py index 68bd31df6b..88c0ccfb25 100644 --- a/tests/test_stream_bare_type.py +++ b/tests/test_stream_bare_type.py @@ -1,7 +1,7 @@ import json from typing import AsyncIterable, Iterable # noqa: UP035 to test coverage -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from fastapi.testclient import TestClient from pydantic import BaseModel @@ -23,6 +23,16 @@ def stream_bare_sync() -> Iterable: yield {"name": "bar"} +router = APIRouter() + + +@router.get("/events-jsonl") +async def stream_events_jsonl() -> AsyncIterable[Item]: + yield Item(name="foo") + + +app.include_router(router, prefix="/api") + client = TestClient(app) @@ -40,3 +50,24 @@ def test_stream_bare_sync_iterable(): assert response.headers["content-type"] == "application/jsonl" lines = [json.loads(line) for line in response.text.strip().splitlines()] assert lines == [{"name": "bar"}] + + +def test_jsonl_router_typed_stream(): + response = client.get("/api/events-jsonl") + assert response.status_code == 200 + assert response.headers["content-type"] == "application/jsonl" + lines = [json.loads(line) for line in response.text.strip().splitlines()] + assert lines == [{"name": "foo"}] + + +def test_jsonl_router_typed_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + paths = response.json()["paths"] + jsonl_response = paths["/api/events-jsonl"]["get"]["responses"]["200"] + assert jsonl_response == { + "description": "Successful Response", + "content": { + "application/jsonl": {"itemSchema": {"$ref": "#/components/schemas/Item"}} + }, + }