Browse Source

Fix optional response encoder import errors

pull/15628/head
Ratish Oberoi 1 week ago
parent
commit
5761366e5c
  1. 24
      fastapi/responses.py
  2. 63
      tests/test_deprecated_responses.py

24
fastapi/responses.py

@ -24,16 +24,22 @@ class _OrjsonModule(Protocol):
def dumps(self, __obj: Any, *, option: int = ...) -> bytes: ...
ujson: _UjsonModule | None
ujson_import_error: ImportError | None = None
try:
ujson = cast(_UjsonModule, importlib.import_module("ujson"))
except ModuleNotFoundError: # pragma: nocover
ujson = None # type: ignore[assignment]
except ImportError as e:
ujson = None
ujson_import_error = e
orjson: _OrjsonModule | None
orjson_import_error: ImportError | None = None
try:
orjson = cast(_OrjsonModule, importlib.import_module("orjson"))
except ModuleNotFoundError: # pragma: nocover
orjson = None # type: ignore[assignment]
except ImportError as e:
orjson = None
orjson_import_error = e
@deprecated(
@ -62,7 +68,10 @@ class UJSONResponse(JSONResponse):
"""
def render(self, content: Any) -> bytes:
assert ujson is not None, "ujson must be installed to use UJSONResponse"
if ujson is None:
raise RuntimeError(
"ujson must be installed to use UJSONResponse"
) from ujson_import_error
return ujson.dumps(content, ensure_ascii=False).encode("utf-8")
@ -92,7 +101,10 @@ class ORJSONResponse(JSONResponse):
"""
def render(self, content: Any) -> bytes:
assert orjson is not None, "orjson must be installed to use ORJSONResponse"
if orjson is None:
raise RuntimeError(
"orjson must be installed to use ORJSONResponse"
) from orjson_import_error
return orjson.dumps(
content, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
)

63
tests/test_deprecated_responses.py

@ -1,5 +1,9 @@
import importlib
import sys
import warnings
from types import ModuleType
import fastapi
import pytest
from fastapi import FastAPI
from fastapi.exceptions import FastAPIDeprecationWarning
@ -15,6 +19,33 @@ class Item(BaseModel):
price: float
def _import_responses_with_failed_optional_import(
monkeypatch: pytest.MonkeyPatch, module_name: str
) -> ModuleType:
original_responses = sys.modules.pop("fastapi.responses", None)
had_responses_attr = hasattr(fastapi, "responses")
original_responses_attr = getattr(fastapi, "responses", None)
if had_responses_attr:
delattr(fastapi, "responses")
real_import_module = importlib.import_module
def import_module(name: str, package: str | None = None) -> ModuleType:
if name == module_name:
raise ImportError(f"broken {module_name}")
return real_import_module(name, package)
monkeypatch.setattr(importlib, "import_module", import_module)
try:
return real_import_module("fastapi.responses")
finally:
sys.modules.pop("fastapi.responses", None)
if original_responses is not None:
sys.modules["fastapi.responses"] = original_responses
if had_responses_attr:
fastapi.responses = original_responses_attr
# ORJSON
@ -47,6 +78,22 @@ def test_orjson_response_emits_deprecation_warning():
ORJSONResponse(content={"hello": "world"})
def test_orjson_import_error_does_not_break_responses_import(
monkeypatch: pytest.MonkeyPatch,
):
responses = _import_responses_with_failed_optional_import(monkeypatch, "orjson")
response = responses.JSONResponse(content={"hello": "world"})
assert response.body == b'{"hello":"world"}'
with warnings.catch_warnings():
warnings.simplefilter("ignore", FastAPIDeprecationWarning)
with pytest.raises(RuntimeError, match="orjson must be installed") as exc_info:
responses.ORJSONResponse(content={"hello": "world"})
assert isinstance(exc_info.value.__cause__, ImportError)
assert str(exc_info.value.__cause__) == "broken orjson"
# UJSON
@ -77,3 +124,19 @@ def test_ujson_response_returns_correct_data():
def test_ujson_response_emits_deprecation_warning():
with pytest.warns(FastAPIDeprecationWarning, match="UJSONResponse is deprecated"):
UJSONResponse(content={"hello": "world"})
def test_ujson_import_error_does_not_break_responses_import(
monkeypatch: pytest.MonkeyPatch,
):
responses = _import_responses_with_failed_optional_import(monkeypatch, "ujson")
response = responses.JSONResponse(content={"hello": "world"})
assert response.body == b'{"hello":"world"}'
with warnings.catch_warnings():
warnings.simplefilter("ignore", FastAPIDeprecationWarning)
with pytest.raises(RuntimeError, match="ujson must be installed") as exc_info:
responses.UJSONResponse(content={"hello": "world"})
assert isinstance(exc_info.value.__cause__, ImportError)
assert str(exc_info.value.__cause__) == "broken ujson"

Loading…
Cancel
Save