Browse Source

Add Django middleware

pull/11807/head
Patrick Arminio 9 months ago
parent
commit
828c71ab82
Failed to extract signature
  1. 82
      fastapi/django/__init__.py
  2. 3
      requirements-tests.txt
  3. 0
      tests/django/__init__.py
  4. 55
      tests/django/conftest.py
  5. 0
      tests/django/proj/__init__.py
  6. 1
      tests/django/proj/urls.py
  7. 53
      tests/django/test_django_request.py
  8. 24
      tests/django/test_django_request_no_middleware.py
  9. 24
      tests/django/utils.py

82
fastapi/django/__init__.py

@ -0,0 +1,82 @@
from contextvars import ContextVar
from typing import Annotated
from django.core.asgi import ASGIHandler
from django.http import HttpRequest, HttpResponse, StreamingHttpResponse
from fastapi import Depends, Request, Response
from fastapi.responses import StreamingResponse
from starlette.middleware.base import BaseHTTPMiddleware
_django_request = ContextVar[HttpRequest | None]("fastapi_django_request", default=None)
def get_django_request():
django_request = _django_request.get()
if not django_request:
raise ValueError(
"Django Request not found, did you forget to add the Django Middleware?"
)
return django_request
DjangoRequestDep = Annotated[HttpRequest, Depends(get_django_request)]
class DjangoMiddleware(BaseHTTPMiddleware, ASGIHandler):
"""A FastAPI Middleware that runs the Django HTTP Request lifecycle.
This middleware is responsible for running the Django HTTP Request lifecycle
in the FastAPI application. It is useful when you want to use Django's
authentication system, or any other Django feature that requires the
Django Request object to be available."""
def __init__(self, *args, **kwargs):
ASGIHandler.__init__(self)
super().__init__(*args, **kwargs)
async def _get_response_async(self, request):
fastapi_response = await self._call_next(request)
assert isinstance(fastapi_response, StreamingResponse)
return StreamingHttpResponse(
streaming_content=fastapi_response.body_iterator,
headers=fastapi_response.headers,
status=fastapi_response.status_code,
)
async def __call__(self, scope, receive, send):
self._django_request, _ = self.create_request(scope, "")
_django_request.set(self._django_request)
await BaseHTTPMiddleware.__call__(self, scope, receive, send)
async def dispatch(self, request: Request, call_next):
self._call_next = call_next
django_response = await self.get_response_async(self._django_request)
if isinstance(django_response, HttpResponse):
return Response(
status_code=django_response.status_code,
content=django_response.content,
headers=django_response.headers,
)
if isinstance(django_response, StreamingHttpResponse):
async def streaming():
async for chunk in django_response.streaming_content:
yield chunk
return StreamingResponse(
status_code=django_response.status_code,
content=streaming(),
headers=django_response.headers,
)
return Response(status_code=500)

3
requirements-tests.txt

@ -18,3 +18,6 @@ passlib[bcrypt] >=1.7.2,<2.0.0
# types
types-ujson ==5.7.0.1
types-orjson ==3.6.2
# django
django ==5.0.6

0
tests/django/__init__.py

55
tests/django/conftest.py

@ -0,0 +1,55 @@
import django
import pytest
from django.conf import settings
from django.core.management.color import no_style
from django.core.management.sql import sql_flush
from django.db import connection
settings.configure(
SECRET_KEY="not_very",
ROOT_URLCONF="tests.django.proj.urls",
INSTALLED_APPS=[
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
],
MIDDLEWARE=[
"django.contrib.sessions.middleware.SessionMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
],
DATABASES={
"default": {
"ENGINE": "django.db.backends.sqlite3",
"NAME": ":memory:",
}
},
)
django.setup()
@pytest.fixture(scope="session", autouse=True)
def django_db_setup():
connection.creation.create_test_db(verbosity=0, autoclobber=True)
yield
connection.creation.destroy_test_db("default", verbosity=0)
@pytest.fixture(autouse=True)
def flush_db():
sql_list = sql_flush(no_style(), connection, allow_cascade=False)
connection.ops.execute_sql_flush(sql_list)
@pytest.fixture
def authenticated_session_id():
from django.contrib.auth.models import User
from tests.django.utils import create_authenticated_session
user = User.objects.create_user(username="test", password="test")
return create_authenticated_session(user)

0
tests/django/proj/__init__.py

1
tests/django/proj/urls.py

@ -0,0 +1 @@
urlpatterns = []

53
tests/django/test_django_request.py

@ -0,0 +1,53 @@
from django.contrib.auth import (
aget_user,
)
from fastapi import FastAPI
from fastapi.django import DjangoMiddleware, DjangoRequestDep
from fastapi.testclient import TestClient
app = FastAPI()
app.add_middleware(DjangoMiddleware)
@app.get("/")
async def root():
return {"message": "Hello World"}
@app.get("/current-user")
async def django_user(django_request: DjangoRequestDep):
user = await aget_user(django_request)
if not user.is_authenticated:
return {"error": "User not authenticated"}
return {"username": user.username}
client = TestClient(app)
def test_unauthenticated():
response = client.get("/current-user")
assert response.status_code == 200
assert response.json() == {"error": "User not authenticated"}
def test_authenticated(authenticated_session_id: str):
client.cookies.set("sessionid", authenticated_session_id)
response = client.get("/current-user")
assert response.status_code == 200
assert response.json() == {"username": "test"}
def test_route_with_no_django_request():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Hello World"}

24
tests/django/test_django_request_no_middleware.py

@ -0,0 +1,24 @@
import pytest
from fastapi import FastAPI
from fastapi.django import DjangoRequestDep
from fastapi.testclient import TestClient
app = FastAPI()
@app.get("/")
async def django_user(django_request: DjangoRequestDep):
user = django_request.user
if not user.is_authenticated:
return {"error": "User not authenticated"}
return {"username": user.username}
client = TestClient(app)
def test_returns_an_error():
with pytest.raises(ValueError, match="Django Request not found"):
client.get("/")

24
tests/django/utils.py

@ -0,0 +1,24 @@
from importlib import import_module
from django.conf import settings
from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY
def create_authenticated_session(user):
"""Creates an authenticated session for the given user."""
engine = import_module(settings.SESSION_ENGINE)
session = engine.SessionStore()
session.create()
session[SESSION_KEY] = str(user.id)
session[BACKEND_SESSION_KEY] = (
user.backend
if hasattr(user, "backend")
else "django.contrib.auth.backends.ModelBackend"
)
session[HASH_SESSION_KEY] = user.get_session_auth_hash()
session.save()
return session.session_key
Loading…
Cancel
Save