You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

78 lines
1.6 KiB

from contextlib import contextmanager
from typing import Annotated, Any, Generator
from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse
from fastapi.testclient import TestClient
class Session:
    """In-memory stand-in for a resource that must be open to be read.

    Iterating yields the stored strings one at a time. The ``open`` flag is
    re-checked before every item, so closing the session part-way through an
    iteration makes the next step raise instead of yielding.
    """

    def __init__(self) -> None:
        # Fixed payload handed out by iteration.
        self.data = ["foo", "bar", "baz"]
        # Set to False when the session is released; checked per item.
        self.open = True

    def __iter__(self) -> Generator[str, None, None]:
        """Yield each stored item; raise ``ValueError`` once the session is closed."""
        for entry in self.data:
            if not self.open:
                raise ValueError("Session closed")
            yield entry
@contextmanager
def acquire_session() -> Generator[Session, None, None]:
    """Provide a fresh ``Session`` and mark it closed when the block exits.

    The ``finally`` clause resets ``open`` to ``False`` whether the ``with``
    body completes normally or raises.
    """
    sess = Session()
    try:
        yield sess
    finally:
        # From here on, Session.__iter__ refuses to yield further items.
        sess.open = False
def dep_session() -> Any:
    """Yield-style dependency that hands out a session from ``acquire_session``.

    Leaving the ``with`` block after the ``yield`` closes the session; that
    part runs as the dependency's teardown.
    """
    with acquire_session() as active_session:
        yield active_session
# Reusable annotated parameter type: a handler argument declared as
# `session: SessionDep` is resolved by FastAPI by calling `dep_session`.
SessionDep = Annotated[Session, Depends(dep_session)]
# Application under test; the route handlers below register on it.
app = FastAPI()
@app.get("/data")
def get_data(session: SessionDep) -> Any:
    """Return every session item as a list, read eagerly inside the handler.

    The list is fully built before the handler returns, so the session is
    still open while it is iterated.
    """
    return list(session)
@app.get("/stream-simple")
def get_stream_simple(session: SessionDep) -> Any:
    """Stream three fixed chunks without touching the session.

    ``session`` is unused in the body. It is presumably requested so that the
    yield-dependency is active around a stream that does not depend on it.
    """

    def produce_chunks():
        # Generator consumed by StreamingResponse after this handler returns.
        for chunk in ("x", "y", "z"):
            yield chunk

    return StreamingResponse(produce_chunks())
@app.get("/stream-session")
def get_stream_session(session: SessionDep) -> Any:
    """Stream the session's items lazily, after this handler has returned.

    The body generator is only consumed by StreamingResponse once the
    handler returns. It therefore needs the ``dep_session`` teardown (which
    sets ``session.open = False``) not to have run yet. Otherwise
    ``Session.__iter__`` raises ``ValueError("Session closed")``.
    NOTE(review): that ordering depends on the FastAPI version. Confirm the
    pinned version runs yield-dependency exit code after the streamed
    response completes.
    """

    def stream_items():
        yield from session

    return StreamingResponse(stream_items())
# Test client bound to `app`, shared by the test functions below.
client = TestClient(app)
def test_regular_no_stream():
    """The eager ``/data`` endpoint returns every session item as JSON."""
    expected = ["foo", "bar", "baz"]
    assert client.get("/data").json() == expected
def test_stream_simple():
    """A stream that never reads the session is concatenated into the body."""
    body = client.get("/stream-simple").text
    assert body == "xyz"
def test_stream_session():
    """Streaming directly from the session succeeds only while it is open.

    If the dependency teardown ran before the body was streamed, the
    session's iterator would raise ``ValueError("Session closed")``.
    """
    body = client.get("/stream-session").text
    assert body == "foobarbaz"