diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index c898ab7db..f9e42d0a8 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -108,7 +108,16 @@ def get_sub_dependant( return sub_dependant -def get_flat_dependant(dependant: Dependant) -> Dependant: +CacheKey = Tuple[Optional[Callable], Tuple[str, ...]] + + +def get_flat_dependant( + dependant: Dependant, *, skip_repeats: bool = False, visited: List[CacheKey] = None +) -> Dependant: + if visited is None: + visited = [] + visited.append(dependant.cache_key) + flat_dependant = Dependant( path_params=dependant.path_params.copy(), query_params=dependant.query_params.copy(), @@ -120,7 +129,11 @@ def get_flat_dependant(dependant: Dependant) -> Dependant: path=dependant.path, ) for sub_dependant in dependant.dependencies: - flat_sub = get_flat_dependant(sub_dependant) + if skip_repeats and sub_dependant.cache_key in visited: + continue + flat_sub = get_flat_dependant( + sub_dependant, skip_repeats=skip_repeats, visited=visited + ) flat_dependant.path_params.extend(flat_sub.path_params) flat_dependant.query_params.extend(flat_sub.query_params) flat_dependant.header_params.extend(flat_sub.header_params) diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 9c043103d..718c7de27 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -45,7 +45,7 @@ validation_error_response_definition = { def get_openapi_params(dependant: Dependant) -> List[Field]: - flat_dependant = get_flat_dependant(dependant) + flat_dependant = get_flat_dependant(dependant, skip_repeats=True) return ( flat_dependant.path_params + flat_dependant.query_params @@ -150,7 +150,7 @@ def get_openapi_path( for method in route.methods: operation = get_openapi_operation_metadata(route=route, method=method) parameters: List[Dict] = [] - flat_dependant = get_flat_dependant(route.dependant) + flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True) security_definitions, operation_security = get_openapi_security_definitions( flat_dependant=flat_dependant ) diff --git a/tests/test_repeated_dependency_schema.py b/tests/test_repeated_dependency_schema.py new file mode 100644 index 000000000..5b8ba82c0 --- /dev/null +++ b/tests/test_repeated_dependency_schema.py @@ -0,0 +1,103 @@ +from fastapi import Depends, FastAPI, Header +from starlette.status import HTTP_200_OK +from starlette.testclient import TestClient + +app = FastAPI() + + +def get_header(*, someheader: str = Header(...)): + return someheader + + +def get_something_else(*, someheader: str = Depends(get_header)): + return f"{someheader}123" + + +@app.get("/") +def get_deps(dep1: str = Depends(get_header), dep2: str = Depends(get_something_else)): + return {"dep1": dep1, "dep2": dep2} + + +client = TestClient(app) + +schema = { + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "title": "Detail", + "type": "array", + } + }, + "title": "HTTPValidationError", + "type": "object", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"type": "string"}, + "title": "Location", + "type": "array", + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error " "Type", "type": "string"}, + }, + "required": ["loc", "msg", "type"], + "title": "ValidationError", + "type": "object", + }, + } + }, + "info": {"title": "Fast API", "version": "0.1.0"}, + "openapi": "3.0.2", + "paths": { + "/": { + "get": { + "operationId": "get_deps__get", + "parameters": [ + { + "in": "header", + "name": "someheader", + "required": True, + "schema": {"title": "Someheader", "type": "string"}, + } + ], + "responses": { + "200": { + "content": {"application/json": {"schema": {}}}, + "description": "Successful " "Response", + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation " "Error", + }, + }, + "summary": "Get Deps", + } + } + }, +} + + +def test_schema(): + response = client.get("/openapi.json") + assert response.status_code == HTTP_200_OK + actual_schema = response.json() + assert actual_schema == schema + assert ( + len(actual_schema["paths"]["/"]["get"]["parameters"]) == 1 + ) # primary goal of this test + + +def test_response(): + response = client.get("/", headers={"someheader": "hello"}) + assert response.status_code == HTTP_200_OK + assert response.json() == {"dep1": "hello", "dep2": "hello123"}