diff --git a/fastapi/django/__init__.py b/fastapi/django/__init__.py new file mode 100644 index 000000000..74111998b --- /dev/null +++ b/fastapi/django/__init__.py @@ -0,0 +1,82 @@ +from contextvars import ContextVar +from typing import Annotated + +from django.core.asgi import ASGIHandler +from django.http import HttpRequest, HttpResponse, StreamingHttpResponse +from fastapi import Depends, Request, Response +from fastapi.responses import StreamingResponse +from starlette.middleware.base import BaseHTTPMiddleware + +_django_request = ContextVar[HttpRequest | None]("fastapi_django_request", default=None) + + +def get_django_request(): + django_request = _django_request.get() + + if not django_request: + raise ValueError( + "Django Request not found, did you forget to add the Django Middleware?" + ) + + return django_request + + +DjangoRequestDep = Annotated[HttpRequest, Depends(get_django_request)] + + +class DjangoMiddleware(BaseHTTPMiddleware, ASGIHandler): + """A FastAPI Middleware that runs the Django HTTP Request lifecycle. + + This middleware is responsible for running the Django HTTP Request lifecycle + in the FastAPI application. It is useful when you want to use Django's + authentication system, or any other Django feature that requires the + Django Request object to be available.""" + + def __init__(self, *args, **kwargs): + ASGIHandler.__init__(self) + + super().__init__(*args, **kwargs) + + async def _get_response_async(self, request): + fastapi_response = await self._call_next(request) + + assert isinstance(fastapi_response, StreamingResponse) + + return StreamingHttpResponse( + streaming_content=fastapi_response.body_iterator, + headers=fastapi_response.headers, + status=fastapi_response.status_code, + ) + + async def __call__(self, scope, receive, send): + self._django_request, _ = self.create_request(scope, "") + + _django_request.set(self._django_request) + + await BaseHTTPMiddleware.__call__(self, scope, receive, send) + + async def dispatch(self, request: Request, call_next): + self._call_next = call_next + + django_response = await self.get_response_async(self._django_request) + + if isinstance(django_response, HttpResponse): + return Response( + status_code=django_response.status_code, + content=django_response.content, + headers=django_response.headers, + ) + + if isinstance(django_response, StreamingHttpResponse): + + async def streaming(): + async for chunk in django_response.streaming_content: + yield chunk + + return StreamingResponse( + status_code=django_response.status_code, + content=streaming(), + headers=django_response.headers, + ) + + return Response(status_code=500) diff --git a/requirements-tests.txt b/requirements-tests.txt index bfe70f2f5..5512d600b 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -18,3 +18,6 @@ passlib[bcrypt] >=1.7.2,<2.0.0 # types types-ujson ==5.7.0.1 types-orjson ==3.6.2 + +# django +django ==5.0.6 diff --git a/tests/django/__init__.py b/tests/django/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/django/conftest.py b/tests/django/conftest.py new file mode 100644 index 000000000..2cb0ac22e --- /dev/null +++ b/tests/django/conftest.py @@ -0,0 +1,55 @@ +import django +import pytest +from django.conf import settings +from django.core.management.color import no_style +from django.core.management.sql import sql_flush +from django.db import connection + +settings.configure( + SECRET_KEY="not_very", + ROOT_URLCONF="tests.django.proj.urls", + INSTALLED_APPS=[ + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + ], + MIDDLEWARE=[ + "django.contrib.sessions.middleware.SessionMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + ], + DATABASES={ + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + } + }, +) + +django.setup() + + +@pytest.fixture(scope="session", autouse=True) +def django_db_setup(): + connection.creation.create_test_db(verbosity=0, autoclobber=True) + + yield + + connection.creation.destroy_test_db("default", verbosity=0) + + +@pytest.fixture(autouse=True) +def flush_db(): + sql_list = sql_flush(no_style(), connection, allow_cascade=False) + + connection.ops.execute_sql_flush(sql_list) + + +@pytest.fixture +def authenticated_session_id(): + from django.contrib.auth.models import User + + from tests.django.utils import create_authenticated_session + + user = User.objects.create_user(username="test", password="test") + + return create_authenticated_session(user) diff --git a/tests/django/proj/__init__.py b/tests/django/proj/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/django/proj/urls.py b/tests/django/proj/urls.py new file mode 100644 index 000000000..637600f58 --- /dev/null +++ b/tests/django/proj/urls.py @@ -0,0 +1 @@ +urlpatterns = [] diff --git a/tests/django/test_django_request.py b/tests/django/test_django_request.py new file mode 100644 index 000000000..84bd0c43e --- /dev/null +++ b/tests/django/test_django_request.py @@ -0,0 +1,53 @@ +from django.contrib.auth import ( + aget_user, +) +from fastapi import FastAPI +from fastapi.django import DjangoMiddleware, DjangoRequestDep +from fastapi.testclient import TestClient + +app = FastAPI() +app.add_middleware(DjangoMiddleware) + + +@app.get("/") +async def root(): + return {"message": "Hello World"} + + +@app.get("/current-user") +async def django_user(django_request: DjangoRequestDep): + user = await aget_user(django_request) + + if not user.is_authenticated: + return {"error": "User not authenticated"} + + return {"username": user.username} + + +client = TestClient(app) + + +def test_unauthenticated(): + response = client.get("/current-user") + + assert response.status_code == 200 + + assert response.json() == {"error": "User not authenticated"} + + +def test_authenticated(authenticated_session_id: str): + client.cookies.set("sessionid", authenticated_session_id) + + response = client.get("/current-user") + + assert response.status_code == 200 + + assert response.json() == {"username": "test"} + + +def test_route_with_no_django_request(): + response = client.get("/") + + assert response.status_code == 200 + + assert response.json() == {"message": "Hello World"} diff --git a/tests/django/test_django_request_no_middleware.py b/tests/django/test_django_request_no_middleware.py new file mode 100644 index 000000000..f9e059b3f --- /dev/null +++ b/tests/django/test_django_request_no_middleware.py @@ -0,0 +1,24 @@ +import pytest +from fastapi import FastAPI +from fastapi.django import DjangoRequestDep +from fastapi.testclient import TestClient + +app = FastAPI() + + +@app.get("/") +async def django_user(django_request: DjangoRequestDep): + user = django_request.user + + if not user.is_authenticated: + return {"error": "User not authenticated"} + + return {"username": user.username} + + +client = TestClient(app) + + +def test_returns_an_error(): + with pytest.raises(ValueError, match="Django Request not found"): + client.get("/") diff --git a/tests/django/utils.py b/tests/django/utils.py new file mode 100644 index 000000000..f29a179e9 --- /dev/null +++ b/tests/django/utils.py @@ -0,0 +1,24 @@ +from importlib import import_module + +from django.conf import settings +from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY + + +def create_authenticated_session(user): + """Creates an authenticated session for the given user.""" + + engine = import_module(settings.SESSION_ENGINE) + session = engine.SessionStore() + session.create() + + session[SESSION_KEY] = str(user.id) + session[BACKEND_SESSION_KEY] = ( + user.backend + if hasattr(user, "backend") + else "django.contrib.auth.backends.ModelBackend" + ) + session[HASH_SESSION_KEY] = user.get_session_auth_hash() + + session.save() + + return session.session_key