Browse Source

Merge 63367915ed into 8af92a6139

pull/12406/merge
Peter Volf 4 days ago
committed by GitHub
parent
commit
818cf23f51
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 14
      fastapi/dependencies/models.py
  2. 43
      fastapi/dependencies/utils.py
  3. 61
      tests/test_duplicate_special_dependencies.py

14
fastapi/dependencies/models.py

@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Sequence, Tuple from typing import Any, Callable, List, Optional, Sequence, Set, Tuple
from fastapi._compat import ModelField from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
@ -22,12 +22,12 @@ class Dependant:
security_requirements: List[SecurityRequirement] = field(default_factory=list) security_requirements: List[SecurityRequirement] = field(default_factory=list)
name: Optional[str] = None name: Optional[str] = None
call: Optional[Callable[..., Any]] = None call: Optional[Callable[..., Any]] = None
request_param_name: Optional[str] = None request_param_names: Set[str] = field(default_factory=set)
websocket_param_name: Optional[str] = None websocket_param_names: Set[str] = field(default_factory=set)
http_connection_param_name: Optional[str] = None http_connection_param_names: Set[str] = field(default_factory=set)
response_param_name: Optional[str] = None response_param_names: Set[str] = field(default_factory=set)
background_tasks_param_name: Optional[str] = None background_tasks_param_names: Set[str] = field(default_factory=set)
security_scopes_param_name: Optional[str] = None security_scopes_param_names: Set[str] = field(default_factory=set)
security_scopes: Optional[List[str]] = None security_scopes: Optional[List[str]] = None
use_cache: bool = True use_cache: bool = True
path: Optional[str] = None path: Optional[str] = None

43
fastapi/dependencies/utils.py

@ -318,22 +318,22 @@ def add_non_field_param_to_dependency(
*, param_name: str, type_annotation: Any, dependant: Dependant *, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]: ) -> Optional[bool]:
if lenient_issubclass(type_annotation, Request): if lenient_issubclass(type_annotation, Request):
dependant.request_param_name = param_name dependant.request_param_names.add(param_name)
return True return True
elif lenient_issubclass(type_annotation, WebSocket): elif lenient_issubclass(type_annotation, WebSocket):
dependant.websocket_param_name = param_name dependant.websocket_param_names.add(param_name)
return True return True
elif lenient_issubclass(type_annotation, HTTPConnection): elif lenient_issubclass(type_annotation, HTTPConnection):
dependant.http_connection_param_name = param_name dependant.http_connection_param_names.add(param_name)
return True return True
elif lenient_issubclass(type_annotation, Response): elif lenient_issubclass(type_annotation, Response):
dependant.response_param_name = param_name dependant.response_param_names.add(param_name)
return True return True
elif lenient_issubclass(type_annotation, StarletteBackgroundTasks): elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
dependant.background_tasks_param_name = param_name dependant.background_tasks_param_names.add(param_name)
return True return True
elif lenient_issubclass(type_annotation, SecurityScopes): elif lenient_issubclass(type_annotation, SecurityScopes):
dependant.security_scopes_param_name = param_name dependant.security_scopes_param_names.add(param_name)
return True return True
return None return None
@ -670,22 +670,25 @@ async def solve_dependencies(
) )
values.update(body_values) values.update(body_values)
errors.extend(body_errors) errors.extend(body_errors)
if dependant.http_connection_param_name: for name in dependant.http_connection_param_names:
values[dependant.http_connection_param_name] = request values[name] = request
if dependant.request_param_name and isinstance(request, Request): if isinstance(request, Request):
values[dependant.request_param_name] = request for name in dependant.request_param_names:
elif dependant.websocket_param_name and isinstance(request, WebSocket): values[name] = request
values[dependant.websocket_param_name] = request elif isinstance(request, WebSocket):
if dependant.background_tasks_param_name: for name in dependant.websocket_param_names:
values[name] = request
if dependant.background_tasks_param_names:
if background_tasks is None: if background_tasks is None:
background_tasks = BackgroundTasks() background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks for name in dependant.background_tasks_param_names:
if dependant.response_param_name: values[name] = background_tasks
values[dependant.response_param_name] = response for name in dependant.response_param_names:
if dependant.security_scopes_param_name: values[name] = response
values[dependant.security_scopes_param_name] = SecurityScopes( if dependant.security_scopes_param_names:
scopes=dependant.security_scopes security_scope = SecurityScopes(scopes=dependant.security_scopes)
) for name in dependant.security_scopes_param_names:
values[name] = security_scope
return SolvedDependency( return SolvedDependency(
values=values, values=values,
errors=errors, errors=errors,

61
tests/test_duplicate_special_dependencies.py

@ -0,0 +1,61 @@
import pytest
from fastapi import BackgroundTasks, FastAPI, Request, Response, WebSocket
from fastapi.security import SecurityScopes
from fastapi.testclient import TestClient
app = FastAPI()
@app.get("/request")
def request(r1: Request, r2: Request) -> str:
assert r1 is not None
assert r1 is r2
return "success"
@app.get("/response")
def response(r1: Response, r2: Response) -> str:
assert r1 is not None
assert r1 is r2
return "success"
@app.get("/background-tasks")
def background_tasks(t1: BackgroundTasks, t2: BackgroundTasks) -> str:
assert t1 is not None
assert t1 is t2
return "success"
@app.get("/security-scopes")
def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str:
assert sc1 is not None
assert sc1 is sc2
return "success"
@app.websocket("/websocket")
async def websocket(ws1: WebSocket, ws2: WebSocket) -> str:
assert ws1 is ws2
await ws1.accept()
await ws1.send_text("success")
await ws1.close()
@pytest.mark.parametrize(
"url",
(
"/request",
"/response",
"/background-tasks",
"/security-scopes",
),
)
def test_duplicate_special_dependency(url: str) -> None:
assert TestClient(app).get(url).text == '"success"'
def test_duplicate_websocket_dependency() -> None:
with TestClient(app).websocket_connect("/websocket") as ws:
text = ws.receive_text()
assert text == "success"
Loading…
Cancel
Save