AlberLC 4 weeks ago
committed by GitHub
parent
commit
317138a788
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 145
      fastapi/openapi/utils.py
  2. 65
      fastapi/routing.py
  3. 1053
      tests/test_response_class_as_return_annotation.py
  4. 29
      tests/test_response_model_as_return_annotation.py
  5. 1
      tests/test_tutorial/test_response_model/test_tutorial003_02.py
  6. 7
      tests/test_tutorial/test_response_model/test_tutorial003_03.py
  7. 11
      tests/test_tutorial/test_response_model/test_tutorial003_04.py

145
fastapi/openapi/utils.py

@ -337,76 +337,93 @@ def get_openapi_path(
) )
callbacks[callback.name] = {callback.path: cb_path} callbacks[callback.name] = {callback.path: cb_path}
operation["callbacks"] = callbacks operation["callbacks"] = callbacks
if route.status_code is not None:
status_code = str(route.status_code) if route.return_response_classes:
if route.return_response_models:
response_classes = (
current_response_class,
*route.return_response_classes,
)
else:
response_classes = tuple(route.return_response_classes)
else: else:
# It would probably make more sense for all response classes to have an response_classes = (current_response_class,)
# explicit default status_code, and to extract it from them, instead of
# doing this inspection tricks, that would probably be in the future for response_class in response_classes:
# TODO: probably make status_code a default class attribute for all if route.status_code is not None:
# responses in Starlette status_code = str(route.status_code)
response_signature = inspect.signature(current_response_class.__init__) else:
status_code_param = response_signature.parameters.get("status_code") # It would probably make more sense for all response classes to have an
if status_code_param is not None: # explicit default status_code, and to extract it from them, instead of
if isinstance(status_code_param.default, int): # doing this inspection tricks, that would probably be in the future
status_code = str(status_code_param.default) # TODO: probably make status_code a default class attribute for all
operation.setdefault("responses", {}).setdefault(status_code, {})[ # responses in Starlette
"description" response_signature = inspect.signature(response_class.__init__)
] = route.response_description status_code_param = response_signature.parameters.get("status_code")
if is_body_allowed_for_status_code(route.status_code): if status_code_param is not None:
# Check for JSONL streaming (generator endpoints) if isinstance(status_code_param.default, int):
if route.is_json_stream: status_code = str(status_code_param.default)
jsonl_content: dict[str, Any] = {} operation.setdefault("responses", {}).setdefault(status_code, {})[
if route.stream_item_field: "description"
item_schema = get_schema_from_model_field( ] = route.response_description
field=route.stream_item_field,
model_name_map=model_name_map, if is_body_allowed_for_status_code(route.status_code):
field_mapping=field_mapping, return_response_media_type: str | None = response_class.media_type
separate_input_output_schemas=separate_input_output_schemas,
) # Check for JSONL streaming (generator endpoints)
jsonl_content["itemSchema"] = item_schema if route.is_json_stream:
else: jsonl_content: dict[str, Any] = {}
jsonl_content["itemSchema"] = {} if route.stream_item_field:
operation.setdefault("responses", {}).setdefault( item_schema = get_schema_from_model_field(
status_code, {} field=route.stream_item_field,
).setdefault("content", {})["application/jsonl"] = jsonl_content
elif route.is_sse_stream:
sse_content: dict[str, Any] = {}
item_schema = copy.deepcopy(_SSE_EVENT_SCHEMA)
if route.stream_item_field:
content_schema = get_schema_from_model_field(
field=route.stream_item_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
item_schema["required"] = ["data"]
item_schema["properties"]["data"] = {
"type": "string",
"contentMediaType": "application/json",
"contentSchema": content_schema,
}
sse_content["itemSchema"] = item_schema
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {})["text/event-stream"] = sse_content
elif route_response_media_type:
response_schema = {"type": "string"}
if lenient_issubclass(current_response_class, JSONResponse):
if route.response_field:
response_schema = get_schema_from_model_field(
field=route.response_field,
model_name_map=model_name_map, model_name_map=model_name_map,
field_mapping=field_mapping, field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas, separate_input_output_schemas=separate_input_output_schemas,
) )
jsonl_content["itemSchema"] = item_schema
else: else:
response_schema = {} jsonl_content["itemSchema"] = {}
operation.setdefault("responses", {}).setdefault( operation.setdefault("responses", {}).setdefault(
status_code, {} status_code, {}
).setdefault("content", {}).setdefault( ).setdefault("content", {})["application/jsonl"] = jsonl_content
route_response_media_type, {} elif route.is_sse_stream:
)["schema"] = response_schema sse_content: dict[str, Any] = {}
item_schema = copy.deepcopy(_SSE_EVENT_SCHEMA)
if route.stream_item_field:
content_schema = get_schema_from_model_field(
field=route.stream_item_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
item_schema["required"] = ["data"]
item_schema["properties"]["data"] = {
"type": "string",
"contentMediaType": "application/json",
"contentSchema": content_schema,
}
sse_content["itemSchema"] = item_schema
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {})["text/event-stream"] = sse_content
elif return_response_media_type:
response_schema = {"type": "string"}
if lenient_issubclass(response_class, JSONResponse):
if route.response_field:
response_schema = get_schema_from_model_field(
field=route.response_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
else:
response_schema = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {}).setdefault(
return_response_media_type, {}
)["schema"] = response_schema
if route.responses: if route.responses:
operation_responses = operation.setdefault("responses", {}) operation_responses = operation.setdefault("responses", {})
for ( for (

65
fastapi/routing.py

@ -3,6 +3,7 @@ import email.message
import functools import functools
import inspect import inspect
import json import json
import operator
import types import types
from collections.abc import ( from collections.abc import (
AsyncIterator, AsyncIterator,
@ -26,7 +27,10 @@ from typing import (
Annotated, Annotated,
Any, Any,
TypeVar, TypeVar,
Union,
cast, cast,
get_args,
get_origin,
) )
import anyio import anyio
@ -65,7 +69,7 @@ from fastapi.sse import (
ServerSentEvent, ServerSentEvent,
format_sse_event, format_sse_event,
) )
from fastapi.types import DecoratedCallable, IncEx from fastapi.types import DecoratedCallable, IncEx, UnionType
from fastapi.utils import ( from fastapi.utils import (
create_model_field, create_model_field,
generate_unique_id, generate_unique_id,
@ -844,27 +848,54 @@ class APIRoute(routing.Route):
self.path = path self.path = path
self.endpoint = endpoint self.endpoint = endpoint
self.stream_item_type: Any | None = None self.stream_item_type: Any | None = None
self.return_response_models = []
self.return_response_classes = []
if isinstance(response_model, DefaultPlaceholder): if isinstance(response_model, DefaultPlaceholder):
return_annotation = get_typed_return_annotation(endpoint) return_annotation = get_typed_return_annotation(endpoint)
if lenient_issubclass(return_annotation, Response): stream_item = get_stream_item_type(return_annotation)
if stream_item is not None:
# Extract item type for JSONL or SSE streaming when
# response_class is DefaultPlaceholder (JSONL) or
# EventSourceResponse (SSE).
# ServerSentEvent is excluded: it's a transport
# wrapper, not a data model, so it shouldn't feed
# into validation or OpenAPI schema generation.
if (
isinstance(response_class, DefaultPlaceholder)
or lenient_issubclass(response_class, EventSourceResponse)
) and not lenient_issubclass(stream_item, ServerSentEvent):
self.stream_item_type = stream_item
response_model = None response_model = None
else: else:
stream_item = get_stream_item_type(return_annotation) origin = get_origin(return_annotation)
if stream_item is not None:
# Extract item type for JSONL or SSE streaming when if origin is Union or origin is UnionType:
# response_class is DefaultPlaceholder (JSONL) or for arg in get_args(return_annotation):
# EventSourceResponse (SSE). if arg is type(None):
# ServerSentEvent is excluded: it's a transport continue
# wrapper, not a data model, so it shouldn't feed
# into validation or OpenAPI schema generation. if lenient_issubclass(arg, Response):
if ( self.return_response_classes.append(arg)
isinstance(response_class, DefaultPlaceholder) else:
or lenient_issubclass(response_class, EventSourceResponse) self.return_response_models.append(arg)
) and not lenient_issubclass(stream_item, ServerSentEvent): elif lenient_issubclass(return_annotation, Response):
self.stream_item_type = stream_item self.return_response_classes.append(return_annotation)
response_model = None
else: else:
response_model = return_annotation self.return_response_models.append(return_annotation)
if self.return_response_models:
if len(self.return_response_models) == 1:
response_model = self.return_response_models[0]
else:
response_model = functools.reduce(
operator.or_, self.return_response_models
)
else:
response_model = None
self.response_model = response_model self.response_model = response_model
self.summary = summary self.summary = summary
self.response_description = response_description self.response_description = response_description

1053
tests/test_response_class_as_return_annotation.py

File diff suppressed because it is too large

29
tests/test_response_model_as_return_annotation.py

@ -1,6 +1,6 @@
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.exceptions import FastAPIError, ResponseValidationError from fastapi.exceptions import ResponseValidationError
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from inline_snapshot import snapshot from inline_snapshot import snapshot
@ -248,6 +248,11 @@ def no_response_model_annotation_json_response_class() -> JSONResponse:
return JSONResponse(content={"foo": "bar"}) return JSONResponse(content={"foo": "bar"})
@app.get("/no_response_model-annotation_response_or_none")
def no_response_model_annotation_response_or_none() -> Response | None:
return Response(content="Foo")
client = TestClient(app) client = TestClient(app)
@ -496,16 +501,10 @@ def test_no_response_model_annotation_json_response_class():
assert response.json() == {"foo": "bar"} assert response.json() == {"foo": "bar"}
def test_invalid_response_model_field(): def test_no_response_model_annotation_response_or_none():
app = FastAPI() response = client.get("/no_response_model-annotation_response_or_none")
with pytest.raises(FastAPIError) as e: assert response.status_code == 200
assert response.text == "Foo"
@app.get("/")
def read_root() -> Response | None:
return Response(content="Foo") # pragma: no cover
assert "valid Pydantic field type" in e.value.args[0]
assert "parameter response_model=None" in e.value.args[0]
def test_openapi_schema(): def test_openapi_schema():
@ -1077,7 +1076,6 @@ def test_openapi_schema():
"responses": { "responses": {
"200": { "200": {
"description": "Successful Response", "description": "Successful Response",
"content": {"application/json": {"schema": {}}},
} }
}, },
} }
@ -1094,6 +1092,13 @@ def test_openapi_schema():
}, },
} }
}, },
"/no_response_model-annotation_response_or_none": {
"get": {
"summary": "No Response Model Annotation Response Or None",
"operationId": "no_response_model_annotation_response_or_none_no_response_model_annotation_response_or_none_get",
"responses": {"200": {"description": "Successful Response"}},
}
},
}, },
"components": { "components": {
"schemas": { "schemas": {

1
tests/test_tutorial/test_response_model/test_tutorial003_02.py

@ -45,7 +45,6 @@ def test_openapi_schema():
"responses": { "responses": {
"200": { "200": {
"description": "Successful Response", "description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}, },
"422": { "422": {
"description": "Validation Error", "description": "Validation Error",

7
tests/test_tutorial/test_response_model/test_tutorial003_03.py

@ -24,12 +24,7 @@ def test_openapi_schema():
"get": { "get": {
"summary": "Get Teleport", "summary": "Get Teleport",
"operationId": "get_teleport_teleport_get", "operationId": "get_teleport_teleport_get",
"responses": { "responses": {"307": {"description": "Successful Response"}},
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
} }
} }
}, },

11
tests/test_tutorial/test_response_model/test_tutorial003_04.py

@ -1,17 +1,20 @@
import importlib import importlib
import pytest import pytest
from fastapi.exceptions import FastAPIError
from ...utils import needs_py310 from ...utils import needs_py310
# Previously, unions including `Response` in the return annotation were
# considered invalid and raised FastAPIError at import time.
# They are now supported as part of the enhanced return annotation handling.
# Importing the module should not raise FastAPIError anymore.
@pytest.mark.parametrize( @pytest.mark.parametrize(
"module_name", "module_name",
[ [
pytest.param("tutorial003_04_py310", marks=needs_py310), pytest.param("tutorial003_04_py310", marks=needs_py310),
], ],
) )
def test_invalid_response_model(module_name: str) -> None: def test_response_union_with_response_is_valid(module_name: str) -> None:
with pytest.raises(FastAPIError): module = importlib.import_module(f"docs_src.response_model.{module_name}")
importlib.import_module(f"docs_src.response_model.{module_name}") assert module is not None

Loading…
Cancel
Save