diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index c8cceb911..c6b961ac0 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -1,11 +1,11 @@ -from typing import Optional +from typing import Literal, Optional from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.status import HTTP_403_FORBIDDEN -from typing_extensions import Annotated, Doc +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN +from typing_extensions import Annotated, Doc, deprecated class OpenIdConnect(SecurityBase): @@ -65,20 +65,55 @@ class OpenIdConnect(SecurityBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if no HTTP Authorization header provided and `auto_error` + is set to `True`, it will automatically raise an`HTTPException` with + the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): self.model = OpenIdConnectModel( openIdConnectUrl=openIdConnectUrl, description=description ) self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + self.not_authenticated_status_code = not_authenticated_status_code async def __call__(self, request: Request) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.not_authenticated_status_code == HTTP_403_FORBIDDEN: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) else: return None return authorization diff --git a/tests/test_security_openid_connect.py b/tests/test_security_openid_connect.py index 1e322e640..c9a0a8db7 100644 --- a/tests/test_security_openid_connect.py +++ b/tests/test_security_openid_connect.py @@ -39,8 +39,9 @@ def test_security_oauth2_password_other_header(): def test_security_oauth2_password_bearer_no_header(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Bearer" def test_openapi_schema(): diff --git a/tests/test_security_openid_connect_description.py b/tests/test_security_openid_connect_description.py index 44cf57f86..d008cbc63 100644 --- a/tests/test_security_openid_connect_description.py +++ b/tests/test_security_openid_connect_description.py @@ -41,8 +41,9 @@ def test_security_oauth2_password_other_header(): def test_security_oauth2_password_bearer_no_header(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Bearer" def test_openapi_schema(): diff --git a/tests/test_security_status_code_403_option.py b/tests/test_security_status_code_403_option.py index bf91f6202..29866bfd4 100644 --- a/tests/test_security_status_code_403_option.py +++ b/tests/test_security_status_code_403_option.py @@ -1,7 +1,10 @@ +from typing import Union + import pytest from fastapi import FastAPI, Security from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyHeader, APIKeyQuery from fastapi.security.http import HTTPBase, HTTPBearer, HTTPDigest +from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.testclient import TestClient @@ -61,9 +64,10 @@ def test_apikey_status_code_403_on_auth_error_no_auto_error(auth: APIKeyBase): "auth", [ HTTPBearer(not_authenticated_status_code=403), + OpenIdConnect(not_authenticated_status_code=403, openIdConnectUrl="/openid"), ], ) -def test_oauth2_status_code_403_on_auth_error(auth: HTTPBase): +def test_oauth2_status_code_403_on_auth_error(auth: Union[HTTPBase, OpenIdConnect]): """ Test temporary `not_authenticated_status_code` parameter for security classes that follow rfc6750. @@ -86,10 +90,15 @@ def test_oauth2_status_code_403_on_auth_error(auth: HTTPBase): "auth", [ HTTPBearer(not_authenticated_status_code=403, auto_error=False), + OpenIdConnect( + not_authenticated_status_code=403, + openIdConnectUrl="/openid", + auto_error=False, + ), ], ) def test_oauth2_status_code_403_on_auth_error_no_auto_error( - auth: HTTPBase, + auth: Union[HTTPBase, OpenIdConnect], ): """ Test temporary `not_authenticated_status_code` parameter for security classes that