Browse Source

Fix `HTTPBase` class status code on "Not authenticated" error

pull/13786/head
Yurii Motov 1 month ago
parent
commit
54d60db33d
  1. 10
      fastapi/security/http.py
  2. 3
      tests/test_security_http_base.py
  3. 3
      tests/test_security_http_base_description.py
  4. 40
      tests/test_security_status_code_403_option.py

10
fastapi/security/http.py

@ -74,10 +74,13 @@ class HTTPBase(SecurityBase):
scheme_name: Optional[str] = None, scheme_name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
auto_error: bool = True, auto_error: bool = True,
not_authenticated_status_code: Literal[401, 403] = 401,
): ):
self.model = HTTPBaseModel(scheme=scheme, description=description) self.model = HTTPBaseModel(scheme=scheme, description=description)
self.model_scheme = scheme
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
self.not_authenticated_status_code = not_authenticated_status_code
async def __call__( async def __call__(
self, request: Request self, request: Request
@ -86,9 +89,16 @@ class HTTPBase(SecurityBase):
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:
if self.not_authenticated_status_code == HTTP_403_FORBIDDEN:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
) )
else:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": self.model_scheme},
)
else: else:
return None return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)

3
tests/test_security_http_base.py

@ -23,8 +23,9 @@ def test_security_http_base():
def test_security_http_base_no_credentials(): def test_security_http_base_no_credentials():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Other"
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_http_base_description.py

@ -23,8 +23,9 @@ def test_security_http_base():
def test_security_http_base_no_credentials(): def test_security_http_base_no_credentials():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Other"
def test_openapi_schema(): def test_openapi_schema():

40
tests/test_security_status_code_403_option.py

@ -147,3 +147,43 @@ def test_digest_status_code_403_on_auth_error_no_auto_error():
response = client.get("/") response = client.get("/")
assert response.status_code == 200 assert response.status_code == 200
def test_httpbase_status_code_403_on_auth_error():
"""
Test temporary `not_authenticated_status_code` parameter for `HTTPBase` class.
"""
app = FastAPI()
auth = HTTPBase(scheme="Other", not_authenticated_status_code=403)
@app.get("/")
async def protected(_: str = Security(auth)):
pass # pragma: no cover
client = TestClient(app)
response = client.get("/")
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}
def test_httpbase_status_code_403_on_auth_error_no_auto_error():
"""
Test temporary `not_authenticated_status_code` parameter for `HTTPBase` class with
`auto_error=False`.
"""
app = FastAPI()
auth = HTTPBase(scheme="Other", not_authenticated_status_code=403, auto_error=False)
@app.get("/")
async def protected(_: str = Security(auth)):
pass # pragma: no cover
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200

Loading…
Cancel
Save