diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 70c2dca8a..a6f4e85d0 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -1,21 +1,60 @@ -from typing import Optional +from typing import Optional, Union from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.status import HTTP_403_FORBIDDEN -from typing_extensions import Annotated, Doc +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN +from typing_extensions import Annotated, Doc, Literal, deprecated class APIKeyBase(SecurityBase): - @staticmethod - def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]: + def __init__( + self, + location: APIKeyIn, + name: str, + description: Union[str, None], + scheme_name: Union[str, None], + auto_error: bool, + not_authenticated_status_code: Literal[401, 403], + ): + self.parameter_location = location.value + self.parameter_name = name + self.auto_error = auto_error + self.not_authenticated_status_code = not_authenticated_status_code + + self.model: APIKey = APIKey( + **{"in": location}, # type: ignore[arg-type] + name=name, + description=description, + ) + self.scheme_name = scheme_name or self.__class__.__name__ + + def format_www_authenticate_header_value(self) -> str: + """ + The WWW-Authenticate header is not standardized for API Key authentication. + It's considered good practice to include information about the authentication + challange. + This method follows one of the common templates. + If a different format is required, override this method in a subclass. + """ + + return f'ApiKey in="{self.parameter_location}", name="{self.parameter_name}"' + + def check_api_key(self, api_key: Optional[str]) -> Optional[str]: if not api_key: - if auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + if self.not_authenticated_status_code == HTTP_403_FORBIDDEN: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: # By default use 401 + www_authenticate = self.format_www_authenticate_header_value() + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": www_authenticate}, + ) return None return api_key @@ -98,18 +137,47 @@ class APIKeyQuery(APIKeyBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if the query parameter is not provided and `auto_error` is + set to `True`, `APIKeyQuery` will automatically raise an + `HTTPException` with the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + Consider updating your clients to align with the new behavior. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): - self.model: APIKey = APIKey( - **{"in": APIKeyIn.query}, # type: ignore[arg-type] + super().__init__( + location=APIKeyIn.query, name=name, + scheme_name=scheme_name, description=description, + auto_error=auto_error, + not_authenticated_status_code=not_authenticated_status_code, ) - self.scheme_name = scheme_name or self.__class__.__name__ - self.auto_error = auto_error async def __call__(self, request: Request) -> Optional[str]: api_key = request.query_params.get(self.model.name) - return self.check_api_key(api_key, self.auto_error) + return self.check_api_key(api_key) class APIKeyHeader(APIKeyBase): @@ -186,18 +254,47 @@ class APIKeyHeader(APIKeyBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if the header is not provided and `auto_error` is + set to `True`, `APIKeyHeader` will automatically raise an + `HTTPException` with the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + Consider updating your clients to align with the new behavior. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): - self.model: APIKey = APIKey( - **{"in": APIKeyIn.header}, # type: ignore[arg-type] + super().__init__( + location=APIKeyIn.header, name=name, + scheme_name=scheme_name, description=description, + auto_error=auto_error, + not_authenticated_status_code=not_authenticated_status_code, ) - self.scheme_name = scheme_name or self.__class__.__name__ - self.auto_error = auto_error async def __call__(self, request: Request) -> Optional[str]: api_key = request.headers.get(self.model.name) - return self.check_api_key(api_key, self.auto_error) + return self.check_api_key(api_key) class APIKeyCookie(APIKeyBase): @@ -274,15 +371,44 @@ class APIKeyCookie(APIKeyBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if the cookie is not provided and `auto_error` is + set to `True`, `APIKeyCookie` will automatically raise an + `HTTPException` with the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + Consider updating your clients to align with the new behavior. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): - self.model: APIKey = APIKey( - **{"in": APIKeyIn.cookie}, # type: ignore[arg-type] + super().__init__( + location=APIKeyIn.cookie, name=name, + scheme_name=scheme_name, description=description, + auto_error=auto_error, + not_authenticated_status_code=not_authenticated_status_code, ) - self.scheme_name = scheme_name or self.__class__.__name__ - self.auto_error = auto_error async def __call__(self, request: Request) -> Optional[str]: api_key = request.cookies.get(self.model.name) - return self.check_api_key(api_key, self.auto_error) + return self.check_api_key(api_key) diff --git a/fastapi/security/http.py b/fastapi/security/http.py index 9ab2df3c9..7d106fbe1 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -10,7 +10,7 @@ from fastapi.security.utils import get_authorization_scheme_param from pydantic import BaseModel from starlette.requests import Request from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc, Literal, deprecated class HTTPBasicCredentials(BaseModel): @@ -74,10 +74,13 @@ class HTTPBase(SecurityBase): scheme_name: Optional[str] = None, description: Optional[str] = None, auto_error: bool = True, + not_authenticated_status_code: Literal[401, 403] = 401, ): self.model = HTTPBaseModel(scheme=scheme, description=description) + self.model_scheme = scheme self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + self.not_authenticated_status_code = not_authenticated_status_code async def __call__( self, request: Request @@ -86,9 +89,16 @@ class HTTPBase(SecurityBase): scheme, credentials = get_authorization_scheme_param(authorization) if not (authorization and scheme and credentials): if self.auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.not_authenticated_status_code == HTTP_403_FORBIDDEN: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": self.model_scheme}, + ) else: return None return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) @@ -293,10 +303,38 @@ class HTTPBearer(HTTPBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if the HTTP Bearer token is not provided and `auto_error` + is set to `True`, `HTTPBearer` will automatically raise an + `HTTPException` with the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): self.model = HTTPBearerModel(bearerFormat=bearerFormat, description=description) self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + self.not_authenticated_status_code = not_authenticated_status_code async def __call__( self, request: Request @@ -305,21 +343,28 @@ class HTTPBearer(HTTPBase): scheme, credentials = get_authorization_scheme_param(authorization) if not (authorization and scheme and credentials): if self.auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + self._raise_not_authenticated_error(error_message="Not authenticated") else: return None if scheme.lower() != "bearer": if self.auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="Invalid authentication credentials", + self._raise_not_authenticated_error( + error_message="Invalid authentication credentials" ) else: return None return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) + def _raise_not_authenticated_error(self, error_message: str) -> None: + if self.not_authenticated_status_code == HTTP_403_FORBIDDEN: + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=error_message) + else: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=error_message, + headers={"WWW-Authenticate": "Bearer"}, + ) + class HTTPDigest(HTTPBase): """ @@ -395,10 +440,38 @@ class HTTPDigest(HTTPBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if the HTTP Digest is not provided and `auto_error` + is set to `True`, `HTTPDigest` will automatically raise an + `HTTPException` with the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): self.model = HTTPBaseModel(scheme="digest", description=description) self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + self.not_authenticated_status_code = not_authenticated_status_code async def __call__( self, request: Request @@ -407,17 +480,24 @@ class HTTPDigest(HTTPBase): scheme, credentials = get_authorization_scheme_param(authorization) if not (authorization and scheme and credentials): if self.auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + self._raise_not_authenticated_error(error_message="Not authenticated") else: return None if scheme.lower() != "digest": if self.auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="Invalid authentication credentials", + self._raise_not_authenticated_error( + error_message="Invalid authentication credentials", ) else: return None return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) + + def _raise_not_authenticated_error(self, error_message: str) -> None: + if self.not_authenticated_status_code == HTTP_403_FORBIDDEN: + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=error_message) + else: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=error_message, + headers={"WWW-Authenticate": "Digest"}, + ) diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index 88e394db1..85dc75bf5 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -10,7 +10,7 @@ from starlette.requests import Request from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN # TODO: import from typing when deprecating Python 3.9 -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc, Literal, deprecated class OAuth2PasswordRequestForm: @@ -369,20 +369,56 @@ class OAuth2(SecurityBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if no HTTP Authorization header provided and `auto_error` + is set to `True`, it will automatically raise an`HTTPException` with + the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): self.model = OAuth2Model( flows=cast(OAuthFlowsModel, flows), description=description ) self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + self.not_authenticated_status_code = not_authenticated_status_code async def __call__(self, request: Request) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.not_authenticated_status_code == HTTP_403_FORBIDDEN: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + else: return None return authorization diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index c8cceb911..76d91941e 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -4,8 +4,8 @@ from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.status import HTTP_403_FORBIDDEN -from typing_extensions import Annotated, Doc +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN +from typing_extensions import Annotated, Doc, Literal, deprecated class OpenIdConnect(SecurityBase): @@ -65,20 +65,55 @@ class OpenIdConnect(SecurityBase): """ ), ] = True, + not_authenticated_status_code: Annotated[ + Literal[401, 403], + Doc( + """ + By default, if no HTTP Authorization header provided and `auto_error` + is set to `True`, it will automatically raise an`HTTPException` with + the status code `401`. + + If your client relies on the old (incorrect) behavior and expects the + status code to be `403`, you can set `not_authenticated_status_code` to + `403` to achieve it. + + Keep in mind that this parameter is temporary and will be removed in + the near future. + """ + ), + deprecated( + """ + This parameter is temporary. It was introduced to give users time + to upgrade their clients to follow the new behavior and will eventually + be removed. + + Use it as a short-term workaround, but consider updating your clients + to align with the new behavior. + """ + ), + ] = 401, ): self.model = OpenIdConnectModel( openIdConnectUrl=openIdConnectUrl, description=description ) self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + self.not_authenticated_status_code = not_authenticated_status_code async def __call__(self, request: Request) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.not_authenticated_status_code == HTTP_403_FORBIDDEN: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) else: return None return authorization diff --git a/tests/test_security_api_key_cookie.py b/tests/test_security_api_key_cookie.py index 4ddb8e2ee..488503817 100644 --- a/tests/test_security_api_key_cookie.py +++ b/tests/test_security_api_key_cookie.py @@ -32,8 +32,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): client = TestClient(app) response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="cookie", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_api_key_cookie_description.py b/tests/test_security_api_key_cookie_description.py index d99d616e0..e0e448471 100644 --- a/tests/test_security_api_key_cookie_description.py +++ b/tests/test_security_api_key_cookie_description.py @@ -32,8 +32,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): client = TestClient(app) response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="cookie", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_api_key_header.py b/tests/test_security_api_key_header.py index 1ff883703..b72d258c4 100644 --- a/tests/test_security_api_key_header.py +++ b/tests/test_security_api_key_header.py @@ -33,8 +33,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="header", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_api_key_header_description.py b/tests/test_security_api_key_header_description.py index 27f9d0f29..70b85af03 100644 --- a/tests/test_security_api_key_header_description.py +++ b/tests/test_security_api_key_header_description.py @@ -33,8 +33,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="header", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_api_key_query.py b/tests/test_security_api_key_query.py index dc7a0a621..7a01101c4 100644 --- a/tests/test_security_api_key_query.py +++ b/tests/test_security_api_key_query.py @@ -33,8 +33,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="query", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_api_key_query_description.py b/tests/test_security_api_key_query_description.py index 35dc7743a..45102eb17 100644 --- a/tests/test_security_api_key_query_description.py +++ b/tests/test_security_api_key_query_description.py @@ -33,8 +33,9 @@ def test_security_api_key(): def test_security_api_key_no_key(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == 'ApiKey in="query", name="key"' def test_openapi_schema(): diff --git a/tests/test_security_http_base.py b/tests/test_security_http_base.py index 51928bafd..8cf259a75 100644 --- a/tests/test_security_http_base.py +++ b/tests/test_security_http_base.py @@ -23,8 +23,9 @@ def test_security_http_base(): def test_security_http_base_no_credentials(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Other" def test_openapi_schema(): diff --git a/tests/test_security_http_base_description.py b/tests/test_security_http_base_description.py index bc79f3242..791ea59f4 100644 --- a/tests/test_security_http_base_description.py +++ b/tests/test_security_http_base_description.py @@ -23,8 +23,9 @@ def test_security_http_base(): def test_security_http_base_no_credentials(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Other" def test_openapi_schema(): diff --git a/tests/test_security_http_bearer.py b/tests/test_security_http_bearer.py index 5b9e2d691..de4e0427a 100644 --- a/tests/test_security_http_bearer.py +++ b/tests/test_security_http_bearer.py @@ -23,14 +23,16 @@ def test_security_http_bearer(): def test_security_http_bearer_no_credentials(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Bearer" def test_security_http_bearer_incorrect_scheme_credentials(): response = client.get("/users/me", headers={"Authorization": "Basic notreally"}) - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Invalid authentication credentials"} + assert response.headers["WWW-Authenticate"] == "Bearer" def test_openapi_schema(): diff --git a/tests/test_security_http_bearer_description.py b/tests/test_security_http_bearer_description.py index 2f11c3a14..f87df5434 100644 --- a/tests/test_security_http_bearer_description.py +++ b/tests/test_security_http_bearer_description.py @@ -23,14 +23,16 @@ def test_security_http_bearer(): def test_security_http_bearer_no_credentials(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Bearer" def test_security_http_bearer_incorrect_scheme_credentials(): response = client.get("/users/me", headers={"Authorization": "Basic notreally"}) - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Invalid authentication credentials"} + assert response.headers["WWW-Authenticate"] == "Bearer" def test_openapi_schema(): diff --git a/tests/test_security_http_digest.py b/tests/test_security_http_digest.py index 133d35763..a195430d2 100644 --- a/tests/test_security_http_digest.py +++ b/tests/test_security_http_digest.py @@ -23,16 +23,18 @@ def test_security_http_digest(): def test_security_http_digest_no_credentials(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Digest" def test_security_http_digest_incorrect_scheme_credentials(): response = client.get( "/users/me", headers={"Authorization": "Other invalidauthorization"} ) - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Invalid authentication credentials"} + assert response.headers["WWW-Authenticate"] == "Digest" def test_openapi_schema(): diff --git a/tests/test_security_http_digest_description.py b/tests/test_security_http_digest_description.py index 4e31a0c00..0ced8494e 100644 --- a/tests/test_security_http_digest_description.py +++ b/tests/test_security_http_digest_description.py @@ -23,16 +23,18 @@ def test_security_http_digest(): def test_security_http_digest_no_credentials(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Digest" def test_security_http_digest_incorrect_scheme_credentials(): response = client.get( "/users/me", headers={"Authorization": "Other invalidauthorization"} ) - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Invalid authentication credentials"} + assert response.headers["WWW-Authenticate"] == "Digest" def test_openapi_schema(): diff --git a/tests/test_security_oauth2.py b/tests/test_security_oauth2.py index 2b7e3457a..804e4152d 100644 --- a/tests/test_security_oauth2.py +++ b/tests/test_security_oauth2.py @@ -56,8 +56,9 @@ def test_security_oauth2_password_other_header(): def test_security_oauth2_password_bearer_no_header(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Bearer" def test_strict_login_no_data(): diff --git a/tests/test_security_openid_connect.py b/tests/test_security_openid_connect.py index 1e322e640..c9a0a8db7 100644 --- a/tests/test_security_openid_connect.py +++ b/tests/test_security_openid_connect.py @@ -39,8 +39,9 @@ def test_security_oauth2_password_other_header(): def test_security_oauth2_password_bearer_no_header(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Bearer" def test_openapi_schema(): diff --git a/tests/test_security_openid_connect_description.py b/tests/test_security_openid_connect_description.py index 44cf57f86..d008cbc63 100644 --- a/tests/test_security_openid_connect_description.py +++ b/tests/test_security_openid_connect_description.py @@ -41,8 +41,9 @@ def test_security_oauth2_password_other_header(): def test_security_oauth2_password_bearer_no_header(): response = client.get("/users/me") - assert response.status_code == 403, response.text + assert response.status_code == 401, response.text assert response.json() == {"detail": "Not authenticated"} + assert response.headers["WWW-Authenticate"] == "Bearer" def test_openapi_schema(): diff --git a/tests/test_security_status_code_403_option.py b/tests/test_security_status_code_403_option.py new file mode 100644 index 000000000..95d385ae4 --- /dev/null +++ b/tests/test_security_status_code_403_option.py @@ -0,0 +1,208 @@ +from typing import Union + +import pytest +from fastapi import FastAPI, Security +from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyHeader, APIKeyQuery +from fastapi.security.http import HTTPBase, HTTPBearer, HTTPDigest +from fastapi.security.oauth2 import OAuth2 +from fastapi.security.open_id_connect_url import OpenIdConnect +from fastapi.testclient import TestClient + + +@pytest.mark.parametrize( + "auth", + [ + APIKeyQuery(name="key", not_authenticated_status_code=403), + APIKeyHeader(name="key", not_authenticated_status_code=403), + APIKeyCookie(name="key", not_authenticated_status_code=403), + ], +) +def test_apikey_status_code_403_on_auth_error(auth: APIKeyBase): + """ + Test temporary `not_authenticated_status_code` parameter for APIKey** classes. + """ + + app = FastAPI() + + @app.get("/") + async def protected(_: str = Security(auth)): + pass # pragma: no cover + + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 403 + assert response.json() == {"detail": "Not authenticated"} + + +@pytest.mark.parametrize( + "auth", + [ + APIKeyQuery(name="key", not_authenticated_status_code=403, auto_error=False), + APIKeyHeader(name="key", not_authenticated_status_code=403, auto_error=False), + APIKeyCookie(name="key", not_authenticated_status_code=403, auto_error=False), + ], +) +def test_apikey_status_code_403_on_auth_error_no_auto_error(auth: APIKeyBase): + """ + Test temporary `not_authenticated_status_code` parameter for APIKey** classes with + `auto_error=False`. + """ + + app = FastAPI() + + @app.get("/") + async def protected(_: str = Security(auth)): + pass # pragma: no cover + + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200 + + +@pytest.mark.parametrize( + "auth", + [ + HTTPBearer(not_authenticated_status_code=403), + OpenIdConnect(not_authenticated_status_code=403, openIdConnectUrl="/openid"), + OAuth2( + not_authenticated_status_code=403, + flows={"password": {"tokenUrl": "token", "scopes": {}}}, + ), + ], +) +def test_oauth2_status_code_403_on_auth_error(auth: Union[HTTPBase, OpenIdConnect]): + """ + Test temporary `not_authenticated_status_code` parameter for security classes that + follow rfc6750. + """ + + app = FastAPI() + + @app.get("/") + async def protected(_: str = Security(auth)): + pass # pragma: no cover + + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 403 + assert response.json() == {"detail": "Not authenticated"} + + +@pytest.mark.parametrize( + "auth", + [ + HTTPBearer(not_authenticated_status_code=403, auto_error=False), + OpenIdConnect( + not_authenticated_status_code=403, + openIdConnectUrl="/openid", + auto_error=False, + ), + OAuth2( + not_authenticated_status_code=403, + flows={"password": {"tokenUrl": "token", "scopes": {}}}, + auto_error=False, + ), + ], +) +def test_oauth2_status_code_403_on_auth_error_no_auto_error( + auth: Union[HTTPBase, OpenIdConnect], +): + """ + Test temporary `not_authenticated_status_code` parameter for security classes that + follow rfc6750. + With `auto_error=False`. Response code should be 200 + """ + + app = FastAPI() + + @app.get("/") + async def protected(_: str = Security(auth)): + pass # pragma: no cover + + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200 + + +def test_digest_status_code_403_on_auth_error(): + """ + Test temporary `not_authenticated_status_code` parameter for `Digest` scheme. + """ + + app = FastAPI() + + auth = HTTPDigest(not_authenticated_status_code=403) + + @app.get("/") + async def protected(_: str = Security(auth)): + pass # pragma: no cover + + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 403 + assert response.json() == {"detail": "Not authenticated"} + + +def test_digest_status_code_403_on_auth_error_no_auto_error(): + """ + Test temporary `not_authenticated_status_code` parameter for `Digest` scheme with + `auto_error=False`. + """ + + app = FastAPI() + + auth = HTTPDigest(not_authenticated_status_code=403, auto_error=False) + + @app.get("/") + async def protected(_: str = Security(auth)): + pass # pragma: no cover + + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200 + + +def test_httpbase_status_code_403_on_auth_error(): + """ + Test temporary `not_authenticated_status_code` parameter for `HTTPBase` class. + """ + + app = FastAPI() + + auth = HTTPBase(scheme="Other", not_authenticated_status_code=403) + + @app.get("/") + async def protected(_: str = Security(auth)): + pass # pragma: no cover + + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 403 + assert response.json() == {"detail": "Not authenticated"} + + +def test_httpbase_status_code_403_on_auth_error_no_auto_error(): + """ + Test temporary `not_authenticated_status_code` parameter for `HTTPBase` class with + `auto_error=False`. + """ + + app = FastAPI() + + auth = HTTPBase(scheme="Other", not_authenticated_status_code=403, auto_error=False) + + @app.get("/") + async def protected(_: str = Security(auth)): + pass # pragma: no cover + + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200