From b9a0179a03b8f7fd4ff396fa37bc05b7f1a39af9 Mon Sep 17 00:00:00 2001 From: Nik Date: Sun, 9 Aug 2020 16:56:41 +0300 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20support=20for=20injecting=20H?= =?UTF-8?q?TTPConnection=20(#1827)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/models.py | 2 ++ fastapi/dependencies/utils.py | 7 ++++- fastapi/requests.py | 2 +- tests/test_http_connection_injection.py | 39 +++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 tests/test_http_connection_injection.py diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 586852211..8e0c7830a 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -34,6 +34,7 @@ class Dependant: call: Optional[Callable] = None, request_param_name: Optional[str] = None, websocket_param_name: Optional[str] = None, + http_connection_param_name: Optional[str] = None, response_param_name: Optional[str] = None, background_tasks_param_name: Optional[str] = None, security_scopes_param_name: Optional[str] = None, @@ -50,6 +51,7 @@ class Dependant: self.security_requirements = security_schemes or [] self.request_param_name = request_param_name self.websocket_param_name = websocket_param_name + self.http_connection_param_name = http_connection_param_name self.response_param_name = response_param_name self.background_tasks_param_name = background_tasks_param_name self.security_scopes = security_scopes diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index a45a8fe09..6c49941ad 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -41,7 +41,7 @@ from pydantic.utils import lenient_issubclass from starlette.background import BackgroundTasks from starlette.concurrency import run_in_threadpool from starlette.datastructures import FormData, Headers, QueryParams, UploadFile -from starlette.requests import Request +from starlette.requests import HTTPConnection, Request from starlette.responses import Response from starlette.websockets import WebSocket @@ -371,6 +371,9 @@ def add_non_field_param_to_dependency( elif lenient_issubclass(param.annotation, WebSocket): dependant.websocket_param_name = param.name return True + elif lenient_issubclass(param.annotation, HTTPConnection): + dependant.http_connection_param_name = param.name + return True elif lenient_issubclass(param.annotation, Response): dependant.response_param_name = param.name return True @@ -607,6 +610,8 @@ async def solve_dependencies( ) values.update(body_values) errors.extend(body_errors) + if dependant.http_connection_param_name: + values[dependant.http_connection_param_name] = request if dependant.request_param_name and isinstance(request, Request): values[dependant.request_param_name] = request elif dependant.websocket_param_name and isinstance(request, WebSocket): diff --git a/fastapi/requests.py b/fastapi/requests.py index eb13f0380..06d8f01cc 100644 --- a/fastapi/requests.py +++ b/fastapi/requests.py @@ -1 +1 @@ -from starlette.requests import Request # noqa +from starlette.requests import HTTPConnection, Request # noqa diff --git a/tests/test_http_connection_injection.py b/tests/test_http_connection_injection.py new file mode 100644 index 000000000..6e321b53b --- /dev/null +++ b/tests/test_http_connection_injection.py @@ -0,0 +1,39 @@ +from fastapi import Depends, FastAPI +from fastapi.requests import HTTPConnection +from fastapi.testclient import TestClient +from starlette.websockets import WebSocket + +app = FastAPI() +app.state.value = 42 + + +async def extract_value_from_http_connection(conn: HTTPConnection): + return conn.app.state.value + + +@app.get("/http") +async def get_value_by_http(value: int = Depends(extract_value_from_http_connection)): + return value + + +@app.websocket("/ws") +async def get_value_by_ws( + websocket: WebSocket, value: int = Depends(extract_value_from_http_connection) +): + await websocket.accept() + await websocket.send_json(value) + await websocket.close() + + +client = TestClient(app) + + +def test_value_extracting_by_http(): + response = client.get("/http") + assert response.status_code == 200 + assert response.json() == 42 + + +def test_value_extracting_by_ws(): + with client.websocket_connect("/ws") as websocket: + assert websocket.receive_json() == 42