Browse Source

Add auto_error to security utils (#134)

to allow them to be optional, also allowing the declaration of multiple security schemes
pull/138/head
Sebastián Ramírez 6 years ago
committed by GitHub
parent
commit
fad3a9e1dc
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 50
      fastapi/security/api_key.py
  2. 73
      fastapi/security/http.py
  3. 39
      fastapi/security/oauth2.py
  4. 18
      fastapi/security/open_id_connect_url.py
  5. 75
      tests/test_security_api_key_cookie_optional.py
  6. 74
      tests/test_security_api_key_header_optional.py
  7. 74
      tests/test_security_api_key_query_optional.py
  8. 62
      tests/test_security_http_base_optional.py
  9. 79
      tests/test_security_http_basic_optional.py
  10. 68
      tests/test_security_http_bearer_optional.py
  11. 70
      tests/test_security_http_digest_optional.py
  12. 254
      tests/test_security_oauth2_optional.py
  13. 71
      tests/test_security_oauth2_password_bearer_optional.py
  14. 80
      tests/test_security_openid_connect_optional.py

50
fastapi/security/api_key.py

@ -1,3 +1,5 @@
from typing import Optional
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
@ -10,42 +12,54 @@ class APIKeyBase(SecurityBase):
class APIKeyQuery(APIKeyBase): class APIKeyQuery(APIKeyBase):
def __init__(self, *, name: str, scheme_name: str = None): def __init__(self, *, name: str, scheme_name: str = None, auto_error: bool = True):
self.model = APIKey(**{"in": APIKeyIn.query}, name=name) self.model: APIKey = APIKey(**{"in": APIKeyIn.query}, name=name)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> Optional[str]:
api_key: str = request.query_params.get(self.model.name) api_key: str = request.query_params.get(self.model.name)
if not api_key: if not api_key:
raise HTTPException( if self.auto_error:
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" raise HTTPException(
) status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key return api_key
class APIKeyHeader(APIKeyBase): class APIKeyHeader(APIKeyBase):
def __init__(self, *, name: str, scheme_name: str = None): def __init__(self, *, name: str, scheme_name: str = None, auto_error: bool = True):
self.model = APIKey(**{"in": APIKeyIn.header}, name=name) self.model: APIKey = APIKey(**{"in": APIKeyIn.header}, name=name)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> Optional[str]:
api_key: str = request.headers.get(self.model.name) api_key: str = request.headers.get(self.model.name)
if not api_key: if not api_key:
raise HTTPException( if self.auto_error:
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" raise HTTPException(
) status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key return api_key
class APIKeyCookie(APIKeyBase): class APIKeyCookie(APIKeyBase):
def __init__(self, *, name: str, scheme_name: str = None): def __init__(self, *, name: str, scheme_name: str = None, auto_error: bool = True):
self.model = APIKey(**{"in": APIKeyIn.cookie}, name=name) self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name=name)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> Optional[str]:
api_key: str = request.cookies.get(self.model.name) api_key: str = request.cookies.get(self.model.name)
if not api_key: if not api_key:
raise HTTPException( if self.auto_error:
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" raise HTTPException(
) status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key return api_key

73
fastapi/security/http.py

@ -1,5 +1,6 @@
import binascii import binascii
from base64 import b64decode from base64 import b64decode
from typing import Optional
from fastapi.openapi.models import ( from fastapi.openapi.models import (
HTTPBase as HTTPBaseModel, HTTPBase as HTTPBaseModel,
@ -24,27 +25,38 @@ class HTTPAuthorizationCredentials(BaseModel):
class HTTPBase(SecurityBase): class HTTPBase(SecurityBase):
def __init__(self, *, scheme: str, scheme_name: str = None): def __init__(
self, *, scheme: str, scheme_name: str = None, auto_error: bool = True
):
self.model = HTTPBaseModel(scheme=scheme) self.model = HTTPBaseModel(scheme=scheme)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> str: async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
raise HTTPException( if self.auto_error:
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" raise HTTPException(
) status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
class HTTPBasic(HTTPBase): class HTTPBasic(HTTPBase):
def __init__(self, *, scheme_name: str = None, realm: str = None): def __init__(
self, *, scheme_name: str = None, realm: str = None, auto_error: bool = True
):
self.model = HTTPBaseModel(scheme="basic") self.model = HTTPBaseModel(scheme="basic")
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.realm = realm self.realm = realm
self.auto_error = auto_error
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> Optional[HTTPBasicCredentials]:
authorization: str = request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
# before implementing headers with 401 errors, wait for: https://github.com/encode/starlette/issues/295 # before implementing headers with 401 errors, wait for: https://github.com/encode/starlette/issues/295
@ -53,9 +65,12 @@ class HTTPBasic(HTTPBase):
status_code=HTTP_403_FORBIDDEN, detail="Invalid authentication credentials" status_code=HTTP_403_FORBIDDEN, detail="Invalid authentication credentials"
) )
if not authorization or scheme.lower() != "basic": if not authorization or scheme.lower() != "basic":
raise HTTPException( if self.auto_error:
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" raise HTTPException(
) status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
try: try:
data = b64decode(param).decode("ascii") data = b64decode(param).decode("ascii")
except (ValueError, UnicodeDecodeError, binascii.Error): except (ValueError, UnicodeDecodeError, binascii.Error):
@ -67,17 +82,29 @@ class HTTPBasic(HTTPBase):
class HTTPBearer(HTTPBase): class HTTPBearer(HTTPBase):
def __init__(self, *, bearerFormat: str = None, scheme_name: str = None): def __init__(
self,
*,
bearerFormat: str = None,
scheme_name: str = None,
auto_error: bool = True
):
self.model = HTTPBearerModel(bearerFormat=bearerFormat) self.model = HTTPBearerModel(bearerFormat=bearerFormat)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> str: async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
raise HTTPException( if self.auto_error:
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" raise HTTPException(
) status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
if scheme.lower() != "bearer": if scheme.lower() != "bearer":
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, status_code=HTTP_403_FORBIDDEN,
@ -87,17 +114,23 @@ class HTTPBearer(HTTPBase):
class HTTPDigest(HTTPBase): class HTTPDigest(HTTPBase):
def __init__(self, *, scheme_name: str = None): def __init__(self, *, scheme_name: str = None, auto_error: bool = True):
self.model = HTTPBaseModel(scheme="digest") self.model = HTTPBaseModel(scheme="digest")
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> str: async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
raise HTTPException( if self.auto_error:
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" raise HTTPException(
) status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
if scheme.lower() != "digest": if scheme.lower() != "digest":
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, status_code=HTTP_403_FORBIDDEN,

39
fastapi/security/oauth2.py

@ -113,32 +113,49 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
class OAuth2(SecurityBase): class OAuth2(SecurityBase):
def __init__( def __init__(
self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None self,
*,
flows: OAuthFlowsModel = OAuthFlowsModel(),
scheme_name: str = None,
auto_error: bool = True
): ):
self.model = OAuth2Model(flows=flows) self.model = OAuth2Model(flows=flows)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
if not authorization: if not authorization:
raise HTTPException( if self.auto_error:
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" raise HTTPException(
) status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return authorization return authorization
class OAuth2PasswordBearer(OAuth2): class OAuth2PasswordBearer(OAuth2):
def __init__(self, tokenUrl: str, scheme_name: str = None, scopes: dict = None): def __init__(
self,
tokenUrl: str,
scheme_name: str = None,
scopes: dict = None,
auto_error: bool = True,
):
if not scopes: if not scopes:
scopes = {} scopes = {}
flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes}) flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes})
super().__init__(flows=flows, scheme_name=scheme_name) super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
raise HTTPException( if self.auto_error:
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" raise HTTPException(
) status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return param return param

18
fastapi/security/open_id_connect_url.py

@ -1,3 +1,5 @@
from typing import Optional
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
@ -6,14 +8,20 @@ from starlette.status import HTTP_403_FORBIDDEN
class OpenIdConnect(SecurityBase): class OpenIdConnect(SecurityBase):
def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None): def __init__(
self, *, openIdConnectUrl: str, scheme_name: str = None, auto_error: bool = True
):
self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl) self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
if not authorization: if not authorization:
raise HTTPException( if self.auto_error:
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" raise HTTPException(
) status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return authorization return authorization

75
tests/test_security_api_key_cookie_optional.py

@ -0,0 +1,75 @@
from typing import Optional
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyCookie
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
api_key = APIKeyCookie(name="key", auto_error=False)
class User(BaseModel):
username: str
def get_current_user(oauth_header: Optional[str] = Security(api_key)):
if oauth_header is None:
return None
user = User(username=oauth_header)
return user
@app.get("/users/me")
def read_current_user(current_user: User = Depends(get_current_user)):
if current_user is None:
return {"msg": "Create an account first"}
else:
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"APIKeyCookie": []}],
}
}
},
"components": {
"securitySchemes": {
"APIKeyCookie": {"type": "apiKey", "name": "key", "in": "cookie"}
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_api_key():
response = client.get("/users/me", cookies={"key": "secret"})
assert response.status_code == 200
assert response.json() == {"username": "secret"}
def test_security_api_key_no_key():
response = client.get("/users/me")
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}

74
tests/test_security_api_key_header_optional.py

@ -0,0 +1,74 @@
from typing import Optional
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
api_key = APIKeyHeader(name="key", auto_error=False)
class User(BaseModel):
username: str
def get_current_user(oauth_header: Optional[str] = Security(api_key)):
if oauth_header is None:
return None
user = User(username=oauth_header)
return user
@app.get("/users/me")
def read_current_user(current_user: Optional[User] = Depends(get_current_user)):
if current_user is None:
return {"msg": "Create an account first"}
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"APIKeyHeader": []}],
}
}
},
"components": {
"securitySchemes": {
"APIKeyHeader": {"type": "apiKey", "name": "key", "in": "header"}
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_api_key():
response = client.get("/users/me", headers={"key": "secret"})
assert response.status_code == 200
assert response.json() == {"username": "secret"}
def test_security_api_key_no_key():
response = client.get("/users/me")
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}

74
tests/test_security_api_key_query_optional.py

@ -0,0 +1,74 @@
from typing import Optional
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyQuery
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
api_key = APIKeyQuery(name="key", auto_error=False)
class User(BaseModel):
username: str
def get_current_user(oauth_header: Optional[str] = Security(api_key)):
if oauth_header is None:
return None
user = User(username=oauth_header)
return user
@app.get("/users/me")
def read_current_user(current_user: Optional[User] = Depends(get_current_user)):
if current_user is None:
return {"msg": "Create an account first"}
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"APIKeyQuery": []}],
}
}
},
"components": {
"securitySchemes": {
"APIKeyQuery": {"type": "apiKey", "name": "key", "in": "query"}
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_api_key():
response = client.get("/users/me?key=secret")
assert response.status_code == 200
assert response.json() == {"username": "secret"}
def test_security_api_key_no_key():
response = client.get("/users/me")
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}

62
tests/test_security_http_base_optional.py

@ -0,0 +1,62 @@
from typing import Optional
from fastapi import FastAPI, Security
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from starlette.testclient import TestClient
app = FastAPI()
security = HTTPBase(scheme="Other", auto_error=False)
@app.get("/users/me")
def read_current_user(
credentials: Optional[HTTPAuthorizationCredentials] = Security(security)
):
if credentials is None:
return {"msg": "Create an account first"}
return {"scheme": credentials.scheme, "credentials": credentials.credentials}
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"HTTPBase": []}],
}
}
},
"components": {
"securitySchemes": {"HTTPBase": {"type": "http", "scheme": "Other"}}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_http_base():
response = client.get("/users/me", headers={"Authorization": "Other foobar"})
assert response.status_code == 200
assert response.json() == {"scheme": "Other", "credentials": "foobar"}
def test_security_http_base_no_credentials():
response = client.get("/users/me")
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}

79
tests/test_security_http_basic_optional.py

@ -0,0 +1,79 @@
from base64 import b64encode
from typing import Optional
from fastapi import FastAPI, Security
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from requests.auth import HTTPBasicAuth
from starlette.testclient import TestClient
app = FastAPI()
security = HTTPBasic(auto_error=False)
@app.get("/users/me")
def read_current_user(credentials: Optional[HTTPBasicCredentials] = Security(security)):
if credentials is None:
return {"msg": "Create an account first"}
return {"username": credentials.username, "password": credentials.password}
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"HTTPBasic": []}],
}
}
},
"components": {
"securitySchemes": {"HTTPBasic": {"type": "http", "scheme": "basic"}}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_http_basic():
auth = HTTPBasicAuth(username="john", password="secret")
response = client.get("/users/me", auth=auth)
assert response.status_code == 200
assert response.json() == {"username": "john", "password": "secret"}
def test_security_http_basic_no_credentials():
response = client.get("/users/me")
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}
def test_security_http_basic_invalid_credentials():
response = client.get(
"/users/me", headers={"Authorization": "Basic notabase64token"}
)
assert response.status_code == 403
assert response.json() == {"detail": "Invalid authentication credentials"}
def test_security_http_basic_non_basic_credentials():
payload = b64encode(b"johnsecret").decode("ascii")
auth_header = f"Basic {payload}"
response = client.get("/users/me", headers={"Authorization": auth_header})
assert response.status_code == 403
assert response.json() == {"detail": "Invalid authentication credentials"}

68
tests/test_security_http_bearer_optional.py

@ -0,0 +1,68 @@
from typing import Optional
from fastapi import FastAPI, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.testclient import TestClient
app = FastAPI()
security = HTTPBearer(auto_error=False)
@app.get("/users/me")
def read_current_user(
credentials: Optional[HTTPAuthorizationCredentials] = Security(security)
):
if credentials is None:
return {"msg": "Create an account first"}
return {"scheme": credentials.scheme, "credentials": credentials.credentials}
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"HTTPBearer": []}],
}
}
},
"components": {
"securitySchemes": {"HTTPBearer": {"type": "http", "scheme": "bearer"}}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_http_bearer():
response = client.get("/users/me", headers={"Authorization": "Bearer foobar"})
assert response.status_code == 200
assert response.json() == {"scheme": "Bearer", "credentials": "foobar"}
def test_security_http_bearer_no_credentials():
response = client.get("/users/me")
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}
def test_security_http_bearer_incorrect_scheme_credentials():
response = client.get("/users/me", headers={"Authorization": "Basic notreally"})
assert response.status_code == 403
assert response.json() == {"detail": "Invalid authentication credentials"}

70
tests/test_security_http_digest_optional.py

@ -0,0 +1,70 @@
from typing import Optional
from fastapi import FastAPI, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPDigest
from starlette.testclient import TestClient
app = FastAPI()
security = HTTPDigest(auto_error=False)
@app.get("/users/me")
def read_current_user(
credentials: Optional[HTTPAuthorizationCredentials] = Security(security)
):
if credentials is None:
return {"msg": "Create an account first"}
return {"scheme": credentials.scheme, "credentials": credentials.credentials}
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"HTTPDigest": []}],
}
}
},
"components": {
"securitySchemes": {"HTTPDigest": {"type": "http", "scheme": "digest"}}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_http_digest():
response = client.get("/users/me", headers={"Authorization": "Digest foobar"})
assert response.status_code == 200
assert response.json() == {"scheme": "Digest", "credentials": "foobar"}
def test_security_http_digest_no_credentials():
response = client.get("/users/me")
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}
def test_security_http_digest_incorrect_scheme_credentials():
response = client.get(
"/users/me", headers={"Authorization": "Other invalidauthorization"}
)
assert response.status_code == 403
assert response.json() == {"detail": "Invalid authentication credentials"}

254
tests/test_security_oauth2_optional.py

@ -0,0 +1,254 @@
from typing import Optional
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security import OAuth2
from fastapi.security.oauth2 import OAuth2PasswordRequestFormStrict
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
reusable_oauth2 = OAuth2(
flows={
"password": {
"tokenUrl": "/token",
"scopes": {"read:users": "Read the users", "write:users": "Create users"},
}
},
auto_error=False,
)
class User(BaseModel):
username: str
def get_current_user(oauth_header: Optional[str] = Security(reusable_oauth2)):
if oauth_header is None:
return None
user = User(username=oauth_header)
return user
@app.post("/login")
def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()):
return form_data
@app.get("/users/me")
def read_current_user(current_user: Optional[User] = Depends(get_current_user)):
if current_user is None:
return {"msg": "Create an account first"}
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/login": {
"post": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"summary": "Read Current User Post",
"operationId": "read_current_user_login_post",
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"$ref": "#/components/schemas/Body_read_current_user"
}
}
},
"required": True,
},
}
},
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"OAuth2": []}],
}
},
},
"components": {
"schemas": {
"Body_read_current_user": {
"title": "Body_read_current_user",
"required": ["grant_type", "username", "password"],
"type": "object",
"properties": {
"grant_type": {
"title": "Grant_Type",
"pattern": "password",
"type": "string",
},
"username": {"title": "Username", "type": "string"},
"password": {"title": "Password", "type": "string"},
"scope": {"title": "Scope", "type": "string", "default": ""},
"client_id": {"title": "Client_Id", "type": "string"},
"client_secret": {"title": "Client_Secret", "type": "string"},
},
},
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
},
"securitySchemes": {
"OAuth2": {
"type": "oauth2",
"flows": {
"password": {
"scopes": {
"read:users": "Read the users",
"write:users": "Create users",
},
"tokenUrl": "/token",
}
},
}
},
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_oauth2():
response = client.get("/users/me", headers={"Authorization": "Bearer footokenbar"})
assert response.status_code == 200
assert response.json() == {"username": "Bearer footokenbar"}
def test_security_oauth2_password_other_header():
response = client.get("/users/me", headers={"Authorization": "Other footokenbar"})
assert response.status_code == 200
assert response.json() == {"username": "Other footokenbar"}
def test_security_oauth2_password_bearer_no_header():
response = client.get("/users/me")
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}
required_params = {
"detail": [
{
"loc": ["body", "grant_type"],
"msg": "field required",
"type": "value_error.missing",
},
{
"loc": ["body", "username"],
"msg": "field required",
"type": "value_error.missing",
},
{
"loc": ["body", "password"],
"msg": "field required",
"type": "value_error.missing",
},
]
}
grant_type_required = {
"detail": [
{
"loc": ["body", "grant_type"],
"msg": "field required",
"type": "value_error.missing",
}
]
}
grant_type_incorrect = {
"detail": [
{
"loc": ["body", "grant_type"],
"msg": 'string does not match regex "password"',
"type": "value_error.str.regex",
"ctx": {"pattern": "password"},
}
]
}
@pytest.mark.parametrize(
"data,expected_status,expected_response",
[
(None, 422, required_params),
({"username": "johndoe", "password": "secret"}, 422, grant_type_required),
(
{"username": "johndoe", "password": "secret", "grant_type": "incorrect"},
422,
grant_type_incorrect,
),
(
{"username": "johndoe", "password": "secret", "grant_type": "password"},
200,
{
"grant_type": "password",
"username": "johndoe",
"password": "secret",
"scopes": [],
"client_id": None,
"client_secret": None,
},
),
],
)
def test_strict_login(data, expected_status, expected_response):
response = client.post("/login", data=data)
assert response.status_code == expected_status
assert response.json() == expected_response

71
tests/test_security_oauth2_password_bearer_optional.py

@ -0,0 +1,71 @@
from typing import Optional
from fastapi import FastAPI, Security
from fastapi.security import OAuth2PasswordBearer
from starlette.testclient import TestClient
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token", auto_error=False)
@app.get("/items/")
async def read_items(token: Optional[str] = Security(oauth2_scheme)):
if token is None:
return {"msg": "Create an account first"}
return {"token": token}
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/items/": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Items Get",
"operationId": "read_items_items__get",
"security": [{"OAuth2PasswordBearer": []}],
}
}
},
"components": {
"securitySchemes": {
"OAuth2PasswordBearer": {
"type": "oauth2",
"flows": {"password": {"scopes": {}, "tokenUrl": "/token"}},
}
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_no_token():
response = client.get("/items")
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}
def test_token():
response = client.get("/items", headers={"Authorization": "Bearer testtoken"})
assert response.status_code == 200
assert response.json() == {"token": "testtoken"}
def test_incorrect_token():
response = client.get("/items", headers={"Authorization": "Notexistent testtoken"})
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}

80
tests/test_security_openid_connect_optional.py

@ -0,0 +1,80 @@
from typing import Optional
from fastapi import Depends, FastAPI, Security
from fastapi.security.open_id_connect_url import OpenIdConnect
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
oid = OpenIdConnect(openIdConnectUrl="/openid", auto_error=False)
class User(BaseModel):
username: str
def get_current_user(oauth_header: Optional[str] = Security(oid)):
if oauth_header is None:
return None
user = User(username=oauth_header)
return user
@app.get("/users/me")
def read_current_user(current_user: Optional[User] = Depends(get_current_user)):
if current_user is None:
return {"msg": "Create an account first"}
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"OpenIdConnect": []}],
}
}
},
"components": {
"securitySchemes": {
"OpenIdConnect": {"type": "openIdConnect", "openIdConnectUrl": "/openid"}
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_oauth2():
response = client.get("/users/me", headers={"Authorization": "Bearer footokenbar"})
assert response.status_code == 200
assert response.json() == {"username": "Bearer footokenbar"}
def test_security_oauth2_password_other_header():
response = client.get("/users/me", headers={"Authorization": "Other footokenbar"})
assert response.status_code == 200
assert response.json() == {"username": "Other footokenbar"}
def test_security_oauth2_password_bearer_no_header():
response = client.get("/users/me")
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}
Loading…
Cancel
Save