Browse Source

fix test and apply pydantic2.x case

pull/4573/head
jujumilk3 1 year ago
parent
commit
28d91bb3cd
  1. 187
      fastapi/dependencies/utils.py
  2. 186
      tests/test_dependency_schema_query.py

187
fastapi/dependencies/utils.py

@ -119,9 +119,7 @@ def get_param_sub_dependant(
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
assert callable( assert callable(depends.dependency), "A parameter-less dependency must have a callable dependency"
depends.dependency
), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
@ -142,9 +140,7 @@ def get_sub_dependant(
use_scopes: List[str] = [] use_scopes: List[str] = []
if isinstance(dependency, (OAuth2, OpenIdConnect)): if isinstance(dependency, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes use_scopes = security_scopes
security_requirement = SecurityRequirement( security_requirement = SecurityRequirement(security_scheme=dependency, scopes=use_scopes)
security_scheme=dependency, scopes=use_scopes
)
sub_dependant = get_dependant( sub_dependant = get_dependant(
path=path, path=path,
call=dependency, call=dependency,
@ -183,9 +179,7 @@ def get_flat_dependant(
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.dependencies:
if skip_repeats and sub_dependant.cache_key in visited: if skip_repeats and sub_dependant.cache_key in visited:
continue continue
flat_sub = get_flat_dependant( flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited)
sub_dependant, skip_repeats=skip_repeats, visited=visited
)
flat_dependant.path_params.extend(flat_sub.path_params) flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params) flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params) flat_dependant.header_params.extend(flat_sub.header_params)
@ -208,29 +202,38 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call) signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {}) globalns = getattr(call, "__globals__", {})
fields = getattr(call, "__fields__", {}) fields = getattr(call, "model_fields", {})
if len(fields): if len(fields):
alias_dict = {}
query_extra_info = {} query_extra_info = {}
for param in fields: for param in fields:
query_extra_info[param] = dict(fields[param].field_info.__repr_args__()) query_extra_info[param] = dict(fields[param].__repr_args__())
if "alias" in query_extra_info[param]: if "alias" in query_extra_info[param]:
query_extra_info[query_extra_info[param]["alias"]] = dict( query_extra_info[query_extra_info[param]["alias"]] = dict(fields[param].__repr_args__())
fields[param].field_info.__repr_args__() alias_dict[query_extra_info[param]["alias"]] = param
)
query_extra_info[param]["default"] = ( query_extra_info[param]["default"] = (
Required Required if getattr(fields[param], "required", False) else fields[param].default
if getattr(fields[param], "required", False)
else fields[param].default
)
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=params.Param(**query_extra_info[param.name]),
annotation=get_typed_annotation(param.annotation, globalns),
) )
for param in signature.parameters.values() typed_params = []
]
for param in signature.parameters.values():
if param.name in alias_dict:
original_param_name = alias_dict[param.name]
created_param = inspect.Parameter(
name=original_param_name,
kind=param.kind,
default=params.Param(**query_extra_info[original_param_name]),
annotation=get_typed_annotation(param.annotation, globalns),
)
else:
created_param = inspect.Parameter(
name=param.name,
kind=param.kind,
default=params.Param(**query_extra_info[param.name]),
annotation=get_typed_annotation(param.annotation, globalns),
)
typed_params.append(created_param)
else: else:
typed_params = [ typed_params = [
inspect.Parameter( inspect.Parameter(
@ -303,9 +306,7 @@ def get_dependant(
type_annotation=type_annotation, type_annotation=type_annotation,
dependant=dependant, dependant=dependant,
): ):
assert ( assert param_field is None, f"Cannot specify multiple FastAPI annotations for {param_name!r}"
param_field is None
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
continue continue
assert param_field is not None assert param_field is not None
if is_body_param(param_field=param_field, is_path_param=is_path_param): if is_body_param(param_field=param_field, is_path_param=is_path_param):
@ -315,9 +316,7 @@ def get_dependant(
return dependant return dependant
def add_non_field_param_to_dependency( def add_non_field_param_to_dependency(*, param_name: str, type_annotation: Any, dependant: Dependant) -> Optional[bool]:
*, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
if lenient_issubclass(type_annotation, Request): if lenient_issubclass(type_annotation, Request):
dependant.request_param_name = param_name dependant.request_param_name = param_name
return True return True
@ -349,26 +348,17 @@ def analyze_param(
field_info = None field_info = None
depends = None depends = None
type_annotation: Any = Any type_annotation: Any = Any
if ( if annotation is not inspect.Signature.empty and get_origin(annotation) is Annotated:
annotation is not inspect.Signature.empty
and get_origin(annotation) is Annotated
):
annotated_args = get_args(annotation) annotated_args = get_args(annotation)
type_annotation = annotated_args[0] type_annotation = annotated_args[0]
fastapi_annotations = [ fastapi_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, (FieldInfo, params.Depends))]
arg
for arg in annotated_args[1:]
if isinstance(arg, (FieldInfo, params.Depends))
]
assert ( assert (
len(fastapi_annotations) <= 1 len(fastapi_annotations) <= 1
), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}" ), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}"
fastapi_annotation = next(iter(fastapi_annotations), None) fastapi_annotation = next(iter(fastapi_annotations), None)
if isinstance(fastapi_annotation, FieldInfo): if isinstance(fastapi_annotation, FieldInfo):
# Copy `field_info` because we mutate `field_info.default` below. # Copy `field_info` because we mutate `field_info.default` below.
field_info = copy_field_info( field_info = copy_field_info(field_info=fastapi_annotation, annotation=annotation)
field_info=fastapi_annotation, annotation=annotation
)
assert field_info.default is Undefined or field_info.default is Required, ( assert field_info.default is Undefined or field_info.default is Required, (
f"`{field_info.__class__.__name__}` default value cannot be set in" f"`{field_info.__class__.__name__}` default value cannot be set in"
f" `Annotated` for {param_name!r}. Set the default value with `=` instead." f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
@ -385,8 +375,7 @@ def analyze_param(
if isinstance(value, params.Depends): if isinstance(value, params.Depends):
assert depends is None, ( assert depends is None, (
"Cannot specify `Depends` in `Annotated` and default value" "Cannot specify `Depends` in `Annotated` and default value" f" together for {param_name!r}"
f" together for {param_name!r}"
) )
assert field_info is None, ( assert field_info is None, (
"Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a" "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
@ -395,8 +384,7 @@ def analyze_param(
depends = value depends = value
elif isinstance(value, FieldInfo): elif isinstance(value, FieldInfo):
assert field_info is None, ( assert field_info is None, (
"Cannot specify FastAPI annotations in `Annotated` and default value" "Cannot specify FastAPI annotations in `Annotated` and default value" f" together for {param_name!r}"
f" together for {param_name!r}"
) )
field_info = value field_info = value
if PYDANTIC_V2: if PYDANTIC_V2:
@ -417,9 +405,7 @@ def analyze_param(
), ),
): ):
assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}" assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
assert ( assert field_info is None, f"Cannot specify FastAPI annotation for type {type_annotation!r}"
field_info is None
), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
elif field_info is None and depends is None: elif field_info is None and depends is None:
default_value = value if value is not inspect.Signature.empty else Required default_value = value if value is not inspect.Signature.empty else Required
if is_path_param: if is_path_param:
@ -427,9 +413,9 @@ def analyze_param(
# parameter might sometimes be a path parameter and sometimes not. See # parameter might sometimes be a path parameter and sometimes not. See
# `tests/test_infer_param_optionality.py` for an example. # `tests/test_infer_param_optionality.py` for an example.
field_info = params.Path(annotation=type_annotation) field_info = params.Path(annotation=type_annotation)
elif is_uploadfile_or_nonable_uploadfile_annotation( elif is_uploadfile_or_nonable_uploadfile_annotation(type_annotation) or is_uploadfile_sequence_annotation(
type_annotation type_annotation
) or is_uploadfile_sequence_annotation(type_annotation): ):
field_info = params.File(annotation=type_annotation, default=default_value) field_info = params.File(annotation=type_annotation, default=default_value)
elif not field_annotation_is_scalar(annotation=type_annotation): elif not field_annotation_is_scalar(annotation=type_annotation):
field_info = params.Body(annotation=type_annotation, default=default_value) field_info = params.Body(annotation=type_annotation, default=default_value)
@ -440,13 +426,9 @@ def analyze_param(
if field_info is not None: if field_info is not None:
if is_path_param: if is_path_param:
assert isinstance(field_info, params.Path), ( assert isinstance(field_info, params.Path), (
f"Cannot use `{field_info.__class__.__name__}` for path param" f"Cannot use `{field_info.__class__.__name__}` for path param" f" {param_name!r}"
f" {param_name!r}"
) )
elif ( elif isinstance(field_info, params.Param) and getattr(field_info, "in_", None) is None:
isinstance(field_info, params.Param)
and getattr(field_info, "in_", None) is None
):
field_info.in_ = params.ParamTypes.query field_info.in_ = params.ParamTypes.query
use_annotation = get_annotation_from_field_info( use_annotation = get_annotation_from_field_info(
type_annotation, type_annotation,
@ -472,15 +454,11 @@ def analyze_param(
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
if is_path_param: if is_path_param:
assert is_scalar_field( assert is_scalar_field(field=param_field), "Path params must be of one of the supported types"
field=param_field
), "Path params must be of one of the supported types"
return False return False
elif is_scalar_field(field=param_field): elif is_scalar_field(field=param_field):
return False return False
elif isinstance( elif isinstance(param_field.field_info, (params.Query, params.Header)) and is_scalar_sequence_field(param_field):
param_field.field_info, (params.Query, params.Header)
) and is_scalar_sequence_field(param_field):
return False return False
else: else:
assert isinstance( assert isinstance(
@ -527,9 +505,7 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
return inspect.isgeneratorfunction(dunder_call) return inspect.isgeneratorfunction(dunder_call)
async def solve_generator( async def solve_generator(*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]) -> Any:
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
if is_gen_callable(call): if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call): elif is_async_gen_callable(call):
@ -563,19 +539,12 @@ async def solve_dependencies(
sub_dependant: Dependant sub_dependant: Dependant
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast( sub_dependant.cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key)
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call call = sub_dependant.call
use_sub_dependant = sub_dependant use_sub_dependant = sub_dependant
if ( if dependency_overrides_provider and dependency_overrides_provider.dependency_overrides:
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call original_call = sub_dependant.call
call = getattr( call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(original_call, original_call)
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant( use_sub_dependant = get_dependant(
path=use_path, path=use_path,
@ -609,9 +578,7 @@ async def solve_dependencies(
elif is_gen_callable(call) or is_async_gen_callable(call): elif is_gen_callable(call) or is_async_gen_callable(call):
stack = request.scope.get("fastapi_astack") stack = request.scope.get("fastapi_astack")
assert isinstance(stack, AsyncExitStack) assert isinstance(stack, AsyncExitStack)
solved = await solve_generator( solved = await solve_generator(call=call, stack=stack, sub_values=sub_values)
call=call, stack=stack, sub_values=sub_values
)
elif is_coroutine_callable(call): elif is_coroutine_callable(call):
solved = await call(**sub_values) solved = await call(**sub_values)
else: else:
@ -620,18 +587,10 @@ async def solve_dependencies(
values[sub_dependant.name] = solved values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache: if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved dependency_cache[sub_dependant.cache_key] = solved
path_values, path_errors = request_params_to_args( path_values, path_errors = request_params_to_args(dependant.path_params, request.path_params)
dependant.path_params, request.path_params query_values, query_errors = request_params_to_args(dependant.query_params, request.query_params)
) header_values, header_errors = request_params_to_args(dependant.header_params, request.headers)
query_values, query_errors = request_params_to_args( cookie_values, cookie_errors = request_params_to_args(dependant.cookie_params, request.cookies)
dependant.query_params, request.query_params
)
header_values, header_errors = request_params_to_args(
dependant.header_params, request.headers
)
cookie_values, cookie_errors = request_params_to_args(
dependant.cookie_params, request.cookies
)
values.update(path_values) values.update(path_values)
values.update(query_values) values.update(query_values)
values.update(header_values) values.update(header_values)
@ -659,9 +618,7 @@ async def solve_dependencies(
if dependant.response_param_name: if dependant.response_param_name:
values[dependant.response_param_name] = response values[dependant.response_param_name] = response
if dependant.security_scopes_param_name: if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes( values[dependant.security_scopes_param_name] = SecurityScopes(scopes=dependant.security_scopes)
scopes=dependant.security_scopes
)
return values, errors, background_tasks, response, dependency_cache return values, errors, background_tasks, response, dependency_cache
@ -672,16 +629,12 @@ def request_params_to_args(
values = {} values = {}
errors = [] errors = []
for field in required_params: for field in required_params:
if is_scalar_sequence_field(field) and isinstance( if is_scalar_sequence_field(field) and isinstance(received_params, (QueryParams, Headers)):
received_params, (QueryParams, Headers)
):
value = received_params.getlist(field.alias) or field.default value = received_params.getlist(field.alias) or field.default
else: else:
value = received_params.get(field.alias) value = received_params.get(field.alias)
field_info = field.field_info field_info = field.field_info
assert isinstance( assert isinstance(field_info, params.Param), "Params must be subclasses of Param"
field_info, params.Param
), "Params must be subclasses of Param"
loc = (field_info.in_.value, field.alias) loc = (field_info.in_.value, field.alias)
if value is None: if value is None:
if field.required: if field.required:
@ -734,35 +687,21 @@ async def request_body_to_args(
if ( if (
value is None value is None
or (isinstance(field_info, params.Form) and value == "") or (isinstance(field_info, params.Form) and value == "")
or ( or (isinstance(field_info, params.Form) and is_sequence_field(field) and len(value) == 0)
isinstance(field_info, params.Form)
and is_sequence_field(field)
and len(value) == 0
)
): ):
if field.required: if field.required:
errors.append(get_missing_field_error(loc)) errors.append(get_missing_field_error(loc))
else: else:
values[field.name] = deepcopy(field.default) values[field.name] = deepcopy(field.default)
continue continue
if ( if isinstance(field_info, params.File) and is_bytes_field(field) and isinstance(value, UploadFile):
isinstance(field_info, params.File)
and is_bytes_field(field)
and isinstance(value, UploadFile)
):
value = await value.read() value = await value.read()
elif ( elif is_bytes_sequence_field(field) and isinstance(field_info, params.File) and value_is_sequence(value):
is_bytes_sequence_field(field)
and isinstance(field_info, params.File)
and value_is_sequence(value)
):
# For types # For types
assert isinstance(value, sequence_types) # type: ignore[arg-type] assert isinstance(value, sequence_types) # type: ignore[arg-type]
results: List[Union[bytes, str]] = [] results: List[Union[bytes, str]] = []
async def process_fn( async def process_fn(fn: Callable[[], Coroutine[Any, Any, Any]]) -> None:
fn: Callable[[], Coroutine[Any, Any, Any]]
) -> None:
result = await fn() result = await fn()
results.append(result) # noqa: B023 results.append(result) # noqa: B023
@ -799,9 +738,7 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
for param in flat_dependant.body_params: for param in flat_dependant.body_params:
setattr(param.field_info, "embed", True) # noqa: B010 setattr(param.field_info, "embed", True) # noqa: B010
model_name = "Body_" + name model_name = "Body_" + name
BodyModel = create_body_model( BodyModel = create_body_model(fields=flat_dependant.body_params, model_name=model_name)
fields=flat_dependant.body_params, model_name=model_name
)
required = any(True for f in flat_dependant.body_params if f.required) required = any(True for f in flat_dependant.body_params if f.required)
BodyFieldInfo_kwargs: Dict[str, Any] = { BodyFieldInfo_kwargs: Dict[str, Any] = {
"annotation": BodyModel, "annotation": BodyModel,
@ -817,9 +754,7 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
BodyFieldInfo = params.Body BodyFieldInfo = params.Body
body_param_media_types = [ body_param_media_types = [
f.field_info.media_type f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, params.Body)
for f in flat_dependant.body_params
if isinstance(f.field_info, params.Body)
] ]
if len(set(body_param_media_types)) == 1: if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]

186
tests/test_dependency_schema_query.py

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from fastapi import Depends, FastAPI, Query from fastapi import Depends, FastAPI, Query, status
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pydantic import BaseModel from pydantic import BaseModel
@ -14,7 +14,9 @@ class Item(BaseModel):
name_required_without_default: str = Query( name_required_without_default: str = Query(
description="This is a name_required_without_default field." description="This is a name_required_without_default field."
) )
optional_int: Optional[int] = Query(description="This is a optional_int field") optional_int: Optional[int] = Query(
default=None, description="This is a optional_int field"
)
optional_str: Optional[str] = Query( optional_str: Optional[str] = Query(
"default_exists", description="This is a optional_str field" "default_exists", description="This is a optional_str field"
) )
@ -38,7 +40,7 @@ async def item_with_query_dependency(item: Item = Depends()):
client = TestClient(app) client = TestClient(app)
openapi_schema_with_not_omitted_description = { openapi_schema_with_not_omitted_description = {
"openapi": "3.0.2", "openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"}, "info": {"title": "FastAPI", "version": "0.1.0"},
"paths": { "paths": {
"/item": { "/item": {
@ -47,95 +49,103 @@ openapi_schema_with_not_omitted_description = {
"operationId": "item_with_query_dependency_item_get", "operationId": "item_with_query_dependency_item_get",
"parameters": [ "parameters": [
{ {
"description": "This is a name_required_with_default field.", "name": "name_required_with_default",
"in": "query",
"required": False, "required": False,
"schema": { "schema": {
"title": "Name Required With Default",
"type": "string", "type": "string",
"description": "This is a name_required_with_default field.", "description": "This is a name_required_with_default field.",
"required": False,
"default": "name default", "default": "name default",
"extra": {}, "title": "Name Required With Default",
}, },
"name": "name_required_with_default", "description": "This is a name_required_with_default field.",
"in": "query",
}, },
{ {
"description": "This is a name_required_without_default field.", "name": "name_required_without_default",
"in": "query",
"required": True, "required": True,
"schema": { "schema": {
"title": "Name Required Without Default",
"type": "string", "type": "string",
"description": "This is a name_required_without_default field.", "description": "This is a name_required_without_default field.",
"extra": {}, "required": True,
"title": "Name Required Without Default",
}, },
"name": "name_required_without_default", "description": "This is a name_required_without_default field.",
"in": "query",
}, },
{ {
"description": "This is a optional_int field", "name": "optional_int",
"in": "query",
"required": False, "required": False,
"schema": { "schema": {
"title": "Optional Int", "anyOf": [{"type": "integer"}, {"type": "null"}],
"type": "integer",
"description": "This is a optional_int field", "description": "This is a optional_int field",
"extra": {}, "required": False,
"title": "Optional Int",
}, },
"name": "optional_int", "description": "This is a optional_int field",
"in": "query",
}, },
{ {
"description": "This is a optional_str field", "name": "optional_str",
"in": "query",
"required": False, "required": False,
"schema": { "schema": {
"title": "Optional Str", "anyOf": [{"type": "string"}, {"type": "null"}],
"type": "string",
"description": "This is a optional_str field", "description": "This is a optional_str field",
"required": False,
"default": "default_exists", "default": "default_exists",
"extra": {}, "title": "Optional Str",
}, },
"name": "optional_str", "description": "This is a optional_str field",
"in": "query",
}, },
{ {
"required": True,
"schema": {"title": "Model", "type": "string", "extra": {}},
"name": "model", "name": "model",
"in": "query", "in": "query",
},
{
"required": True, "required": True,
"schema": { "schema": {
"title": "Manufacturer",
"type": "string", "type": "string",
"extra": {}, "required": True,
"title": "Model",
}, },
},
{
"name": "manufacturer", "name": "manufacturer",
"in": "query", "in": "query",
"required": True,
"schema": {
"type": "string",
"required": True,
"title": "Manufacturer",
},
}, },
{ {
"required": True,
"schema": {"title": "Price", "type": "number", "extra": {}},
"name": "price", "name": "price",
"in": "query", "in": "query",
"required": True,
"schema": {
"type": "number",
"required": True,
"title": "Price",
},
}, },
{ {
"required": True,
"schema": {"title": "Tax", "type": "number", "extra": {}},
"name": "tax", "name": "tax",
"in": "query", "in": "query",
"required": True,
"schema": {"type": "number", "required": True, "title": "Tax"},
}, },
{ {
"description": "This is a extra_optional_attributes field", "name": "extra_optional_attributes_alias",
"required": True, "in": "query",
"required": False,
"schema": { "schema": {
"title": "Extra Optional Attributes Alias",
"maxLength": 30,
"type": "string", "type": "string",
"description": "This is a extra_optional_attributes field", "description": "This is a extra_optional_attributes field",
"extra": {}, "required": False,
"metadata": [{"max_length": 30}],
"title": "Extra Optional Attributes Alias",
}, },
"name": "extra_optional_attributes_alias", "description": "This is a extra_optional_attributes field",
"in": "query",
}, },
], ],
"responses": { "responses": {
@ -160,29 +170,29 @@ openapi_schema_with_not_omitted_description = {
"components": { "components": {
"schemas": { "schemas": {
"HTTPValidationError": { "HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": { "properties": {
"detail": { "detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"}, "items": {"$ref": "#/components/schemas/ValidationError"},
"type": "array",
"title": "Detail",
} }
}, },
"type": "object",
"title": "HTTPValidationError",
}, },
"ValidationError": { "ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": { "properties": {
"loc": { "loc": {
"title": "Location",
"type": "array",
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
"type": "array",
"title": "Location",
}, },
"msg": {"title": "Message", "type": "string"}, "msg": {"type": "string", "title": "Message"},
"type": {"title": "Error Type", "type": "string"}, "type": {"type": "string", "title": "Error Type"},
}, },
"type": "object",
"required": ["loc", "msg", "type"],
"title": "ValidationError",
}, },
} }
}, },
@ -192,6 +202,7 @@ openapi_schema_with_not_omitted_description = {
def test_openapi_schema_with_query_dependency(): def test_openapi_schema_with_query_dependency():
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
print(response.json())
assert response.json() == openapi_schema_with_not_omitted_description assert response.json() == openapi_schema_with_not_omitted_description
@ -205,10 +216,73 @@ def test_response():
"manufacturer": "manufacturer", "manufacturer": "manufacturer",
"price": 100.0, "price": 100.0,
"tax": 9.0, "tax": 9.0,
"extra_optional_attributes_alias": "alias_query", "extra_optional_attributes": "alias_query",
} }
response = client.get( response = client.get(
"/item?name_required_with_default=name%20default&name_required_without_default=default&optional_str=default_exists&model=model&manufacturer=manufacturer&price=100&tax=9&extra_optional_attributes_alias=alias_query" "/item",
params={
"name_required_with_default": "name default",
"name_required_without_default": "default",
"optional_str": "default_exists",
"model": "model",
"manufacturer": "manufacturer",
"price": 100,
"tax": 9,
"extra_optional_attributes_alias": "alias_query",
},
) )
assert response.status_code == 200, response.text assert response.status_code == status.HTTP_200_OK, response.text
assert response.json() == expected_response
expected_response = {
"name_required_with_default": "name default",
"name_required_without_default": "default",
"optional_int": None,
"optional_str": "",
"model": "model",
"manufacturer": "manufacturer",
"price": 100.0,
"tax": 9.0,
"extra_optional_attributes": "alias_query",
}
response = client.get(
"/item",
params={
"name_required_with_default": "name default",
"name_required_without_default": "default",
"optional_str": None,
"model": "model",
"manufacturer": "manufacturer",
"price": 100,
"tax": 9,
"extra_optional_attributes_alias": "alias_query",
},
)
assert response.status_code == status.HTTP_200_OK, response.text
assert response.json() == expected_response
expected_response = {
"name_required_with_default": "name default",
"name_required_without_default": "default",
"optional_int": None,
"optional_str": "default_exists",
"model": "model",
"manufacturer": "manufacturer",
"price": 100.0,
"tax": 9.0,
"extra_optional_attributes": "alias_query",
}
response = client.get(
"/item",
params={
"name_required_with_default": "name default",
"name_required_without_default": "default",
"model": "model",
"manufacturer": "manufacturer",
"price": 100,
"tax": 9,
"extra_optional_attributes_alias": "alias_query",
},
)
assert response.status_code == status.HTTP_200_OK, response.text
assert response.json() == expected_response assert response.json() == expected_response

Loading…
Cancel
Save