Browse Source

Add support for injecting HTTPConnection (#1827)

pull/1860/head
Nik 5 years ago
committed by GitHub
parent
commit
b9a0179a03
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      fastapi/dependencies/models.py
  2. 7
      fastapi/dependencies/utils.py
  3. 2
      fastapi/requests.py
  4. 39
      tests/test_http_connection_injection.py

2
fastapi/dependencies/models.py

@ -34,6 +34,7 @@ class Dependant:
call: Optional[Callable] = None, call: Optional[Callable] = None,
request_param_name: Optional[str] = None, request_param_name: Optional[str] = None,
websocket_param_name: Optional[str] = None, websocket_param_name: Optional[str] = None,
http_connection_param_name: Optional[str] = None,
response_param_name: Optional[str] = None, response_param_name: Optional[str] = None,
background_tasks_param_name: Optional[str] = None, background_tasks_param_name: Optional[str] = None,
security_scopes_param_name: Optional[str] = None, security_scopes_param_name: Optional[str] = None,
@ -50,6 +51,7 @@ class Dependant:
self.security_requirements = security_schemes or [] self.security_requirements = security_schemes or []
self.request_param_name = request_param_name self.request_param_name = request_param_name
self.websocket_param_name = websocket_param_name self.websocket_param_name = websocket_param_name
self.http_connection_param_name = http_connection_param_name
self.response_param_name = response_param_name self.response_param_name = response_param_name
self.background_tasks_param_name = background_tasks_param_name self.background_tasks_param_name = background_tasks_param_name
self.security_scopes = security_scopes self.security_scopes = security_scopes

7
fastapi/dependencies/utils.py

@ -41,7 +41,7 @@ from pydantic.utils import lenient_issubclass
from starlette.background import BackgroundTasks from starlette.background import BackgroundTasks
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import Request from starlette.requests import HTTPConnection, Request
from starlette.responses import Response from starlette.responses import Response
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
@ -371,6 +371,9 @@ def add_non_field_param_to_dependency(
elif lenient_issubclass(param.annotation, WebSocket): elif lenient_issubclass(param.annotation, WebSocket):
dependant.websocket_param_name = param.name dependant.websocket_param_name = param.name
return True return True
elif lenient_issubclass(param.annotation, HTTPConnection):
dependant.http_connection_param_name = param.name
return True
elif lenient_issubclass(param.annotation, Response): elif lenient_issubclass(param.annotation, Response):
dependant.response_param_name = param.name dependant.response_param_name = param.name
return True return True
@ -607,6 +610,8 @@ async def solve_dependencies(
) )
values.update(body_values) values.update(body_values)
errors.extend(body_errors) errors.extend(body_errors)
if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request): if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request values[dependant.request_param_name] = request
elif dependant.websocket_param_name and isinstance(request, WebSocket): elif dependant.websocket_param_name and isinstance(request, WebSocket):

2
fastapi/requests.py

@ -1 +1 @@
from starlette.requests import Request # noqa from starlette.requests import HTTPConnection, Request # noqa

39
tests/test_http_connection_injection.py

@ -0,0 +1,39 @@
from fastapi import Depends, FastAPI
from fastapi.requests import HTTPConnection
from fastapi.testclient import TestClient
from starlette.websockets import WebSocket
app = FastAPI()
app.state.value = 42
async def extract_value_from_http_connection(conn: HTTPConnection):
return conn.app.state.value
@app.get("/http")
async def get_value_by_http(value: int = Depends(extract_value_from_http_connection)):
return value
@app.websocket("/ws")
async def get_value_by_ws(
websocket: WebSocket, value: int = Depends(extract_value_from_http_connection)
):
await websocket.accept()
await websocket.send_json(value)
await websocket.close()
client = TestClient(app)
def test_value_extracting_by_http():
response = client.get("/http")
assert response.status_code == 200
assert response.json() == 42
def test_value_extracting_by_ws():
with client.websocket_connect("/ws") as websocket:
assert websocket.receive_json() == 42
Loading…
Cancel
Save