4 changed files with 137 additions and 3 deletions
@ -0,0 +1,78 @@ |
|||
from contextlib import contextmanager |
|||
from typing import Annotated, Any, Generator |
|||
|
|||
import pytest |
|||
from fastapi import Depends, FastAPI, WebSocket |
|||
from fastapi.testclient import TestClient |
|||
|
|||
|
|||
class Session: |
|||
def __init__(self) -> None: |
|||
self.data = ["foo", "bar", "baz"] |
|||
self.open = True |
|||
|
|||
def __iter__(self) -> Generator[str, None, None]: |
|||
for item in self.data: |
|||
if self.open: |
|||
yield item |
|||
else: |
|||
raise ValueError("Session closed") |
|||
|
|||
|
|||
@contextmanager |
|||
def acquire_session() -> Generator[Session, None, None]: |
|||
session = Session() |
|||
try: |
|||
yield session |
|||
finally: |
|||
session.open = False |
|||
|
|||
|
|||
def dep_session() -> Any: |
|||
with acquire_session() as s: |
|||
yield s |
|||
|
|||
|
|||
def broken_dep_session() -> Any: |
|||
with acquire_session() as s: |
|||
s.open = False |
|||
yield s |
|||
|
|||
|
|||
SessionDep = Annotated[Session, Depends(dep_session)] |
|||
BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
@app.websocket("/ws") |
|||
async def websocket_endpoint(websocket: WebSocket, session: SessionDep): |
|||
await websocket.accept() |
|||
for item in session: |
|||
await websocket.send_text(f"{item}") |
|||
|
|||
|
|||
@app.websocket("/ws-broken") |
|||
async def websocket_endpoint_broken(websocket: WebSocket, session: BrokenSessionDep): |
|||
await websocket.accept() |
|||
for item in session: |
|||
await websocket.send_text(f"{item}") # pragma no cover |
|||
|
|||
|
|||
client = TestClient(app) |
|||
|
|||
|
|||
def test_websocket_dependency_after_yield(): |
|||
with client.websocket_connect("/ws") as websocket: |
|||
data = websocket.receive_text() |
|||
assert data == "foo" |
|||
data = websocket.receive_text() |
|||
assert data == "bar" |
|||
data = websocket.receive_text() |
|||
assert data == "baz" |
|||
|
|||
|
|||
def test_websocket_dependency_after_yield_broken(): |
|||
with pytest.raises(ValueError, match="Session closed"): |
|||
with client.websocket_connect("/ws-broken"): |
|||
pass # pragma no cover |
Loading…
Reference in new issue