Browse Source

Fix `OAuth2` security scheme status code on "Not authenticated" error

pull/13786/head
Yurii Motov 1 month ago
parent
commit
ade9d830f6
  1. 38
      fastapi/security/oauth2.py
  2. 3
      tests/test_security_oauth2.py
  3. 11
      tests/test_security_status_code_403_option.py

38
fastapi/security/oauth2.py

@ -10,7 +10,7 @@ from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
# TODO: import from typing when deprecating Python 3.9 # TODO: import from typing when deprecating Python 3.9
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc, Literal, deprecated
class OAuth2PasswordRequestForm: class OAuth2PasswordRequestForm:
@ -369,20 +369,56 @@ class OAuth2(SecurityBase):
""" """
), ),
] = True, ] = True,
not_authenticated_status_code: Annotated[
Literal[401, 403],
Doc(
"""
By default, if no HTTP Authorization header provided and `auto_error`
is set to `True`, it will automatically raise an`HTTPException` with
the status code `401`.
If your client relies on the old (incorrect) behavior and expects the
status code to be `403`, you can set `not_authenticated_status_code` to
`403` to achieve it.
Keep in mind that this parameter is temporary and will be removed in
the near future.
"""
),
deprecated(
"""
This parameter is temporary. It was introduced to give users time
to upgrade their clients to follow the new behavior and will eventually
be removed.
Use it as a short-term workaround, but consider updating your clients
to align with the new behavior.
"""
),
] = 401,
): ):
self.model = OAuth2Model( self.model = OAuth2Model(
flows=cast(OAuthFlowsModel, flows), description=description flows=cast(OAuthFlowsModel, flows), description=description
) )
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
self.not_authenticated_status_code = not_authenticated_status_code
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
if self.not_authenticated_status_code == HTTP_403_FORBIDDEN:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
) )
else:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
else: else:
return None return None
return authorization return authorization

3
tests/test_security_oauth2.py

@ -56,8 +56,9 @@ def test_security_oauth2_password_other_header():
def test_security_oauth2_password_bearer_no_header(): def test_security_oauth2_password_bearer_no_header():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_strict_login_no_data(): def test_strict_login_no_data():

11
tests/test_security_status_code_403_option.py

@ -4,6 +4,7 @@ import pytest
from fastapi import FastAPI, Security from fastapi import FastAPI, Security
from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyHeader, APIKeyQuery from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyHeader, APIKeyQuery
from fastapi.security.http import HTTPBase, HTTPBearer, HTTPDigest from fastapi.security.http import HTTPBase, HTTPBearer, HTTPDigest
from fastapi.security.oauth2 import OAuth2
from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -65,6 +66,10 @@ def test_apikey_status_code_403_on_auth_error_no_auto_error(auth: APIKeyBase):
[ [
HTTPBearer(not_authenticated_status_code=403), HTTPBearer(not_authenticated_status_code=403),
OpenIdConnect(not_authenticated_status_code=403, openIdConnectUrl="/openid"), OpenIdConnect(not_authenticated_status_code=403, openIdConnectUrl="/openid"),
OAuth2(
not_authenticated_status_code=403,
flows={"password": {"tokenUrl": "token", "scopes": {}}},
),
], ],
) )
def test_oauth2_status_code_403_on_auth_error(auth: Union[HTTPBase, OpenIdConnect]): def test_oauth2_status_code_403_on_auth_error(auth: Union[HTTPBase, OpenIdConnect]):
@ -95,6 +100,12 @@ def test_oauth2_status_code_403_on_auth_error(auth: Union[HTTPBase, OpenIdConnec
openIdConnectUrl="/openid", openIdConnectUrl="/openid",
auto_error=False, auto_error=False,
), ),
OAuth2(
not_authenticated_status_code=403,
flows={"password": {"tokenUrl": "token", "scopes": {}}},
auto_error=False,
),
], ],
) )
def test_oauth2_status_code_403_on_auth_error_no_auto_error( def test_oauth2_status_code_403_on_auth_error_no_auto_error(

Loading…
Cancel
Save