Browse Source

Make ServerSentEvent generic with typed data field

Introduce `ServerSentEvent[Data]` so endpoints that yield typed SSE
events get a `contentSchema` in the OpenAPI spec reflecting the data
payload type, while retaining full control over SSE fields (`event`,
`id`, `retry`, `comment`).

- `ServerSentEvent` now inherits from `Generic[Data]` with `data: Data`
- `validate_default=True` ensures `ServerSentEvent[Item]()` raises a
  ValidationError (data is effectively required when Data is concrete)
- `ServerSentEvent[Item | None]` allows optional data; bare
  `ServerSentEvent` is fully backward compatible (`Data=Any`)
- Added `get_sse_data_type()` helper (uses Pydantic's
  `__pydantic_generic_metadata__`) to extract `Data` from a
  parameterized `ServerSentEvent[Data]` annotation
- Routing layer now extracts `Data` from `ServerSentEvent[Data]` and
  uses it as `stream_item_type`, feeding it into the existing OpenAPI
  `contentSchema` pipeline
- Added tutorial006 and corresponding snapshot test
- Extended test_sse.py with generic SSE unit and app-level tests

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
pull/15191/head
Ben Mosher 3 months ago
parent
commit
e9212d8483
  1. 25
      docs_src/server_sent_events/tutorial006_py310.py
  2. 29
      fastapi/routing.py
  3. 59
      fastapi/sse.py
  4. 123
      tests/test_sse.py
  5. 100
      tests/test_tutorial/test_server_sent_events/test_tutorial006.py

25
docs_src/server_sent_events/tutorial006_py310.py

@ -0,0 +1,25 @@
from collections.abc import AsyncIterable
from fastapi import FastAPI
from fastapi.sse import EventSourceResponse, ServerSentEvent
from pydantic import BaseModel
app = FastAPI()
class Item(BaseModel):
name: str
price: float
items = [
Item(name="Plumbus", price=32.99),
Item(name="Portal Gun", price=999.99),
Item(name="Meeseeks Box", price=49.99),
]
@app.get("/items/stream", response_class=EventSourceResponse)
async def stream_items() -> AsyncIterable[ServerSentEvent[Item]]:
for i, item in enumerate(items):
yield ServerSentEvent[Item](data=item, event="item_update", id=str(i + 1))

29
fastapi/routing.py

@ -64,6 +64,7 @@ from fastapi.sse import (
EventSourceResponse,
ServerSentEvent,
format_sse_event,
get_sse_data_type,
)
from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import (
@ -854,14 +855,26 @@ class APIRoute(routing.Route):
# Extract item type for JSONL or SSE streaming when
# response_class is DefaultPlaceholder (JSONL) or
# EventSourceResponse (SSE).
# ServerSentEvent is excluded: it's a transport
# wrapper, not a data model, so it shouldn't feed
# into validation or OpenAPI schema generation.
if (
isinstance(response_class, DefaultPlaceholder)
or lenient_issubclass(response_class, EventSourceResponse)
) and not lenient_issubclass(stream_item, ServerSentEvent):
self.stream_item_type = stream_item
# Bare ServerSentEvent is excluded: it's a transport
# wrapper with no specific data type, so it doesn't
# feed into validation or OpenAPI schema generation.
# Parameterized ServerSentEvent[Data] is handled by
# extracting Data and using it as the item type.
if isinstance(
response_class, DefaultPlaceholder
) or lenient_issubclass(response_class, EventSourceResponse):
sse_data_type = get_sse_data_type(stream_item)
if sse_data_type is not None:
# ServerSentEvent[Data]: use Data for contentSchema
self.stream_item_type = sse_data_type
elif lenient_issubclass(stream_item, ServerSentEvent):
# Bare ServerSentEvent (no type param): transport
# wrapper with no specific data type, so no
# contentSchema in OpenAPI.
pass
else:
# Plain model (e.g. Item): use as-is
self.stream_item_type = stream_item
response_model = None
else:
response_model = return_annotation

59
fastapi/sse.py

@ -1,9 +1,16 @@
from typing import Annotated, Any
from typing import Annotated, Any, Generic, TypeVar
from annotated_doc import Doc
from pydantic import AfterValidator, BaseModel, Field, model_validator
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, model_validator
from starlette.responses import StreamingResponse
Data = TypeVar("Data")
"""Type variable for the `data` payload of a `ServerSentEvent`.
Use ``ServerSentEvent[MyModel]`` to indicate that every event in the
stream carries a ``MyModel`` instance as its ``data`` field.
"""
# Canonical SSE event schema matching the OpenAPI 3.2 spec
# (Section 4.14.4 "Special Considerations for Server-Sent Events")
_SSE_EVENT_SCHEMA: dict[str, Any] = {
@ -39,7 +46,7 @@ def _check_id_no_null(v: str | None) -> str | None:
return v
class ServerSentEvent(BaseModel):
class ServerSentEvent(BaseModel, Generic[Data]):
"""Represents a single Server-Sent Event.
When `yield`ed from a *path operation function* that uses
@ -56,8 +63,14 @@ class ServerSentEvent(BaseModel):
quotes).
"""
# validate_default=True ensures that when Data is a concrete type (e.g.
# ServerSentEvent[Item]), omitting `data` raises a ValidationError rather
# than silently storing the None default. Without this, Pydantic skips
# default validation and None would be accepted even when Data=Item.
model_config = ConfigDict(validate_default=True)
data: Annotated[
Any,
Data,
Doc(
"""
The event payload.
@ -66,10 +79,19 @@ class ServerSentEvent(BaseModel):
string, number, etc. It is **always** serialized to JSON: strings
are quoted (`"hello"` becomes `data: "hello"` on the wire).
The type of `data` is controlled by the type variable `Data`:
* `ServerSentEvent[Item]` `data` must be an `Item` instance
(non-nullable; omitting `data` will raise a validation error).
* `ServerSentEvent[Item | None]` `data` may be `None`, which is
useful for comment-only or metadata events.
* Bare `ServerSentEvent` (no type parameter) `data` accepts any
value including `None`, preserving backward compatibility.
Mutually exclusive with `raw_data`.
"""
),
] = None
] = None # type: ignore[assignment]
raw_data: Annotated[
str | None,
Doc(
@ -220,3 +242,30 @@ KEEPALIVE_COMMENT = b": ping\n\n"
# Seconds between keep-alive pings when a generator is idle.
# Private but importable so tests can monkeypatch it.
_PING_INTERVAL: float = 15.0
def get_sse_data_type(annotation: Any) -> Any | None:
"""Extract the ``Data`` type from a ``ServerSentEvent[Data]`` annotation.
Returns ``None`` for bare ``ServerSentEvent`` (no type parameter) or for
any annotation that is not a parameterized ``ServerSentEvent``.
Used by the routing layer to build the ``stream_item_field`` for OpenAPI
schema generation when the endpoint yields ``ServerSentEvent[Data]``.
Pydantic's generic BaseModel creates a real subclass (not a
``_GenericAlias``), so ``get_origin`` returns ``None``. Instead, we
inspect ``__pydantic_generic_metadata__`` which Pydantic always attaches
to parameterised models.
"""
if not (isinstance(annotation, type) and issubclass(annotation, ServerSentEvent)):
return None
if annotation is ServerSentEvent:
return None
meta = getattr(annotation, "__pydantic_generic_metadata__", None)
if not meta:
return None
args = meta.get("args", ())
if not args or isinstance(args[0], TypeVar):
return None
return args[0]

123
tests/test_sse.py

@ -6,7 +6,7 @@ import fastapi.routing
import pytest
from fastapi import APIRouter, FastAPI
from fastapi.responses import EventSourceResponse
from fastapi.sse import ServerSentEvent
from fastapi.sse import ServerSentEvent, get_sse_data_type
from fastapi.testclient import TestClient
from pydantic import BaseModel
@ -316,3 +316,124 @@ def test_no_keepalive_when_fast(client: TestClient):
assert response.status_code == 200
# KEEPALIVE_COMMENT is ": ping\n\n".
assert ": ping\n" not in response.text
# ---------------------------------------------------------------------------
# Generic ServerSentEvent[T] tests
# ---------------------------------------------------------------------------
def test_get_sse_data_type_parameterized():
"""get_sse_data_type returns the type argument for ServerSentEvent[T]."""
assert get_sse_data_type(ServerSentEvent[Item]) is Item
def test_get_sse_data_type_bare():
"""get_sse_data_type returns None for bare ServerSentEvent."""
assert get_sse_data_type(ServerSentEvent) is None
def test_get_sse_data_type_non_sse():
"""get_sse_data_type returns None for unrelated types."""
assert get_sse_data_type(Item) is None
assert get_sse_data_type(str) is None
assert get_sse_data_type(None) is None
def test_generic_sse_construction_validates_data():
"""ServerSentEvent[Item] requires data to be an Item."""
item = Item(name="Foo", description=None)
evt = ServerSentEvent[Item](data=item, event="update")
assert evt.data == item
assert evt.event == "update"
def test_generic_sse_rejects_wrong_type():
"""ServerSentEvent[Item] rejects data that is not an Item."""
import pytest
from pydantic import ValidationError
with pytest.raises(ValidationError):
ServerSentEvent[Item](data="not an item")
def test_generic_sse_rejects_none_data():
"""ServerSentEvent[Item] rejects None as data (use Item | None if optional)."""
import pytest
from pydantic import ValidationError
with pytest.raises(ValidationError):
ServerSentEvent[Item]()
def test_generic_sse_optional_data_allows_none():
"""ServerSentEvent[Item | None] accepts None as data."""
evt = ServerSentEvent[Item | None]()
assert evt.data is None
def test_bare_sse_still_accepts_none_data():
"""Bare ServerSentEvent (T=Any) still accepts None (backward compat)."""
evt = ServerSentEvent()
assert evt.data is None
# App-level test for generic SSE streaming and OpenAPI schema
_generic_app = FastAPI()
@_generic_app.get("/stream", response_class=EventSourceResponse)
async def _stream_typed() -> AsyncIterable[ServerSentEvent[Item]]:
for i, item in enumerate(items):
yield ServerSentEvent[Item](data=item, event="item", id=str(i + 1))
def test_generic_sse_streams_correctly():
with TestClient(_generic_app) as c:
response = c.get("/stream")
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
data_lines = [
line for line in response.text.split("\n") if line.startswith("data: ")
]
assert len(data_lines) == 3
import json
first = json.loads(data_lines[0][len("data: ") :])
assert first["name"] == "Plumbus"
def test_generic_sse_openapi_has_content_schema():
with TestClient(_generic_app) as c:
response = c.get("/openapi.json")
assert response.status_code == 200
schema = response.json()
sse_schema = schema["paths"]["/stream"]["get"]["responses"]["200"]["content"][
"text/event-stream"
]["itemSchema"]
assert sse_schema.get("required") == ["data"]
data_prop = sse_schema["properties"]["data"]
assert data_prop.get("contentMediaType") == "application/json"
content_schema = data_prop.get("contentSchema", {})
# Should reference Item (either inline or via $ref)
assert "$ref" in content_schema or content_schema.get("title") == "Item"
def test_bare_sse_openapi_has_no_content_schema():
"""Bare ServerSentEvent return type produces no contentSchema (backward compat)."""
bare_app = FastAPI()
@bare_app.get("/stream", response_class=EventSourceResponse)
async def _bare_stream() -> AsyncIterable[ServerSentEvent]:
yield ServerSentEvent(comment="ping")
with TestClient(bare_app) as c:
response = c.get("/openapi.json")
assert response.status_code == 200
schema = response.json()
sse_schema = schema["paths"]["/stream"]["get"]["responses"]["200"]["content"][
"text/event-stream"
]["itemSchema"]
assert "required" not in sse_schema
assert "contentSchema" not in sse_schema["properties"]["data"]

100
tests/test_tutorial/test_server_sent_events/test_tutorial006.py

@ -0,0 +1,100 @@
import importlib
import json
import pytest
from fastapi.testclient import TestClient
from inline_snapshot import snapshot
@pytest.fixture(
name="client",
params=[
pytest.param("tutorial006_py310"),
],
)
def get_client(request: pytest.FixtureRequest):
mod = importlib.import_module(f"docs_src.server_sent_events.{request.param}")
client = TestClient(mod.app)
return client
def test_stream_items(client: TestClient):
response = client.get("/items/stream")
assert response.status_code == 200, response.text
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
lines = response.text.strip().split("\n")
event_lines = [line for line in lines if line.startswith("event: ")]
assert len(event_lines) == 3
assert all(line == "event: item_update" for line in event_lines)
data_lines = [line for line in lines if line.startswith("data: ")]
assert len(data_lines) == 3
payloads = [json.loads(line[len("data: ") :]) for line in data_lines]
assert payloads[0] == {"name": "Plumbus", "price": 32.99}
assert payloads[1] == {"name": "Portal Gun", "price": 999.99}
assert payloads[2] == {"name": "Meeseeks Box", "price": 49.99}
id_lines = [line for line in lines if line.startswith("id: ")]
assert id_lines == ["id: 1", "id: 2", "id: 3"]
def test_openapi_schema(client: TestClient):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == snapshot(
{
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/items/stream": {
"get": {
"summary": "Stream Items",
"operationId": "stream_items_items_stream_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"text/event-stream": {
"itemSchema": {
"type": "object",
"properties": {
"data": {
"type": "string",
"contentMediaType": "application/json",
"contentSchema": {
"$ref": "#/components/schemas/Item"
},
},
"event": {"type": "string"},
"id": {"type": "string"},
"retry": {
"type": "integer",
"minimum": 0,
},
},
"required": ["data"],
}
}
},
}
},
}
}
},
"components": {
"schemas": {
"Item": {
"properties": {
"name": {"type": "string", "title": "Name"},
"price": {"type": "number", "title": "Price"},
},
"type": "object",
"required": ["name", "price"],
"title": "Item",
}
}
},
}
)
Loading…
Cancel
Save