Browse Source
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/13974/head
committed by
GitHub
14 changed files with 729 additions and 182 deletions
@ -0,0 +1,38 @@ |
|||||
|
import time |
||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, HTTPException |
||||
|
from fastapi.responses import StreamingResponse |
||||
|
from sqlmodel import Field, Session, SQLModel, create_engine |
||||
|
|
||||
|
engine = create_engine("postgresql+psycopg://postgres:postgres@localhost/db") |
||||
|
|
||||
|
|
||||
|
class User(SQLModel, table=True): |
||||
|
id: int | None = Field(default=None, primary_key=True) |
||||
|
name: str |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
def get_session(): |
||||
|
with Session(engine) as session: |
||||
|
yield session |
||||
|
|
||||
|
|
||||
|
def get_user(user_id: int, session: Annotated[Session, Depends(get_session)]): |
||||
|
user = session.get(User, user_id) |
||||
|
if not user: |
||||
|
raise HTTPException(status_code=403, detail="Not authorized") |
||||
|
|
||||
|
|
||||
|
def generate_stream(query: str): |
||||
|
for ch in query: |
||||
|
yield ch |
||||
|
time.sleep(0.1) |
||||
|
|
||||
|
|
||||
|
@app.get("/generate", dependencies=[Depends(get_user)]) |
||||
|
def generate(query: str): |
||||
|
return StreamingResponse(content=generate_stream(query)) |
@ -0,0 +1,39 @@ |
|||||
|
import time |
||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, HTTPException |
||||
|
from fastapi.responses import StreamingResponse |
||||
|
from sqlmodel import Field, Session, SQLModel, create_engine |
||||
|
|
||||
|
engine = create_engine("postgresql+psycopg://postgres:postgres@localhost/db") |
||||
|
|
||||
|
|
||||
|
class User(SQLModel, table=True): |
||||
|
id: int | None = Field(default=None, primary_key=True) |
||||
|
name: str |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
def get_session(): |
||||
|
with Session(engine) as session: |
||||
|
yield session |
||||
|
|
||||
|
|
||||
|
def get_user(user_id: int, session: Annotated[Session, Depends(get_session)]): |
||||
|
user = session.get(User, user_id) |
||||
|
if not user: |
||||
|
raise HTTPException(status_code=403, detail="Not authorized") |
||||
|
session.close() |
||||
|
|
||||
|
|
||||
|
def generate_stream(query: str): |
||||
|
for ch in query: |
||||
|
yield ch |
||||
|
time.sleep(0.1) |
||||
|
|
||||
|
|
||||
|
@app.get("/generate", dependencies=[Depends(get_user)]) |
||||
|
def generate(query: str): |
||||
|
return StreamingResponse(content=generate_stream(query)) |
@ -0,0 +1,18 @@ |
|||||
|
from contextlib import AsyncExitStack |
||||
|
|
||||
|
from starlette.types import ASGIApp, Receive, Scope, Send |
||||
|
|
||||
|
|
||||
|
# Used mainly to close files after the request is done, dependencies are closed |
||||
|
# in their own AsyncExitStack |
||||
|
class AsyncExitStackMiddleware: |
||||
|
def __init__( |
||||
|
self, app: ASGIApp, context_name: str = "fastapi_middleware_astack" |
||||
|
) -> None: |
||||
|
self.app = app |
||||
|
self.context_name = context_name |
||||
|
|
||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
||||
|
async with AsyncExitStack() as stack: |
||||
|
scope[self.context_name] = stack |
||||
|
await self.app(scope, receive, send) |
@ -0,0 +1,69 @@ |
|||||
|
from typing import Any |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import Depends, FastAPI, HTTPException |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import Annotated |
||||
|
|
||||
|
|
||||
|
class CustomError(Exception): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
def catching_dep() -> Any: |
||||
|
try: |
||||
|
yield "s" |
||||
|
except CustomError as err: |
||||
|
raise HTTPException(status_code=418, detail="Session error") from err |
||||
|
|
||||
|
|
||||
|
def broken_dep() -> Any: |
||||
|
yield "s" |
||||
|
raise ValueError("Broken after yield") |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
@app.get("/catching") |
||||
|
def catching(d: Annotated[str, Depends(catching_dep)]) -> Any: |
||||
|
raise CustomError("Simulated error during streaming") |
||||
|
|
||||
|
|
||||
|
@app.get("/broken") |
||||
|
def broken(d: Annotated[str, Depends(broken_dep)]) -> Any: |
||||
|
return {"message": "all good?"} |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_catching(): |
||||
|
response = client.get("/catching") |
||||
|
assert response.status_code == 418 |
||||
|
assert response.json() == {"detail": "Session error"} |
||||
|
|
||||
|
|
||||
|
def test_broken_raise(): |
||||
|
with pytest.raises(ValueError, match="Broken after yield"): |
||||
|
client.get("/broken") |
||||
|
|
||||
|
|
||||
|
def test_broken_no_raise(): |
||||
|
""" |
||||
|
When a dependency with yield raises after the yield (not in an except), the |
||||
|
response is already "successfully" sent back to the client, but there's still |
||||
|
an error in the server afterwards, an exception is raised and captured or shown |
||||
|
in the server logs. |
||||
|
""" |
||||
|
with TestClient(app, raise_server_exceptions=False) as client: |
||||
|
response = client.get("/broken") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == {"message": "all good?"} |
||||
|
|
||||
|
|
||||
|
def test_broken_return_finishes(): |
||||
|
client = TestClient(app, raise_server_exceptions=False) |
||||
|
response = client.get("/broken") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == {"message": "all good?"} |
@ -0,0 +1,130 @@ |
|||||
|
from contextlib import contextmanager |
||||
|
from typing import Any, Generator |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import Depends, FastAPI |
||||
|
from fastapi.responses import StreamingResponse |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import Annotated |
||||
|
|
||||
|
|
||||
|
class Session: |
||||
|
def __init__(self) -> None: |
||||
|
self.data = ["foo", "bar", "baz"] |
||||
|
self.open = True |
||||
|
|
||||
|
def __iter__(self) -> Generator[str, None, None]: |
||||
|
for item in self.data: |
||||
|
if self.open: |
||||
|
yield item |
||||
|
else: |
||||
|
raise ValueError("Session closed") |
||||
|
|
||||
|
|
||||
|
@contextmanager |
||||
|
def acquire_session() -> Generator[Session, None, None]: |
||||
|
session = Session() |
||||
|
try: |
||||
|
yield session |
||||
|
finally: |
||||
|
session.open = False |
||||
|
|
||||
|
|
||||
|
def dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
def broken_dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
s.open = False |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
SessionDep = Annotated[Session, Depends(dep_session)] |
||||
|
BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
@app.get("/data") |
||||
|
def get_data(session: SessionDep) -> Any: |
||||
|
data = list(session) |
||||
|
return data |
||||
|
|
||||
|
|
||||
|
@app.get("/stream-simple") |
||||
|
def get_stream_simple(session: SessionDep) -> Any: |
||||
|
def iter_data(): |
||||
|
yield from ["x", "y", "z"] |
||||
|
|
||||
|
return StreamingResponse(iter_data()) |
||||
|
|
||||
|
|
||||
|
@app.get("/stream-session") |
||||
|
def get_stream_session(session: SessionDep) -> Any: |
||||
|
def iter_data(): |
||||
|
yield from session |
||||
|
|
||||
|
return StreamingResponse(iter_data()) |
||||
|
|
||||
|
|
||||
|
@app.get("/broken-session-data") |
||||
|
def get_broken_session_data(session: BrokenSessionDep) -> Any: |
||||
|
return list(session) |
||||
|
|
||||
|
|
||||
|
@app.get("/broken-session-stream") |
||||
|
def get_broken_session_stream(session: BrokenSessionDep) -> Any: |
||||
|
def iter_data(): |
||||
|
yield from session |
||||
|
|
||||
|
return StreamingResponse(iter_data()) |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_regular_no_stream(): |
||||
|
response = client.get("/data") |
||||
|
assert response.json() == ["foo", "bar", "baz"] |
||||
|
|
||||
|
|
||||
|
def test_stream_simple(): |
||||
|
response = client.get("/stream-simple") |
||||
|
assert response.text == "xyz" |
||||
|
|
||||
|
|
||||
|
def test_stream_session(): |
||||
|
response = client.get("/stream-session") |
||||
|
assert response.text == "foobarbaz" |
||||
|
|
||||
|
|
||||
|
def test_broken_session_data(): |
||||
|
with pytest.raises(ValueError, match="Session closed"): |
||||
|
client.get("/broken-session-data") |
||||
|
|
||||
|
|
||||
|
def test_broken_session_data_no_raise(): |
||||
|
client = TestClient(app, raise_server_exceptions=False) |
||||
|
response = client.get("/broken-session-data") |
||||
|
assert response.status_code == 500 |
||||
|
assert response.text == "Internal Server Error" |
||||
|
|
||||
|
|
||||
|
def test_broken_session_stream_raise(): |
||||
|
# Can raise ValueError on Pydantic v2 and ExceptionGroup on Pydantic v1 |
||||
|
with pytest.raises((ValueError, Exception)): |
||||
|
client.get("/broken-session-stream") |
||||
|
|
||||
|
|
||||
|
def test_broken_session_stream_no_raise(): |
||||
|
""" |
||||
|
When a dependency with yield raises after the streaming response already started |
||||
|
the 200 status code is already sent, but there's still an error in the server |
||||
|
afterwards, an exception is raised and captured or shown in the server logs. |
||||
|
""" |
||||
|
with TestClient(app, raise_server_exceptions=False) as client: |
||||
|
response = client.get("/broken-session-stream") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.text == "" |
@ -0,0 +1,79 @@ |
|||||
|
from contextlib import contextmanager |
||||
|
from typing import Any, Generator |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import Depends, FastAPI, WebSocket |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import Annotated |
||||
|
|
||||
|
|
||||
|
class Session: |
||||
|
def __init__(self) -> None: |
||||
|
self.data = ["foo", "bar", "baz"] |
||||
|
self.open = True |
||||
|
|
||||
|
def __iter__(self) -> Generator[str, None, None]: |
||||
|
for item in self.data: |
||||
|
if self.open: |
||||
|
yield item |
||||
|
else: |
||||
|
raise ValueError("Session closed") |
||||
|
|
||||
|
|
||||
|
@contextmanager |
||||
|
def acquire_session() -> Generator[Session, None, None]: |
||||
|
session = Session() |
||||
|
try: |
||||
|
yield session |
||||
|
finally: |
||||
|
session.open = False |
||||
|
|
||||
|
|
||||
|
def dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
def broken_dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
s.open = False |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
SessionDep = Annotated[Session, Depends(dep_session)] |
||||
|
BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
@app.websocket("/ws") |
||||
|
async def websocket_endpoint(websocket: WebSocket, session: SessionDep): |
||||
|
await websocket.accept() |
||||
|
for item in session: |
||||
|
await websocket.send_text(f"{item}") |
||||
|
|
||||
|
|
||||
|
@app.websocket("/ws-broken") |
||||
|
async def websocket_endpoint_broken(websocket: WebSocket, session: BrokenSessionDep): |
||||
|
await websocket.accept() |
||||
|
for item in session: |
||||
|
await websocket.send_text(f"{item}") # pragma no cover |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_websocket_dependency_after_yield(): |
||||
|
with client.websocket_connect("/ws") as websocket: |
||||
|
data = websocket.receive_text() |
||||
|
assert data == "foo" |
||||
|
data = websocket.receive_text() |
||||
|
assert data == "bar" |
||||
|
data = websocket.receive_text() |
||||
|
assert data == "baz" |
||||
|
|
||||
|
|
||||
|
def test_websocket_dependency_after_yield_broken(): |
||||
|
with pytest.raises(ValueError, match="Session closed"): |
||||
|
with client.websocket_connect("/ws-broken"): |
||||
|
pass # pragma no cover |
Loading…
Reference in new issue