diff --git a/fastapi/_compat.py b/fastapi/_compat.py index c07e4a3b0..f019c961d 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -197,7 +197,20 @@ if PYDANTIC_V2: None if separate_input_output_schemas else "validation" ) # This expects that GenerateJsonSchema was already used to generate the definitions - json_schema = field_mapping[(field, override_mode or field.mode)] + try: + json_schema = field_mapping[(field, override_mode or field.mode)] + except KeyError: + inputs = [ + (field, override_mode or field.mode, field._type_adapter.core_schema) + ] + new_generator = GenerateJsonSchema( + ref_template=schema_generator.ref_template + ) + new_field_mapping, definitions = new_generator.generate_definitions( + inputs=inputs + ) + field_mapping.update(new_field_mapping) + json_schema = field_mapping[(field, override_mode or field.mode)] if "$ref" not in json_schema: # TODO remove when deprecating Pydantic v1 # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207 diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 84dfa4d03..55a9e9d67 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -121,6 +121,7 @@ def get_param_sub_dependant( depends: params.Depends, path: str, security_scopes: Optional[List[str]] = None, + dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None, ) -> Dependant: assert depends.dependency return get_sub_dependant( @@ -129,6 +130,7 @@ def get_param_sub_dependant( path=path, name=param_name, security_scopes=security_scopes, + dependency_overrides=dependency_overrides, ) @@ -146,7 +148,9 @@ def get_sub_dependant( path: str, name: Optional[str] = None, security_scopes: Optional[List[str]] = None, + dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None, ) -> Dependant: + dependency = (dependency_overrides or {}).get(dependency, dependency) security_requirement = None security_scopes = security_scopes or [] if isinstance(depends, params.Security): @@ -165,6 +169,7 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, + dependency_overrides=dependency_overrides, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) @@ -262,6 +267,46 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: return get_typed_annotation(annotation, globalns) +def get_resolved_dependant( + *, + dependant: Dependant, + dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None, +) -> Dependant: + new_call = call = dependant.call + if call: + new_call = (dependency_overrides or {}).get(call) + if new_call: + resolved_dependant = get_dependant( + path=dependant.path or "", + call=new_call, + name=dependant.name, + security_scopes=dependant.security_scopes, + use_cache=False, + dependency_overrides=dependency_overrides, + ) + else: + resolved_dependant = Dependant( + call=dependant.call, + name=dependant.name, + path=dependant.path, + security_scopes=dependant.security_scopes, + use_cache=False, + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + security_requirements=dependant.security_requirements.copy(), + ) + for sub_dependant in dependant.dependencies: + resolved_dependant.dependencies.append( + get_resolved_dependant( + dependant=sub_dependant, dependency_overrides=dependency_overrides + ) + ) + return resolved_dependant + + def get_dependant( *, path: str, @@ -269,7 +314,9 @@ def get_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + dependency_overrides: Optional[Dict[Callable[..., Any], Callable[..., Any]]] = None, ) -> Dependant: + call = (dependency_overrides or {}).get(call, call) path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters @@ -294,6 +341,7 @@ def get_dependant( depends=param_details.depends, path=path, security_scopes=security_scopes, + dependency_overrides=dependency_overrides, ) dependant.dependencies.append(sub_dependant) continue diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 808646cc2..47cba8bdf 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -20,6 +20,7 @@ from fastapi.dependencies.utils import ( _get_flat_fields_from_params, get_flat_dependant, get_flat_params, + get_resolved_dependant, ) from fastapi.encoders import jsonable_encoder from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE @@ -277,8 +278,17 @@ def get_openapi_path( operation = get_openapi_operation_metadata( route=route, method=method, operation_ids=operation_ids ) + dependency_overrides = None + if route.dependency_overrides_provider: + dependency_overrides = ( + route.dependency_overrides_provider.dependency_overrides + ) + dependant = get_resolved_dependant( + dependant=route.dependant, + dependency_overrides=dependency_overrides, + ) parameters: List[Dict[str, Any]] = [] - flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True) + flat_dependant = get_flat_dependant(dependant, skip_repeats=True) security_definitions, operation_security = get_openapi_security_definitions( flat_dependant=flat_dependant ) @@ -287,7 +297,7 @@ def get_openapi_path( if security_definitions: security_schemes.update(security_definitions) operation_parameters = _get_openapi_operation_parameters( - dependant=route.dependant, + dependant=dependant, schema_generator=schema_generator, model_name_map=model_name_map, field_mapping=field_mapping, @@ -466,6 +476,16 @@ def get_fields_from_routes( if route.callbacks: callback_flat_models.extend(get_fields_from_routes(route.callbacks)) params = get_flat_params(route.dependant) + dependency_overrides = None + if route.dependency_overrides_provider: + dependency_overrides = ( + route.dependency_overrides_provider.dependency_overrides + ) + dependant = get_resolved_dependant( + dependant=route.dependant, + dependency_overrides=dependency_overrides, + ) + params.extend(get_flat_params(dependant)) request_fields_from_routes.extend(params) flat_models = callback_flat_models + list( diff --git a/tests/test_dependency_overrides_openapi.py b/tests/test_dependency_overrides_openapi.py new file mode 100644 index 000000000..4fee7366b --- /dev/null +++ b/tests/test_dependency_overrides_openapi.py @@ -0,0 +1,201 @@ +from typing import Optional + +from fastapi import APIRouter, Depends, FastAPI +from fastapi._compat import PYDANTIC_V2 +from fastapi.testclient import TestClient + +app = FastAPI() + +router = APIRouter() + + +async def common_parameters(q: str, skip: int = 0, limit: int = 100): + pass # pragma: no cover + + +@app.get("/main-depends/") +async def main_depends(commons: dict = Depends(common_parameters)): + pass # pragma: no cover + + +app.include_router(router) + +client = TestClient(app) + + +async def overrider_dependency_simple(q: Optional[str] = None): + pass # pragma: no cover + + +async def overrider_sub_dependency(k: str): + pass # pragma: no cover + + +async def overrider_dependency_with_sub(msg: dict = Depends(overrider_sub_dependency)): + pass # pragma: no cover + + +override_simple_openapi_schema = { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/main-depends/": { + "get": { + "summary": "Main Depends", + "operationId": "main_depends_main_depends__get", + "parameters": [ + { + "required": False, + "schema": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Q", + }, + "name": "q", + "in": "query", + }, + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + } + }, +} +if not PYDANTIC_V2: + override_simple_openapi_schema["paths"]["/main-depends/"]["get"]["parameters"][0][ + "schema" + ] = {"title": "Q", "type": "string"} + + +def test_override_simple_openapi(): + app.dependency_overrides[common_parameters] = overrider_dependency_simple + app.openapi_schema = None + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == override_simple_openapi_schema + + +overrider_dependency_with_sub_schema = { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/main-depends/": { + "get": { + "summary": "Main Depends", + "operationId": "main_depends_main_depends__get", + "parameters": [ + { + "required": True, + "schema": {"title": "K", "type": "string"}, + "name": "k", + "in": "query", + }, + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + } + }, +} + + +def test_overrider_dependency_with_sub(): + app.dependency_overrides[common_parameters] = overrider_dependency_with_sub + app.openapi_schema = None + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == overrider_dependency_with_sub_schema + + +def test_overrider_dependency_with_overriden_sub(): + app.dependency_overrides[common_parameters] = overrider_dependency_with_sub + app.dependency_overrides[overrider_sub_dependency] = overrider_dependency_simple + app.openapi_schema = None + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == override_simple_openapi_schema diff --git a/tests/test_dependency_security_overrides_openapi.py b/tests/test_dependency_security_overrides_openapi.py new file mode 100644 index 000000000..70aa5b3ea --- /dev/null +++ b/tests/test_dependency_security_overrides_openapi.py @@ -0,0 +1,107 @@ +from fastapi import Depends, FastAPI, Header +from fastapi.security import OAuth2PasswordBearer +from fastapi.testclient import TestClient + +app = FastAPI() + + +def get_user_id() -> int: + pass # pragma: no cover + + +def get_user(user_id=Depends(get_user_id)): + pass # pragma: no cover + + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +def get_user_id_from_auth_override(token: str = Depends(oauth2_scheme)): + pass # pragma: no cover + + +def get_user_id_from_header_override(user_id: int = Header()): + pass # pragma: no cover + + +@app.get("/user") +def read_user( + user: str = Depends(get_user), +): + pass # pragma: no cover + + +client = TestClient(app) + + +override_with_security_schema = { + "components": { + "securitySchemes": { + "OAuth2PasswordBearer": { + "flows": {"password": {"scopes": {}, "tokenUrl": "token"}}, + "type": "oauth2", + } + } + }, + "info": {"title": "FastAPI", "version": "0.1.0"}, + "openapi": "3.1.0", + "paths": { + "/user": { + "get": { + "operationId": "read_user_user_get", + "responses": { + "200": { + "content": {"application/json": {"schema": {}}}, + "description": "Successful " "Response", + } + }, + "security": [{"OAuth2PasswordBearer": []}], + "summary": "Read User", + } + } + }, +} + + +def test_override_with_security(): + app.dependency_overrides[get_user_id] = get_user_id_from_auth_override + app.openapi_schema = None + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == override_with_security_schema + + +override_with_header_schema = { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/user": { + "get": { + "summary": "Read User", + "operationId": "read_user_user_get", + "parameters": [ + { + "name": "user-id", + "in": "header", + "required": True, + "schema": {"type": "integer", "title": "User-Id"}, + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + } + } + }, +} + + +def test_override_with_header(): + app.dependency_overrides[get_user_id] = get_user_id_from_header_override + app.openapi_schema = None + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == override_with_header_schema