pythonasyncioapiasyncfastapiframeworkjsonjson-schemaopenapiopenapi3pydanticpython-typespython3redocreststarletteswaggerswagger-uiuvicornweb
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
439 lines
14 KiB
439 lines
14 KiB
import asyncio
|
|
import time
|
|
from collections.abc import AsyncIterable, Iterable
|
|
|
|
import fastapi.routing
|
|
import pytest
|
|
from fastapi import APIRouter, FastAPI
|
|
from fastapi.responses import EventSourceResponse
|
|
from fastapi.sse import ServerSentEvent, get_sse_data_type
|
|
from fastapi.testclient import TestClient
|
|
from pydantic import BaseModel
|
|
|
|
|
|
class Item(BaseModel):
    # Model streamed as SSE ``data`` by most endpoints in this module.
    name: str
    # Optional free-text description; defaults to None.
    description: str | None = None
|
|
|
|
|
|
# Shared fixture data for the streaming endpoints. Order matters: tests index
# the streamed events positionally (Plumbus first, Meeseeks Box last).
items = [
    Item(name=name, description=description)
    for name, description in (
        ("Plumbus", "A multi-purpose household device."),
        ("Portal Gun", "A portal opening device."),
        ("Meeseeks Box", "A box that summons a Meeseeks."),
    )
]

# Main application exercised through the ``client`` fixture.
app = FastAPI()
|
|
|
|
|
|
@app.get("/items/stream", response_class=EventSourceResponse)
async def sse_items() -> AsyncIterable[Item]:
    # Async generator with an explicit item type; every Item yielded here is
    # expected to show up as one ``data:`` line in the stream.
    for product in items:
        yield product
|
|
|
|
|
|
@app.get("/items/stream-sync", response_class=EventSourceResponse)
def sse_items_sync() -> Iterable[Item]:
    # Plain (synchronous) generator counterpart of ``sse_items``.
    for product in items:
        yield product
|
|
|
|
|
|
@app.get("/items/stream-no-annotation", response_class=EventSourceResponse)
async def sse_items_no_annotation():
    # Same stream as ``sse_items`` but with no return annotation to infer from.
    for product in items:
        yield product
|
|
|
|
|
|
@app.get("/items/stream-sync-no-annotation", response_class=EventSourceResponse)
def sse_items_sync_no_annotation():
    # Synchronous generator without a return annotation.
    for product in items:
        yield product
|
|
|
|
|
|
@app.get("/items/stream-dict", response_class=EventSourceResponse)
async def sse_items_dict():
    # Yields plain dicts (not models) carrying the same two fields as Item.
    for product in items:
        payload = dict(name=product.name, description=product.description)
        yield payload
|
|
|
|
|
|
@app.get("/items/stream-sse-event", response_class=EventSourceResponse)
async def sse_items_event():
    # One event per ServerSentEvent feature: named event with string data and
    # id, named event with dict data and id, a comment-only event, and a
    # retry hint.
    yield ServerSentEvent(id="1", event="greeting", data="hello")
    yield ServerSentEvent(id="2", event="json-data", data={"key": "value"})
    yield ServerSentEvent(comment="just a comment")
    yield ServerSentEvent(retry=5000, data="retry-test")
|
|
|
|
|
|
@app.get("/items/stream-mixed", response_class=EventSourceResponse)
async def sse_items_mixed() -> AsyncIterable[Item]:
    # The annotation only names Item, yet a ServerSentEvent is yielded between
    # two items; test_mixed_plain_and_sse_events expects it to be emitted with
    # its own event name.
    yield items[0]
    yield ServerSentEvent(event="special", data="custom-event")
    yield items[1]
|
|
|
|
|
|
@app.get("/items/stream-string", response_class=EventSourceResponse)
async def sse_items_string():
    # A str passed as ``data`` is still JSON-encoded, i.e. sent quoted.
    event = ServerSentEvent(data="plain text data")
    yield event
|
|
|
|
|
|
@app.post("/items/stream-post", response_class=EventSourceResponse)
async def sse_items_post() -> AsyncIterable[Item]:
    # Same stream as ``sse_items`` but served for POST requests.
    for product in items:
        yield product
|
|
|
|
|
|
@app.get("/items/stream-raw", response_class=EventSourceResponse)
async def sse_items_raw():
    # ``raw_data`` payloads are expected to be sent verbatim (no JSON quoting):
    # plain text, an HTML fragment, and a CSV row.
    yield ServerSentEvent(raw_data="plain text without quotes")
    yield ServerSentEvent(event="html", raw_data="<div>html fragment</div>")
    yield ServerSentEvent(event="csv", raw_data="cpu,87.3,1709145600")
|
|
|
|
|
|
# Separate router mounted under "/api", to check that SSE routes declared on an
# APIRouter keep working once included in the app.
router = APIRouter()


@router.get("/events", response_class=EventSourceResponse)
async def stream_events():
    # Two dict events, in order.
    for greeting in ("hello", "world"):
        yield {"msg": greeting}


app.include_router(router, prefix="/api")
|
|
|
|
|
|
@pytest.fixture(name="client")
def client_fixture():
    """Yield a TestClient for ``app``, entered as a context manager for the test."""
    with TestClient(app) as test_client:
        yield test_client
|
|
|
|
|
|
def test_async_generator_with_model(client: TestClient):
    """An AsyncIterable[Item] handler streams one JSON data event per item.

    Also checks the SSE-specific response headers.
    """
    import json

    response = client.get("/items/stream")
    assert response.status_code == 200
    assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
    assert response.headers["cache-control"] == "no-cache"
    assert response.headers["x-accel-buffering"] == "no"

    lines = response.text.strip().split("\n")
    data_lines = [line for line in lines if line.startswith("data: ")]
    assert len(data_lines) == 3
    # Decode each payload instead of substring matching, so the check does not
    # depend on how much whitespace the serializer puts after ':'.
    payloads = [json.loads(line[len("data: ") :]) for line in data_lines]
    assert [payload["name"] for payload in payloads] == [
        "Plumbus",
        "Portal Gun",
        "Meeseeks Box",
    ]
|
|
|
|
|
|
def test_sync_generator_with_model(client: TestClient):
    """A sync generator annotated with Iterable[Item] streams all three items."""
    response = client.get("/items/stream-sync")
    assert response.status_code == 200
    assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

    body_lines = response.text.strip().split("\n")
    data_count = sum(1 for line in body_lines if line.startswith("data: "))
    assert data_count == 3
|
|
|
|
|
|
def test_async_generator_no_annotation(client: TestClient):
    """An unannotated async generator still streams every yielded item."""
    response = client.get("/items/stream-no-annotation")
    assert response.status_code == 200
    assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

    body_lines = response.text.strip().split("\n")
    data_count = sum(1 for line in body_lines if line.startswith("data: "))
    assert data_count == 3
|
|
|
|
|
|
def test_sync_generator_no_annotation(client: TestClient):
    """An unannotated sync generator still streams every yielded item."""
    response = client.get("/items/stream-sync-no-annotation")
    assert response.status_code == 200
    assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

    body_lines = response.text.strip().split("\n")
    data_count = sum(1 for line in body_lines if line.startswith("data: "))
    assert data_count == 3
|
|
|
|
|
|
def test_dict_items(client: TestClient):
    """Plain dicts yielded by the handler arrive as JSON-encoded data events."""
    import json

    response = client.get("/items/stream-dict")
    assert response.status_code == 200
    data_lines = [
        line for line in response.text.strip().split("\n") if line.startswith("data: ")
    ]
    assert len(data_lines) == 3
    # Compare every decoded payload with what sse_items_dict yields, rather than
    # only checking that a "name" key appears in the first event.
    payloads = [json.loads(line[len("data: ") :]) for line in data_lines]
    assert payloads == [
        {"name": item.name, "description": item.description} for item in items
    ]
|
|
|
|
|
|
def test_post_method_sse(client: TestClient):
    """SSE should work with POST (needed for MCP compatibility)."""
    response = client.post("/items/stream-post")
    assert response.status_code == 200
    assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

    body_lines = response.text.strip().split("\n")
    data_count = sum(1 for line in body_lines if line.startswith("data: "))
    assert data_count == 3
|
|
|
|
|
|
def test_sse_events_with_fields(client: TestClient):
    """Each ServerSentEvent field is rendered as its own SSE line."""
    response = client.get("/items/stream-sse-event")
    assert response.status_code == 200
    body = response.text

    expected_fragments = (
        # named event with string data and an id
        "event: greeting\n",
        'data: "hello"\n',
        "id: 1\n",
        # named event with dict data and an id
        "event: json-data\n",
        "id: 2\n",
        'data: {"key": "value"}\n',
        # comment-only event
        ": just a comment\n",
        # retry hint
        "retry: 5000\n",
        'data: "retry-test"\n',
    )
    for fragment in expected_fragments:
        assert fragment in body
|
|
|
|
|
|
def test_mixed_plain_and_sse_events(client: TestClient):
    """Explicit ServerSentEvents and plain items can share one stream."""
    response = client.get("/items/stream-mixed")
    assert response.status_code == 200
    body = response.text

    # The explicit event keeps its name and JSON-quoted data...
    assert "event: special\n" in body
    assert 'data: "custom-event"\n' in body
    # ...while the surrounding Item values are serialized too.
    assert '"name"' in body
|
|
|
|
|
|
def test_string_data_json_encoded(client: TestClient):
    """Strings are always JSON-encoded (quoted)."""
    response = client.get("/items/stream-string")
    assert response.status_code == 200
    expected_line = 'data: "plain text data"\n'
    assert expected_line in response.text
|
|
|
|
|
|
def test_server_sent_event_null_id_rejected():
    """An ``id`` containing a NUL character is refused at construction time."""
    bad_id = "has\x00null"
    with pytest.raises(ValueError, match="null"):
        ServerSentEvent(data="test", id=bad_id)
|
|
|
|
|
|
def test_server_sent_event_negative_retry_rejected():
    """A negative ``retry`` value is rejected."""
    bad_retry = -1
    with pytest.raises(ValueError):
        ServerSentEvent(data="test", retry=bad_retry)
|
|
|
|
|
|
def test_server_sent_event_float_retry_rejected():
    """A non-integral ``retry`` value is rejected."""
    fractional_retry = 1.5
    with pytest.raises(ValueError):
        ServerSentEvent(data="test", retry=fractional_retry)  # type: ignore[arg-type]
|
|
|
|
|
|
def test_raw_data_sent_without_json_encoding(client: TestClient):
    """raw_data is sent as-is, not JSON-encoded."""
    response = client.get("/items/stream-raw")
    assert response.status_code == 200
    body = response.text

    must_appear = (
        # verbatim text, without surrounding JSON quotes
        "data: plain text without quotes\n",
        # HTML fragment under a named event
        "event: html\n",
        "data: <div>html fragment</div>\n",
        # CSV row under a named event
        "event: csv\n",
        "data: cpu,87.3,1709145600\n",
    )
    for fragment in must_appear:
        assert fragment in body
    # The JSON-quoted form of the first payload must not be present.
    assert 'data: "plain text without quotes"' not in body
|
|
|
|
|
|
def test_data_and_raw_data_mutually_exclusive():
    """Supplying both ``data`` and ``raw_data`` is an error."""
    with pytest.raises(ValueError, match="Cannot set both"):
        ServerSentEvent(raw_data="raw", data="json")
|
|
|
|
|
|
def test_sse_on_router_included_in_app(client: TestClient):
    """An SSE route on an included APIRouter is served under the router prefix."""
    response = client.get("/api/events")
    assert response.status_code == 200
    assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

    body_lines = response.text.strip().split("\n")
    data_count = sum(1 for line in body_lines if line.startswith("data: "))
    assert data_count == 2
|
|
|
|
|
|
# Keepalive ping tests
#
# Dedicated app whose handlers pause between items. The keepalive tests below
# shrink ``fastapi.routing._PING_INTERVAL`` via monkeypatch so that the pauses
# exceed it.
keepalive_app = FastAPI()
|
|
|
|
|
|
@keepalive_app.get("/slow-async", response_class=EventSourceResponse)
async def slow_async_stream():
    # Pause between the two events for longer than the (monkeypatched) ping
    # interval, so a keepalive comment is emitted before the second event.
    pause_seconds = 0.3
    yield {"n": 1}
    await asyncio.sleep(pause_seconds)
    yield {"n": 2}
|
|
|
|
|
|
@keepalive_app.get("/slow-sync", response_class=EventSourceResponse)
def slow_sync_stream():
    # Synchronous variant: a blocking sleep between the two events, longer
    # than the (monkeypatched) ping interval.
    pause_seconds = 0.3
    yield {"n": 1}
    time.sleep(pause_seconds)
    yield {"n": 2}
|
|
|
|
|
|
def test_keepalive_ping_async(monkeypatch: pytest.MonkeyPatch):
    """An async handler that stalls past the ping interval gets keepalive comments."""
    # Shrink the interval so the 0.3 s sleep in slow_async_stream exceeds it.
    monkeypatch.setattr(fastapi.routing, "_PING_INTERVAL", 0.05)
    with TestClient(keepalive_app) as c:
        response = c.get("/slow-async")
        assert response.status_code == 200
        text = response.text
        # The keepalive comment ": ping" should appear while the handler stalls.
        assert ": ping\n" in text
        data_lines = [line for line in text.split("\n") if line.startswith("data: ")]
        assert len(data_lines) == 2
        # The stall sits between the two items, so at least one ping must be
        # emitted before the second data event (not merely somewhere in the body).
        assert text.index(": ping\n") < text.rindex("data: ")
|
|
|
|
|
|
def test_keepalive_ping_sync(monkeypatch: pytest.MonkeyPatch):
    """A sync handler that blocks past the ping interval also gets keepalives."""
    # Shrink the interval so the 0.3 s time.sleep in slow_sync_stream exceeds it.
    monkeypatch.setattr(fastapi.routing, "_PING_INTERVAL", 0.05)
    with TestClient(keepalive_app) as c:
        response = c.get("/slow-sync")
        assert response.status_code == 200
        text = response.text
        assert ": ping\n" in text
        data_lines = [line for line in text.split("\n") if line.startswith("data: ")]
        assert len(data_lines) == 2
        # The blocking sleep sits between the two items, so at least one ping
        # must be emitted before the second data event.
        assert text.index(": ping\n") < text.rindex("data: ")
|
|
|
|
|
|
def test_no_keepalive_when_fast(client: TestClient):
    """No keepalive comment when items arrive quickly."""
    response = client.get("/items/stream")
    assert response.status_code == 200
    # KEEPALIVE_COMMENT is ": ping\n\n"; even its prefix must be absent here.
    keepalive_prefix = ": ping\n"
    assert keepalive_prefix not in response.text
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Generic ServerSentEvent[T] tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_get_sse_data_type_parameterized():
    """For ServerSentEvent[T], get_sse_data_type yields T itself."""
    parameterized = ServerSentEvent[Item]
    assert get_sse_data_type(parameterized) is Item
|
|
|
|
|
|
def test_get_sse_data_type_bare():
    """An unparameterized ServerSentEvent has no data type (None)."""
    result = get_sse_data_type(ServerSentEvent)
    assert result is None
|
|
|
|
|
|
def test_get_sse_data_type_non_sse():
    """Types other than ServerSentEvent (including None) map to None."""
    for candidate in (Item, str, None):
        assert get_sse_data_type(candidate) is None
|
|
|
|
|
|
def test_generic_sse_construction_validates_data():
    """ServerSentEvent[Item] accepts an Item as data and keeps the event name."""
    model = Item(name="Foo", description=None)
    event = ServerSentEvent[Item](event="update", data=model)
    assert event.data == model
    assert event.event == "update"
|
|
|
|
|
|
def test_generic_sse_rejects_wrong_type():
    """ServerSentEvent[Item] rejects data that is not an Item."""
    # ``pytest`` is already a module-level import; re-importing it locally was
    # redundant. Only ValidationError is needed here.
    from pydantic import ValidationError

    with pytest.raises(ValidationError):
        ServerSentEvent[Item](data="not an item")
|
|
|
|
|
|
def test_generic_sse_rejects_none_data():
    """ServerSentEvent[Item] rejects None as data (use Item | None if optional)."""
    # ``pytest`` is already a module-level import; only ValidationError is needed.
    from pydantic import ValidationError

    with pytest.raises(ValidationError):
        # No ``data`` at all: the implicit None is not an Item.
        ServerSentEvent[Item]()
|
|
|
|
|
|
def test_generic_sse_optional_data_allows_none():
    """With an optional type parameter, omitting data leaves it as None."""
    optional_event = ServerSentEvent[Item | None]()
    assert optional_event.data is None
|
|
|
|
|
|
def test_bare_sse_still_accepts_none_data():
    """Bare ServerSentEvent (T=Any) still accepts None (backward compat)."""
    bare_event = ServerSentEvent()
    assert bare_event.data is None
|
|
|
|
|
|
# App-level test for generic SSE streaming and OpenAPI schema
#
# Separate app, so the OpenAPI checks below only see the routes registered on
# it here.
_generic_app = FastAPI()
|
|
|
|
|
|
@_generic_app.get("/stream", response_class=EventSourceResponse)
async def _stream_typed() -> AsyncIterable[ServerSentEvent[Item]]:
    # Wrap each Item in a typed event named "item" with a 1-based string id.
    for seq, product in enumerate(items, start=1):
        yield ServerSentEvent[Item](data=product, event="item", id=str(seq))
|
|
|
|
|
|
def test_generic_sse_streams_correctly():
    """Typed ServerSentEvent[Item] events stream with their data, event and id."""
    import json

    with TestClient(_generic_app) as c:
        response = c.get("/stream")
        assert response.status_code == 200
        assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
        text = response.text

        data_lines = [line for line in text.split("\n") if line.startswith("data: ")]
        assert len(data_lines) == 3
        # Check every payload, in order, not just the first one.
        payloads = [json.loads(line[len("data: ") :]) for line in data_lines]
        assert [payload["name"] for payload in payloads] == [
            item.name for item in items
        ]

        # _stream_typed names every event "item" and numbers ids from 1.
        assert text.count("event: item\n") == 3
        for expected_id in ("1", "2", "3"):
            assert f"id: {expected_id}\n" in text
|
|
|
|
|
|
def test_generic_sse_openapi_has_content_schema():
    """ServerSentEvent[Item] documents its data as JSON matching the Item schema."""
    with TestClient(_generic_app) as c:
        response = c.get("/openapi.json")
        assert response.status_code == 200
        schema = response.json()
        sse_schema = schema["paths"]["/stream"]["get"]["responses"]["200"]["content"][
            "text/event-stream"
        ]["itemSchema"]
        assert sse_schema.get("required") == ["data"]
        data_prop = sse_schema["properties"]["data"]
        assert data_prop.get("contentMediaType") == "application/json"
        content_schema = data_prop.get("contentSchema", {})
        # Item may be referenced via $ref or inlined. Either way it must be Item
        # specifically; previously any $ref at all satisfied the check.
        ref = content_schema.get("$ref")
        if ref is not None:
            assert ref.endswith("/Item"), ref
        else:
            assert content_schema.get("title") == "Item"
|
|
|
|
|
|
def test_bare_sse_openapi_has_no_content_schema():
    """Bare ServerSentEvent return type produces no contentSchema (backward compat)."""
    bare_app = FastAPI()

    @bare_app.get("/stream", response_class=EventSourceResponse)
    async def _bare_stream() -> AsyncIterable[ServerSentEvent]:
        yield ServerSentEvent(comment="ping")

    with TestClient(bare_app) as c:
        response = c.get("/openapi.json")
        assert response.status_code == 200
        operation = response.json()["paths"]["/stream"]["get"]
        media = operation["responses"]["200"]["content"]["text/event-stream"]
        item_schema = media["itemSchema"]
        # Without a type parameter, data is neither required nor typed.
        assert "required" not in item_schema
        assert "contentSchema" not in item_schema["properties"]["data"]
|
|
|