Browse Source

Read cookie from `starlette.requests.HTTPConnection` class

pull/13758/head
Harry Lees 1 month ago
parent
commit
560016f85f
Failed to extract signature
  1. 8
      fastapi/security/api_key.py
  2. 57
      tests/test_security_api_key_cookie_websocket.py

8
fastapi/security/api_key.py

@ -3,7 +3,7 @@ from typing import Optional
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc
@ -107,7 +107,7 @@ class APIKeyQuery(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name) api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(api_key, self.auto_error)
@ -195,7 +195,7 @@ class APIKeyHeader(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name) api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(api_key, self.auto_error)
@ -283,6 +283,6 @@ class APIKeyCookie(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name) api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(api_key, self.auto_error)

57
tests/test_security_api_key_cookie_websocket.py

@ -0,0 +1,57 @@
import pytest
from fastapi import Depends, FastAPI, Security, WebSocket
from fastapi.security import APIKeyCookie
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.testclient import WebSocketDenialResponse
app = FastAPI()
api_key = APIKeyCookie(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
return User(username=oauth_header)
@app.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
current_user: User = Depends(get_current_user),
):
await websocket.accept()
data = await websocket.receive_text()
await websocket.send_text(f"{data}:{current_user.username}")
def test_security_api_key():
client = TestClient(app, cookies={"key": "secret"})
with client.websocket_connect("/ws") as websocket:
message = "test"
websocket.send_text(message)
data = websocket.receive_text()
assert data == f"{message}:{client.cookies['key']}"
def test_security_api_key_no_key():
client = TestClient(app)
with pytest.raises(WebSocketDenialResponse) as exc:
with client.websocket_connect("/ws"):
pass
assert exc.value.status_code == 403, exc.value.text
def test_openapi_schema():
client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == {
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {},
}
Loading…
Cancel
Save