From 4f8157588e47f909276e0474f6926740a2e55b9c Mon Sep 17 00:00:00 2001 From: Abdullah Hashim Date: Wed, 4 Dec 2024 01:37:12 +0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Preserve=20traceback=20when=20ex?= =?UTF-8?q?ception=20is=20raised=20in=20sync=20dependency=20with=20`yield`?= =?UTF-8?q?=20(#5823)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marcelo Trylesinski --- fastapi/concurrency.py | 4 ++-- tests/test_exception_handlers.py | 23 ++++++++++++++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index 894bd3ed1..3202c7078 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -1,7 +1,7 @@ from contextlib import asynccontextmanager as asynccontextmanager from typing import AsyncGenerator, ContextManager, TypeVar -import anyio +import anyio.to_thread from anyio import CapacityLimiter from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa @@ -28,7 +28,7 @@ async def contextmanager_in_threadpool( except Exception as e: ok = bool( await anyio.to_thread.run_sync( - cm.__exit__, type(e), e, None, limiter=exit_limiter + cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter ) ) if not ok: diff --git a/tests/test_exception_handlers.py b/tests/test_exception_handlers.py index 67a4becec..6a3cbd830 100644 --- a/tests/test_exception_handlers.py +++ b/tests/test_exception_handlers.py @@ -1,5 +1,5 @@ import pytest -from fastapi import FastAPI, HTTPException +from fastapi import Depends, FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.testclient import TestClient from starlette.responses import JSONResponse @@ -28,6 +28,18 @@ app = FastAPI( client = TestClient(app) +def raise_value_error(): + raise ValueError() + + +def dependency_with_yield(): + yield raise_value_error() + + +@app.get("/dependency-with-yield", dependencies=[Depends(dependency_with_yield)]) +def with_yield(): ... + + @app.get("/http-exception") def route_with_http_exception(): raise HTTPException(status_code=400) @@ -65,3 +77,12 @@ def test_override_server_error_exception_response(): response = client.get("/server-error") assert response.status_code == 500 assert response.json() == {"exception": "server-error"} + + +def test_traceback_for_dependency_with_yield(): + client = TestClient(app, raise_server_exceptions=True) + with pytest.raises(ValueError) as exc_info: + client.get("/dependency-with-yield") + last_frame = exc_info.traceback[-1] + assert str(last_frame.path) == __file__ + assert last_frame.lineno == raise_value_error.__code__.co_firstlineno