From 9a23243cc64d09dc750bc2fe341d0dab5e6da5a5 Mon Sep 17 00:00:00 2001 From: gyudoza Date: Mon, 14 Feb 2022 23:46:01 +0900 Subject: [PATCH 01/17] =?UTF-8?q?=F0=9F=94=A7=20Keep=20description=20when?= =?UTF-8?q?=20endpoint=20depends=20schema?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index d4028d067..d3fb6aea7 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -245,15 +245,31 @@ def is_scalar_sequence_field(field: ModelField) -> bool: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) - typed_params = [ - inspect.Parameter( - name=param.name, - kind=param.kind, - default=param.default, - annotation=get_typed_annotation(param, globalns), - ) - for param in signature.parameters.values() - ] + if inspect.isclass(call): + from fastapi import Query + parameters = {} + fields = getattr(call, '__fields__', {}) + for param in fields: + parameters[param] = dict((fields[param].field_info.__repr_args__())) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=Query(parameters[param.name].get("default"), description=parameters[param.name].get("description")), + annotation=get_typed_annotation(param, globalns), + ) + for param in signature.parameters.values() + ] + else: + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param, globalns), + ) + for param in signature.parameters.values() + ] typed_signature = inspect.Signature(typed_params) return typed_signature From 318834ea0da7f1f21be6d485741d86879b7ab4a6 Mon Sep 17 00:00:00 2001 From: jujumilk3 Date: Sun, 4 Sep 2022 22:48:29 +0900 Subject: [PATCH 02/17] add test, arrange codes, fix bug --- fastapi/dependencies/utils.py | 16 ++- tests/test_dependency_schema_query.py | 178 ++++++++++++++++++++++++++ 2 files changed, 188 insertions(+), 6 deletions(-) create mode 100644 tests/test_dependency_schema_query.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index e47c15758..9a5b9e1aa 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -247,17 +247,21 @@ def is_scalar_sequence_field(field: ModelField) -> bool: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) - if inspect.isclass(call): - from fastapi import Query - parameters = {} - fields = getattr(call, '__fields__', {}) + fields = getattr(call, "__fields__", {}) + if len(fields): + query_extra_info = dict() for param in fields: - parameters[param] = dict((fields[param].field_info.__repr_args__())) + query_extra_info[param] = dict((fields[param].field_info.__repr_args__())) + query_extra_info[param]["default"] = ( + Required + if getattr(fields[param], "required", False) + else fields[param].default + ) typed_params = [ inspect.Parameter( name=param.name, kind=param.kind, - default=Query(parameters[param.name].get("default"), description=parameters[param.name].get("description")), + default=params.Param(**query_extra_info[param.name]), annotation=get_typed_annotation(param, globalns), ) for param in signature.parameters.values() diff --git a/tests/test_dependency_schema_query.py b/tests/test_dependency_schema_query.py new file mode 100644 index 000000000..72535741f --- /dev/null +++ b/tests/test_dependency_schema_query.py @@ -0,0 +1,178 @@ +from typing import Optional + +from fastapi import FastAPI, Depends, Query +from fastapi.testclient import TestClient +from pydantic import BaseModel + +app = FastAPI() + + +class Item(BaseModel): + name_required_with_default: str = Query( + "name default", description="This is a name_required_with_default field." + ) + name_required_without_default: str = Query( + None, description="This is a name_required_without_default field." + ) + optional_int: Optional[int] = Query( + None, description="This is a optional_int field" + ) + optional_str: Optional[str] = Query( + "default_exists", description="This is a optional_str field" + ) + model: str + manufacturer: str + price: float + tax: float + + +@app.get("/item") +async def item_with_query_dependency(item: Item = Depends()): + return item + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/item": { + "get": { + "summary": "Item With Query Dependency", + "operationId": "item_with_query_dependency_item_get", + "parameters": [ + { + "description": "This is a name_required_with_default field.", + "required": False, + "schema": { + "title": "Name Required With Default", + "type": "string", + "description": "This is a name_required_with_default field.", + "default": "name default", + "extra": {}, + }, + "name": "name_required_with_default", + "in": "query", + }, + { + "description": "This is a name_required_without_default field.", + "required": False, + "schema": { + "title": "Name Required Without Default", + "type": "string", + "description": "This is a name_required_without_default field.", + "extra": {}, + }, + "name": "name_required_without_default", + "in": "query", + }, + { + "description": "This is a optional_int field", + "required": False, + "schema": { + "title": "Optional Int", + "type": "integer", + "description": "This is a optional_int field", + "extra": {}, + }, + "name": "optional_int", + "in": "query", + }, + { + "description": "This is a optional_str field", + "required": False, + "schema": { + "title": "Optional Str", + "type": "string", + "description": "This is a optional_str field", + "default": "default_exists", + "extra": {}, + }, + "name": "optional_str", + "in": "query", + }, + { + "required": True, + "schema": {"title": "Model", "type": "string", "extra": {}}, + "name": "model", + "in": "query", + }, + { + "required": True, + "schema": { + "title": "Manufacturer", + "type": "string", + "extra": {}, + }, + "name": "manufacturer", + "in": "query", + }, + { + "required": True, + "schema": {"title": "Price", "type": "number", "extra": {}}, + "name": "price", + "in": "query", + }, + { + "required": True, + "schema": {"title": "Tax", "type": "number", "extra": {}}, + "name": "tax", + "in": "query", + }, + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + } + }, +} + + +def test_openapi_schema_with_query_dependency(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == openapi_schema From 76c5c9b51d046a34cf9fd21370e5a8bfde376ab0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 4 Sep 2022 13:49:27 +0000 Subject: [PATCH 03/17] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 2 +- tests/test_dependency_schema_query.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 9a5b9e1aa..c80737fbc 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -251,7 +251,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: if len(fields): query_extra_info = dict() for param in fields: - query_extra_info[param] = dict((fields[param].field_info.__repr_args__())) + query_extra_info[param] = dict(fields[param].field_info.__repr_args__()) query_extra_info[param]["default"] = ( Required if getattr(fields[param], "required", False) diff --git a/tests/test_dependency_schema_query.py b/tests/test_dependency_schema_query.py index 72535741f..60d93c884 100644 --- a/tests/test_dependency_schema_query.py +++ b/tests/test_dependency_schema_query.py @@ -1,6 +1,6 @@ from typing import Optional -from fastapi import FastAPI, Depends, Query +from fastapi import Depends, FastAPI, Query from fastapi.testclient import TestClient from pydantic import BaseModel From 28afb1de0c27c652f74f76e71933a8d92765d59c Mon Sep 17 00:00:00 2001 From: jujumilk3 Date: Sun, 4 Sep 2022 23:07:12 +0900 Subject: [PATCH 04/17] fill test coverage of this PR --- tests/test_dependency_schema_query.py | 30 ++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/test_dependency_schema_query.py b/tests/test_dependency_schema_query.py index 72535741f..59bf24a15 100644 --- a/tests/test_dependency_schema_query.py +++ b/tests/test_dependency_schema_query.py @@ -12,11 +12,9 @@ class Item(BaseModel): "name default", description="This is a name_required_with_default field." ) name_required_without_default: str = Query( - None, description="This is a name_required_without_default field." - ) - optional_int: Optional[int] = Query( - None, description="This is a optional_int field" + description="This is a name_required_without_default field." ) + optional_int: Optional[int] = Query(description="This is a optional_int field") optional_str: Optional[str] = Query( "default_exists", description="This is a optional_str field" ) @@ -33,7 +31,7 @@ async def item_with_query_dependency(item: Item = Depends()): client = TestClient(app) -openapi_schema = { +openapi_schema_with_not_omitted_description = { "openapi": "3.0.2", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -57,7 +55,7 @@ openapi_schema = { }, { "description": "This is a name_required_without_default field.", - "required": False, + "required": True, "schema": { "title": "Name Required Without Default", "type": "string", @@ -175,4 +173,22 @@ openapi_schema = { def test_openapi_schema_with_query_dependency(): response = client.get("/openapi.json") assert response.status_code == 200, response.text - assert response.json() == openapi_schema + assert response.json() == openapi_schema_with_not_omitted_description + + +def test_response(): + expected_response = { + "name_required_with_default": "name default", + "name_required_without_default": "default", + "optional_int": None, + "optional_str": "default_exists", + "model": "model", + "manufacturer": "manufacturer", + "price": 100.0, + "tax": 9.0, + } + response = client.get( + "/item?name_required_with_default=name%20default&name_required_without_default=default&optional_str=default_exists&model=model&manufacturer=manufacturer&price=100&tax=9" + ) + assert response.status_code == 200, response.text + assert response.json() == expected_response From 2172e11b25d28f629b278f7893b05875777b14b5 Mon Sep 17 00:00:00 2001 From: jujumilk3 Date: Sat, 29 Oct 2022 18:21:59 +0900 Subject: [PATCH 05/17] add alias case --- fastapi/dependencies/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index c80737fbc..fa7a1c87c 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -252,6 +252,10 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: query_extra_info = dict() for param in fields: query_extra_info[param] = dict(fields[param].field_info.__repr_args__()) + if "alias" in query_extra_info[param]: + query_extra_info[query_extra_info[param]["alias"]] = dict( + fields[param].field_info.__repr_args__() + ) query_extra_info[param]["default"] = ( Required if getattr(fields[param], "required", False) From 88240590c06106c0bcabd58d824290271ae38f70 Mon Sep 17 00:00:00 2001 From: jujumilk3 Date: Sat, 29 Oct 2022 18:52:23 +0900 Subject: [PATCH 06/17] fill testcode cov and add alias case --- tests/test_dependency_schema_query.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/test_dependency_schema_query.py b/tests/test_dependency_schema_query.py index 3b6cf811d..7b56956e3 100644 --- a/tests/test_dependency_schema_query.py +++ b/tests/test_dependency_schema_query.py @@ -22,6 +22,12 @@ class Item(BaseModel): manufacturer: str price: float tax: float + extra_optional_attributes: str = Query( + None, + description="This is a extra_optional_attributes field", + alias="extra_optional_attributes_alias", + max_length=30, + ) @app.get("/item") @@ -118,6 +124,19 @@ openapi_schema_with_not_omitted_description = { "name": "tax", "in": "query", }, + { + "description": "This is a extra_optional_attributes field", + "required": True, + "schema": { + "title": "Extra Optional Attributes Alias", + "maxLength": 30, + "type": "string", + "description": "This is a extra_optional_attributes field", + "extra": {}, + }, + "name": "extra_optional_attributes_alias", + "in": "query", + }, ], "responses": { "200": { @@ -186,9 +205,10 @@ def test_response(): "manufacturer": "manufacturer", "price": 100.0, "tax": 9.0, + "extra_optional_attributes_alias": "alias_query", } response = client.get( - "/item?name_required_with_default=name%20default&name_required_without_default=default&optional_str=default_exists&model=model&manufacturer=manufacturer&price=100&tax=9" + "/item?name_required_with_default=name%20default&name_required_without_default=default&optional_str=default_exists&model=model&manufacturer=manufacturer&price=100&tax=9&extra_optional_attributes_alias=alias_query" ) assert response.status_code == 200, response.text assert response.json() == expected_response From a22b2b794782274b2ef9a7a97f983a2d78cd16a6 Mon Sep 17 00:00:00 2001 From: jujumilk3 Date: Sat, 29 Oct 2022 19:06:50 +0900 Subject: [PATCH 07/17] fix test error relevant trio (#5547) --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 755723224..c76538c17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,5 +137,8 @@ filterwarnings = [ 'ignore:The loop argument is deprecated since Python 3\.8, and scheduled for removal in Python 3\.10:DeprecationWarning:asyncio', 'ignore:starlette.middleware.wsgi is deprecated and will be removed in a future release\..*:DeprecationWarning:starlette', # see https://trio.readthedocs.io/en/stable/history.html#trio-0-22-0-2022-09-28 - 'ignore::trio.TrioDeprecationWarning', + "ignore:You seem to already have a custom.*:RuntimeWarning:trio", + "ignore::trio.TrioDeprecationWarning", + # TODO remove pytest-cov + 'ignore::pytest.PytestDeprecationWarning:pytest_cov', ] From e91aa530c9691cf74aab28a64fb106007a7b3dda Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Dec 2022 19:18:53 +0000 Subject: [PATCH 08/17] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 115723662..2387826fd 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -250,7 +250,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: globalns = getattr(call, "__globals__", {}) fields = getattr(call, "__fields__", {}) if len(fields): - query_extra_info = dict() + query_extra_info = {} for param in fields: query_extra_info[param] = dict(fields[param].field_info.__repr_args__()) if "alias" in query_extra_info[param]: From a97064624397faedc1b05f2ba2f28cf9422b4eae Mon Sep 17 00:00:00 2001 From: jujumilk3 Date: Thu, 12 Jan 2023 21:59:40 +0900 Subject: [PATCH 09/17] fix signature error --- fastapi/dependencies/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 1d4d8d23d..16947ceca 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -267,7 +267,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: name=param.name, kind=param.kind, default=params.Param(**query_extra_info[param.name]), - annotation=get_typed_annotation(param, globalns), + annotation=get_typed_annotation(param.annotation, globalns), ) for param in signature.parameters.values() ] @@ -277,7 +277,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: name=param.name, kind=param.kind, default=param.default, - annotation=get_typed_annotation(param, globalns), + annotation=get_typed_annotation(param.annotation, globalns), ) for param in signature.parameters.values() ] From 28d91bb3cd3dd4431fb4526887ea3f630db0780a Mon Sep 17 00:00:00 2001 From: jujumilk3 Date: Sat, 4 Nov 2023 02:55:57 +0900 Subject: [PATCH 10/17] fix test and apply pydantic2.x case --- fastapi/dependencies/utils.py | 187 +++++++++----------------- tests/test_dependency_schema_query.py | 186 +++++++++++++++++-------- 2 files changed, 191 insertions(+), 182 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index dc9db91a4..95fd5b033 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -119,9 +119,7 @@ def get_param_sub_dependant( def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: - assert callable( - depends.dependency - ), "A parameter-less dependency must have a callable dependency" + assert callable(depends.dependency), "A parameter-less dependency must have a callable dependency" return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) @@ -142,9 +140,7 @@ def get_sub_dependant( use_scopes: List[str] = [] if isinstance(dependency, (OAuth2, OpenIdConnect)): use_scopes = security_scopes - security_requirement = SecurityRequirement( - security_scheme=dependency, scopes=use_scopes - ) + security_requirement = SecurityRequirement(security_scheme=dependency, scopes=use_scopes) sub_dependant = get_dependant( path=path, call=dependency, @@ -183,9 +179,7 @@ def get_flat_dependant( for sub_dependant in dependant.dependencies: if skip_repeats and sub_dependant.cache_key in visited: continue - flat_sub = get_flat_dependant( - sub_dependant, skip_repeats=skip_repeats, visited=visited - ) + flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) flat_dependant.path_params.extend(flat_sub.path_params) flat_dependant.query_params.extend(flat_sub.query_params) flat_dependant.header_params.extend(flat_sub.header_params) @@ -208,29 +202,38 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) - fields = getattr(call, "__fields__", {}) + fields = getattr(call, "model_fields", {}) if len(fields): + alias_dict = {} query_extra_info = {} for param in fields: - query_extra_info[param] = dict(fields[param].field_info.__repr_args__()) + query_extra_info[param] = dict(fields[param].__repr_args__()) if "alias" in query_extra_info[param]: - query_extra_info[query_extra_info[param]["alias"]] = dict( - fields[param].field_info.__repr_args__() - ) + query_extra_info[query_extra_info[param]["alias"]] = dict(fields[param].__repr_args__()) + alias_dict[query_extra_info[param]["alias"]] = param query_extra_info[param]["default"] = ( - Required - if getattr(fields[param], "required", False) - else fields[param].default - ) - typed_params = [ - inspect.Parameter( - name=param.name, - kind=param.kind, - default=params.Param(**query_extra_info[param.name]), - annotation=get_typed_annotation(param.annotation, globalns), + Required if getattr(fields[param], "required", False) else fields[param].default ) - for param in signature.parameters.values() - ] + typed_params = [] + + for param in signature.parameters.values(): + if param.name in alias_dict: + original_param_name = alias_dict[param.name] + created_param = inspect.Parameter( + name=original_param_name, + kind=param.kind, + default=params.Param(**query_extra_info[original_param_name]), + annotation=get_typed_annotation(param.annotation, globalns), + ) + else: + created_param = inspect.Parameter( + name=param.name, + kind=param.kind, + default=params.Param(**query_extra_info[param.name]), + annotation=get_typed_annotation(param.annotation, globalns), + ) + + typed_params.append(created_param) else: typed_params = [ inspect.Parameter( @@ -303,9 +306,7 @@ def get_dependant( type_annotation=type_annotation, dependant=dependant, ): - assert ( - param_field is None - ), f"Cannot specify multiple FastAPI annotations for {param_name!r}" + assert param_field is None, f"Cannot specify multiple FastAPI annotations for {param_name!r}" continue assert param_field is not None if is_body_param(param_field=param_field, is_path_param=is_path_param): @@ -315,9 +316,7 @@ def get_dependant( return dependant -def add_non_field_param_to_dependency( - *, param_name: str, type_annotation: Any, dependant: Dependant -) -> Optional[bool]: +def add_non_field_param_to_dependency(*, param_name: str, type_annotation: Any, dependant: Dependant) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): dependant.request_param_name = param_name return True @@ -349,26 +348,17 @@ def analyze_param( field_info = None depends = None type_annotation: Any = Any - if ( - annotation is not inspect.Signature.empty - and get_origin(annotation) is Annotated - ): + if annotation is not inspect.Signature.empty and get_origin(annotation) is Annotated: annotated_args = get_args(annotation) type_annotation = annotated_args[0] - fastapi_annotations = [ - arg - for arg in annotated_args[1:] - if isinstance(arg, (FieldInfo, params.Depends)) - ] + fastapi_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, (FieldInfo, params.Depends))] assert ( len(fastapi_annotations) <= 1 ), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}" fastapi_annotation = next(iter(fastapi_annotations), None) if isinstance(fastapi_annotation, FieldInfo): # Copy `field_info` because we mutate `field_info.default` below. - field_info = copy_field_info( - field_info=fastapi_annotation, annotation=annotation - ) + field_info = copy_field_info(field_info=fastapi_annotation, annotation=annotation) assert field_info.default is Undefined or field_info.default is Required, ( f"`{field_info.__class__.__name__}` default value cannot be set in" f" `Annotated` for {param_name!r}. Set the default value with `=` instead." @@ -385,8 +375,7 @@ def analyze_param( if isinstance(value, params.Depends): assert depends is None, ( - "Cannot specify `Depends` in `Annotated` and default value" - f" together for {param_name!r}" + "Cannot specify `Depends` in `Annotated` and default value" f" together for {param_name!r}" ) assert field_info is None, ( "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a" @@ -395,8 +384,7 @@ def analyze_param( depends = value elif isinstance(value, FieldInfo): assert field_info is None, ( - "Cannot specify FastAPI annotations in `Annotated` and default value" - f" together for {param_name!r}" + "Cannot specify FastAPI annotations in `Annotated` and default value" f" together for {param_name!r}" ) field_info = value if PYDANTIC_V2: @@ -417,9 +405,7 @@ def analyze_param( ), ): assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}" - assert ( - field_info is None - ), f"Cannot specify FastAPI annotation for type {type_annotation!r}" + assert field_info is None, f"Cannot specify FastAPI annotation for type {type_annotation!r}" elif field_info is None and depends is None: default_value = value if value is not inspect.Signature.empty else Required if is_path_param: @@ -427,9 +413,9 @@ def analyze_param( # parameter might sometimes be a path parameter and sometimes not. See # `tests/test_infer_param_optionality.py` for an example. field_info = params.Path(annotation=type_annotation) - elif is_uploadfile_or_nonable_uploadfile_annotation( + elif is_uploadfile_or_nonable_uploadfile_annotation(type_annotation) or is_uploadfile_sequence_annotation( type_annotation - ) or is_uploadfile_sequence_annotation(type_annotation): + ): field_info = params.File(annotation=type_annotation, default=default_value) elif not field_annotation_is_scalar(annotation=type_annotation): field_info = params.Body(annotation=type_annotation, default=default_value) @@ -440,13 +426,9 @@ def analyze_param( if field_info is not None: if is_path_param: assert isinstance(field_info, params.Path), ( - f"Cannot use `{field_info.__class__.__name__}` for path param" - f" {param_name!r}" + f"Cannot use `{field_info.__class__.__name__}` for path param" f" {param_name!r}" ) - elif ( - isinstance(field_info, params.Param) - and getattr(field_info, "in_", None) is None - ): + elif isinstance(field_info, params.Param) and getattr(field_info, "in_", None) is None: field_info.in_ = params.ParamTypes.query use_annotation = get_annotation_from_field_info( type_annotation, @@ -472,15 +454,11 @@ def analyze_param( def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: if is_path_param: - assert is_scalar_field( - field=param_field - ), "Path params must be of one of the supported types" + assert is_scalar_field(field=param_field), "Path params must be of one of the supported types" return False elif is_scalar_field(field=param_field): return False - elif isinstance( - param_field.field_info, (params.Query, params.Header) - ) and is_scalar_sequence_field(param_field): + elif isinstance(param_field.field_info, (params.Query, params.Header)) and is_scalar_sequence_field(param_field): return False else: assert isinstance( @@ -527,9 +505,7 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: return inspect.isgeneratorfunction(dunder_call) -async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] -) -> Any: +async def solve_generator(*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]) -> Any: if is_gen_callable(call): cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) elif is_async_gen_callable(call): @@ -563,19 +539,12 @@ async def solve_dependencies( sub_dependant: Dependant for sub_dependant in dependant.dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) - sub_dependant.cache_key = cast( - Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key - ) + sub_dependant.cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key) call = sub_dependant.call use_sub_dependant = sub_dependant - if ( - dependency_overrides_provider - and dependency_overrides_provider.dependency_overrides - ): + if dependency_overrides_provider and dependency_overrides_provider.dependency_overrides: original_call = sub_dependant.call - call = getattr( - dependency_overrides_provider, "dependency_overrides", {} - ).get(original_call, original_call) + call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(original_call, original_call) use_path: str = sub_dependant.path # type: ignore use_sub_dependant = get_dependant( path=use_path, @@ -609,9 +578,7 @@ async def solve_dependencies( elif is_gen_callable(call) or is_async_gen_callable(call): stack = request.scope.get("fastapi_astack") assert isinstance(stack, AsyncExitStack) - solved = await solve_generator( - call=call, stack=stack, sub_values=sub_values - ) + solved = await solve_generator(call=call, stack=stack, sub_values=sub_values) elif is_coroutine_callable(call): solved = await call(**sub_values) else: @@ -620,18 +587,10 @@ async def solve_dependencies( values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: dependency_cache[sub_dependant.cache_key] = solved - path_values, path_errors = request_params_to_args( - dependant.path_params, request.path_params - ) - query_values, query_errors = request_params_to_args( - dependant.query_params, request.query_params - ) - header_values, header_errors = request_params_to_args( - dependant.header_params, request.headers - ) - cookie_values, cookie_errors = request_params_to_args( - dependant.cookie_params, request.cookies - ) + path_values, path_errors = request_params_to_args(dependant.path_params, request.path_params) + query_values, query_errors = request_params_to_args(dependant.query_params, request.query_params) + header_values, header_errors = request_params_to_args(dependant.header_params, request.headers) + cookie_values, cookie_errors = request_params_to_args(dependant.cookie_params, request.cookies) values.update(path_values) values.update(query_values) values.update(header_values) @@ -659,9 +618,7 @@ async def solve_dependencies( if dependant.response_param_name: values[dependant.response_param_name] = response if dependant.security_scopes_param_name: - values[dependant.security_scopes_param_name] = SecurityScopes( - scopes=dependant.security_scopes - ) + values[dependant.security_scopes_param_name] = SecurityScopes(scopes=dependant.security_scopes) return values, errors, background_tasks, response, dependency_cache @@ -672,16 +629,12 @@ def request_params_to_args( values = {} errors = [] for field in required_params: - if is_scalar_sequence_field(field) and isinstance( - received_params, (QueryParams, Headers) - ): + if is_scalar_sequence_field(field) and isinstance(received_params, (QueryParams, Headers)): value = received_params.getlist(field.alias) or field.default else: value = received_params.get(field.alias) field_info = field.field_info - assert isinstance( - field_info, params.Param - ), "Params must be subclasses of Param" + assert isinstance(field_info, params.Param), "Params must be subclasses of Param" loc = (field_info.in_.value, field.alias) if value is None: if field.required: @@ -734,35 +687,21 @@ async def request_body_to_args( if ( value is None or (isinstance(field_info, params.Form) and value == "") - or ( - isinstance(field_info, params.Form) - and is_sequence_field(field) - and len(value) == 0 - ) + or (isinstance(field_info, params.Form) and is_sequence_field(field) and len(value) == 0) ): if field.required: errors.append(get_missing_field_error(loc)) else: values[field.name] = deepcopy(field.default) continue - if ( - isinstance(field_info, params.File) - and is_bytes_field(field) - and isinstance(value, UploadFile) - ): + if isinstance(field_info, params.File) and is_bytes_field(field) and isinstance(value, UploadFile): value = await value.read() - elif ( - is_bytes_sequence_field(field) - and isinstance(field_info, params.File) - and value_is_sequence(value) - ): + elif is_bytes_sequence_field(field) and isinstance(field_info, params.File) and value_is_sequence(value): # For types assert isinstance(value, sequence_types) # type: ignore[arg-type] results: List[Union[bytes, str]] = [] - async def process_fn( - fn: Callable[[], Coroutine[Any, Any, Any]] - ) -> None: + async def process_fn(fn: Callable[[], Coroutine[Any, Any, Any]]) -> None: result = await fn() results.append(result) # noqa: B023 @@ -799,9 +738,7 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: for param in flat_dependant.body_params: setattr(param.field_info, "embed", True) # noqa: B010 model_name = "Body_" + name - BodyModel = create_body_model( - fields=flat_dependant.body_params, model_name=model_name - ) + BodyModel = create_body_model(fields=flat_dependant.body_params, model_name=model_name) required = any(True for f in flat_dependant.body_params if f.required) BodyFieldInfo_kwargs: Dict[str, Any] = { "annotation": BodyModel, @@ -817,9 +754,7 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: BodyFieldInfo = params.Body body_param_media_types = [ - f.field_info.media_type - for f in flat_dependant.body_params - if isinstance(f.field_info, params.Body) + f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, params.Body) ] if len(set(body_param_media_types)) == 1: BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] diff --git a/tests/test_dependency_schema_query.py b/tests/test_dependency_schema_query.py index 7b56956e3..950ff6acd 100644 --- a/tests/test_dependency_schema_query.py +++ b/tests/test_dependency_schema_query.py @@ -1,6 +1,6 @@ from typing import Optional -from fastapi import Depends, FastAPI, Query +from fastapi import Depends, FastAPI, Query, status from fastapi.testclient import TestClient from pydantic import BaseModel @@ -14,7 +14,9 @@ class Item(BaseModel): name_required_without_default: str = Query( description="This is a name_required_without_default field." ) - optional_int: Optional[int] = Query(description="This is a optional_int field") + optional_int: Optional[int] = Query( + default=None, description="This is a optional_int field" + ) optional_str: Optional[str] = Query( "default_exists", description="This is a optional_str field" ) @@ -38,7 +40,7 @@ async def item_with_query_dependency(item: Item = Depends()): client = TestClient(app) openapi_schema_with_not_omitted_description = { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/item": { @@ -47,95 +49,103 @@ openapi_schema_with_not_omitted_description = { "operationId": "item_with_query_dependency_item_get", "parameters": [ { - "description": "This is a name_required_with_default field.", + "name": "name_required_with_default", + "in": "query", "required": False, "schema": { - "title": "Name Required With Default", "type": "string", "description": "This is a name_required_with_default field.", + "required": False, "default": "name default", - "extra": {}, + "title": "Name Required With Default", }, - "name": "name_required_with_default", - "in": "query", + "description": "This is a name_required_with_default field.", }, { - "description": "This is a name_required_without_default field.", + "name": "name_required_without_default", + "in": "query", "required": True, "schema": { - "title": "Name Required Without Default", "type": "string", "description": "This is a name_required_without_default field.", - "extra": {}, + "required": True, + "title": "Name Required Without Default", }, - "name": "name_required_without_default", - "in": "query", + "description": "This is a name_required_without_default field.", }, { - "description": "This is a optional_int field", + "name": "optional_int", + "in": "query", "required": False, "schema": { - "title": "Optional Int", - "type": "integer", + "anyOf": [{"type": "integer"}, {"type": "null"}], "description": "This is a optional_int field", - "extra": {}, + "required": False, + "title": "Optional Int", }, - "name": "optional_int", - "in": "query", + "description": "This is a optional_int field", }, { - "description": "This is a optional_str field", + "name": "optional_str", + "in": "query", "required": False, "schema": { - "title": "Optional Str", - "type": "string", + "anyOf": [{"type": "string"}, {"type": "null"}], "description": "This is a optional_str field", + "required": False, "default": "default_exists", - "extra": {}, + "title": "Optional Str", }, - "name": "optional_str", - "in": "query", + "description": "This is a optional_str field", }, { - "required": True, - "schema": {"title": "Model", "type": "string", "extra": {}}, "name": "model", "in": "query", - }, - { "required": True, "schema": { - "title": "Manufacturer", "type": "string", - "extra": {}, + "required": True, + "title": "Model", }, + }, + { "name": "manufacturer", "in": "query", + "required": True, + "schema": { + "type": "string", + "required": True, + "title": "Manufacturer", + }, }, { - "required": True, - "schema": {"title": "Price", "type": "number", "extra": {}}, "name": "price", "in": "query", + "required": True, + "schema": { + "type": "number", + "required": True, + "title": "Price", + }, }, { - "required": True, - "schema": {"title": "Tax", "type": "number", "extra": {}}, "name": "tax", "in": "query", + "required": True, + "schema": {"type": "number", "required": True, "title": "Tax"}, }, { - "description": "This is a extra_optional_attributes field", - "required": True, + "name": "extra_optional_attributes_alias", + "in": "query", + "required": False, "schema": { - "title": "Extra Optional Attributes Alias", - "maxLength": 30, "type": "string", "description": "This is a extra_optional_attributes field", - "extra": {}, + "required": False, + "metadata": [{"max_length": 30}], + "title": "Extra Optional Attributes Alias", }, - "name": "extra_optional_attributes_alias", - "in": "query", + "description": "This is a extra_optional_attributes field", }, ], "responses": { @@ -160,29 +170,29 @@ openapi_schema_with_not_omitted_description = { "components": { "schemas": { "HTTPValidationError": { - "title": "HTTPValidationError", - "type": "object", "properties": { "detail": { - "title": "Detail", - "type": "array", "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", } }, + "type": "object", + "title": "HTTPValidationError", }, "ValidationError": { - "title": "ValidationError", - "required": ["loc", "msg", "type"], - "type": "object", "properties": { "loc": { - "title": "Location", - "type": "array", "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + "title": "Location", }, - "msg": {"title": "Message", "type": "string"}, - "type": {"title": "Error Type", "type": "string"}, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", }, } }, @@ -192,6 +202,7 @@ openapi_schema_with_not_omitted_description = { def test_openapi_schema_with_query_dependency(): response = client.get("/openapi.json") assert response.status_code == 200, response.text + print(response.json()) assert response.json() == openapi_schema_with_not_omitted_description @@ -205,10 +216,73 @@ def test_response(): "manufacturer": "manufacturer", "price": 100.0, "tax": 9.0, - "extra_optional_attributes_alias": "alias_query", + "extra_optional_attributes": "alias_query", } response = client.get( - "/item?name_required_with_default=name%20default&name_required_without_default=default&optional_str=default_exists&model=model&manufacturer=manufacturer&price=100&tax=9&extra_optional_attributes_alias=alias_query" + "/item", + params={ + "name_required_with_default": "name default", + "name_required_without_default": "default", + "optional_str": "default_exists", + "model": "model", + "manufacturer": "manufacturer", + "price": 100, + "tax": 9, + "extra_optional_attributes_alias": "alias_query", + }, ) - assert response.status_code == 200, response.text + assert response.status_code == status.HTTP_200_OK, response.text + assert response.json() == expected_response + + expected_response = { + "name_required_with_default": "name default", + "name_required_without_default": "default", + "optional_int": None, + "optional_str": "", + "model": "model", + "manufacturer": "manufacturer", + "price": 100.0, + "tax": 9.0, + "extra_optional_attributes": "alias_query", + } + response = client.get( + "/item", + params={ + "name_required_with_default": "name default", + "name_required_without_default": "default", + "optional_str": None, + "model": "model", + "manufacturer": "manufacturer", + "price": 100, + "tax": 9, + "extra_optional_attributes_alias": "alias_query", + }, + ) + assert response.status_code == status.HTTP_200_OK, response.text + assert response.json() == expected_response + + expected_response = { + "name_required_with_default": "name default", + "name_required_without_default": "default", + "optional_int": None, + "optional_str": "default_exists", + "model": "model", + "manufacturer": "manufacturer", + "price": 100.0, + "tax": 9.0, + "extra_optional_attributes": "alias_query", + } + response = client.get( + "/item", + params={ + "name_required_with_default": "name default", + "name_required_without_default": "default", + "model": "model", + "manufacturer": "manufacturer", + "price": 100, + "tax": 9, + "extra_optional_attributes_alias": "alias_query", + }, + ) + assert response.status_code == status.HTTP_200_OK, response.text assert response.json() == expected_response From 21532f792836f65f6ffe0d752e7dfc5b62aaa78d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 17:56:26 +0000 Subject: [PATCH 11/17] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 152 +++++++++++++++++++++++++--------- 1 file changed, 115 insertions(+), 37 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 95fd5b033..caa18e118 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -119,7 +119,9 @@ def get_param_sub_dependant( def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: - assert callable(depends.dependency), "A parameter-less dependency must have a callable dependency" + assert callable( + depends.dependency + ), "A parameter-less dependency must have a callable dependency" return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) @@ -140,7 +142,9 @@ def get_sub_dependant( use_scopes: List[str] = [] if isinstance(dependency, (OAuth2, OpenIdConnect)): use_scopes = security_scopes - security_requirement = SecurityRequirement(security_scheme=dependency, scopes=use_scopes) + security_requirement = SecurityRequirement( + security_scheme=dependency, scopes=use_scopes + ) sub_dependant = get_dependant( path=path, call=dependency, @@ -179,7 +183,9 @@ def get_flat_dependant( for sub_dependant in dependant.dependencies: if skip_repeats and sub_dependant.cache_key in visited: continue - flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) + flat_sub = get_flat_dependant( + sub_dependant, skip_repeats=skip_repeats, visited=visited + ) flat_dependant.path_params.extend(flat_sub.path_params) flat_dependant.query_params.extend(flat_sub.query_params) flat_dependant.header_params.extend(flat_sub.header_params) @@ -209,10 +215,14 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: for param in fields: query_extra_info[param] = dict(fields[param].__repr_args__()) if "alias" in query_extra_info[param]: - query_extra_info[query_extra_info[param]["alias"]] = dict(fields[param].__repr_args__()) + query_extra_info[query_extra_info[param]["alias"]] = dict( + fields[param].__repr_args__() + ) alias_dict[query_extra_info[param]["alias"]] = param query_extra_info[param]["default"] = ( - Required if getattr(fields[param], "required", False) else fields[param].default + Required + if getattr(fields[param], "required", False) + else fields[param].default ) typed_params = [] @@ -306,7 +316,9 @@ def get_dependant( type_annotation=type_annotation, dependant=dependant, ): - assert param_field is None, f"Cannot specify multiple FastAPI annotations for {param_name!r}" + assert ( + param_field is None + ), f"Cannot specify multiple FastAPI annotations for {param_name!r}" continue assert param_field is not None if is_body_param(param_field=param_field, is_path_param=is_path_param): @@ -316,7 +328,9 @@ def get_dependant( return dependant -def add_non_field_param_to_dependency(*, param_name: str, type_annotation: Any, dependant: Dependant) -> Optional[bool]: +def add_non_field_param_to_dependency( + *, param_name: str, type_annotation: Any, dependant: Dependant +) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): dependant.request_param_name = param_name return True @@ -348,17 +362,26 @@ def analyze_param( field_info = None depends = None type_annotation: Any = Any - if annotation is not inspect.Signature.empty and get_origin(annotation) is Annotated: + if ( + annotation is not inspect.Signature.empty + and get_origin(annotation) is Annotated + ): annotated_args = get_args(annotation) type_annotation = annotated_args[0] - fastapi_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, (FieldInfo, params.Depends))] + fastapi_annotations = [ + arg + for arg in annotated_args[1:] + if isinstance(arg, (FieldInfo, params.Depends)) + ] assert ( len(fastapi_annotations) <= 1 ), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}" fastapi_annotation = next(iter(fastapi_annotations), None) if isinstance(fastapi_annotation, FieldInfo): # Copy `field_info` because we mutate `field_info.default` below. - field_info = copy_field_info(field_info=fastapi_annotation, annotation=annotation) + field_info = copy_field_info( + field_info=fastapi_annotation, annotation=annotation + ) assert field_info.default is Undefined or field_info.default is Required, ( f"`{field_info.__class__.__name__}` default value cannot be set in" f" `Annotated` for {param_name!r}. Set the default value with `=` instead." @@ -375,7 +398,8 @@ def analyze_param( if isinstance(value, params.Depends): assert depends is None, ( - "Cannot specify `Depends` in `Annotated` and default value" f" together for {param_name!r}" + "Cannot specify `Depends` in `Annotated` and default value" + f" together for {param_name!r}" ) assert field_info is None, ( "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a" @@ -384,7 +408,8 @@ def analyze_param( depends = value elif isinstance(value, FieldInfo): assert field_info is None, ( - "Cannot specify FastAPI annotations in `Annotated` and default value" f" together for {param_name!r}" + "Cannot specify FastAPI annotations in `Annotated` and default value" + f" together for {param_name!r}" ) field_info = value if PYDANTIC_V2: @@ -405,7 +430,9 @@ def analyze_param( ), ): assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}" - assert field_info is None, f"Cannot specify FastAPI annotation for type {type_annotation!r}" + assert ( + field_info is None + ), f"Cannot specify FastAPI annotation for type {type_annotation!r}" elif field_info is None and depends is None: default_value = value if value is not inspect.Signature.empty else Required if is_path_param: @@ -413,9 +440,9 @@ def analyze_param( # parameter might sometimes be a path parameter and sometimes not. See # `tests/test_infer_param_optionality.py` for an example. field_info = params.Path(annotation=type_annotation) - elif is_uploadfile_or_nonable_uploadfile_annotation(type_annotation) or is_uploadfile_sequence_annotation( + elif is_uploadfile_or_nonable_uploadfile_annotation( type_annotation - ): + ) or is_uploadfile_sequence_annotation(type_annotation): field_info = params.File(annotation=type_annotation, default=default_value) elif not field_annotation_is_scalar(annotation=type_annotation): field_info = params.Body(annotation=type_annotation, default=default_value) @@ -426,9 +453,13 @@ def analyze_param( if field_info is not None: if is_path_param: assert isinstance(field_info, params.Path), ( - f"Cannot use `{field_info.__class__.__name__}` for path param" f" {param_name!r}" + f"Cannot use `{field_info.__class__.__name__}` for path param" + f" {param_name!r}" ) - elif isinstance(field_info, params.Param) and getattr(field_info, "in_", None) is None: + elif ( + isinstance(field_info, params.Param) + and getattr(field_info, "in_", None) is None + ): field_info.in_ = params.ParamTypes.query use_annotation = get_annotation_from_field_info( type_annotation, @@ -454,11 +485,15 @@ def analyze_param( def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: if is_path_param: - assert is_scalar_field(field=param_field), "Path params must be of one of the supported types" + assert is_scalar_field( + field=param_field + ), "Path params must be of one of the supported types" return False elif is_scalar_field(field=param_field): return False - elif isinstance(param_field.field_info, (params.Query, params.Header)) and is_scalar_sequence_field(param_field): + elif isinstance( + param_field.field_info, (params.Query, params.Header) + ) and is_scalar_sequence_field(param_field): return False else: assert isinstance( @@ -505,7 +540,9 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: return inspect.isgeneratorfunction(dunder_call) -async def solve_generator(*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]) -> Any: +async def solve_generator( + *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] +) -> Any: if is_gen_callable(call): cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) elif is_async_gen_callable(call): @@ -539,12 +576,19 @@ async def solve_dependencies( sub_dependant: Dependant for sub_dependant in dependant.dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) - sub_dependant.cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key) + sub_dependant.cache_key = cast( + Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key + ) call = sub_dependant.call use_sub_dependant = sub_dependant - if dependency_overrides_provider and dependency_overrides_provider.dependency_overrides: + if ( + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides + ): original_call = sub_dependant.call - call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(original_call, original_call) + call = getattr( + dependency_overrides_provider, "dependency_overrides", {} + ).get(original_call, original_call) use_path: str = sub_dependant.path # type: ignore use_sub_dependant = get_dependant( path=use_path, @@ -578,7 +622,9 @@ async def solve_dependencies( elif is_gen_callable(call) or is_async_gen_callable(call): stack = request.scope.get("fastapi_astack") assert isinstance(stack, AsyncExitStack) - solved = await solve_generator(call=call, stack=stack, sub_values=sub_values) + solved = await solve_generator( + call=call, stack=stack, sub_values=sub_values + ) elif is_coroutine_callable(call): solved = await call(**sub_values) else: @@ -587,10 +633,18 @@ async def solve_dependencies( values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: dependency_cache[sub_dependant.cache_key] = solved - path_values, path_errors = request_params_to_args(dependant.path_params, request.path_params) - query_values, query_errors = request_params_to_args(dependant.query_params, request.query_params) - header_values, header_errors = request_params_to_args(dependant.header_params, request.headers) - cookie_values, cookie_errors = request_params_to_args(dependant.cookie_params, request.cookies) + path_values, path_errors = request_params_to_args( + dependant.path_params, request.path_params + ) + query_values, query_errors = request_params_to_args( + dependant.query_params, request.query_params + ) + header_values, header_errors = request_params_to_args( + dependant.header_params, request.headers + ) + cookie_values, cookie_errors = request_params_to_args( + dependant.cookie_params, request.cookies + ) values.update(path_values) values.update(query_values) values.update(header_values) @@ -618,7 +672,9 @@ async def solve_dependencies( if dependant.response_param_name: values[dependant.response_param_name] = response if dependant.security_scopes_param_name: - values[dependant.security_scopes_param_name] = SecurityScopes(scopes=dependant.security_scopes) + values[dependant.security_scopes_param_name] = SecurityScopes( + scopes=dependant.security_scopes + ) return values, errors, background_tasks, response, dependency_cache @@ -629,12 +685,16 @@ def request_params_to_args( values = {} errors = [] for field in required_params: - if is_scalar_sequence_field(field) and isinstance(received_params, (QueryParams, Headers)): + if is_scalar_sequence_field(field) and isinstance( + received_params, (QueryParams, Headers) + ): value = received_params.getlist(field.alias) or field.default else: value = received_params.get(field.alias) field_info = field.field_info - assert isinstance(field_info, params.Param), "Params must be subclasses of Param" + assert isinstance( + field_info, params.Param + ), "Params must be subclasses of Param" loc = (field_info.in_.value, field.alias) if value is None: if field.required: @@ -687,21 +747,35 @@ async def request_body_to_args( if ( value is None or (isinstance(field_info, params.Form) and value == "") - or (isinstance(field_info, params.Form) and is_sequence_field(field) and len(value) == 0) + or ( + isinstance(field_info, params.Form) + and is_sequence_field(field) + and len(value) == 0 + ) ): if field.required: errors.append(get_missing_field_error(loc)) else: values[field.name] = deepcopy(field.default) continue - if isinstance(field_info, params.File) and is_bytes_field(field) and isinstance(value, UploadFile): + if ( + isinstance(field_info, params.File) + and is_bytes_field(field) + and isinstance(value, UploadFile) + ): value = await value.read() - elif is_bytes_sequence_field(field) and isinstance(field_info, params.File) and value_is_sequence(value): + elif ( + is_bytes_sequence_field(field) + and isinstance(field_info, params.File) + and value_is_sequence(value) + ): # For types assert isinstance(value, sequence_types) # type: ignore[arg-type] results: List[Union[bytes, str]] = [] - async def process_fn(fn: Callable[[], Coroutine[Any, Any, Any]]) -> None: + async def process_fn( + fn: Callable[[], Coroutine[Any, Any, Any]] + ) -> None: result = await fn() results.append(result) # noqa: B023 @@ -738,7 +812,9 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: for param in flat_dependant.body_params: setattr(param.field_info, "embed", True) # noqa: B010 model_name = "Body_" + name - BodyModel = create_body_model(fields=flat_dependant.body_params, model_name=model_name) + BodyModel = create_body_model( + fields=flat_dependant.body_params, model_name=model_name + ) required = any(True for f in flat_dependant.body_params if f.required) BodyFieldInfo_kwargs: Dict[str, Any] = { "annotation": BodyModel, @@ -754,7 +830,9 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: BodyFieldInfo = params.Body body_param_media_types = [ - f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, params.Body) + f.field_info.media_type + for f in flat_dependant.body_params + if isinstance(f.field_info, params.Body) ] if len(set(body_param_media_types)) == 1: BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] From ce8de15a29feac9f5154b15be5c25e2c5e1745c5 Mon Sep 17 00:00:00 2001 From: jujumilk3 Date: Sat, 4 Nov 2023 03:21:40 +0900 Subject: [PATCH 12/17] separate pydantic v1 v2 cases --- fastapi/dependencies/utils.py | 22 +++- tests/test_dependency_schema_query.py | 163 +++++++++++++++++++++++++- 2 files changed, 177 insertions(+), 8 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index caa18e118..ffa798b1c 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -208,16 +208,28 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) - fields = getattr(call, "model_fields", {}) + if PYDANTIC_V2: + fields = getattr(call, "model_fields", {}) + else: + fields = getattr(call, "__fields__", {}) if len(fields): alias_dict = {} query_extra_info = {} for param in fields: - query_extra_info[param] = dict(fields[param].__repr_args__()) + if PYDANTIC_V2: + query_extra_info[param] = dict(fields[param].__repr_args__()) + else: + query_extra_info[param] = dict(fields[param].field_info.__repr_args__()) + if "alias" in query_extra_info[param]: - query_extra_info[query_extra_info[param]["alias"]] = dict( - fields[param].__repr_args__() - ) + if PYDANTIC_V2: + query_extra_info[query_extra_info[param]["alias"]] = dict( + fields[param].__repr_args__() + ) + else: + query_extra_info[query_extra_info[param]["alias"]] = dict( + fields[param].field_info.__repr_args__() + ) alias_dict[query_extra_info[param]["alias"]] = param query_extra_info[param]["default"] = ( Required diff --git a/tests/test_dependency_schema_query.py b/tests/test_dependency_schema_query.py index 950ff6acd..d1d809c90 100644 --- a/tests/test_dependency_schema_query.py +++ b/tests/test_dependency_schema_query.py @@ -3,6 +3,10 @@ from typing import Optional from fastapi import Depends, FastAPI, Query, status from fastapi.testclient import TestClient from pydantic import BaseModel +from pydantic.version import VERSION as PYDANTIC_VERSION + + +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") app = FastAPI() @@ -39,7 +43,158 @@ async def item_with_query_dependency(item: Item = Depends()): client = TestClient(app) -openapi_schema_with_not_omitted_description = { +openapi_schema_with_not_omitted_description_pydantic_v1 = { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/item": { + "get": { + "summary": "Item With Query Dependency", + "operationId": "item_with_query_dependency_item_get", + "parameters": [ + { + "description": "This is a name_required_with_default field.", + "required": False, + "schema": { + "type": "string", + "title": "Name Required With Default", + "description": "This is a name_required_with_default field.", + "default": "name default", + "extra": {}, + }, + "name": "name_required_with_default", + "in": "query", + }, + { + "description": "This is a name_required_without_default field.", + "required": True, + "schema": { + "type": "string", + "title": "Name Required Without Default", + "description": "This is a name_required_without_default field.", + "extra": {}, + }, + "name": "name_required_without_default", + "in": "query", + }, + { + "description": "This is a optional_int field", + "required": False, + "schema": { + "type": "integer", + "title": "Optional Int", + "description": "This is a optional_int field", + "extra": {}, + }, + "name": "optional_int", + "in": "query", + }, + { + "description": "This is a optional_str field", + "required": False, + "schema": { + "type": "string", + "title": "Optional Str", + "description": "This is a optional_str field", + "default": "default_exists", + "extra": {}, + }, + "name": "optional_str", + "in": "query", + }, + { + "required": True, + "schema": {"type": "string", "title": "Model", "extra": {}}, + "name": "model", + "in": "query", + }, + { + "required": True, + "schema": { + "type": "string", + "title": "Manufacturer", + "extra": {}, + }, + "name": "manufacturer", + "in": "query", + }, + { + "required": True, + "schema": {"type": "number", "title": "Price", "extra": {}}, + "name": "price", + "in": "query", + }, + { + "required": True, + "schema": {"type": "number", "title": "Tax", "extra": {}}, + "name": "tax", + "in": "query", + }, + { + "description": "This is a extra_optional_attributes field", + "required": False, + "schema": { + "type": "string", + "maxLength": 30, + "title": "Extra Optional Attributes Alias", + "description": "This is a extra_optional_attributes field", + "extra": {}, + }, + "name": "extra_optional_attributes_alias", + "in": "query", + }, + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, +} + +openapi_schema_with_not_omitted_description_pydantic_v2 = { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -202,8 +357,10 @@ openapi_schema_with_not_omitted_description = { def test_openapi_schema_with_query_dependency(): response = client.get("/openapi.json") assert response.status_code == 200, response.text - print(response.json()) - assert response.json() == openapi_schema_with_not_omitted_description + if PYDANTIC_V2: + assert response.json() == openapi_schema_with_not_omitted_description_pydantic_v2 + else: + assert response.json() == openapi_schema_with_not_omitted_description_pydantic_v1 def test_response(): From 279e17689bf4754674bc429b7995471615f9dcd5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 18:21:59 +0000 Subject: [PATCH 13/17] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 2 +- tests/test_dependency_schema_query.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index ffa798b1c..f5bc230ce 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -220,7 +220,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: query_extra_info[param] = dict(fields[param].__repr_args__()) else: query_extra_info[param] = dict(fields[param].field_info.__repr_args__()) - + if "alias" in query_extra_info[param]: if PYDANTIC_V2: query_extra_info[query_extra_info[param]["alias"]] = dict( diff --git a/tests/test_dependency_schema_query.py b/tests/test_dependency_schema_query.py index d1d809c90..79f172eac 100644 --- a/tests/test_dependency_schema_query.py +++ b/tests/test_dependency_schema_query.py @@ -5,7 +5,6 @@ from fastapi.testclient import TestClient from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION - PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") app = FastAPI() @@ -358,9 +357,13 @@ def test_openapi_schema_with_query_dependency(): response = client.get("/openapi.json") assert response.status_code == 200, response.text if PYDANTIC_V2: - assert response.json() == openapi_schema_with_not_omitted_description_pydantic_v2 + assert ( + response.json() == openapi_schema_with_not_omitted_description_pydantic_v2 + ) else: - assert response.json() == openapi_schema_with_not_omitted_description_pydantic_v1 + assert ( + response.json() == openapi_schema_with_not_omitted_description_pydantic_v1 + ) def test_response(): From 5bf577e563df35129aba5a5e9cec289a11fba669 Mon Sep 17 00:00:00 2001 From: jujumilk3 Date: Sat, 4 Nov 2023 12:38:02 +0900 Subject: [PATCH 14/17] include an inherent bug in the test --- tests/test_dependency_schema_query.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_dependency_schema_query.py b/tests/test_dependency_schema_query.py index 79f172eac..b7307519a 100644 --- a/tests/test_dependency_schema_query.py +++ b/tests/test_dependency_schema_query.py @@ -378,6 +378,9 @@ def test_response(): "tax": 9.0, "extra_optional_attributes": "alias_query", } + if not PYDANTIC_V2: + expected_response.pop("extra_optional_attributes") + expected_response["extra_optional_attributes_alias"] = None response = client.get( "/item", params={ @@ -405,6 +408,9 @@ def test_response(): "tax": 9.0, "extra_optional_attributes": "alias_query", } + if not PYDANTIC_V2: + expected_response.pop("extra_optional_attributes") + expected_response["extra_optional_attributes_alias"] = None response = client.get( "/item", params={ @@ -432,6 +438,9 @@ def test_response(): "tax": 9.0, "extra_optional_attributes": "alias_query", } + if not PYDANTIC_V2: + expected_response.pop("extra_optional_attributes") + expected_response["extra_optional_attributes_alias"] = None response = client.get( "/item", params={ From 2438c17823a90d5d5ffc97053680ba81cd97d106 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 03:38:22 +0000 Subject: [PATCH 15/17] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_dependency_schema_query.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_dependency_schema_query.py b/tests/test_dependency_schema_query.py index b7307519a..30204a557 100644 --- a/tests/test_dependency_schema_query.py +++ b/tests/test_dependency_schema_query.py @@ -410,7 +410,7 @@ def test_response(): } if not PYDANTIC_V2: expected_response.pop("extra_optional_attributes") - expected_response["extra_optional_attributes_alias"] = None + expected_response["extra_optional_attributes_alias"] = None response = client.get( "/item", params={ @@ -440,7 +440,7 @@ def test_response(): } if not PYDANTIC_V2: expected_response.pop("extra_optional_attributes") - expected_response["extra_optional_attributes_alias"] = None + expected_response["extra_optional_attributes_alias"] = None response = client.get( "/item", params={ From 121fb20f26572cea26a8b60e7f65a8673d3562e8 Mon Sep 17 00:00:00 2001 From: jujumilk3 Date: Sat, 4 Nov 2023 14:36:44 +0900 Subject: [PATCH 16/17] simple condition --- fastapi/dependencies/utils.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index f5bc230ce..e82e210dd 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -239,22 +239,13 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: typed_params = [] for param in signature.parameters.values(): - if param.name in alias_dict: - original_param_name = alias_dict[param.name] - created_param = inspect.Parameter( - name=original_param_name, - kind=param.kind, - default=params.Param(**query_extra_info[original_param_name]), - annotation=get_typed_annotation(param.annotation, globalns), - ) - else: - created_param = inspect.Parameter( - name=param.name, - kind=param.kind, - default=params.Param(**query_extra_info[param.name]), - annotation=get_typed_annotation(param.annotation, globalns), - ) - + param_name = param.name if param.name not in alias_dict else alias_dict[param.name] + created_param = inspect.Parameter( + name=param_name, + kind=param.kind, + default=params.Param(**query_extra_info[param_name]), + annotation=get_typed_annotation(param.annotation, globalns), + ) typed_params.append(created_param) else: typed_params = [ From 53b08f50d4d8e1696c9ee97926cc9c314c8937a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 05:37:44 +0000 Subject: [PATCH 17/17] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index e82e210dd..9d99e9e69 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -239,7 +239,9 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: typed_params = [] for param in signature.parameters.values(): - param_name = param.name if param.name not in alias_dict else alias_dict[param.name] + param_name = ( + param.name if param.name not in alias_dict else alias_dict[param.name] + ) created_param = inspect.Parameter( name=param_name, kind=param.kind,