Browse Source

Merge d6c58fe43c into 460f8d2cc8

pull/15040/merge
AlberLC 13 hours ago
committed by GitHub
parent
commit
317138a788
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 145
      fastapi/openapi/utils.py
  2. 65
      fastapi/routing.py
  3. 1053
      tests/test_response_class_as_return_annotation.py
  4. 29
      tests/test_response_model_as_return_annotation.py
  5. 1
      tests/test_tutorial/test_response_model/test_tutorial003_02.py
  6. 7
      tests/test_tutorial/test_response_model/test_tutorial003_03.py
  7. 11
      tests/test_tutorial/test_response_model/test_tutorial003_04.py

145
fastapi/openapi/utils.py

@ -337,76 +337,93 @@ def get_openapi_path(
)
callbacks[callback.name] = {callback.path: cb_path}
operation["callbacks"] = callbacks
if route.status_code is not None:
status_code = str(route.status_code)
if route.return_response_classes:
if route.return_response_models:
response_classes = (
current_response_class,
*route.return_response_classes,
)
else:
response_classes = tuple(route.return_response_classes)
else:
# It would probably make more sense for all response classes to have an
# explicit default status_code, and to extract it from them, instead of
# doing this inspection tricks, that would probably be in the future
# TODO: probably make status_code a default class attribute for all
# responses in Starlette
response_signature = inspect.signature(current_response_class.__init__)
status_code_param = response_signature.parameters.get("status_code")
if status_code_param is not None:
if isinstance(status_code_param.default, int):
status_code = str(status_code_param.default)
operation.setdefault("responses", {}).setdefault(status_code, {})[
"description"
] = route.response_description
if is_body_allowed_for_status_code(route.status_code):
# Check for JSONL streaming (generator endpoints)
if route.is_json_stream:
jsonl_content: dict[str, Any] = {}
if route.stream_item_field:
item_schema = get_schema_from_model_field(
field=route.stream_item_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
jsonl_content["itemSchema"] = item_schema
else:
jsonl_content["itemSchema"] = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {})["application/jsonl"] = jsonl_content
elif route.is_sse_stream:
sse_content: dict[str, Any] = {}
item_schema = copy.deepcopy(_SSE_EVENT_SCHEMA)
if route.stream_item_field:
content_schema = get_schema_from_model_field(
field=route.stream_item_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
item_schema["required"] = ["data"]
item_schema["properties"]["data"] = {
"type": "string",
"contentMediaType": "application/json",
"contentSchema": content_schema,
}
sse_content["itemSchema"] = item_schema
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {})["text/event-stream"] = sse_content
elif route_response_media_type:
response_schema = {"type": "string"}
if lenient_issubclass(current_response_class, JSONResponse):
if route.response_field:
response_schema = get_schema_from_model_field(
field=route.response_field,
response_classes = (current_response_class,)
for response_class in response_classes:
if route.status_code is not None:
status_code = str(route.status_code)
else:
# It would probably make more sense for all response classes to have an
# explicit default status_code, and to extract it from them, instead of
# doing this inspection tricks, that would probably be in the future
# TODO: probably make status_code a default class attribute for all
# responses in Starlette
response_signature = inspect.signature(response_class.__init__)
status_code_param = response_signature.parameters.get("status_code")
if status_code_param is not None:
if isinstance(status_code_param.default, int):
status_code = str(status_code_param.default)
operation.setdefault("responses", {}).setdefault(status_code, {})[
"description"
] = route.response_description
if is_body_allowed_for_status_code(route.status_code):
return_response_media_type: str | None = response_class.media_type
# Check for JSONL streaming (generator endpoints)
if route.is_json_stream:
jsonl_content: dict[str, Any] = {}
if route.stream_item_field:
item_schema = get_schema_from_model_field(
field=route.stream_item_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
jsonl_content["itemSchema"] = item_schema
else:
response_schema = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {}).setdefault(
route_response_media_type, {}
)["schema"] = response_schema
jsonl_content["itemSchema"] = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {})["application/jsonl"] = jsonl_content
elif route.is_sse_stream:
sse_content: dict[str, Any] = {}
item_schema = copy.deepcopy(_SSE_EVENT_SCHEMA)
if route.stream_item_field:
content_schema = get_schema_from_model_field(
field=route.stream_item_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
item_schema["required"] = ["data"]
item_schema["properties"]["data"] = {
"type": "string",
"contentMediaType": "application/json",
"contentSchema": content_schema,
}
sse_content["itemSchema"] = item_schema
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {})["text/event-stream"] = sse_content
elif return_response_media_type:
response_schema = {"type": "string"}
if lenient_issubclass(response_class, JSONResponse):
if route.response_field:
response_schema = get_schema_from_model_field(
field=route.response_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
else:
response_schema = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {}).setdefault(
return_response_media_type, {}
)["schema"] = response_schema
if route.responses:
operation_responses = operation.setdefault("responses", {})
for (

65
fastapi/routing.py

@ -3,6 +3,7 @@ import email.message
import functools
import inspect
import json
import operator
import types
from collections.abc import (
AsyncIterator,
@ -26,7 +27,10 @@ from typing import (
Annotated,
Any,
TypeVar,
Union,
cast,
get_args,
get_origin,
)
import anyio
@ -65,7 +69,7 @@ from fastapi.sse import (
ServerSentEvent,
format_sse_event,
)
from fastapi.types import DecoratedCallable, IncEx
from fastapi.types import DecoratedCallable, IncEx, UnionType
from fastapi.utils import (
create_model_field,
generate_unique_id,
@ -844,27 +848,54 @@ class APIRoute(routing.Route):
self.path = path
self.endpoint = endpoint
self.stream_item_type: Any | None = None
self.return_response_models = []
self.return_response_classes = []
if isinstance(response_model, DefaultPlaceholder):
return_annotation = get_typed_return_annotation(endpoint)
if lenient_issubclass(return_annotation, Response):
stream_item = get_stream_item_type(return_annotation)
if stream_item is not None:
# Extract item type for JSONL or SSE streaming when
# response_class is DefaultPlaceholder (JSONL) or
# EventSourceResponse (SSE).
# ServerSentEvent is excluded: it's a transport
# wrapper, not a data model, so it shouldn't feed
# into validation or OpenAPI schema generation.
if (
isinstance(response_class, DefaultPlaceholder)
or lenient_issubclass(response_class, EventSourceResponse)
) and not lenient_issubclass(stream_item, ServerSentEvent):
self.stream_item_type = stream_item
response_model = None
else:
stream_item = get_stream_item_type(return_annotation)
if stream_item is not None:
# Extract item type for JSONL or SSE streaming when
# response_class is DefaultPlaceholder (JSONL) or
# EventSourceResponse (SSE).
# ServerSentEvent is excluded: it's a transport
# wrapper, not a data model, so it shouldn't feed
# into validation or OpenAPI schema generation.
if (
isinstance(response_class, DefaultPlaceholder)
or lenient_issubclass(response_class, EventSourceResponse)
) and not lenient_issubclass(stream_item, ServerSentEvent):
self.stream_item_type = stream_item
response_model = None
origin = get_origin(return_annotation)
if origin is Union or origin is UnionType:
for arg in get_args(return_annotation):
if arg is type(None):
continue
if lenient_issubclass(arg, Response):
self.return_response_classes.append(arg)
else:
self.return_response_models.append(arg)
elif lenient_issubclass(return_annotation, Response):
self.return_response_classes.append(return_annotation)
else:
response_model = return_annotation
self.return_response_models.append(return_annotation)
if self.return_response_models:
if len(self.return_response_models) == 1:
response_model = self.return_response_models[0]
else:
response_model = functools.reduce(
operator.or_, self.return_response_models
)
else:
response_model = None
self.response_model = response_model
self.summary = summary
self.response_description = response_description

1053
tests/test_response_class_as_return_annotation.py

File diff suppressed because it is too large

29
tests/test_response_model_as_return_annotation.py

@ -1,6 +1,6 @@
import pytest
from fastapi import FastAPI
from fastapi.exceptions import FastAPIError, ResponseValidationError
from fastapi.exceptions import ResponseValidationError
from fastapi.responses import JSONResponse, Response
from fastapi.testclient import TestClient
from inline_snapshot import snapshot
@ -248,6 +248,11 @@ def no_response_model_annotation_json_response_class() -> JSONResponse:
return JSONResponse(content={"foo": "bar"})
@app.get("/no_response_model-annotation_response_or_none")
def no_response_model_annotation_response_or_none() -> Response | None:
return Response(content="Foo")
client = TestClient(app)
@ -496,16 +501,10 @@ def test_no_response_model_annotation_json_response_class():
assert response.json() == {"foo": "bar"}
def test_invalid_response_model_field():
app = FastAPI()
with pytest.raises(FastAPIError) as e:
@app.get("/")
def read_root() -> Response | None:
return Response(content="Foo") # pragma: no cover
assert "valid Pydantic field type" in e.value.args[0]
assert "parameter response_model=None" in e.value.args[0]
def test_no_response_model_annotation_response_or_none():
response = client.get("/no_response_model-annotation_response_or_none")
assert response.status_code == 200
assert response.text == "Foo"
def test_openapi_schema():
@ -1077,7 +1076,6 @@ def test_openapi_schema():
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
}
@ -1094,6 +1092,13 @@ def test_openapi_schema():
},
}
},
"/no_response_model-annotation_response_or_none": {
"get": {
"summary": "No Response Model Annotation Response Or None",
"operationId": "no_response_model_annotation_response_or_none_no_response_model_annotation_response_or_none_get",
"responses": {"200": {"description": "Successful Response"}},
}
},
},
"components": {
"schemas": {

1
tests/test_tutorial/test_response_model/test_tutorial003_02.py

@ -45,7 +45,6 @@ def test_openapi_schema():
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",

7
tests/test_tutorial/test_response_model/test_tutorial003_03.py

@ -24,12 +24,7 @@ def test_openapi_schema():
"get": {
"summary": "Get Teleport",
"operationId": "get_teleport_teleport_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"responses": {"307": {"description": "Successful Response"}},
}
}
},

11
tests/test_tutorial/test_response_model/test_tutorial003_04.py

@ -1,17 +1,20 @@
import importlib
import pytest
from fastapi.exceptions import FastAPIError
from ...utils import needs_py310
# Previously, unions including `Response` in the return annotation were
# considered invalid and raised FastAPIError at import time.
# They are now supported as part of the enhanced return annotation handling.
# Importing the module should not raise FastAPIError anymore.
@pytest.mark.parametrize(
"module_name",
[
pytest.param("tutorial003_04_py310", marks=needs_py310),
],
)
def test_invalid_response_model(module_name: str) -> None:
with pytest.raises(FastAPIError):
importlib.import_module(f"docs_src.response_model.{module_name}")
def test_response_union_with_response_is_valid(module_name: str) -> None:
module = importlib.import_module(f"docs_src.response_model.{module_name}")
assert module is not None

Loading…
Cancel
Save