diff --git a/tests/test_dependency_contextmanager.py b/tests/test_dependency_contextmanager.py index 039c423b9..3a49dbc67 100644 --- a/tests/test_dependency_contextmanager.py +++ b/tests/test_dependency_contextmanager.py @@ -2,7 +2,7 @@ import json from typing import Dict import pytest -from fastapi import BackgroundTasks, Depends, FastAPI +from fastapi import BackgroundTasks, Depends, FastAPI, Query from fastapi.responses import StreamingResponse from fastapi.testclient import TestClient @@ -12,6 +12,8 @@ state = { "/sync": "generator not started", "/async_raise": "asyncgen raise not started", "/sync_raise": "generator raise not started", + "/async_rewrite_exception": "asyncgen rewrite exception not started", + "/sync_rewrite_exception": "generator rewrite exception not started", "context_a": "not started a", "context_b": "not started b", "bg": "not set", @@ -71,6 +73,28 @@ def generator_state_try(state: Dict[str, str] = Depends(get_state)): state["/sync_raise"] = "generator raise finalized" +async def asyncgen_state_rewrite_exception(state: Dict[str, str] = Depends(get_state)): + state["/async_rewrite_exception"] = "asyncgen rewrite exception started" + try: + yield state["/async_rewrite_exception"] + except Exception as error: + errors.append("/async_rewrite_exception") + raise OtherDependencyError() from error + finally: + state["/async_rewrite_exception"] = "asyncgen rewrite exception finalized" + + +def generator_state_rewrite_exception(state: Dict[str, str] = Depends(get_state)): + state["/sync_rewrite_exception"] = "generator rewrite exception started" + try: + yield state["/sync_rewrite_exception"] + except Exception as error: + errors.append("/sync_rewrite_exception") + raise OtherDependencyError() from error + finally: + state["/sync_rewrite_exception"] = "generator rewrite exception finalized" + + async def context_a(state: dict = Depends(get_state)): state["context_a"] = "started a" try: @@ -121,6 +145,26 @@ async def get_sync_raise_other(state: str = Depends(generator_state_try)): raise OtherDependencyError() +@app.get("/async_rewrite_exception") +async def get_async_rewrite_exception( + state: str = Depends(asyncgen_state_rewrite_exception), + do_raise: bool = Query(), +): + assert state == "asyncgen rewrite exception started" + if do_raise: + raise AsyncDependencyError() + + +@app.get("/sync_rewrite_exception") +def get_sync_rewrite_exception( + state: str = Depends(generator_state_rewrite_exception), + do_raise: bool = Query(), +): + assert state == "generator rewrite exception started" + if do_raise: + raise SyncDependencyError() + + @app.get("/context_b") async def get_context_b(state: dict = Depends(context_b)): return state @@ -263,6 +307,58 @@ def test_async_raise_server_error(): errors.clear() +def test_async_rewrite_exception_no_raise(): + state["/async_rewrite_exception"] = "asyncgen rewrite exception not started" + client.get("/async_rewrite_exception?do_raise=false") + assert state["/async_rewrite_exception"] == "asyncgen rewrite exception finalized" + assert "/async_rewrite_exception" not in errors + errors.clear() + + +def test_async_rewrite_exception_handler_raise(): + state["/async_rewrite_exception"] = "asyncgen rewrite exception not started" + with pytest.raises(OtherDependencyError): + client.get("/async_rewrite_exception?do_raise=true") + assert state["/async_rewrite_exception"] == "asyncgen rewrite exception finalized" + assert "/async_rewrite_exception" in errors + errors.clear() + + +def test_async_rewrite_exception_validator_raise(): + state["/async_rewrite_exception"] = "asyncgen rewrite exception not started" + with pytest.raises(OtherDependencyError): + client.get("/async_rewrite_exception?do_raise=invalid_value") + assert state["/async_rewrite_exception"] == "asyncgen rewrite exception finalized" + assert "/async_rewrite_exception" in errors + errors.clear() + + +def test_sync_rewrite_exception_no_raise(): + state["/sync_rewrite_exception"] = "generator rewrite exception not started" + client.get("/sync_rewrite_exception?do_raise=false") + assert state["/sync_rewrite_exception"] == "generator rewrite exception finalized" + assert "/sync_rewrite_exception" not in errors + errors.clear() + + +def test_sync_rewrite_exception_handler_raise(): + state["/sync_rewrite_exception"] = "generator rewrite exception not started" + with pytest.raises(OtherDependencyError): + client.get("/sync_rewrite_exception?do_raise=true") + assert state["/sync_rewrite_exception"] == "generator rewrite exception finalized" + assert "/sync_rewrite_exception" in errors + errors.clear() + + +def test_sync_rewrite_exception_validator_raise(): + state["/sync_rewrite_exception"] = "generator rewrite exception not started" + with pytest.raises(OtherDependencyError): + client.get("/sync_rewrite_exception?do_raise=invalid_value") + assert state["/sync_rewrite_exception"] == "generator rewrite exception finalized" + assert "/sync_rewrite_exception" in errors + errors.clear() + + def test_context_b(): response = client.get("/context_b") data = response.json()