Browse Source

Fix `HTTPBearer` security scheme status code on "Not authenticated" error

pull/13786/head
Yurii Motov 1 month ago
parent
commit
6c53a67422
  1. 51
      fastapi/security/http.py
  2. 6
      tests/test_security_http_bearer.py
  3. 6
      tests/test_security_http_bearer_description.py
  4. 54
      tests/test_security_status_code_403_option.py

51
fastapi/security/http.py

@ -1,6 +1,6 @@
import binascii
from base64 import b64decode
from typing import Optional
from typing import Literal, Optional
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
@ -10,7 +10,7 @@ from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc
from typing_extensions import Annotated, Doc, deprecated
class HTTPBasicCredentials(BaseModel):
@ -293,10 +293,38 @@ class HTTPBearer(HTTPBase):
"""
),
] = True,
not_authenticated_status_code: Annotated[
Literal[401, 403],
Doc(
"""
By default, if the HTTP Bearer token is not provided and `auto_error`
is set to `True`, `HTTPBearer` will automatically raise an
`HTTPException` with the status code `401`.
If your client relies on the old (incorrect) behavior and expects the
status code to be `403`, you can set `not_authenticated_status_code` to
`403` to achieve it.
Keep in mind that this parameter is temporary and will be removed in
the near future.
"""
),
deprecated(
"""
This parameter is temporary. It was introduced to give users time
to upgrade their clients to follow the new behavior and will eventually
be removed.
Use it as a short-term workaround, but consider updating your clients
to align with the new behavior.
"""
),
] = 401,
):
self.model = HTTPBearerModel(bearerFormat=bearerFormat, description=description)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
self.not_authenticated_status_code = not_authenticated_status_code
async def __call__(
self, request: Request
@ -305,21 +333,28 @@ class HTTPBearer(HTTPBase):
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
self._raise_not_authenticated_error(error_message="Not authenticated")
else:
return None
if scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
self._raise_not_authenticated_error(
error_message="Invalid authentication credentials"
)
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
def _raise_not_authenticated_error(self, error_message: str):
if self.not_authenticated_status_code == HTTP_403_FORBIDDEN:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=error_message)
else:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error_message,
headers={"WWW-Authenticate": "Bearer"},
)
class HTTPDigest(HTTPBase):
"""

6
tests/test_security_http_bearer.py

@ -23,14 +23,16 @@ def test_security_http_bearer():
def test_security_http_bearer_no_credentials():
response = client.get("/users/me")
assert response.status_code == 403, response.text
assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_security_http_bearer_incorrect_scheme_credentials():
response = client.get("/users/me", headers={"Authorization": "Basic notreally"})
assert response.status_code == 403, response.text
assert response.status_code == 401, response.text
assert response.json() == {"detail": "Invalid authentication credentials"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_openapi_schema():

6
tests/test_security_http_bearer_description.py

@ -23,14 +23,16 @@ def test_security_http_bearer():
def test_security_http_bearer_no_credentials():
response = client.get("/users/me")
assert response.status_code == 403, response.text
assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_security_http_bearer_incorrect_scheme_credentials():
response = client.get("/users/me", headers={"Authorization": "Basic notreally"})
assert response.status_code == 403, response.text
assert response.status_code == 401, response.text
assert response.json() == {"detail": "Invalid authentication credentials"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_openapi_schema():

54
tests/test_security_status_code_403_option.py

@ -1,6 +1,8 @@
import pytest
from fastapi import FastAPI, Security
from fastapi.openapi.models import HTTPBase
from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyHeader, APIKeyQuery
from fastapi.security.http import HTTPBearer, HTTPDigest
from fastapi.testclient import TestClient
@ -54,3 +56,55 @@ def test_apikey_status_code_403_on_auth_error_no_auto_error(auth: APIKeyBase):
response = client.get("/")
assert response.status_code == 200
@pytest.mark.parametrize(
"auth",
[
HTTPBearer(not_authenticated_status_code=403),
],
)
def test_oauth2_status_code_403_on_auth_error(auth: HTTPBase):
"""
Test temporary `not_authenticated_status_code` parameter for security classes that
follow rfc6750.
"""
app = FastAPI()
@app.get("/")
async def protected(_: str = Security(auth)):
pass # pragma: no cover
client = TestClient(app)
response = client.get("/")
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}
@pytest.mark.parametrize(
"auth",
[
HTTPBearer(not_authenticated_status_code=403, auto_error=False),
],
)
def test_oauth2_status_code_403_on_auth_error_no_auto_error(
auth: HTTPBase,
):
"""
Test temporary `not_authenticated_status_code` parameter for security classes that
follow rfc6750.
With `auto_error=False`. Response code should be 200
"""
app = FastAPI()
@app.get("/")
async def protected(_: str = Security(auth)):
pass # pragma: no cover
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200

Loading…
Cancel
Save