Browse Source

Read cookie from `starlette.requests.HTTPConnection` class

pull/13758/head
Harry Lees 1 month ago
parent
commit
560016f85f
Failed to extract signature
  1. 8
      fastapi/security/api_key.py
  2. 57
      tests/test_security_api_key_cookie_websocket.py

8
fastapi/security/api_key.py

@ -3,7 +3,7 @@ from typing import Optional
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc
@ -107,7 +107,7 @@ class APIKeyQuery(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key, self.auto_error)
@ -195,7 +195,7 @@ class APIKeyHeader(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key, self.auto_error)
@ -283,6 +283,6 @@ class APIKeyCookie(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key, self.auto_error)

57
tests/test_security_api_key_cookie_websocket.py

@ -0,0 +1,57 @@
import pytest
from fastapi import Depends, FastAPI, Security, WebSocket
from fastapi.security import APIKeyCookie
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.testclient import WebSocketDenialResponse
app = FastAPI()
api_key = APIKeyCookie(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
return User(username=oauth_header)
@app.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
current_user: User = Depends(get_current_user),
):
await websocket.accept()
data = await websocket.receive_text()
await websocket.send_text(f"{data}:{current_user.username}")
def test_security_api_key():
client = TestClient(app, cookies={"key": "secret"})
with client.websocket_connect("/ws") as websocket:
message = "test"
websocket.send_text(message)
data = websocket.receive_text()
assert data == f"{message}:{client.cookies['key']}"
def test_security_api_key_no_key():
client = TestClient(app)
with pytest.raises(WebSocketDenialResponse) as exc:
with client.websocket_connect("/ws"):
pass
assert exc.value.status_code == 403, exc.value.text
def test_openapi_schema():
client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == {
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {},
}
Loading…
Cancel
Save