9 changed files with 242 additions and 0 deletions
@ -0,0 +1,82 @@ |
|||||
|
from contextvars import ContextVar |
||||
|
from typing import Annotated |
||||
|
|
||||
|
from django.core.asgi import ASGIHandler |
||||
|
from django.http import HttpRequest, HttpResponse, StreamingHttpResponse |
||||
|
from fastapi import Depends, Request, Response |
||||
|
from fastapi.responses import StreamingResponse |
||||
|
from starlette.middleware.base import BaseHTTPMiddleware |
||||
|
|
||||
|
_django_request = ContextVar[HttpRequest | None]("fastapi_django_request", default=None) |
||||
|
|
||||
|
|
||||
|
def get_django_request(): |
||||
|
django_request = _django_request.get() |
||||
|
|
||||
|
if not django_request: |
||||
|
raise ValueError( |
||||
|
"Django Request not found, did you forget to add the Django Middleware?" |
||||
|
) |
||||
|
|
||||
|
return django_request |
||||
|
|
||||
|
|
||||
|
DjangoRequestDep = Annotated[HttpRequest, Depends(get_django_request)] |
||||
|
|
||||
|
|
||||
|
class DjangoMiddleware(BaseHTTPMiddleware, ASGIHandler): |
||||
|
"""A FastAPI Middleware that runs the Django HTTP Request lifecycle. |
||||
|
|
||||
|
This middleware is responsible for running the Django HTTP Request lifecycle |
||||
|
in the FastAPI application. It is useful when you want to use Django's |
||||
|
authentication system, or any other Django feature that requires the |
||||
|
Django Request object to be available.""" |
||||
|
|
||||
|
def __init__(self, *args, **kwargs): |
||||
|
ASGIHandler.__init__(self) |
||||
|
|
||||
|
super().__init__(*args, **kwargs) |
||||
|
|
||||
|
async def _get_response_async(self, request): |
||||
|
fastapi_response = await self._call_next(request) |
||||
|
|
||||
|
assert isinstance(fastapi_response, StreamingResponse) |
||||
|
|
||||
|
return StreamingHttpResponse( |
||||
|
streaming_content=fastapi_response.body_iterator, |
||||
|
headers=fastapi_response.headers, |
||||
|
status=fastapi_response.status_code, |
||||
|
) |
||||
|
|
||||
|
async def __call__(self, scope, receive, send): |
||||
|
self._django_request, _ = self.create_request(scope, "") |
||||
|
|
||||
|
_django_request.set(self._django_request) |
||||
|
|
||||
|
await BaseHTTPMiddleware.__call__(self, scope, receive, send) |
||||
|
|
||||
|
async def dispatch(self, request: Request, call_next): |
||||
|
self._call_next = call_next |
||||
|
|
||||
|
django_response = await self.get_response_async(self._django_request) |
||||
|
|
||||
|
if isinstance(django_response, HttpResponse): |
||||
|
return Response( |
||||
|
status_code=django_response.status_code, |
||||
|
content=django_response.content, |
||||
|
headers=django_response.headers, |
||||
|
) |
||||
|
|
||||
|
if isinstance(django_response, StreamingHttpResponse): |
||||
|
|
||||
|
async def streaming(): |
||||
|
async for chunk in django_response.streaming_content: |
||||
|
yield chunk |
||||
|
|
||||
|
return StreamingResponse( |
||||
|
status_code=django_response.status_code, |
||||
|
content=streaming(), |
||||
|
headers=django_response.headers, |
||||
|
) |
||||
|
|
||||
|
return Response(status_code=500) |
@ -0,0 +1,55 @@ |
|||||
|
import django |
||||
|
import pytest |
||||
|
from django.conf import settings |
||||
|
from django.core.management.color import no_style |
||||
|
from django.core.management.sql import sql_flush |
||||
|
from django.db import connection |
||||
|
|
||||
|
settings.configure( |
||||
|
SECRET_KEY="not_very", |
||||
|
ROOT_URLCONF="tests.django.proj.urls", |
||||
|
INSTALLED_APPS=[ |
||||
|
"django.contrib.auth", |
||||
|
"django.contrib.contenttypes", |
||||
|
"django.contrib.sessions", |
||||
|
], |
||||
|
MIDDLEWARE=[ |
||||
|
"django.contrib.sessions.middleware.SessionMiddleware", |
||||
|
"django.contrib.auth.middleware.AuthenticationMiddleware", |
||||
|
], |
||||
|
DATABASES={ |
||||
|
"default": { |
||||
|
"ENGINE": "django.db.backends.sqlite3", |
||||
|
"NAME": ":memory:", |
||||
|
} |
||||
|
}, |
||||
|
) |
||||
|
|
||||
|
django.setup() |
||||
|
|
||||
|
|
||||
|
@pytest.fixture(scope="session", autouse=True) |
||||
|
def django_db_setup(): |
||||
|
connection.creation.create_test_db(verbosity=0, autoclobber=True) |
||||
|
|
||||
|
yield |
||||
|
|
||||
|
connection.creation.destroy_test_db("default", verbosity=0) |
||||
|
|
||||
|
|
||||
|
@pytest.fixture(autouse=True) |
||||
|
def flush_db(): |
||||
|
sql_list = sql_flush(no_style(), connection, allow_cascade=False) |
||||
|
|
||||
|
connection.ops.execute_sql_flush(sql_list) |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def authenticated_session_id(): |
||||
|
from django.contrib.auth.models import User |
||||
|
|
||||
|
from tests.django.utils import create_authenticated_session |
||||
|
|
||||
|
user = User.objects.create_user(username="test", password="test") |
||||
|
|
||||
|
return create_authenticated_session(user) |
@ -0,0 +1 @@ |
|||||
|
urlpatterns = [] |
@ -0,0 +1,53 @@ |
|||||
|
from django.contrib.auth import ( |
||||
|
aget_user, |
||||
|
) |
||||
|
from fastapi import FastAPI |
||||
|
from fastapi.django import DjangoMiddleware, DjangoRequestDep |
||||
|
from fastapi.testclient import TestClient |
||||
|
|
||||
|
app = FastAPI() |
||||
|
app.add_middleware(DjangoMiddleware) |
||||
|
|
||||
|
|
||||
|
@app.get("/") |
||||
|
async def root(): |
||||
|
return {"message": "Hello World"} |
||||
|
|
||||
|
|
||||
|
@app.get("/current-user") |
||||
|
async def django_user(django_request: DjangoRequestDep): |
||||
|
user = await aget_user(django_request) |
||||
|
|
||||
|
if not user.is_authenticated: |
||||
|
return {"error": "User not authenticated"} |
||||
|
|
||||
|
return {"username": user.username} |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_unauthenticated(): |
||||
|
response = client.get("/current-user") |
||||
|
|
||||
|
assert response.status_code == 200 |
||||
|
|
||||
|
assert response.json() == {"error": "User not authenticated"} |
||||
|
|
||||
|
|
||||
|
def test_authenticated(authenticated_session_id: str): |
||||
|
client.cookies.set("sessionid", authenticated_session_id) |
||||
|
|
||||
|
response = client.get("/current-user") |
||||
|
|
||||
|
assert response.status_code == 200 |
||||
|
|
||||
|
assert response.json() == {"username": "test"} |
||||
|
|
||||
|
|
||||
|
def test_route_with_no_django_request(): |
||||
|
response = client.get("/") |
||||
|
|
||||
|
assert response.status_code == 200 |
||||
|
|
||||
|
assert response.json() == {"message": "Hello World"} |
@ -0,0 +1,24 @@ |
|||||
|
import pytest |
||||
|
from fastapi import FastAPI |
||||
|
from fastapi.django import DjangoRequestDep |
||||
|
from fastapi.testclient import TestClient |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
@app.get("/") |
||||
|
async def django_user(django_request: DjangoRequestDep): |
||||
|
user = django_request.user |
||||
|
|
||||
|
if not user.is_authenticated: |
||||
|
return {"error": "User not authenticated"} |
||||
|
|
||||
|
return {"username": user.username} |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_returns_an_error(): |
||||
|
with pytest.raises(ValueError, match="Django Request not found"): |
||||
|
client.get("/") |
@ -0,0 +1,24 @@ |
|||||
|
from importlib import import_module |
||||
|
|
||||
|
from django.conf import settings |
||||
|
from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY |
||||
|
|
||||
|
|
||||
|
def create_authenticated_session(user): |
||||
|
"""Creates an authenticated session for the given user.""" |
||||
|
|
||||
|
engine = import_module(settings.SESSION_ENGINE) |
||||
|
session = engine.SessionStore() |
||||
|
session.create() |
||||
|
|
||||
|
session[SESSION_KEY] = str(user.id) |
||||
|
session[BACKEND_SESSION_KEY] = ( |
||||
|
user.backend |
||||
|
if hasattr(user, "backend") |
||||
|
else "django.contrib.auth.backends.ModelBackend" |
||||
|
) |
||||
|
session[HASH_SESSION_KEY] = user.get_session_auth_hash() |
||||
|
|
||||
|
session.save() |
||||
|
|
||||
|
return session.session_key |
Loading…
Reference in new issue