Browse Source

Add parse_sse_events: parse SSE wire-format responses into ParsedSSEEvent objects

pull/15645/head
AshNicolus 5 days ago
parent
commit
a774cc6297
  1. 162
      fastapi/sse.py
  2. 67
      tests/test_sse.py

162
fastapi/sse.py

@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Annotated, Any
from annotated_doc import Doc
@ -227,6 +228,167 @@ def format_sse_event(
return "\n".join(lines).encode("utf-8")
@dataclass(frozen=True)
class ParsedSSEEvent:
    """One Server-Sent Event as read back from the wire format.

    Instances are produced by `parse_sse_events()`. This is the
    *receiver-side* counterpart to [`ServerSentEvent`](#serversentevent),
    which is used to *send* events.

    `data` holds the raw string taken from the wire: the values of the
    event's `data:` lines joined with `\\n`. It is **not** JSON-decoded;
    the payload may be JSON, plain text, or anything else the server chose,
    so decoding is left to the caller.

    `id` and `retry` are not carried over from one emitted event to the
    next, unlike in a browser `EventSource` client. Callers that need
    that stickiness have to implement it themselves.

    The dataclass is frozen, so instances are immutable and compare by value.
    """

    # Required field: every emitted event has a payload string.
    data: Annotated[
        str,
        Doc(
            """
            The event payload multi-line `data:` lines joined with `\\n`,
            with a single trailing `\\n` stripped per the SSE spec.
            """
        ),
    ]
    # Falls back to "message" when the event has no (or an empty) `event:` field.
    event: Annotated[
        str,
        Doc(
            """
            The event type. Defaults to `"message"` when no `event:` field
            is present, matching what an `EventSource` browser client would
            dispatch.
            """
        ),
    ] = "message"
    # Kept as a string exactly as received; never coerced to a number.
    id: Annotated[
        str | None,
        Doc(
            """
            The event ID from the `id:` field, or `None` if not set on this
            event block. (Not carried over from the previous event.)
            """
        ),
    ] = None
    # Milliseconds, already converted to `int` by the parser.
    retry: Annotated[
        int | None,
        Doc(
            """
            The reconnection time in milliseconds from the `retry:` field,
            or `None` if not set on this event block.
            """
        ),
    ] = None
def parse_sse_events(
    raw: Annotated[
        bytes | str,
        Doc(
            """
            SSE wire-format text or bytes. Typically the full body of a
            `text/event-stream` response.
            """
        ),
    ],
) -> list[ParsedSSEEvent]:
    """Parse an SSE event stream into a list of `ParsedSSEEvent` objects.

    Implements the [WHATWG SSE parsing algorithm](https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation)
    for a complete stream. This is the receiver-side counterpart to
    `format_sse_event()`.

    Useful for **tests**, **clients**, or any code that consumes the response
    of an `EventSourceResponse` *path operation*.

    Parsing rules followed (per spec):

    * Lines may be separated by `\\n`, `\\r`, or `\\r\\n`.
    * A leading UTF-8 BOM is stripped.
    * Comment lines (those starting with `:`) are skipped.
    * Multi-line `data:` fields are joined with `\\n`, with a single trailing
      `\\n` stripped.
    * Events with an empty data buffer are not emitted.
    * Unknown field names are ignored.
    * `id` values containing NULL bytes are ignored.
    * `retry` values that aren't made up solely of ASCII digits are ignored.

    Note: this returns events as they appear on the wire. `id` and `retry`
    are **not sticky** across events in the returned list: each
    `ParsedSSEEvent` reflects only the fields seen in its own block. Fields
    seen in a block that carries no `data:` line are discarded together with
    that block.

    Raises `UnicodeDecodeError` if `raw` is bytes that are not valid UTF-8.
    """
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    # Strip a single leading BOM if present (per spec). The escape sequence is
    # deliberate: a literal BOM character is invisible and easily lost, which
    # would degrade this check to `startswith("")` (always true) and silently
    # chop the first character off every input.
    if raw.startswith("\ufeff"):
        raw = raw[1:]
    # Normalize line endings: \r\n or \r -> \n.
    text = raw.replace("\r\n", "\n").replace("\r", "\n")

    events: list[ParsedSSEEvent] = []
    data_buf: list[str] = []
    event_type: str | None = None
    last_id: str | None = None
    retry: int | None = None

    def _dispatch() -> None:
        """Emit the pending event (if it has data) and reset per-block state."""
        nonlocal event_type, last_id, retry
        # Per spec: if the data buffer is empty, do not dispatch the event.
        if data_buf:
            events.append(
                ParsedSSEEvent(
                    data="\n".join(data_buf),
                    # An empty `event:` value falls back to the default type.
                    event=event_type or "message",
                    id=last_id,
                    retry=retry,
                )
            )
            data_buf.clear()
        # Reset every field on each block boundary, including blocks without
        # data, so an `id`/`retry` from a data-less block cannot leak into the
        # next emitted event.
        event_type = None
        last_id = None
        retry = None

    # Note: the fragment after the last newline is processed like any other
    # line, so a stream ending in a single newline still emits its pending
    # event (more lenient than the spec, which drops an unterminated event).
    for line in text.split("\n"):
        if line == "":
            _dispatch()
            continue
        if line.startswith(":"):
            # Comment line, ignored per spec.
            continue
        # A line with no colon is a field with an empty value; `partition`
        # yields exactly that ("name", "", "") when the separator is missing.
        field, _, value = line.partition(":")
        # An optional single leading space after the colon is stripped.
        value = value.removeprefix(" ")
        if field == "data":
            data_buf.append(value)
        elif field == "event":
            event_type = value
        elif field == "id":
            # Per spec: ignore IDs containing NULL bytes.
            if "\0" not in value:
                last_id = value
        elif field == "retry":
            # Per spec: ASCII digits only. `str.isdigit()` alone also accepts
            # characters such as "²", on which `int()` raises ValueError.
            if value.isascii() and value.isdigit():
                retry = int(value)
        # Other fields are ignored per spec.
    return events
# Keep-alive payload, per the SSE spec recommendation: a single comment line
# (it starts with ":") followed by the blank line that terminates the block.
# Comment lines carry no event data, so a receiver skips them.
KEEPALIVE_COMMENT = b": ping\n\n"

67
tests/test_sse.py

@ -6,7 +6,12 @@ import fastapi.routing
import pytest
from fastapi import APIRouter, FastAPI
from fastapi.responses import EventSourceResponse
from fastapi.sse import ServerSentEvent
from fastapi.sse import (
ParsedSSEEvent,
ServerSentEvent,
format_sse_event,
parse_sse_events,
)
from fastapi.testclient import TestClient
from pydantic import BaseModel
@ -325,3 +330,63 @@ def test_no_keepalive_when_fast(client: TestClient):
assert response.status_code == 200
# KEEPALIVE_COMMENT is ": ping\n\n".
assert ": ping\n" not in response.text
def test_parse_sse_events_format_round_trip():
    """Events serialized by format_sse_event come back unchanged from the parser."""
    # Three events back to back: all fields but retry, all fields, data only.
    wire = b"".join(
        [
            format_sse_event(data_str="hello", event="greeting", id="1"),
            format_sse_event(data_str='{"k": 1}', event="json", id="2", retry=5000),
            format_sse_event(data_str="plain"),
        ]
    )
    expected = [
        ParsedSSEEvent(data="hello", event="greeting", id="1"),
        ParsedSSEEvent(data='{"k": 1}', event="json", id="2", retry=5000),
        ParsedSSEEvent(data="plain"),
    ]
    assert parse_sse_events(wire) == expected
def test_parse_sse_events_multiline_data_joined_with_newline():
    """Consecutive `data:` lines of one event end up as one `\\n`-joined payload."""
    wire = "data: line1\ndata: line2\ndata: line3\n\n"
    parsed = parse_sse_events(wire)
    expected = ParsedSSEEvent(data="line1\nline2\nline3")
    assert parsed == [expected]
def test_parse_sse_events_comments_and_unknown_fields_ignored():
    """Neither a comment line nor an unrecognized field affects the parsed event."""
    parsed = parse_sse_events(": this is a comment\nfoo: bar\ndata: payload\n\n")
    # Only the `data:` line contributes to the result.
    expected = [ParsedSSEEvent(data="payload")]
    assert parsed == expected
@pytest.mark.parametrize(
    "raw",
    [
        b"data: hi\n\n",  # bytes input
        "data: hi\r\n\r\n",  # CRLF line endings
        "data: hi\r\r",  # CR-only line endings
        # BOM-prefixed. Written as an escape so the (invisible) BOM cannot be
        # lost by an editor or copy/paste, which would leave this case
        # identical to a plain-string input and BOM handling untested.
        "\ufeffdata: hi\n\n",
    ],
)
def test_parse_sse_events_input_variants(raw: bytes | str):
    """Bytes, CRLF, CR-only, and BOM-prefixed inputs are all accepted."""
    assert parse_sse_events(raw) == [ParsedSSEEvent(data="hi")]
def test_parse_sse_events_invalid_id_and_retry_dropped():
    """An id containing NULL and a non-numeric retry leave both fields unset."""
    parsed = parse_sse_events("id: bad\0id\nretry: not-a-number\ndata: ok\n\n")
    # Only `data` survives; `id` and `retry` keep their `None` defaults.
    expected = [ParsedSSEEvent(data="ok")]
    assert parsed == expected
def test_parse_sse_events_round_trip_through_endpoint(client: TestClient):
    """End-to-end: the body served by an EventSourceResponse endpoint parses cleanly."""
    resp = client.get("/items/stream-sse-event")
    assert resp.status_code == 200

    parsed = parse_sse_events(resp.text)
    assert parsed, "expected at least one parsed event"

    # The fixture endpoint is expected to yield events with custom types
    # (greeting/json-data/etc.), so at least one must not be the default.
    custom_typed = [ev for ev in parsed if ev.event != "message"]
    assert custom_typed

    # `id`s on the wire are strings (per the SSE spec); we don't coerce them.
    assert all(isinstance(ev.id, str) for ev in parsed if ev.id is not None)

Loading…
Cancel
Save