diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 5cebbf00f..d64c47c70 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -113,6 +113,7 @@ def get_param_sub_dependant( depends: params.Depends, path: str, security_scopes: Optional[List[str]] = None, + dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None, ) -> Dependant: assert depends.dependency return get_sub_dependant( @@ -121,6 +122,7 @@ def get_param_sub_dependant( path=path, name=param_name, security_scopes=security_scopes, + dependency_overrides=dependency_overrides, ) @@ -138,7 +140,9 @@ def get_sub_dependant( path: str, name: Optional[str] = None, security_scopes: Optional[List[str]] = None, + dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None, ) -> Dependant: + dependency = (dependency_overrides or {}).get(dependency, dependency) security_requirement = None security_scopes = security_scopes or [] if isinstance(depends, params.Security): @@ -157,6 +161,7 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, + dependency_overrides=dependency_overrides, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) @@ -254,6 +259,46 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: return get_typed_annotation(annotation, globalns) +def get_resolved_dependant( + *, + dependant: Dependant, + dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None, +) -> Dependant: + new_call = call = dependant.call + if call: + new_call = (dependency_overrides or {}).get(call) + if new_call: + resolved_dependant = get_dependant( + path=dependant.path or "", + call=new_call, + name=dependant.name, + security_scopes=dependant.security_scopes, + use_cache=False, + dependency_overrides=dependency_overrides, + ) + else: + resolved_dependant = Dependant( + call=dependant.call, + name=dependant.name, + path=dependant.path, + security_scopes=dependant.security_scopes, + use_cache=False, + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + security_schemes=dependant.security_requirements.copy(), + ) + for sub_dependant in dependant.dependencies: + resolved_dependant.dependencies.append( + get_resolved_dependant( + dependant=sub_dependant, dependency_overrides=dependency_overrides + ) + ) + return resolved_dependant + + def get_dependant( *, path: str, @@ -261,7 +306,9 @@ def get_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None, ) -> Dependant: + call = (dependency_overrides or {}).get(call, call) path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters @@ -286,6 +333,7 @@ def get_dependant( depends=param_details.depends, path=path, security_scopes=security_scopes, + dependency_overrides=dependency_overrides, ) dependant.dependencies.append(sub_dependant) continue diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 947eca948..a27304fb6 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -20,6 +20,7 @@ from fastapi.dependencies.utils import ( _get_flat_fields_from_params, get_flat_dependant, get_flat_params, + get_resolved_dependant, ) from fastapi.encoders import jsonable_encoder from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE @@ -256,8 +257,17 @@ def get_openapi_path( operation = get_openapi_operation_metadata( route=route, method=method, operation_ids=operation_ids ) + dependency_overrides = None + if route.dependency_overrides_provider: + dependency_overrides = ( + route.dependency_overrides_provider.dependency_overrides + ) + dependant = get_resolved_dependant( + dependant=route.dependant, + dependency_overrides=dependency_overrides, + ) parameters: List[Dict[str, Any]] = [] - flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True) + flat_dependant = get_flat_dependant(dependant, skip_repeats=True) security_definitions, operation_security = get_openapi_security_definitions( flat_dependant=flat_dependant ) @@ -265,8 +275,14 @@ def get_openapi_path( operation.setdefault("security", []).extend(operation_security) if security_definitions: security_schemes.update(security_definitions) +<<<<<<< HEAD operation_parameters = _get_openapi_operation_parameters( dependant=route.dependant, +======= + all_route_params = get_flat_params(dependant) + operation_parameters = get_openapi_operation_parameters( + all_route_params=all_route_params, +>>>>>>> 022f1e79 (Fix openapi document with dependencies override (#5451)) schema_generator=schema_generator, model_name_map=model_name_map, field_mapping=field_mapping, diff --git a/tests/test_dependency_overrides_openapi.py b/tests/test_dependency_overrides_openapi.py new file mode 100644 index 000000000..82bada04c --- /dev/null +++ b/tests/test_dependency_overrides_openapi.py @@ -0,0 +1,193 @@ +from typing import Optional + +from fastapi import APIRouter, Depends, FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() + +router = APIRouter() + + +async def common_parameters(q: str, skip: int = 0, limit: int = 100): + pass # pragma: no cover + + +@app.get("/main-depends/") +async def main_depends(commons: dict = Depends(common_parameters)): + pass # pragma: no cover + + +app.include_router(router) + +client = TestClient(app) + + +async def overrider_dependency_simple(q: Optional[str] = None): + pass # pragma: no cover + + +async def overrider_sub_dependency(k: str): + pass # pragma: no cover + + +async def overrider_dependency_with_sub(msg: dict = Depends(overrider_sub_dependency)): + pass # pragma: no cover + + +override_simple_openapi_schema = { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/main-depends/": { + "get": { + "summary": "Main Depends", + "operationId": "main_depends_main_depends__get", + "parameters": [ + { + "required": False, + "schema": {"title": "Q", "type": "string"}, + "name": "q", + "in": "query", + }, + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + } + }, +} + + +def test_override_simple_openapi(): + app.dependency_overrides[common_parameters] = overrider_dependency_simple + app.openapi_schema = None + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == override_simple_openapi_schema + + +overrider_dependency_with_sub_schema = { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/main-depends/": { + "get": { + "summary": "Main Depends", + "operationId": "main_depends_main_depends__get", + "parameters": [ + { + "required": True, + "schema": {"title": "K", "type": "string"}, + "name": "k", + "in": "query", + }, + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + } + }, +} + + +def test_overrider_dependency_with_sub(): + app.dependency_overrides[common_parameters] = overrider_dependency_with_sub + app.openapi_schema = None + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == overrider_dependency_with_sub_schema + + +def test_overrider_dependency_with_overriden_sub(): + app.dependency_overrides[common_parameters] = overrider_dependency_with_sub + app.dependency_overrides[overrider_sub_dependency] = overrider_dependency_simple + app.openapi_schema = None + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == override_simple_openapi_schema diff --git a/tests/test_dependency_security_overrides_openapi.py b/tests/test_dependency_security_overrides_openapi.py new file mode 100644 index 000000000..f6b55cf15 --- /dev/null +++ b/tests/test_dependency_security_overrides_openapi.py @@ -0,0 +1,146 @@ +from fastapi import Depends, FastAPI, Header +from fastapi.security import OAuth2PasswordBearer +from fastapi.testclient import TestClient + +app = FastAPI() + + +def get_user_id() -> int: + pass # pragma: no cover + + +def get_user(user_id=Depends(get_user_id)): + pass # pragma: no cover + + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +def get_user_id_from_auth_override(token: str = Depends(oauth2_scheme)): + pass # pragma: no cover + + +def get_user_id_from_header_override(user_id: int = Header()): + pass # pragma: no cover + + +@app.get("/user") +def read_user( + user: str = Depends(get_user), +): + pass # pragma: no cover + + +client = TestClient(app) + + +override_with_security_schema = { + "components": { + "securitySchemes": { + "OAuth2PasswordBearer": { + "flows": {"password": {"scopes": {}, "tokenUrl": "token"}}, + "type": "oauth2", + } + } + }, + "info": {"title": "FastAPI", "version": "0.1.0"}, + "openapi": "3.1.0", + "paths": { + "/user": { + "get": { + "operationId": "read_user_user_get", + "responses": { + "200": { + "content": {"application/json": {"schema": {}}}, + "description": "Successful " "Response", + } + }, + "security": [{"OAuth2PasswordBearer": []}], + "summary": "Read User", + } + } + }, +} + + +def test_override_with_security(): + app.dependency_overrides[get_user_id] = get_user_id_from_auth_override + app.openapi_schema = None + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == override_with_security_schema + + +override_with_header_schema = { + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "title": "Detail", + "type": "array", + } + }, + "title": "HTTPValidationError", + "type": "object", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "title": "Location", + "type": "array", + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error " "Type", "type": "string"}, + }, + "required": ["loc", "msg", "type"], + "title": "ValidationError", + "type": "object", + }, + } + }, + "info": {"title": "FastAPI", "version": "0.1.0"}, + "openapi": "3.1.0", + "paths": { + "/user": { + "get": { + "operationId": "read_user_user_get", + "parameters": [ + { + "in": "header", + "name": "user-id", + "required": True, + "schema": {"title": "User-Id", "type": "integer"}, + } + ], + "responses": { + "200": { + "content": {"application/json": {"schema": {}}}, + "description": "Successful " "Response", + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation " "Error", + }, + }, + "summary": "Read User", + } + } + }, +} + + +def test_override_with_header(): + app.dependency_overrides[get_user_id] = get_user_id_from_header_override + app.openapi_schema = None + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == override_with_header_schema