Browse Source

Merge 8413c755a2 into 8032e21418

pull/5452/merge
Laurent Mignon (ACSONE) 1 day ago
committed by GitHub
parent
commit
59c24257a0
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 15
      fastapi/_compat.py
  2. 48
      fastapi/dependencies/utils.py
  3. 24
      fastapi/openapi/utils.py
  4. 201
      tests/test_dependency_overrides_openapi.py
  5. 107
      tests/test_dependency_security_overrides_openapi.py

15
fastapi/_compat.py

@ -197,7 +197,20 @@ if PYDANTIC_V2:
None if separate_input_output_schemas else "validation"
)
# This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = field_mapping[(field, override_mode or field.mode)]
try:
json_schema = field_mapping[(field, override_mode or field.mode)]
except KeyError:
inputs = [
(field, override_mode or field.mode, field._type_adapter.core_schema)
]
new_generator = GenerateJsonSchema(
ref_template=schema_generator.ref_template
)
new_field_mapping, definitions = new_generator.generate_definitions(
inputs=inputs
)
field_mapping.update(new_field_mapping)
json_schema = field_mapping[(field, override_mode or field.mode)]
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207

48
fastapi/dependencies/utils.py

@ -121,6 +121,7 @@ def get_param_sub_dependant(
depends: params.Depends,
path: str,
security_scopes: Optional[List[str]] = None,
dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None,
) -> Dependant:
assert depends.dependency
return get_sub_dependant(
@ -129,6 +130,7 @@ def get_param_sub_dependant(
path=path,
name=param_name,
security_scopes=security_scopes,
dependency_overrides=dependency_overrides,
)
@ -146,7 +148,9 @@ def get_sub_dependant(
path: str,
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None,
) -> Dependant:
dependency = (dependency_overrides or {}).get(dependency, dependency)
security_requirement = None
security_scopes = security_scopes or []
if isinstance(depends, params.Security):
@ -165,6 +169,7 @@ def get_sub_dependant(
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
dependency_overrides=dependency_overrides,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
@ -262,6 +267,46 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns)
def get_resolved_dependant(
*,
dependant: Dependant,
dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None,
) -> Dependant:
new_call = call = dependant.call
if call:
new_call = (dependency_overrides or {}).get(call)
if new_call:
resolved_dependant = get_dependant(
path=dependant.path or "",
call=new_call,
name=dependant.name,
security_scopes=dependant.security_scopes,
use_cache=False,
dependency_overrides=dependency_overrides,
)
else:
resolved_dependant = Dependant(
call=dependant.call,
name=dependant.name,
path=dependant.path,
security_scopes=dependant.security_scopes,
use_cache=False,
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(),
)
for sub_dependant in dependant.dependencies:
resolved_dependant.dependencies.append(
get_resolved_dependant(
dependant=sub_dependant, dependency_overrides=dependency_overrides
)
)
return resolved_dependant
def get_dependant(
*,
path: str,
@ -269,7 +314,9 @@ def get_dependant(
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None,
) -> Dependant:
call = (dependency_overrides or {}).get(call, call)
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
@ -294,6 +341,7 @@ def get_dependant(
depends=param_details.depends,
path=path,
security_scopes=security_scopes,
dependency_overrides=dependency_overrides,
)
dependant.dependencies.append(sub_dependant)
continue

24
fastapi/openapi/utils.py

@ -20,6 +20,7 @@ from fastapi.dependencies.utils import (
_get_flat_fields_from_params,
get_flat_dependant,
get_flat_params,
get_resolved_dependant,
)
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
@ -277,8 +278,17 @@ def get_openapi_path(
operation = get_openapi_operation_metadata(
route=route, method=method, operation_ids=operation_ids
)
dependency_overrides = None
if route.dependency_overrides_provider:
dependency_overrides = (
route.dependency_overrides_provider.dependency_overrides
)
dependant = get_resolved_dependant(
dependant=route.dependant,
dependency_overrides=dependency_overrides,
)
parameters: List[Dict[str, Any]] = []
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
)
@ -287,7 +297,7 @@ def get_openapi_path(
if security_definitions:
security_schemes.update(security_definitions)
operation_parameters = _get_openapi_operation_parameters(
dependant=route.dependant,
dependant=dependant,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
@ -466,6 +476,16 @@ def get_fields_from_routes(
if route.callbacks:
callback_flat_models.extend(get_fields_from_routes(route.callbacks))
params = get_flat_params(route.dependant)
dependency_overrides = None
if route.dependency_overrides_provider:
dependency_overrides = (
route.dependency_overrides_provider.dependency_overrides
)
dependant = get_resolved_dependant(
dependant=route.dependant,
dependency_overrides=dependency_overrides,
)
params.extend(get_flat_params(dependant))
request_fields_from_routes.extend(params)
flat_models = callback_flat_models + list(

201
tests/test_dependency_overrides_openapi.py

@ -0,0 +1,201 @@
from typing import Optional
from fastapi import APIRouter, Depends, FastAPI
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient
app = FastAPI()
router = APIRouter()
async def common_parameters(q: str, skip: int = 0, limit: int = 100):
pass # pragma: no cover
@app.get("/main-depends/")
async def main_depends(commons: dict = Depends(common_parameters)):
pass # pragma: no cover
app.include_router(router)
client = TestClient(app)
async def overrider_dependency_simple(q: Optional[str] = None):
pass # pragma: no cover
async def overrider_sub_dependency(k: str):
pass # pragma: no cover
async def overrider_dependency_with_sub(msg: dict = Depends(overrider_sub_dependency)):
pass # pragma: no cover
override_simple_openapi_schema = {
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/main-depends/": {
"get": {
"summary": "Main Depends",
"operationId": "main_depends_main_depends__get",
"parameters": [
{
"required": False,
"schema": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"title": "Q",
},
"name": "q",
"in": "query",
},
],
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
}
}
},
"components": {
"schemas": {
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
}
},
}
if not PYDANTIC_V2:
override_simple_openapi_schema["paths"]["/main-depends/"]["get"]["parameters"][0][
"schema"
] = {"title": "Q", "type": "string"}
def test_override_simple_openapi():
app.dependency_overrides[common_parameters] = overrider_dependency_simple
app.openapi_schema = None
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == override_simple_openapi_schema
overrider_dependency_with_sub_schema = {
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/main-depends/": {
"get": {
"summary": "Main Depends",
"operationId": "main_depends_main_depends__get",
"parameters": [
{
"required": True,
"schema": {"title": "K", "type": "string"},
"name": "k",
"in": "query",
},
],
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
}
}
},
"components": {
"schemas": {
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
}
},
}
def test_overrider_dependency_with_sub():
app.dependency_overrides[common_parameters] = overrider_dependency_with_sub
app.openapi_schema = None
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == overrider_dependency_with_sub_schema
def test_overrider_dependency_with_overriden_sub():
app.dependency_overrides[common_parameters] = overrider_dependency_with_sub
app.dependency_overrides[overrider_sub_dependency] = overrider_dependency_simple
app.openapi_schema = None
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == override_simple_openapi_schema

107
tests/test_dependency_security_overrides_openapi.py

@ -0,0 +1,107 @@
from fastapi import Depends, FastAPI, Header
from fastapi.security import OAuth2PasswordBearer
from fastapi.testclient import TestClient
app = FastAPI()
def get_user_id() -> int:
pass # pragma: no cover
def get_user(user_id=Depends(get_user_id)):
pass # pragma: no cover
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def get_user_id_from_auth_override(token: str = Depends(oauth2_scheme)):
pass # pragma: no cover
def get_user_id_from_header_override(user_id: int = Header()):
pass # pragma: no cover
@app.get("/user")
def read_user(
user: str = Depends(get_user),
):
pass # pragma: no cover
client = TestClient(app)
override_with_security_schema = {
"components": {
"securitySchemes": {
"OAuth2PasswordBearer": {
"flows": {"password": {"scopes": {}, "tokenUrl": "token"}},
"type": "oauth2",
}
}
},
"info": {"title": "FastAPI", "version": "0.1.0"},
"openapi": "3.1.0",
"paths": {
"/user": {
"get": {
"operationId": "read_user_user_get",
"responses": {
"200": {
"content": {"application/json": {"schema": {}}},
"description": "Successful " "Response",
}
},
"security": [{"OAuth2PasswordBearer": []}],
"summary": "Read User",
}
}
},
}
def test_override_with_security():
app.dependency_overrides[get_user_id] = get_user_id_from_auth_override
app.openapi_schema = None
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == override_with_security_schema
override_with_header_schema = {
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/user": {
"get": {
"summary": "Read User",
"operationId": "read_user_user_get",
"parameters": [
{
"name": "user-id",
"in": "header",
"required": True,
"schema": {"type": "integer", "title": "User-Id"},
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
}
}
},
}
def test_override_with_header():
app.dependency_overrides[get_user_id] = get_user_id_from_header_override
app.openapi_schema = None
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == override_with_header_schema
Loading…
Cancel
Save