From 873aeb70360b073c287b5815c53100850a3e7aa2 Mon Sep 17 00:00:00 2001 From: Anton Ryzhov Date: Tue, 22 Oct 2024 21:20:49 +0200 Subject: [PATCH 1/2] Test context manager dependencies can catch and overwrite exceptions --- tests/test_dependency_contextmanager.py | 98 ++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/tests/test_dependency_contextmanager.py b/tests/test_dependency_contextmanager.py index 039c423b9..3a49dbc67 100644 --- a/tests/test_dependency_contextmanager.py +++ b/tests/test_dependency_contextmanager.py @@ -2,7 +2,7 @@ import json from typing import Dict import pytest -from fastapi import BackgroundTasks, Depends, FastAPI +from fastapi import BackgroundTasks, Depends, FastAPI, Query from fastapi.responses import StreamingResponse from fastapi.testclient import TestClient @@ -12,6 +12,8 @@ state = { "/sync": "generator not started", "/async_raise": "asyncgen raise not started", "/sync_raise": "generator raise not started", + "/async_rewrite_exception": "asyncgen rewrite exception not started", + "/sync_rewrite_exception": "generator rewrite exception not started", "context_a": "not started a", "context_b": "not started b", "bg": "not set", @@ -71,6 +73,28 @@ def generator_state_try(state: Dict[str, str] = Depends(get_state)): state["/sync_raise"] = "generator raise finalized" +async def asyncgen_state_rewrite_exception(state: Dict[str, str] = Depends(get_state)): + state["/async_rewrite_exception"] = "asyncgen rewrite exception started" + try: + yield state["/async_rewrite_exception"] + except Exception as error: + errors.append("/async_rewrite_exception") + raise OtherDependencyError() from error + finally: + state["/async_rewrite_exception"] = "asyncgen rewrite exception finalized" + + +def generator_state_rewrite_exception(state: Dict[str, str] = Depends(get_state)): + state["/sync_rewrite_exception"] = "generator rewrite exception started" + try: + yield state["/sync_rewrite_exception"] + except Exception as error: + errors.append("/sync_rewrite_exception") + raise OtherDependencyError() from error + finally: + state["/sync_rewrite_exception"] = "generator rewrite exception finalized" + + async def context_a(state: dict = Depends(get_state)): state["context_a"] = "started a" try: @@ -121,6 +145,26 @@ async def get_sync_raise_other(state: str = Depends(generator_state_try)): raise OtherDependencyError() +@app.get("/async_rewrite_exception") +async def get_async_rewrite_exception( + state: str = Depends(asyncgen_state_rewrite_exception), + do_raise: bool = Query(), +): + assert state == "asyncgen rewrite exception started" + if do_raise: + raise AsyncDependencyError() + + +@app.get("/sync_rewrite_exception") +def get_sync_rewrite_exception( + state: str = Depends(generator_state_rewrite_exception), + do_raise: bool = Query(), +): + assert state == "generator rewrite exception started" + if do_raise: + raise SyncDependencyError() + + @app.get("/context_b") async def get_context_b(state: dict = Depends(context_b)): return state @@ -263,6 +307,58 @@ def test_async_raise_server_error(): errors.clear() +def test_async_rewrite_exception_no_raise(): + state["/async_rewrite_exception"] = "asyncgen rewrite exception not started" + client.get("/async_rewrite_exception?do_raise=false") + assert state["/async_rewrite_exception"] == "asyncgen rewrite exception finalized" + assert "/async_rewrite_exception" not in errors + errors.clear() + + +def test_async_rewrite_exception_handler_raise(): + state["/async_rewrite_exception"] = "asyncgen rewrite exception not started" + with pytest.raises(OtherDependencyError): + client.get("/async_rewrite_exception?do_raise=true") + assert state["/async_rewrite_exception"] == "asyncgen rewrite exception finalized" + assert "/async_rewrite_exception" in errors + errors.clear() + + +def test_async_rewrite_exception_validator_raise(): + state["/async_rewrite_exception"] = "asyncgen rewrite exception not started" + with pytest.raises(OtherDependencyError): + client.get("/async_rewrite_exception?do_raise=invalid_value") + assert state["/async_rewrite_exception"] == "asyncgen rewrite exception finalized" + assert "/async_rewrite_exception" in errors + errors.clear() + + +def test_sync_rewrite_exception_no_raise(): + state["/sync_rewrite_exception"] = "generator rewrite exception not started" + client.get("/sync_rewrite_exception?do_raise=false") + assert state["/sync_rewrite_exception"] == "generator rewrite exception finalized" + assert "/sync_rewrite_exception" not in errors + errors.clear() + + +def test_sync_rewrite_exception_handler_raise(): + state["/sync_rewrite_exception"] = "generator rewrite exception not started" + with pytest.raises(OtherDependencyError): + client.get("/sync_rewrite_exception?do_raise=true") + assert state["/sync_rewrite_exception"] == "generator rewrite exception finalized" + assert "/sync_rewrite_exception" in errors + errors.clear() + + +def test_sync_rewrite_exception_validator_raise(): + state["/sync_rewrite_exception"] = "generator rewrite exception not started" + with pytest.raises(OtherDependencyError): + client.get("/sync_rewrite_exception?do_raise=invalid_value") + assert state["/sync_rewrite_exception"] == "generator rewrite exception finalized" + assert "/sync_rewrite_exception" in errors + errors.clear() + + def test_context_b(): response = client.get("/context_b") data = response.json() From 01999fa07dc3df194fccaee7460b53ac46fba5e5 Mon Sep 17 00:00:00 2001 From: Anton Ryzhov Date: Tue, 22 Oct 2024 21:20:55 +0200 Subject: [PATCH 2/2] Enable context manager dependencies to also catch validation errors --- fastapi/routing.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 8ea4bb219..f664f0d18 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -297,7 +297,12 @@ def get_request_handler( embed_body_fields=embed_body_fields, ) errors = solved_result.errors - if not errors: + if errors: + validation_error = RequestValidationError( + _normalize_errors(errors), body=body + ) + raise validation_error + else: raw_response = await run_endpoint_function( dependant=dependant, values=solved_result.values, @@ -339,11 +344,6 @@ def get_request_handler( if not is_body_allowed_for_status_code(response.status_code): response.body = b"" response.headers.raw.extend(solved_result.response.headers.raw) - if errors: - validation_error = RequestValidationError( - _normalize_errors(errors), body=body - ) - raise validation_error if response is None: raise FastAPIError( "No response object was returned. There's a high chance that the "