Browse Source
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/13974/head
committed by
GitHub
14 changed files with 729 additions and 182 deletions
@ -0,0 +1,38 @@ |
|||
import time |
|||
from typing import Annotated |
|||
|
|||
from fastapi import Depends, FastAPI, HTTPException |
|||
from fastapi.responses import StreamingResponse |
|||
from sqlmodel import Field, Session, SQLModel, create_engine |
|||
|
|||
engine = create_engine("postgresql+psycopg://postgres:postgres@localhost/db") |
|||
|
|||
|
|||
class User(SQLModel, table=True): |
|||
id: int | None = Field(default=None, primary_key=True) |
|||
name: str |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
def get_session(): |
|||
with Session(engine) as session: |
|||
yield session |
|||
|
|||
|
|||
def get_user(user_id: int, session: Annotated[Session, Depends(get_session)]): |
|||
user = session.get(User, user_id) |
|||
if not user: |
|||
raise HTTPException(status_code=403, detail="Not authorized") |
|||
|
|||
|
|||
def generate_stream(query: str): |
|||
for ch in query: |
|||
yield ch |
|||
time.sleep(0.1) |
|||
|
|||
|
|||
@app.get("/generate", dependencies=[Depends(get_user)]) |
|||
def generate(query: str): |
|||
return StreamingResponse(content=generate_stream(query)) |
@ -0,0 +1,39 @@ |
|||
import time |
|||
from typing import Annotated |
|||
|
|||
from fastapi import Depends, FastAPI, HTTPException |
|||
from fastapi.responses import StreamingResponse |
|||
from sqlmodel import Field, Session, SQLModel, create_engine |
|||
|
|||
engine = create_engine("postgresql+psycopg://postgres:postgres@localhost/db") |
|||
|
|||
|
|||
class User(SQLModel, table=True): |
|||
id: int | None = Field(default=None, primary_key=True) |
|||
name: str |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
def get_session(): |
|||
with Session(engine) as session: |
|||
yield session |
|||
|
|||
|
|||
def get_user(user_id: int, session: Annotated[Session, Depends(get_session)]): |
|||
user = session.get(User, user_id) |
|||
if not user: |
|||
raise HTTPException(status_code=403, detail="Not authorized") |
|||
session.close() |
|||
|
|||
|
|||
def generate_stream(query: str): |
|||
for ch in query: |
|||
yield ch |
|||
time.sleep(0.1) |
|||
|
|||
|
|||
@app.get("/generate", dependencies=[Depends(get_user)]) |
|||
def generate(query: str): |
|||
return StreamingResponse(content=generate_stream(query)) |
@ -0,0 +1,18 @@ |
|||
from contextlib import AsyncExitStack |
|||
|
|||
from starlette.types import ASGIApp, Receive, Scope, Send |
|||
|
|||
|
|||
# Used mainly to close files after the request is done, dependencies are closed |
|||
# in their own AsyncExitStack |
|||
class AsyncExitStackMiddleware: |
|||
def __init__( |
|||
self, app: ASGIApp, context_name: str = "fastapi_middleware_astack" |
|||
) -> None: |
|||
self.app = app |
|||
self.context_name = context_name |
|||
|
|||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
|||
async with AsyncExitStack() as stack: |
|||
scope[self.context_name] = stack |
|||
await self.app(scope, receive, send) |
@ -0,0 +1,69 @@ |
|||
from typing import Any |
|||
|
|||
import pytest |
|||
from fastapi import Depends, FastAPI, HTTPException |
|||
from fastapi.testclient import TestClient |
|||
from typing_extensions import Annotated |
|||
|
|||
|
|||
class CustomError(Exception): |
|||
pass |
|||
|
|||
|
|||
def catching_dep() -> Any: |
|||
try: |
|||
yield "s" |
|||
except CustomError as err: |
|||
raise HTTPException(status_code=418, detail="Session error") from err |
|||
|
|||
|
|||
def broken_dep() -> Any: |
|||
yield "s" |
|||
raise ValueError("Broken after yield") |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
@app.get("/catching") |
|||
def catching(d: Annotated[str, Depends(catching_dep)]) -> Any: |
|||
raise CustomError("Simulated error during streaming") |
|||
|
|||
|
|||
@app.get("/broken") |
|||
def broken(d: Annotated[str, Depends(broken_dep)]) -> Any: |
|||
return {"message": "all good?"} |
|||
|
|||
|
|||
client = TestClient(app) |
|||
|
|||
|
|||
def test_catching(): |
|||
response = client.get("/catching") |
|||
assert response.status_code == 418 |
|||
assert response.json() == {"detail": "Session error"} |
|||
|
|||
|
|||
def test_broken_raise(): |
|||
with pytest.raises(ValueError, match="Broken after yield"): |
|||
client.get("/broken") |
|||
|
|||
|
|||
def test_broken_no_raise(): |
|||
""" |
|||
When a dependency with yield raises after the yield (not in an except), the |
|||
response is already "successfully" sent back to the client, but there's still |
|||
an error in the server afterwards, an exception is raised and captured or shown |
|||
in the server logs. |
|||
""" |
|||
with TestClient(app, raise_server_exceptions=False) as client: |
|||
response = client.get("/broken") |
|||
assert response.status_code == 200 |
|||
assert response.json() == {"message": "all good?"} |
|||
|
|||
|
|||
def test_broken_return_finishes(): |
|||
client = TestClient(app, raise_server_exceptions=False) |
|||
response = client.get("/broken") |
|||
assert response.status_code == 200 |
|||
assert response.json() == {"message": "all good?"} |
@ -0,0 +1,130 @@ |
|||
from contextlib import contextmanager |
|||
from typing import Any, Generator |
|||
|
|||
import pytest |
|||
from fastapi import Depends, FastAPI |
|||
from fastapi.responses import StreamingResponse |
|||
from fastapi.testclient import TestClient |
|||
from typing_extensions import Annotated |
|||
|
|||
|
|||
class Session: |
|||
def __init__(self) -> None: |
|||
self.data = ["foo", "bar", "baz"] |
|||
self.open = True |
|||
|
|||
def __iter__(self) -> Generator[str, None, None]: |
|||
for item in self.data: |
|||
if self.open: |
|||
yield item |
|||
else: |
|||
raise ValueError("Session closed") |
|||
|
|||
|
|||
@contextmanager |
|||
def acquire_session() -> Generator[Session, None, None]: |
|||
session = Session() |
|||
try: |
|||
yield session |
|||
finally: |
|||
session.open = False |
|||
|
|||
|
|||
def dep_session() -> Any: |
|||
with acquire_session() as s: |
|||
yield s |
|||
|
|||
|
|||
def broken_dep_session() -> Any: |
|||
with acquire_session() as s: |
|||
s.open = False |
|||
yield s |
|||
|
|||
|
|||
SessionDep = Annotated[Session, Depends(dep_session)] |
|||
BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
@app.get("/data") |
|||
def get_data(session: SessionDep) -> Any: |
|||
data = list(session) |
|||
return data |
|||
|
|||
|
|||
@app.get("/stream-simple") |
|||
def get_stream_simple(session: SessionDep) -> Any: |
|||
def iter_data(): |
|||
yield from ["x", "y", "z"] |
|||
|
|||
return StreamingResponse(iter_data()) |
|||
|
|||
|
|||
@app.get("/stream-session") |
|||
def get_stream_session(session: SessionDep) -> Any: |
|||
def iter_data(): |
|||
yield from session |
|||
|
|||
return StreamingResponse(iter_data()) |
|||
|
|||
|
|||
@app.get("/broken-session-data") |
|||
def get_broken_session_data(session: BrokenSessionDep) -> Any: |
|||
return list(session) |
|||
|
|||
|
|||
@app.get("/broken-session-stream") |
|||
def get_broken_session_stream(session: BrokenSessionDep) -> Any: |
|||
def iter_data(): |
|||
yield from session |
|||
|
|||
return StreamingResponse(iter_data()) |
|||
|
|||
|
|||
client = TestClient(app) |
|||
|
|||
|
|||
def test_regular_no_stream(): |
|||
response = client.get("/data") |
|||
assert response.json() == ["foo", "bar", "baz"] |
|||
|
|||
|
|||
def test_stream_simple(): |
|||
response = client.get("/stream-simple") |
|||
assert response.text == "xyz" |
|||
|
|||
|
|||
def test_stream_session(): |
|||
response = client.get("/stream-session") |
|||
assert response.text == "foobarbaz" |
|||
|
|||
|
|||
def test_broken_session_data(): |
|||
with pytest.raises(ValueError, match="Session closed"): |
|||
client.get("/broken-session-data") |
|||
|
|||
|
|||
def test_broken_session_data_no_raise(): |
|||
client = TestClient(app, raise_server_exceptions=False) |
|||
response = client.get("/broken-session-data") |
|||
assert response.status_code == 500 |
|||
assert response.text == "Internal Server Error" |
|||
|
|||
|
|||
def test_broken_session_stream_raise(): |
|||
# Can raise ValueError on Pydantic v2 and ExceptionGroup on Pydantic v1 |
|||
with pytest.raises((ValueError, Exception)): |
|||
client.get("/broken-session-stream") |
|||
|
|||
|
|||
def test_broken_session_stream_no_raise(): |
|||
""" |
|||
When a dependency with yield raises after the streaming response already started |
|||
the 200 status code is already sent, but there's still an error in the server |
|||
afterwards, an exception is raised and captured or shown in the server logs. |
|||
""" |
|||
with TestClient(app, raise_server_exceptions=False) as client: |
|||
response = client.get("/broken-session-stream") |
|||
assert response.status_code == 200 |
|||
assert response.text == "" |
@ -0,0 +1,79 @@ |
|||
from contextlib import contextmanager |
|||
from typing import Any, Generator |
|||
|
|||
import pytest |
|||
from fastapi import Depends, FastAPI, WebSocket |
|||
from fastapi.testclient import TestClient |
|||
from typing_extensions import Annotated |
|||
|
|||
|
|||
class Session: |
|||
def __init__(self) -> None: |
|||
self.data = ["foo", "bar", "baz"] |
|||
self.open = True |
|||
|
|||
def __iter__(self) -> Generator[str, None, None]: |
|||
for item in self.data: |
|||
if self.open: |
|||
yield item |
|||
else: |
|||
raise ValueError("Session closed") |
|||
|
|||
|
|||
@contextmanager |
|||
def acquire_session() -> Generator[Session, None, None]: |
|||
session = Session() |
|||
try: |
|||
yield session |
|||
finally: |
|||
session.open = False |
|||
|
|||
|
|||
def dep_session() -> Any: |
|||
with acquire_session() as s: |
|||
yield s |
|||
|
|||
|
|||
def broken_dep_session() -> Any: |
|||
with acquire_session() as s: |
|||
s.open = False |
|||
yield s |
|||
|
|||
|
|||
SessionDep = Annotated[Session, Depends(dep_session)] |
|||
BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
@app.websocket("/ws") |
|||
async def websocket_endpoint(websocket: WebSocket, session: SessionDep): |
|||
await websocket.accept() |
|||
for item in session: |
|||
await websocket.send_text(f"{item}") |
|||
|
|||
|
|||
@app.websocket("/ws-broken") |
|||
async def websocket_endpoint_broken(websocket: WebSocket, session: BrokenSessionDep): |
|||
await websocket.accept() |
|||
for item in session: |
|||
await websocket.send_text(f"{item}") # pragma no cover |
|||
|
|||
|
|||
client = TestClient(app) |
|||
|
|||
|
|||
def test_websocket_dependency_after_yield(): |
|||
with client.websocket_connect("/ws") as websocket: |
|||
data = websocket.receive_text() |
|||
assert data == "foo" |
|||
data = websocket.receive_text() |
|||
assert data == "bar" |
|||
data = websocket.receive_text() |
|||
assert data == "baz" |
|||
|
|||
|
|||
def test_websocket_dependency_after_yield_broken(): |
|||
with pytest.raises(ValueError, match="Session closed"): |
|||
with client.websocket_connect("/ws-broken"): |
|||
pass # pragma no cover |
Loading…
Reference in new issue