diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 70c2dca8a..bed9d96c3 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -3,7 +3,7 @@ from typing import Optional from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_403_FORBIDDEN from typing_extensions import Annotated, Doc @@ -107,7 +107,7 @@ class APIKeyQuery(APIKeyBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.query_params.get(self.model.name) return self.check_api_key(api_key, self.auto_error) @@ -195,7 +195,7 @@ class APIKeyHeader(APIKeyBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.headers.get(self.model.name) return self.check_api_key(api_key, self.auto_error) @@ -283,6 +283,6 @@ class APIKeyCookie(APIKeyBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.cookies.get(self.model.name) return self.check_api_key(api_key, self.auto_error) diff --git a/tests/test_security_api_key_cookie_websocket.py b/tests/test_security_api_key_cookie_websocket.py new file mode 100644 index 000000000..cce404e66 --- /dev/null +++ b/tests/test_security_api_key_cookie_websocket.py @@ -0,0 +1,57 @@ +import pytest +from fastapi import Depends, FastAPI, Security, WebSocket +from fastapi.security import APIKeyCookie +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.testclient import WebSocketDenialResponse + +app = FastAPI() + +api_key = APIKeyCookie(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + return User(username=oauth_header) + + +@app.websocket("/ws") +async def websocket_endpoint( + websocket: WebSocket, + current_user: User = Depends(get_current_user), +): + await websocket.accept() + data = await websocket.receive_text() + await websocket.send_text(f"{data}:{current_user.username}") + + +def test_security_api_key(): + client = TestClient(app, cookies={"key": "secret"}) + with client.websocket_connect("/ws") as websocket: + message = "test" + websocket.send_text(message) + data = websocket.receive_text() + assert data == f"{message}:{client.cookies['key']}" + + +def test_security_api_key_no_key(): + client = TestClient(app) + + with pytest.raises(WebSocketDenialResponse) as exc: + with client.websocket_connect("/ws"): + pass + assert exc.value.status_code == 403, exc.value.text + + +def test_openapi_schema(): + client = TestClient(app) + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": {}, + }