From c653eeff1eb978acede84e9ccf4fd8b33813f69c Mon Sep 17 00:00:00 2001 From: Yurii Motov Date: Mon, 9 Jun 2025 12:45:33 +0200 Subject: [PATCH] Fix `APIKey**` security schemes status code on "Not authenticated" error --- fastapi/security/api_key.py | 121 +++++++++++++++++- tests/test_security_api_key_cookie.py | 3 +- ...est_security_api_key_cookie_description.py | 3 +- tests/test_security_api_key_header.py | 3 +- ...est_security_api_key_header_description.py | 3 +- tests/test_security_api_key_query.py | 3 +- ...test_security_api_key_query_description.py | 3 +- tests/test_security_status_code_403_option.py | 56 ++++++++ 8 files changed, 183 insertions(+), 12 deletions(-) create mode 100644 tests/test_security_status_code_403_option.py diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index d0616c73f..b1388f92c 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -1,11 +1,11 @@ -from typing import Optional, Union +from typing import Literal, Optional, Union from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.status import HTTP_403_FORBIDDEN -from typing_extensions import Annotated, Doc +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN +from typing_extensions import Annotated, Doc, deprecated class APIKeyBase(SecurityBase): @@ -16,10 +16,13 @@ class APIKeyBase(SecurityBase): description: Union[str, None], scheme_name: Union[str, None], auto_error: bool, + not_authenticated_status_code: Literal[401, 403], ): self.parameter_location = location.value self.parameter_name = name self.auto_error = auto_error + self.not_authenticated_status_code = not_authenticated_status_code + self.model: APIKey = APIKey( **{"in": location}, # type: ignore[arg-type] name=name, @@ -27,12 +30,31 @@ class APIKeyBase(SecurityBase): ) self.scheme_name = scheme_name or self.__class__.__name__ + def format_www_authenticate_header_value(self) -> str: + """ + The WWW-Authenticate header is not standardized for API Key authentication. + It's considered good practice to include information about the authentication + challange. + This method follows one of the common templates. + If a different format is required, override this method in a subclass. + """ + + return f'ApiKey in="{self.parameter_location}", name="{self.parameter_name}"' + def check_api_key(self, api_key: Optional[str]) -> Optional[str]: if not api_key: if self.auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.not_authenticated_status_code == HTTP_403_FORBIDDEN: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: # By default use 401 + www_authenticate = self.format_www_authenticate_header_value() + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": www_authenticate}, + ) return None return api_key @@ -115,6 +137,34 @@ class APIKeyQuery(APIKeyBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if the query parameter is not provided and `auto_error` is + set to `True`, `APIKeyQuery` will automatically raise an + `HTTPException` with the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + Consider updating your clients to align with the new behavior. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): super().__init__( location=APIKeyIn.query, @@ -122,6 +172,7 @@ class APIKeyQuery(APIKeyBase): scheme_name=scheme_name, description=description, auto_error=auto_error, + not_authenticated_status_code=not_authenticated_status_code, ) async def __call__(self, request: Request) -> Optional[str]: @@ -203,6 +254,34 @@ class APIKeyHeader(APIKeyBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if the header is not provided and `auto_error` is + set to `True`, `APIKeyHeader` will automatically raise an + `HTTPException` with the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + Consider updating your clients to align with the new behavior. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): super().__init__( location=APIKeyIn.header, @@ -210,6 +289,7 @@ class APIKeyHeader(APIKeyBase): scheme_name=scheme_name, description=description, auto_error=auto_error, + not_authenticated_status_code=not_authenticated_status_code, ) async def __call__(self, request: Request) -> Optional[str]: @@ -291,6 +371,34 @@ class APIKeyCookie(APIKeyBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if the cookie is not provided and `auto_error` is + set to `True`, `APIKeyCookie` will automatically raise an + `HTTPException` with the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + Consider updating your clients to align with the new behavior. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): super().__init__( location=APIKeyIn.cookie, @@ -298,6 +406,7 @@ class APIKeyCookie(APIKeyBase): scheme_name=scheme_name, description=description, auto_error=auto_error, + not_authenticated_status_code=not_authenticated_status_code, ) async def __call__(self, request: Request) -> Optional[str]: diff --git a/tests/test_security_api_key_cookie.py b/tests/test_security_api_key_cookie.py index 4ddb8e2ee..488503817 100644 --- a/tests/test_security_api_key_cookie.py +++ b/tests/test_security_api_key_cookie.py @@ -32,8 +32,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): client = TestClient(app) response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="cookie", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_api_key_cookie_description.py b/tests/test_security_api_key_cookie_description.py index d99d616e0..e0e448471 100644 --- a/tests/test_security_api_key_cookie_description.py +++ b/tests/test_security_api_key_cookie_description.py @@ -32,8 +32,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): client = TestClient(app) response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="cookie", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_api_key_header.py b/tests/test_security_api_key_header.py index 1ff883703..b72d258c4 100644 --- a/tests/test_security_api_key_header.py +++ b/tests/test_security_api_key_header.py @@ -33,8 +33,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="header", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_api_key_header_description.py b/tests/test_security_api_key_header_description.py index 27f9d0f29..70b85af03 100644 --- a/tests/test_security_api_key_header_description.py +++ b/tests/test_security_api_key_header_description.py @@ -33,8 +33,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="header", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_api_key_query.py b/tests/test_security_api_key_query.py index dc7a0a621..7a01101c4 100644 --- a/tests/test_security_api_key_query.py +++ b/tests/test_security_api_key_query.py @@ -33,8 +33,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="query", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_api_key_query_description.py b/tests/test_security_api_key_query_description.py index 35dc7743a..45102eb17 100644 --- a/tests/test_security_api_key_query_description.py +++ b/tests/test_security_api_key_query_description.py @@ -33,8 +33,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="query", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_status_code_403_option.py b/tests/test_security_status_code_403_option.py new file mode 100644 index 000000000..ccc4d216a --- /dev/null +++ b/tests/test_security_status_code_403_option.py @@ -0,0 +1,56 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyHeader, APIKeyQuery +from fastapi.testclient import TestClient + + +@pytest.mark.parametrize( + "auth", + [ + APIKeyQuery(name="key", not_authenticated_status_code=403), + APIKeyHeader(name="key", not_authenticated_status_code=403), + APIKeyCookie(name="key", not_authenticated_status_code=403), + ], +) +def test_apikey_status_code_403_on_auth_error(auth: APIKeyBase): + """ + Test temporary `not_authenticated_status_code` parameter for APIKey** classes. + """ + + app = FastAPI() + + @app.get("/") + async def protected(_: str = Security(auth)): + pass # pragma: no cover + + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 403 + assert response.json() == {"detail": "Not authenticated"} + + +@pytest.mark.parametrize( + "auth", + [ + APIKeyQuery(name="key", not_authenticated_status_code=403, auto_error=False), + APIKeyHeader(name="key", not_authenticated_status_code=403, auto_error=False), + APIKeyCookie(name="key", not_authenticated_status_code=403, auto_error=False), + ], +) +def test_apikey_status_code_403_on_auth_error_no_auto_error(auth: APIKeyBase): + """ + Test temporary `not_authenticated_status_code` parameter for APIKey** classes with + `auto_error=False`. + """ + + app = FastAPI() + + @app.get("/") + async def protected(_: str = Security(auth)): + pass # pragma: no cover + + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200