committed by
GitHub
9 changed files with 240 additions and 0 deletions
@ -0,0 +1,82 @@ |
|||
from contextvars import ContextVar |
|||
from typing import Annotated |
|||
|
|||
from django.core.asgi import ASGIHandler |
|||
from django.http import HttpRequest, HttpResponse, StreamingHttpResponse |
|||
from fastapi import Depends, Request, Response |
|||
from fastapi.responses import StreamingResponse |
|||
from starlette.middleware.base import BaseHTTPMiddleware |
|||
|
|||
_django_request = ContextVar[HttpRequest | None]("fastapi_django_request", default=None) |
|||
|
|||
|
|||
def get_django_request(): |
|||
django_request = _django_request.get() |
|||
|
|||
if not django_request: |
|||
raise ValueError( |
|||
"Django Request not found, did you forget to add the Django Middleware?" |
|||
) |
|||
|
|||
return django_request |
|||
|
|||
|
|||
DjangoRequestDep = Annotated[HttpRequest, Depends(get_django_request)] |
|||
|
|||
|
|||
class DjangoMiddleware(BaseHTTPMiddleware, ASGIHandler): |
|||
"""A FastAPI Middleware that runs the Django HTTP Request lifecycle. |
|||
|
|||
This middleware is responsible for running the Django HTTP Request lifecycle |
|||
in the FastAPI application. It is useful when you want to use Django's |
|||
authentication system, or any other Django feature that requires the |
|||
Django Request object to be available.""" |
|||
|
|||
def __init__(self, *args, **kwargs): |
|||
ASGIHandler.__init__(self) |
|||
|
|||
super().__init__(*args, **kwargs) |
|||
|
|||
async def _get_response_async(self, request): |
|||
fastapi_response = await self._call_next(request) |
|||
|
|||
assert isinstance(fastapi_response, StreamingResponse) |
|||
|
|||
return StreamingHttpResponse( |
|||
streaming_content=fastapi_response.body_iterator, |
|||
headers=fastapi_response.headers, |
|||
status=fastapi_response.status_code, |
|||
) |
|||
|
|||
async def __call__(self, scope, receive, send): |
|||
self._django_request, _ = self.create_request(scope, "") |
|||
|
|||
_django_request.set(self._django_request) |
|||
|
|||
await BaseHTTPMiddleware.__call__(self, scope, receive, send) |
|||
|
|||
async def dispatch(self, request: Request, call_next): |
|||
self._call_next = call_next |
|||
|
|||
django_response = await self.get_response_async(self._django_request) |
|||
|
|||
if isinstance(django_response, HttpResponse): |
|||
return Response( |
|||
status_code=django_response.status_code, |
|||
content=django_response.content, |
|||
headers=django_response.headers, |
|||
) |
|||
|
|||
if isinstance(django_response, StreamingHttpResponse): |
|||
|
|||
async def streaming(): |
|||
async for chunk in django_response.streaming_content: |
|||
yield chunk |
|||
|
|||
return StreamingResponse( |
|||
status_code=django_response.status_code, |
|||
content=streaming(), |
|||
headers=django_response.headers, |
|||
) |
|||
|
|||
return Response(status_code=500) |
@ -0,0 +1,55 @@ |
|||
import django |
|||
import pytest |
|||
from django.conf import settings |
|||
from django.core.management.color import no_style |
|||
from django.core.management.sql import sql_flush |
|||
from django.db import connection |
|||
|
|||
settings.configure( |
|||
SECRET_KEY="not_very", |
|||
ROOT_URLCONF="tests.django.proj.urls", |
|||
INSTALLED_APPS=[ |
|||
"django.contrib.auth", |
|||
"django.contrib.contenttypes", |
|||
"django.contrib.sessions", |
|||
], |
|||
MIDDLEWARE=[ |
|||
"django.contrib.sessions.middleware.SessionMiddleware", |
|||
"django.contrib.auth.middleware.AuthenticationMiddleware", |
|||
], |
|||
DATABASES={ |
|||
"default": { |
|||
"ENGINE": "django.db.backends.sqlite3", |
|||
"NAME": ":memory:", |
|||
} |
|||
}, |
|||
) |
|||
|
|||
django.setup() |
|||
|
|||
|
|||
@pytest.fixture(scope="session", autouse=True) |
|||
def django_db_setup(): |
|||
connection.creation.create_test_db(verbosity=0, autoclobber=True) |
|||
|
|||
yield |
|||
|
|||
connection.creation.destroy_test_db("default", verbosity=0) |
|||
|
|||
|
|||
@pytest.fixture(autouse=True) |
|||
def flush_db(): |
|||
sql_list = sql_flush(no_style(), connection, allow_cascade=False) |
|||
|
|||
connection.ops.execute_sql_flush(sql_list) |
|||
|
|||
|
|||
@pytest.fixture |
|||
def authenticated_session_id(): |
|||
from django.contrib.auth.models import User |
|||
|
|||
from tests.django.utils import create_authenticated_session |
|||
|
|||
user = User.objects.create_user(username="test", password="test") |
|||
|
|||
return create_authenticated_session(user) |
@ -0,0 +1 @@ |
|||
urlpatterns = [] |
@ -0,0 +1,51 @@ |
|||
from django.contrib.auth import aget_user |
|||
from fastapi import FastAPI |
|||
from fastapi.django import DjangoMiddleware, DjangoRequestDep |
|||
from fastapi.testclient import TestClient |
|||
|
|||
app = FastAPI() |
|||
app.add_middleware(DjangoMiddleware) |
|||
|
|||
|
|||
@app.get("/") |
|||
async def root(): |
|||
return {"message": "Hello World"} |
|||
|
|||
|
|||
@app.get("/current-user") |
|||
async def django_user(django_request: DjangoRequestDep): |
|||
user = await aget_user(django_request) |
|||
|
|||
if not user.is_authenticated: |
|||
return {"error": "User not authenticated"} |
|||
|
|||
return {"username": user.username} |
|||
|
|||
|
|||
client = TestClient(app) |
|||
|
|||
|
|||
def test_unauthenticated(): |
|||
response = client.get("/current-user") |
|||
|
|||
assert response.status_code == 200 |
|||
|
|||
assert response.json() == {"error": "User not authenticated"} |
|||
|
|||
|
|||
def test_authenticated(authenticated_session_id: str): |
|||
client.cookies.set("sessionid", authenticated_session_id) |
|||
|
|||
response = client.get("/current-user") |
|||
|
|||
assert response.status_code == 200 |
|||
|
|||
assert response.json() == {"username": "test"} |
|||
|
|||
|
|||
def test_route_with_no_django_request(): |
|||
response = client.get("/") |
|||
|
|||
assert response.status_code == 200 |
|||
|
|||
assert response.json() == {"message": "Hello World"} |
@ -0,0 +1,24 @@ |
|||
import pytest |
|||
from fastapi import FastAPI |
|||
from fastapi.django import DjangoRequestDep |
|||
from fastapi.testclient import TestClient |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
@app.get("/") |
|||
async def django_user(django_request: DjangoRequestDep): |
|||
user = django_request.user |
|||
|
|||
if not user.is_authenticated: |
|||
return {"error": "User not authenticated"} |
|||
|
|||
return {"username": user.username} |
|||
|
|||
|
|||
client = TestClient(app) |
|||
|
|||
|
|||
def test_returns_an_error(): |
|||
with pytest.raises(ValueError, match="Django Request not found"): |
|||
client.get("/") |
@ -0,0 +1,24 @@ |
|||
from importlib import import_module |
|||
|
|||
from django.conf import settings |
|||
from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY |
|||
|
|||
|
|||
def create_authenticated_session(user): |
|||
"""Creates an authenticated session for the given user.""" |
|||
|
|||
engine = import_module(settings.SESSION_ENGINE) |
|||
session = engine.SessionStore() |
|||
session.create() |
|||
|
|||
session[SESSION_KEY] = str(user.id) |
|||
session[BACKEND_SESSION_KEY] = ( |
|||
user.backend |
|||
if hasattr(user, "backend") |
|||
else "django.contrib.auth.backends.ModelBackend" |
|||
) |
|||
session[HASH_SESSION_KEY] = user.get_session_auth_hash() |
|||
|
|||
session.save() |
|||
|
|||
return session.session_key |
Loading…
Reference in new issue