4 changed files with 137 additions and 3 deletions
@ -0,0 +1,78 @@ |
|||||
|
from contextlib import contextmanager |
||||
|
from typing import Annotated, Any, Generator |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import Depends, FastAPI, WebSocket |
||||
|
from fastapi.testclient import TestClient |
||||
|
|
||||
|
|
||||
|
class Session: |
||||
|
def __init__(self) -> None: |
||||
|
self.data = ["foo", "bar", "baz"] |
||||
|
self.open = True |
||||
|
|
||||
|
def __iter__(self) -> Generator[str, None, None]: |
||||
|
for item in self.data: |
||||
|
if self.open: |
||||
|
yield item |
||||
|
else: |
||||
|
raise ValueError("Session closed") |
||||
|
|
||||
|
|
||||
|
@contextmanager |
||||
|
def acquire_session() -> Generator[Session, None, None]: |
||||
|
session = Session() |
||||
|
try: |
||||
|
yield session |
||||
|
finally: |
||||
|
session.open = False |
||||
|
|
||||
|
|
||||
|
def dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
def broken_dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
s.open = False |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
SessionDep = Annotated[Session, Depends(dep_session)] |
||||
|
BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
@app.websocket("/ws") |
||||
|
async def websocket_endpoint(websocket: WebSocket, session: SessionDep): |
||||
|
await websocket.accept() |
||||
|
for item in session: |
||||
|
await websocket.send_text(f"{item}") |
||||
|
|
||||
|
|
||||
|
@app.websocket("/ws-broken") |
||||
|
async def websocket_endpoint_broken(websocket: WebSocket, session: BrokenSessionDep): |
||||
|
await websocket.accept() |
||||
|
for item in session: |
||||
|
await websocket.send_text(f"{item}") # pragma no cover |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_websocket_dependency_after_yield(): |
||||
|
with client.websocket_connect("/ws") as websocket: |
||||
|
data = websocket.receive_text() |
||||
|
assert data == "foo" |
||||
|
data = websocket.receive_text() |
||||
|
assert data == "bar" |
||||
|
data = websocket.receive_text() |
||||
|
assert data == "baz" |
||||
|
|
||||
|
|
||||
|
def test_websocket_dependency_after_yield_broken(): |
||||
|
with pytest.raises(ValueError, match="Session closed"): |
||||
|
with client.websocket_connect("/ws-broken"): |
||||
|
pass # pragma no cover |
Loading…
Reference in new issue