Manish 3 days ago
committed by GitHub
parent
commit
6d89a273b0
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 23
      fastapi/security/api_key.py
  2. 26
      fastapi/security/http.py
  3. 6
      fastapi/security/oauth2.py
  4. 8
      fastapi/security/open_id_connect_url.py
  5. 3
      tests/test_security_api_key_cookie.py
  6. 3
      tests/test_security_api_key_cookie_description.py
  7. 3
      tests/test_security_api_key_header.py
  8. 3
      tests/test_security_api_key_header_description.py
  9. 3
      tests/test_security_api_key_query.py
  10. 3
      tests/test_security_api_key_query_description.py
  11. 3
      tests/test_security_http_base.py
  12. 3
      tests/test_security_http_base_description.py
  13. 6
      tests/test_security_http_bearer.py
  14. 6
      tests/test_security_http_bearer_description.py
  15. 6
      tests/test_security_http_digest.py
  16. 6
      tests/test_security_http_digest_description.py
  17. 3
      tests/test_security_oauth2.py
  18. 2
      tests/test_security_oauth2_authorization_code_bearer.py
  19. 2
      tests/test_security_oauth2_authorization_code_bearer_description.py
  20. 3
      tests/test_security_openid_connect.py
  21. 3
      tests/test_security_openid_connect_description.py

23
fastapi/security/api_key.py

@ -4,17 +4,22 @@ from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc
class APIKeyBase(SecurityBase): class APIKeyBase(SecurityBase):
@staticmethod @staticmethod
def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]: def check_api_key(
api_key: Optional[str], auto_error: bool, key_name: str, key_in: APIKeyIn
) -> Optional[str]:
if not api_key: if not api_key:
if auto_error: if auto_error:
auth_header = f'ApiKey name="{key_name}", in="{key_in.value}"'
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": auth_header},
) )
return None return None
return api_key return api_key
@ -109,7 +114,9 @@ class APIKeyQuery(APIKeyBase):
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: Request) -> Optional[str]:
api_key = request.query_params.get(self.model.name) api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(
api_key, self.auto_error, self.model.name, APIKeyIn.query
)
class APIKeyHeader(APIKeyBase): class APIKeyHeader(APIKeyBase):
@ -197,7 +204,9 @@ class APIKeyHeader(APIKeyBase):
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: Request) -> Optional[str]:
api_key = request.headers.get(self.model.name) api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(
api_key, self.auto_error, self.model.name, APIKeyIn.header
)
class APIKeyCookie(APIKeyBase): class APIKeyCookie(APIKeyBase):
@ -285,4 +294,6 @@ class APIKeyCookie(APIKeyBase):
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: Request) -> Optional[str]:
api_key = request.cookies.get(self.model.name) api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(
api_key, self.auto_error, self.model.name, APIKeyIn.cookie
)

26
fastapi/security/http.py

@ -9,7 +9,7 @@ from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc
@ -75,7 +75,9 @@ class HTTPBase(SecurityBase):
description: Optional[str] = None, description: Optional[str] = None,
auto_error: bool = True, auto_error: bool = True,
): ):
self.model = HTTPBaseModel(scheme=scheme, description=description) self.model: HTTPBaseModel = HTTPBaseModel(
scheme=scheme, description=description
)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@ -87,8 +89,11 @@ class HTTPBase(SecurityBase):
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": self.model.scheme},
) )
else: else:
return None return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
@ -306,15 +311,18 @@ class HTTPBearer(HTTPBase):
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
) )
else: else:
return None return None
if scheme.lower() != "bearer": if scheme.lower() != "bearer":
if self.auto_error: if self.auto_error:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials", detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
) )
else: else:
return None return None
@ -408,16 +416,20 @@ class HTTPDigest(HTTPBase):
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Digest"},
) )
else: else:
return None return None
if scheme.lower() != "digest": if scheme.lower() != "digest":
if self.auto_error: if self.auto_error:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials", detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Digest"},
) )
else: else:
return None return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)

6
fastapi/security/oauth2.py

@ -7,7 +7,7 @@ from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED
# TODO: import from typing when deprecating Python 3.9 # TODO: import from typing when deprecating Python 3.9
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc
@ -381,7 +381,9 @@ class OAuth2(SecurityBase):
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
) )
else: else:
return None return None

8
fastapi/security/open_id_connect_url.py

@ -4,7 +4,7 @@ from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated, Doc from typing_extensions import Annotated, Doc
@ -77,7 +77,11 @@ class OpenIdConnect(SecurityBase):
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={
"WWW-Authenticate": "Bearer",
},
) )
else: else:
return None return None

3
tests/test_security_api_key_cookie.py

@ -32,8 +32,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
client = TestClient(app) client = TestClient(app)
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey name="key", in="cookie"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_api_key_cookie_description.py

@ -32,8 +32,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
client = TestClient(app) client = TestClient(app)
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey name="key", in="cookie"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_api_key_header.py

@ -33,8 +33,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey name="key", in="header"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_api_key_header_description.py

@ -33,8 +33,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey name="key", in="header"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_api_key_query.py

@ -33,8 +33,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey name="key", in="query"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_api_key_query_description.py

@ -33,8 +33,9 @@ def test_security_api_key():
def test_security_api_key_no_key(): def test_security_api_key_no_key():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == 'ApiKey name="key", in="query"'
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_http_base.py

@ -23,8 +23,9 @@ def test_security_http_base():
def test_security_http_base_no_credentials(): def test_security_http_base_no_credentials():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Other"
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_http_base_description.py

@ -23,8 +23,9 @@ def test_security_http_base():
def test_security_http_base_no_credentials(): def test_security_http_base_no_credentials():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.status_code == 401, response.text
assert response.headers["WWW-Authenticate"] == "Other"
def test_openapi_schema(): def test_openapi_schema():

6
tests/test_security_http_bearer.py

@ -23,14 +23,16 @@ def test_security_http_bearer():
def test_security_http_bearer_no_credentials(): def test_security_http_bearer_no_credentials():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_security_http_bearer_incorrect_scheme_credentials(): def test_security_http_bearer_incorrect_scheme_credentials():
response = client.get("/users/me", headers={"Authorization": "Basic notreally"}) response = client.get("/users/me", headers={"Authorization": "Basic notreally"})
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Invalid authentication credentials"} assert response.json() == {"detail": "Invalid authentication credentials"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_openapi_schema(): def test_openapi_schema():

6
tests/test_security_http_bearer_description.py

@ -23,14 +23,16 @@ def test_security_http_bearer():
def test_security_http_bearer_no_credentials(): def test_security_http_bearer_no_credentials():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_security_http_bearer_incorrect_scheme_credentials(): def test_security_http_bearer_incorrect_scheme_credentials():
response = client.get("/users/me", headers={"Authorization": "Basic notreally"}) response = client.get("/users/me", headers={"Authorization": "Basic notreally"})
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Invalid authentication credentials"} assert response.json() == {"detail": "Invalid authentication credentials"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_openapi_schema(): def test_openapi_schema():

6
tests/test_security_http_digest.py

@ -23,16 +23,18 @@ def test_security_http_digest():
def test_security_http_digest_no_credentials(): def test_security_http_digest_no_credentials():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Digest"
def test_security_http_digest_incorrect_scheme_credentials(): def test_security_http_digest_incorrect_scheme_credentials():
response = client.get( response = client.get(
"/users/me", headers={"Authorization": "Other invalidauthorization"} "/users/me", headers={"Authorization": "Other invalidauthorization"}
) )
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Invalid authentication credentials"} assert response.json() == {"detail": "Invalid authentication credentials"}
assert response.headers["WWW-Authenticate"] == "Digest"
def test_openapi_schema(): def test_openapi_schema():

6
tests/test_security_http_digest_description.py

@ -23,15 +23,17 @@ def test_security_http_digest():
def test_security_http_digest_no_credentials(): def test_security_http_digest_no_credentials():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Digest"
def test_security_http_digest_incorrect_scheme_credentials(): def test_security_http_digest_incorrect_scheme_credentials():
response = client.get( response = client.get(
"/users/me", headers={"Authorization": "Other invalidauthorization"} "/users/me", headers={"Authorization": "Other invalidauthorization"}
) )
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.headers["WWW-Authenticate"] == "Digest"
assert response.json() == {"detail": "Invalid authentication credentials"} assert response.json() == {"detail": "Invalid authentication credentials"}

3
tests/test_security_oauth2.py

@ -56,8 +56,9 @@ def test_security_oauth2_password_other_header():
def test_security_oauth2_password_bearer_no_header(): def test_security_oauth2_password_bearer_no_header():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_strict_login_no_data(): def test_strict_login_no_data():

2
tests/test_security_oauth2_authorization_code_bearer.py

@ -23,12 +23,14 @@ def test_no_token():
response = client.get("/items") response = client.get("/items")
assert response.status_code == 401, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_incorrect_token(): def test_incorrect_token():
response = client.get("/items", headers={"Authorization": "Non-existent testtoken"}) response = client.get("/items", headers={"Authorization": "Non-existent testtoken"})
assert response.status_code == 401, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_token(): def test_token():

2
tests/test_security_oauth2_authorization_code_bearer_description.py

@ -26,12 +26,14 @@ def test_no_token():
response = client.get("/items") response = client.get("/items")
assert response.status_code == 401, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_incorrect_token(): def test_incorrect_token():
response = client.get("/items", headers={"Authorization": "Non-existent testtoken"}) response = client.get("/items", headers={"Authorization": "Non-existent testtoken"})
assert response.status_code == 401, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_token(): def test_token():

3
tests/test_security_openid_connect.py

@ -39,8 +39,9 @@ def test_security_oauth2_password_other_header():
def test_security_oauth2_password_bearer_no_header(): def test_security_oauth2_password_bearer_no_header():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_openapi_schema(): def test_openapi_schema():

3
tests/test_security_openid_connect_description.py

@ -41,8 +41,9 @@ def test_security_oauth2_password_other_header():
def test_security_oauth2_password_bearer_no_header(): def test_security_oauth2_password_bearer_no_header():
response = client.get("/users/me") response = client.get("/users/me")
assert response.status_code == 403, response.text assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
assert response.headers["WWW-Authenticate"] == "Bearer"
def test_openapi_schema(): def test_openapi_schema():

Loading…
Cancel
Save