diff --git a/fastapi/responses.py b/fastapi/responses.py index 29df4b7a61..36a4a95607 100644 --- a/fastapi/responses.py +++ b/fastapi/responses.py @@ -24,16 +24,22 @@ class _OrjsonModule(Protocol): def dumps(self, __obj: Any, *, option: int = ...) -> bytes: ... +ujson: _UjsonModule | None +ujson_import_error: ImportError | None = None try: ujson = cast(_UjsonModule, importlib.import_module("ujson")) -except ModuleNotFoundError: # pragma: nocover - ujson = None # type: ignore[assignment] +except ImportError as e: + ujson = None + ujson_import_error = e +orjson: _OrjsonModule | None +orjson_import_error: ImportError | None = None try: orjson = cast(_OrjsonModule, importlib.import_module("orjson")) -except ModuleNotFoundError: # pragma: nocover - orjson = None # type: ignore[assignment] +except ImportError as e: + orjson = None + orjson_import_error = e @deprecated( @@ -62,7 +68,10 @@ class UJSONResponse(JSONResponse): """ def render(self, content: Any) -> bytes: - assert ujson is not None, "ujson must be installed to use UJSONResponse" + if ujson is None: + raise RuntimeError( + "ujson must be installed to use UJSONResponse" + ) from ujson_import_error return ujson.dumps(content, ensure_ascii=False).encode("utf-8") @@ -92,7 +101,10 @@ class ORJSONResponse(JSONResponse): """ def render(self, content: Any) -> bytes: - assert orjson is not None, "orjson must be installed to use ORJSONResponse" + if orjson is None: + raise RuntimeError( + "orjson must be installed to use ORJSONResponse" + ) from orjson_import_error return orjson.dumps( content, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY ) diff --git a/tests/test_deprecated_responses.py b/tests/test_deprecated_responses.py index 8cbd9c11fe..8b891c0fbe 100644 --- a/tests/test_deprecated_responses.py +++ b/tests/test_deprecated_responses.py @@ -1,5 +1,9 @@ +import importlib +import sys import warnings +from types import ModuleType +import fastapi import pytest from fastapi import FastAPI from fastapi.exceptions import FastAPIDeprecationWarning @@ -15,6 +19,33 @@ class Item(BaseModel): price: float +def _import_responses_with_failed_optional_import( + monkeypatch: pytest.MonkeyPatch, module_name: str +) -> ModuleType: + original_responses = sys.modules.pop("fastapi.responses", None) + had_responses_attr = hasattr(fastapi, "responses") + original_responses_attr = getattr(fastapi, "responses", None) + if had_responses_attr: + delattr(fastapi, "responses") + + real_import_module = importlib.import_module + + def import_module(name: str, package: str | None = None) -> ModuleType: + if name == module_name: + raise ImportError(f"broken {module_name}") + return real_import_module(name, package) + + monkeypatch.setattr(importlib, "import_module", import_module) + try: + return real_import_module("fastapi.responses") + finally: + sys.modules.pop("fastapi.responses", None) + if original_responses is not None: + sys.modules["fastapi.responses"] = original_responses + if had_responses_attr: + fastapi.responses = original_responses_attr + + # ORJSON @@ -47,6 +78,22 @@ def test_orjson_response_emits_deprecation_warning(): ORJSONResponse(content={"hello": "world"}) +def test_orjson_import_error_does_not_break_responses_import( + monkeypatch: pytest.MonkeyPatch, +): + responses = _import_responses_with_failed_optional_import(monkeypatch, "orjson") + + response = responses.JSONResponse(content={"hello": "world"}) + assert response.body == b'{"hello":"world"}' + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FastAPIDeprecationWarning) + with pytest.raises(RuntimeError, match="orjson must be installed") as exc_info: + responses.ORJSONResponse(content={"hello": "world"}) + assert isinstance(exc_info.value.__cause__, ImportError) + assert str(exc_info.value.__cause__) == "broken orjson" + + # UJSON @@ -77,3 +124,19 @@ def test_ujson_response_returns_correct_data(): def test_ujson_response_emits_deprecation_warning(): with pytest.warns(FastAPIDeprecationWarning, match="UJSONResponse is deprecated"): UJSONResponse(content={"hello": "world"}) + + +def test_ujson_import_error_does_not_break_responses_import( + monkeypatch: pytest.MonkeyPatch, +): + responses = _import_responses_with_failed_optional_import(monkeypatch, "ujson") + + response = responses.JSONResponse(content={"hello": "world"}) + assert response.body == b'{"hello":"world"}' + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FastAPIDeprecationWarning) + with pytest.raises(RuntimeError, match="ujson must be installed") as exc_info: + responses.UJSONResponse(content={"hello": "world"}) + assert isinstance(exc_info.value.__cause__, ImportError) + assert str(exc_info.value.__cause__) == "broken ujson"