You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

79 lines
2.0 KiB

from contextlib import contextmanager
from typing import Any, Generator
import pytest
from fastapi import Depends, FastAPI, WebSocket
from fastapi.testclient import TestClient
from typing_extensions import Annotated
class Session:
    """Test double for a resource session.

    Holds three fixed string items and an ``open`` flag. Iterating yields the
    items one by one. The flag is re-checked before each item, so a closed
    session raises ``ValueError("Session closed")`` as soon as an item would
    be produced.
    """

    def __init__(self) -> None:
        self.data = ["foo", "bar", "baz"]
        self.open = True

    def __iter__(self) -> Generator[str, None, None]:
        for entry in self.data:
            # Guard clause: checked per item, so closing the session partway
            # through an iteration stops it at the next item.
            if not self.open:
                raise ValueError("Session closed")
            yield entry
@contextmanager
def acquire_session() -> Generator[Session, None, None]:
    """Provide a fresh ``Session`` and mark it closed when the ``with`` exits."""
    new_session = Session()
    try:
        yield new_session
    finally:
        # Runs on normal exit and when the caller's with-body raises, so the
        # session never stays open after the context ends.
        new_session.open = False
def dep_session() -> Any:
    """Yield-style dependency that hands out a still-open ``Session``.

    The session comes from ``acquire_session``. It is closed only when this
    generator resumes past the ``yield`` and leaves the ``with`` block.
    """
    with acquire_session() as session:
        yield session
def broken_dep_session() -> Any:
    """Deliberately broken dependency: closes the session before yielding it.

    Any consumer that iterates the yielded session hits
    ``ValueError("Session closed")`` on the first item.
    """
    with acquire_session() as session:
        # Simulate a dependency whose resource was already released.
        session.open = False
        yield session
# Application under test; the websocket endpoints register on it.
app = FastAPI()
# Annotated alias: a Session parameter resolved through the working yield dependency.
SessionDep = Annotated[Session, Depends(dep_session)]
# Same type, but resolved through the dependency that yields a closed session.
BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)]
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: SessionDep):
    """Accept the connection and send every session item as a text frame.

    Iteration succeeds only while the injected session is still open, i.e.
    only if the yield dependency has not yet run its cleanup.
    """
    await websocket.accept()
    for item in session:
        await websocket.send_text(f"{item}")
@app.websocket("/ws-broken")
async def websocket_endpoint_broken(websocket: WebSocket, session: BrokenSessionDep):
    """Same as the ``/ws`` endpoint, but the injected session is already closed.

    Iterating the session raises ``ValueError("Session closed")`` before
    anything is sent, so the send line is never reached.
    """
    await websocket.accept()
    for item in session:
        await websocket.send_text(f"{item}")  # pragma no cover
# In-process test client for `app`, shared by the test functions in this module.
client = TestClient(app)
def test_websocket_dependency_after_yield():
    """/ws streams all three items while its yield dependency keeps the session open."""
    with client.websocket_connect("/ws") as websocket:
        # Receive and check one message at a time, in the order they are sent.
        for expected in ("foo", "bar", "baz"):
            assert websocket.receive_text() == expected
def test_websocket_dependency_after_yield_broken():
    """Connecting to /ws-broken must surface the closed-session ValueError."""
    expect_closed = pytest.raises(ValueError, match="Session closed")
    # A multi-item `with` enters `expect_closed` before evaluating the
    # websocket_connect call, exactly like the nested form.
    with expect_closed, client.websocket_connect("/ws-broken"):
        pass  # pragma no cover