Browse Source

🐛 Use caching logic to determine OpenAPI spec for duplicate dependencies (#417)

pull/465/head
dmontagu 6 years ago
committed by Sebastián Ramírez
parent
commit
483eb73b26
  1. 17
      fastapi/dependencies/utils.py
  2. 4
      fastapi/openapi/utils.py
  3. 103
      tests/test_repeated_dependency_schema.py

17
fastapi/dependencies/utils.py

@ -108,7 +108,16 @@ def get_sub_dependant(
return sub_dependant
def get_flat_dependant(dependant: Dependant) -> Dependant:
CacheKey = Tuple[Optional[Callable], Tuple[str, ...]]
def get_flat_dependant(
dependant: Dependant, *, skip_repeats: bool = False, visited: List[CacheKey] = None
) -> Dependant:
if visited is None:
visited = []
visited.append(dependant.cache_key)
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
@ -120,7 +129,11 @@ def get_flat_dependant(dependant: Dependant) -> Dependant:
path=dependant.path,
)
for sub_dependant in dependant.dependencies:
flat_sub = get_flat_dependant(sub_dependant)
if skip_repeats and sub_dependant.cache_key in visited:
continue
flat_sub = get_flat_dependant(
sub_dependant, skip_repeats=skip_repeats, visited=visited
)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)

4
fastapi/openapi/utils.py

@ -45,7 +45,7 @@ validation_error_response_definition = {
def get_openapi_params(dependant: Dependant) -> List[Field]:
flat_dependant = get_flat_dependant(dependant)
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
return (
flat_dependant.path_params
+ flat_dependant.query_params
@ -150,7 +150,7 @@ def get_openapi_path(
for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method)
parameters: List[Dict] = []
flat_dependant = get_flat_dependant(route.dependant)
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
)

103
tests/test_repeated_dependency_schema.py

@ -0,0 +1,103 @@
from fastapi import Depends, FastAPI, Header
from starlette.status import HTTP_200_OK
from starlette.testclient import TestClient
app = FastAPI()
def get_header(*, someheader: str = Header(...)):
return someheader
def get_something_else(*, someheader: str = Depends(get_header)):
return f"{someheader}123"
@app.get("/")
def get_deps(dep1: str = Depends(get_header), dep2: str = Depends(get_something_else)):
return {"dep1": dep1, "dep2": dep2}
client = TestClient(app)
schema = {
"components": {
"schemas": {
"HTTPValidationError": {
"properties": {
"detail": {
"items": {"$ref": "#/components/schemas/ValidationError"},
"title": "Detail",
"type": "array",
}
},
"title": "HTTPValidationError",
"type": "object",
},
"ValidationError": {
"properties": {
"loc": {
"items": {"type": "string"},
"title": "Location",
"type": "array",
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error " "Type", "type": "string"},
},
"required": ["loc", "msg", "type"],
"title": "ValidationError",
"type": "object",
},
}
},
"info": {"title": "Fast API", "version": "0.1.0"},
"openapi": "3.0.2",
"paths": {
"/": {
"get": {
"operationId": "get_deps__get",
"parameters": [
{
"in": "header",
"name": "someheader",
"required": True,
"schema": {"title": "Someheader", "type": "string"},
}
],
"responses": {
"200": {
"content": {"application/json": {"schema": {}}},
"description": "Successful " "Response",
},
"422": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
"description": "Validation " "Error",
},
},
"summary": "Get Deps",
}
}
},
}
def test_schema():
response = client.get("/openapi.json")
assert response.status_code == HTTP_200_OK
actual_schema = response.json()
assert actual_schema == schema
assert (
len(actual_schema["paths"]["/"]["get"]["parameters"]) == 1
) # primary goal of this test
def test_response():
response = client.get("/", headers={"someheader": "hello"})
assert response.status_code == HTTP_200_OK
assert response.json() == {"dep1": "hello", "dep2": "hello123"}
Loading…
Cancel
Save