Browse Source

Fix optional response encoder import errors

pull/15628/head
Ratish Oberoi 2 weeks ago
parent
commit
5761366e5c
  1. 24
      fastapi/responses.py
  2. 63
      tests/test_deprecated_responses.py

24
fastapi/responses.py

@ -24,16 +24,22 @@ class _OrjsonModule(Protocol):
def dumps(self, __obj: Any, *, option: int = ...) -> bytes: ... def dumps(self, __obj: Any, *, option: int = ...) -> bytes: ...
ujson: _UjsonModule | None
ujson_import_error: ImportError | None = None
try: try:
ujson = cast(_UjsonModule, importlib.import_module("ujson")) ujson = cast(_UjsonModule, importlib.import_module("ujson"))
except ModuleNotFoundError: # pragma: nocover except ImportError as e:
ujson = None # type: ignore[assignment] ujson = None
ujson_import_error = e
orjson: _OrjsonModule | None
orjson_import_error: ImportError | None = None
try: try:
orjson = cast(_OrjsonModule, importlib.import_module("orjson")) orjson = cast(_OrjsonModule, importlib.import_module("orjson"))
except ModuleNotFoundError: # pragma: nocover except ImportError as e:
orjson = None # type: ignore[assignment] orjson = None
orjson_import_error = e
@deprecated( @deprecated(
@ -62,7 +68,10 @@ class UJSONResponse(JSONResponse):
""" """
def render(self, content: Any) -> bytes: def render(self, content: Any) -> bytes:
assert ujson is not None, "ujson must be installed to use UJSONResponse" if ujson is None:
raise RuntimeError(
"ujson must be installed to use UJSONResponse"
) from ujson_import_error
return ujson.dumps(content, ensure_ascii=False).encode("utf-8") return ujson.dumps(content, ensure_ascii=False).encode("utf-8")
@ -92,7 +101,10 @@ class ORJSONResponse(JSONResponse):
""" """
def render(self, content: Any) -> bytes: def render(self, content: Any) -> bytes:
assert orjson is not None, "orjson must be installed to use ORJSONResponse" if orjson is None:
raise RuntimeError(
"orjson must be installed to use ORJSONResponse"
) from orjson_import_error
return orjson.dumps( return orjson.dumps(
content, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY content, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
) )

63
tests/test_deprecated_responses.py

@ -1,5 +1,9 @@
import importlib
import sys
import warnings import warnings
from types import ModuleType
import fastapi
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.exceptions import FastAPIDeprecationWarning from fastapi.exceptions import FastAPIDeprecationWarning
@ -15,6 +19,33 @@ class Item(BaseModel):
price: float price: float
def _import_responses_with_failed_optional_import(
monkeypatch: pytest.MonkeyPatch, module_name: str
) -> ModuleType:
original_responses = sys.modules.pop("fastapi.responses", None)
had_responses_attr = hasattr(fastapi, "responses")
original_responses_attr = getattr(fastapi, "responses", None)
if had_responses_attr:
delattr(fastapi, "responses")
real_import_module = importlib.import_module
def import_module(name: str, package: str | None = None) -> ModuleType:
if name == module_name:
raise ImportError(f"broken {module_name}")
return real_import_module(name, package)
monkeypatch.setattr(importlib, "import_module", import_module)
try:
return real_import_module("fastapi.responses")
finally:
sys.modules.pop("fastapi.responses", None)
if original_responses is not None:
sys.modules["fastapi.responses"] = original_responses
if had_responses_attr:
fastapi.responses = original_responses_attr
# ORJSON # ORJSON
@ -47,6 +78,22 @@ def test_orjson_response_emits_deprecation_warning():
ORJSONResponse(content={"hello": "world"}) ORJSONResponse(content={"hello": "world"})
def test_orjson_import_error_does_not_break_responses_import(
monkeypatch: pytest.MonkeyPatch,
):
responses = _import_responses_with_failed_optional_import(monkeypatch, "orjson")
response = responses.JSONResponse(content={"hello": "world"})
assert response.body == b'{"hello":"world"}'
with warnings.catch_warnings():
warnings.simplefilter("ignore", FastAPIDeprecationWarning)
with pytest.raises(RuntimeError, match="orjson must be installed") as exc_info:
responses.ORJSONResponse(content={"hello": "world"})
assert isinstance(exc_info.value.__cause__, ImportError)
assert str(exc_info.value.__cause__) == "broken orjson"
# UJSON # UJSON
@ -77,3 +124,19 @@ def test_ujson_response_returns_correct_data():
def test_ujson_response_emits_deprecation_warning(): def test_ujson_response_emits_deprecation_warning():
with pytest.warns(FastAPIDeprecationWarning, match="UJSONResponse is deprecated"): with pytest.warns(FastAPIDeprecationWarning, match="UJSONResponse is deprecated"):
UJSONResponse(content={"hello": "world"}) UJSONResponse(content={"hello": "world"})
def test_ujson_import_error_does_not_break_responses_import(
monkeypatch: pytest.MonkeyPatch,
):
responses = _import_responses_with_failed_optional_import(monkeypatch, "ujson")
response = responses.JSONResponse(content={"hello": "world"})
assert response.body == b'{"hello":"world"}'
with warnings.catch_warnings():
warnings.simplefilter("ignore", FastAPIDeprecationWarning)
with pytest.raises(RuntimeError, match="ujson must be installed") as exc_info:
responses.UJSONResponse(content={"hello": "world"})
assert isinstance(exc_info.value.__cause__, ImportError)
assert str(exc_info.value.__cause__) == "broken ujson"

Loading…
Cancel
Save