Browse Source

Fix `APIKey**` security schemes status code on "Not authenticated" error

pull/13786/head
Yurii Motov 1 month ago
parent
commit
c653eeff1e
  1. 115
      fastapi/security/api_key.py
  2. 3
      tests/test_security_api_key_cookie.py
  3. 3
      tests/test_security_api_key_cookie_description.py
  4. 3
      tests/test_security_api_key_header.py
  5. 3
      tests/test_security_api_key_header_description.py
  6. 3
      tests/test_security_api_key_query.py
  7. 3
      tests/test_security_api_key_query_description.py
  8. 56
      tests/test_security_status_code_403_option.py

115
fastapi/security/api_key.py

@ -1,11 +1,11 @@
from typing import Optional, Union from typing import Literal, Optional, Union
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc, deprecated
class APIKeyBase(SecurityBase): class APIKeyBase(SecurityBase):
@ -16,10 +16,13 @@ class APIKeyBase(SecurityBase):
description: Union[str, None], description: Union[str, None],
scheme_name: Union[str, None], scheme_name: Union[str, None],
auto_error: bool, auto_error: bool,
not_authenticated_status_code: Literal[401, 403],
): ):
self.parameter_location = location.value self.parameter_location = location.value
self.parameter_name = name self.parameter_name = name
self.auto_error = auto_error self.auto_error = auto_error
self.not_authenticated_status_code = not_authenticated_status_code
self.model: APIKey = APIKey( self.model: APIKey = APIKey(
**{"in": location}, # type: ignore[arg-type] **{"in": location}, # type: ignore[arg-type]
name=name, name=name,
@ -27,12 +30,31 @@ class APIKeyBase(SecurityBase):
) )
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
def format_www_authenticate_header_value(self) -> str:
"""
The WWW-Authenticate header is not standardized for API Key authentication.
It's considered good practice to include information about the authentication
challange.
This method follows one of the common templates.
If a different format is required, override this method in a subclass.
"""
return f'ApiKey in="{self.parameter_location}", name="{self.parameter_name}"'
def check_api_key(self, api_key: Optional[str]) -> Optional[str]: def check_api_key(self, api_key: Optional[str]) -> Optional[str]:
if not api_key: if not api_key:
if self.auto_error: if self.auto_error:
if self.not_authenticated_status_code == HTTP_403_FORBIDDEN:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
) )
else: # By default use 401
www_authenticate = self.format_www_authenticate_header_value()
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": www_authenticate},
)
return None return None
return api_key return api_key
@ -115,6 +137,34 @@ class APIKeyQuery(APIKeyBase):
""" """
), ),
] = True, ] = True,
not_authenticated_status_code: Annotated[
Literal[401, 403],
Doc(
"""
By default, if the query parameter is not provided and `auto_error` is
set to `True`, `APIKeyQuery` will automatically raise an
`HTTPException` with the status code `401`.
If your client relies on the old (incorrect) behavior and expects the
status code to be `403`, you can set `not_authenticated_status_code` to
`403` to achieve it.
Keep in mind that this parameter is temporary and will be removed in
the near future.
Consider updating your clients to align with the new behavior.
"""
),
deprecated(
"""
This parameter is temporary. It was introduced to give users time
to upgrade their clients to follow the new behavior and will eventually
be removed.
Use it as a short-term workaround, but consider updating your clients
to align with the new behavior.
"""
),
] = 401,
): ):
super().__init__( super().__init__(
location=APIKeyIn.query, location=APIKeyIn.query,
@ -122,6 +172,7 @@ class APIKeyQuery(APIKeyBase):
scheme_name=scheme_name, scheme_name=scheme_name,
description=description, description=description,
auto_error=auto_error, auto_error=auto_error,
not_authenticated_status_code=not_authenticated_status_code,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: Request) -> Optional[str]:
@ -203,6 +254,34 @@ class APIKeyHeader(APIKeyBase):
""" """
), ),
] = True, ] = True,
not_authenticated_status_code: Annotated[
Literal[401, 403],
Doc(
"""
By default, if the header is not provided and `auto_error` is
set to `True`, `APIKeyHeader` will automatically raise an
`HTTPException` with the status code `401`.
If your client relies on the old (incorrect) behavior and expects the
status code to be `403`, you can set `not_authenticated_status_code` to
`403` to achieve it.
Keep in mind that this parameter is temporary and will be removed in
the near future.
Consider updating your clients to align with the new behavior.
"""
),
deprecated(
"""
This parameter is temporary. It was introduced to give users time
to upgrade their clients to follow the new behavior and will eventually
be removed.
Use it as a short-term workaround, but consider updating your clients
to align with the new behavior.
"""
),
] = 401,
): ):
super().__init__( super().__init__(
location=APIKeyIn.header, location=APIKeyIn.header,
@ -210,6 +289,7 @@ class APIKeyHeader(APIKeyBase):
scheme_name=scheme_name, scheme_name=scheme_name,
description=description, description=description,
auto_error=auto_error, auto_error=auto_error,
not_authenticated_status_code=not_authenticated_status_code,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: Request) -> Optional[str]:
@ -291,6 +371,34 @@ class APIKeyCookie(APIKeyBase):
""" """
), ),
] = True, ] = True,
not_authenticated_status_code: Annotated[
Literal[401, 403],
Doc(
"""
By default, if the cookie is not provided and `auto_error` is
set to `True`, `APIKeyCookie` will automatically raise an
`HTTPException` with the status code `401`.
If your client relies on the old (incorrect) behavior and expects the
status code to be `403`, you can set `not_authenticated_status_code` to
`403` to achieve it.
Keep in mind that this parameter is temporary and will be removed in
the near future.
Consider updating your clients to align with the new behavior.
"""
),
deprecated(
"""
This parameter is temporary. It was introduced to give users time
to upgrade their clients to follow the new behavior and will eventually
be removed.
Use it as a short-term workaround, but consider updating your clients
to align with the new behavior.
"""
),
] = 401,
): ):
super().__init__( super().__init__(
location=APIKeyIn.cookie, location=APIKeyIn.cookie,
@ -298,6 +406,7 @@ class APIKeyCookie(APIKeyBase):
scheme_name=scheme_name, scheme_name=scheme_name,
description=description, description=description,
auto_error=auto_error, auto_error=auto_error,
not_authenticated_status_code=not_authenticated_status_code,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: Request) -> Optional[str]:

3
tests/test_security_api_key_cookie.py

@ -32,8 +32,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
client = TestClient(app) client = TestClient(app)
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey in="cookie", name="key"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_api_key_cookie_description.py

@ -32,8 +32,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
client = TestClient(app) client = TestClient(app)
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey in="cookie", name="key"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_api_key_header.py

@ -33,8 +33,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey in="header", name="key"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_api_key_header_description.py

@ -33,8 +33,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey in="header", name="key"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_api_key_query.py

@ -33,8 +33,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey in="query", name="key"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_api_key_query_description.py

@ -33,8 +33,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey in="query", name="key"'
def test_openapi_schema(): def test_openapi_schema():

56
tests/test_security_status_code_403_option.py

@ -0,0 +1,56 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyHeader, APIKeyQuery
from fastapi.testclient import TestClient
@pytest.mark.parametrize(
"auth",
[
APIKeyQuery(name="key", not_authenticated_status_code=403),
APIKeyHeader(name="key", not_authenticated_status_code=403),
APIKeyCookie(name="key", not_authenticated_status_code=403),
],
)
def test_apikey_status_code_403_on_auth_error(auth: APIKeyBase):
"""
Test temporary `not_authenticated_status_code` parameter for APIKey** classes.
"""
app = FastAPI()
@app.get("/")
async def protected(_: str = Security(auth)):
pass # pragma: no cover
client = TestClient(app)
response = client.get("/")
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}
@pytest.mark.parametrize(
"auth",
[
APIKeyQuery(name="key", not_authenticated_status_code=403, auto_error=False),
APIKeyHeader(name="key", not_authenticated_status_code=403, auto_error=False),
APIKeyCookie(name="key", not_authenticated_status_code=403, auto_error=False),
],
)
def test_apikey_status_code_403_on_auth_error_no_auto_error(auth: APIKeyBase):
"""
Test temporary `not_authenticated_status_code` parameter for APIKey** classes with
`auto_error=False`.
"""
app = FastAPI()
@app.get("/")
async def protected(_: str = Security(auth)):
pass # pragma: no cover
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
Loading…
Cancel
Save