Browse Source

Fix `OpenIdConnect` security scheme status code on "Not authenticated" error

pull/13786/head
Yurii Motov 1 month ago
parent
commit
b9f29b8ee7
  1. 41
      fastapi/security/open_id_connect_url.py
  2. 3
      tests/test_security_openid_connect.py
  3. 3
      tests/test_security_openid_connect_description.py
  4. 13
      tests/test_security_status_code_403_option.py

41
fastapi/security/open_id_connect_url.py

@ -1,11 +1,11 @@
from typing import Optional from typing import Literal, Optional
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc, deprecated
class OpenIdConnect(SecurityBase): class OpenIdConnect(SecurityBase):
@ -65,20 +65,55 @@ class OpenIdConnect(SecurityBase):
""" """
), ),
] = True, ] = True,
not_authenticated_status_code: Annotated[
Literal[401, 403],
Doc(
"""
By default, if no HTTP Authorization header provided and `auto_error`
is set to `True`, it will automatically raise an`HTTPException` with
the status code `401`.
If your client relies on the old (incorrect) behavior and expects the
status code to be `403`, you can set `not_authenticated_status_code` to
`403` to achieve it.
Keep in mind that this parameter is temporary and will be removed in
the near future.
"""
),
deprecated(
"""
This parameter is temporary. It was introduced to give users time
to upgrade their clients to follow the new behavior and will eventually
be removed.
Use it as a short-term workaround, but consider updating your clients
to align with the new behavior.
"""
),
] = 401,
): ):
self.model = OpenIdConnectModel( self.model = OpenIdConnectModel(
openIdConnectUrl=openIdConnectUrl, description=description openIdConnectUrl=openIdConnectUrl, description=description
) )
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
self.not_authenticated_status_code = not_authenticated_status_code
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
if self.not_authenticated_status_code == HTTP_403_FORBIDDEN:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
) )
else:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
else: else:
return None return None
return authorization return authorization

3
tests/test_security_openid_connect.py

@ -39,8 +39,9 @@ def test_security_oauth2_password_other_header():
def test_security_oauth2_password_bearer_no_header(): def test_security_oauth2_password_bearer_no_header():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_openid_connect_description.py

@ -41,8 +41,9 @@ def test_security_oauth2_password_other_header():
def test_security_oauth2_password_bearer_no_header(): def test_security_oauth2_password_bearer_no_header():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_openapi_schema(): def test_openapi_schema():

13
tests/test_security_status_code_403_option.py

@ -1,7 +1,10 @@
from typing import Union
import pytest import pytest
from fastapi import FastAPI, Security from fastapi import FastAPI, Security
from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyHeader, APIKeyQuery from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyHeader, APIKeyQuery
from fastapi.security.http import HTTPBase, HTTPBearer, HTTPDigest from fastapi.security.http import HTTPBase, HTTPBearer, HTTPDigest
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -61,9 +64,10 @@ def test_apikey_status_code_403_on_auth_error_no_auto_error(auth: APIKeyBase):
"auth", "auth",
[ [
HTTPBearer(not_authenticated_status_code=403), HTTPBearer(not_authenticated_status_code=403),
OpenIdConnect(not_authenticated_status_code=403, openIdConnectUrl="/openid"),
], ],
) )
def test_oauth2_status_code_403_on_auth_error(auth: HTTPBase): def test_oauth2_status_code_403_on_auth_error(auth: Union[HTTPBase, OpenIdConnect]):
""" """
Test temporary `not_authenticated_status_code` parameter for security classes that Test temporary `not_authenticated_status_code` parameter for security classes that
follow rfc6750. follow rfc6750.
@ -86,10 +90,15 @@ def test_oauth2_status_code_403_on_auth_error(auth: HTTPBase):
"auth", "auth",
[ [
HTTPBearer(not_authenticated_status_code=403, auto_error=False), HTTPBearer(not_authenticated_status_code=403, auto_error=False),
OpenIdConnect(
not_authenticated_status_code=403,
openIdConnectUrl="/openid",
auto_error=False,
),
], ],
) )
def test_oauth2_status_code_403_on_auth_error_no_auto_error( def test_oauth2_status_code_403_on_auth_error_no_auto_error(
auth: HTTPBase, auth: Union[HTTPBase, OpenIdConnect],
): ):
""" """
Test temporary `not_authenticated_status_code` parameter for security classes that Test temporary `not_authenticated_status_code` parameter for security classes that

Loading…
Cancel
Save