You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

397 lines
12 KiB

from dataclasses import dataclass
from typing import Annotated, Any
from annotated_doc import Doc
from pydantic import AfterValidator, BaseModel, Field, model_validator
from starlette.responses import StreamingResponse
# JSON Schema for a single SSE event, matching the canonical schema in the
# OpenAPI 3.2 spec (Section 4.14.4 "Special Considerations for Server-Sent
# Events"). `data`, `event` and `id` are strings; `retry` is an integer >= 0.
_SSE_EVENT_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        # Each string-typed field gets its own (unshared) sub-schema dict.
        **{name: {"type": "string"} for name in ("data", "event", "id")},
        "retry": {"type": "integer", "minimum": 0},
    },
}
class EventSourceResponse(StreamingResponse):
    """`StreamingResponse` subclass that declares the `text/event-stream` media type.

    Pass it as `response_class=EventSourceResponse` on a *path operation* that
    uses `yield` to turn that operation into a Server Sent Events (SSE)
    endpoint.

    It is not tied to a particular HTTP method: `GET`, `POST`, etc. all work,
    which keeps it usable with protocols like MCP that stream SSE over `POST`.

    This class is mainly a marker that also sets the right `Content-Type`;
    the actual encoding logic lives in the FastAPI routing layer.
    """

    media_type = "text/event-stream"
def _check_single_line(v: str | None, field_name: str) -> str | None:
if v is not None and ("\r" in v or "\n" in v):
raise ValueError(f"SSE '{field_name}' must be a single line")
return v
def _check_event_single_line(v: str | None) -> str | None:
    """Validator for the `event` field: the value must not contain a line break.

    Delegates to `_check_single_line`, which raises `ValueError` on `\\r` or
    `\\n` and otherwise returns *v* unchanged (including `None`).
    """
    return _check_single_line(v, field_name="event")
def _check_id_valid(v: str | None) -> str | None:
    """Validator for the `id` field.

    The null-character check runs first; the single-line check is then
    delegated to `_check_single_line`. `None` passes through unchanged.

    Raises:
        ValueError: if *v* contains a null character, `\\r` or `\\n`.
    """
    contains_null = v is not None and "\0" in v
    if contains_null:
        raise ValueError("SSE 'id' must not contain null characters")
    return _check_single_line(v, field_name="id")
class ServerSentEvent(BaseModel):
    """Represents a single Server-Sent Event.
    When `yield`ed from a *path operation function* that uses
    `response_class=EventSourceResponse`, each `ServerSentEvent` is encoded
    into the [SSE wire format](https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream)
    (`text/event-stream`).
    If you yield a plain object (dict, Pydantic model, etc.) instead, it is
    automatically JSON-encoded and sent as the `data:` field.
    All `data` values **including plain strings** are JSON-serialized.
    For example, `data="hello"` produces `data: "hello"` on the wire (with
    quotes).
    """

    # Every field is optional and defaults to None. `data` and `raw_data` are
    # the two alternative payload sources; the model validator at the bottom
    # rejects events that set both.
    data: Annotated[
        Any,
        Doc(
            """
            The event payload.
            Can be any JSON-serializable value: a Pydantic model, dict, list,
            string, number, etc. It is **always** serialized to JSON: strings
            are quoted (`"hello"` becomes `data: "hello"` on the wire).
            Mutually exclusive with `raw_data`.
            """
        ),
    ] = None
    raw_data: Annotated[
        str | None,
        Doc(
            """
            Raw string to send as the `data:` field **without** JSON encoding.
            Use this when you need to send pre-formatted text, HTML fragments,
            CSV lines, or any non-JSON payload. The string is placed directly
            into the `data:` field as-is.
            Mutually exclusive with `data`.
            """
        ),
    ] = None
    # Validated after type coercion: must not contain "\r" or "\n".
    event: Annotated[
        str | None,
        AfterValidator(_check_event_single_line),
        Doc(
            """
            Optional event type name.
            Maps to `addEventListener(event, ...)` on the browser. When omitted,
            the browser dispatches on the generic `message` event. Must be a
            single line.
            """
        ),
    ] = None
    # Validated after type coercion: no null characters, no line breaks.
    id: Annotated[
        str | None,
        AfterValidator(_check_id_valid),
        Doc(
            """
            Optional event ID.
            The browser sends this value back as the `Last-Event-ID` header on
            automatic reconnection. **Must be a single line** and must not contain
            null (`\\0`) characters.
            """
        ),
    ] = None
    retry: Annotated[
        int | None,
        Field(ge=0),
        Doc(
            """
            Optional reconnection time in **milliseconds**.
            Tells the browser how long to wait before reconnecting after the
            connection is lost. Must be a non-negative integer.
            """
        ),
    ] = None
    comment: Annotated[
        str | None,
        Doc(
            """
            Optional comment line(s).
            Comment lines start with `:` in the SSE wire format and are ignored by
            `EventSource` clients. Useful for keep-alive pings to prevent
            proxy/load-balancer timeouts.
            """
        ),
    ] = None

    @model_validator(mode="after")
    def _check_data_exclusive(self) -> "ServerSentEvent":
        """Reject events that populate both `data` and `raw_data`."""
        # At most one payload source may be set; None means "not set".
        only_one_or_none = self.data is None or self.raw_data is None
        if only_one_or_none:
            return self
        raise ValueError(
            "Cannot set both 'data' and 'raw_data' on the same "
            "ServerSentEvent. Use 'data' for JSON-serialized payloads "
            "or 'raw_data' for pre-formatted strings."
        )
def _split_sse_lines(text: str) -> list[str]:
    """Split *text* into lines on the SSE line terminators only.

    SSE recognises just `\\r\\n`, `\\r` and `\\n` as line endings.
    `str.splitlines()` additionally breaks on characters such as `\\v`,
    `\\f`, `\\x85`, `\\u2028` and `\\u2029`; a receiver would then rejoin those
    pieces with `\\n`, silently altering the payload. Apart from that, the
    result matches `str.splitlines()`: an empty string yields no lines and a
    trailing terminator does not produce an extra empty line.
    """
    if not text:
        return []
    parts = text.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    # A trailing terminator ends the last line; it does not start a new one.
    if parts[-1] == "":
        parts.pop()
    return parts


def format_sse_event(
    *,
    data_str: Annotated[
        str | None,
        Doc(
            """
            Pre-serialized data string to use as the `data:` field.
            """
        ),
    ] = None,
    event: Annotated[
        str | None,
        Doc(
            """
            Optional event type name (`event:` field).
            """
        ),
    ] = None,
    id: Annotated[
        str | None,
        Doc(
            """
            Optional event ID (`id:` field).
            """
        ),
    ] = None,
    retry: Annotated[
        int | None,
        Doc(
            """
            Optional reconnection time in milliseconds (`retry:` field).
            """
        ),
    ] = None,
    comment: Annotated[
        str | None,
        Doc(
            """
            Optional comment line(s) (`:` prefix).
            """
        ),
    ] = None,
) -> bytes:
    """Build SSE wire-format bytes from **pre-serialized** data.

    Fields are emitted in the order comment, `event`, `data`, `id`, `retry`;
    arguments left as `None` are omitted. Multi-line `comment` and `data_str`
    values produce one `: ` / `data: ` line per input line, where only
    `\\r\\n`, `\\r` and `\\n` count as line breaks (other Unicode separators
    such as U+2028 stay inside their line so the payload survives the round
    trip). `event` and `id` are written as given; this function does not
    validate them.

    The result is UTF-8 encoded and always ends with `\\n\\n` (the event
    terminator).
    """
    lines: list[str] = []
    if comment is not None:
        lines.extend(f": {line}" for line in _split_sse_lines(comment))
    if event is not None:
        lines.append(f"event: {event}")
    if data_str is not None:
        lines.extend(f"data: {line}" for line in _split_sse_lines(data_str))
    if id is not None:
        lines.append(f"id: {id}")
    if retry is not None:
        lines.append(f"retry: {retry}")
    # Two empty entries make the join end with "\n\n", i.e. the blank line
    # that terminates the event.
    lines.append("")
    lines.append("")
    return "\n".join(lines).encode("utf-8")
@dataclass(frozen=True)
class ParsedSSEEvent:
    """One Server-Sent Event as read back from the wire format.

    Instances are produced by `parse_sse_events()`. This is the
    *receiver-side* counterpart to [`ServerSentEvent`](#serversentevent)
    (which is used to *send* events).

    `data` holds the raw string from the wire (multi-line `data:` lines
    joined with `\\n`); it is not JSON-decoded. The caller decides how to
    decode it, because the payload may be JSON, plain text, or any other
    format depending on the server.

    Only fields explicitly set in the event's own block on the wire are
    reflected. Unlike a browser `EventSource` client, `id` and `retry` are
    not sticky across events; callers that need stickiness implement it
    themselves.
    """

    data: Annotated[
        str,
        Doc(
            """
            The event payload — multi-line `data:` lines joined with `\\n`,
            with a single trailing `\\n` stripped per the SSE spec.
            """
        ),
    ]
    event: Annotated[
        str,
        Doc(
            """
            The event type. Defaults to `"message"` when no `event:` field
            is present, matching what an `EventSource` browser client would
            dispatch.
            """
        ),
    ] = "message"
    id: Annotated[
        str | None,
        Doc(
            """
            The event ID from the `id:` field, or `None` if not set on this
            event block. (Not carried over from the previous event.)
            """
        ),
    ] = None
    retry: Annotated[
        int | None,
        Doc(
            """
            The reconnection time in milliseconds from the `retry:` field,
            or `None` if not set on this event block.
            """
        ),
    ] = None
def parse_sse_events(
    raw: Annotated[
        bytes | str,
        Doc(
            """
            SSE wire-format text or bytes. Typically the full body of a
            `text/event-stream` response.
            """
        ),
    ],
) -> list[ParsedSSEEvent]:
    """Parse an SSE event stream into a list of `ParsedSSEEvent` objects.

    Implements the [WHATWG SSE parsing algorithm](https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation)
    for a complete stream. This is the receiver-side counterpart to
    `format_sse_event()`.

    Useful for **tests**, **clients**, or any code that consumes the response
    of an `EventSourceResponse` *path operation*.

    Parsing rules followed (per spec):

    * `bytes` input is decoded as UTF-8.
    * Lines may be separated by `\\n`, `\\r`, or `\\r\\n`.
    * A leading UTF-8 BOM is stripped.
    * Comment lines (those starting with `:`) are skipped.
    * Multi-line `data:` fields are joined with `\\n`, with a single trailing
      `\\n` stripped.
    * Events with an empty data buffer are not emitted.
    * A trailing block that is not terminated by a blank line is discarded.
    * Unknown field names are ignored.
    * `id` values containing NULL bytes are ignored.
    * `retry` values that aren't ASCII decimal integers are ignored.

    Note: this returns events as they appear on the wire. `id` and `retry`
    are **not sticky** across events in the returned list — each
    `ParsedSSEEvent` reflects only the fields seen in its own block.
    """
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    # Strip a single leading BOM if present (per spec). The BOM is written as
    # an escape so the literal cannot be lost or confused with "".
    if raw.startswith("\ufeff"):
        raw = raw[1:]
    # Normalize line endings: \r\n or \r → \n.
    text = raw.replace("\r\n", "\n").replace("\r", "\n")
    # The final split element is either the "" after the last terminator or an
    # unterminated partial line. Neither is a complete line, so it is dropped:
    # pending data at end of stream is never dispatched.
    lines = text.split("\n")[:-1]

    events: list[ParsedSSEEvent] = []
    data_buf: list[str] = []
    event_type: str | None = None
    last_id: str | None = None
    retry: int | None = None

    def _dispatch() -> None:
        """Emit the buffered event, if it has data, and reset per-block state."""
        nonlocal event_type, last_id, retry
        # Per spec: if the data buffer is empty, do not dispatch the event.
        if data_buf:
            events.append(
                ParsedSSEEvent(
                    data="\n".join(data_buf),
                    event=event_type if event_type else "message",
                    id=last_id,
                    retry=retry,
                )
            )
            data_buf.clear()
        # Reset everything even when nothing was emitted, so `id`/`retry`
        # from a data-less block cannot leak into the next event.
        event_type = None
        last_id = None
        retry = None

    for line in lines:
        if line == "":
            _dispatch()
            continue
        if line.startswith(":"):
            # Comment line, ignored per spec.
            continue
        # A line with no colon is a field with an empty value; partition()
        # yields exactly that. One optional leading space in the value is
        # stripped.
        field, _, value = line.partition(":")
        value = value.removeprefix(" ")
        if field == "data":
            data_buf.append(value)
        elif field == "event":
            event_type = value
        elif field == "id":
            # Per spec: ignore IDs containing NULL bytes.
            if "\0" not in value:
                last_id = value
        elif field == "retry":
            # Per spec: ASCII digits only. str.isdigit() alone also accepts
            # characters such as "²", on which int() raises ValueError.
            if value.isascii() and value.isdigit():
                retry = int(value)
        # Other fields are ignored per spec.
    return events
# Keep-alive comment, per the SSE spec recommendation: a single `: ping`
# comment line followed by the blank line that terminates the block.
KEEPALIVE_COMMENT: bytes = b": ping\n\n"
# Seconds between keep-alive pings when a generator is idle.
# Private but importable so tests can monkeypatch it.
_PING_INTERVAL: float = 15.0