diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 018e4f99e..356681eee 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -1,3 +1,5 @@ +from typing import Optional + from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException @@ -10,42 +12,54 @@ class APIKeyBase(SecurityBase): class APIKeyQuery(APIKeyBase): - def __init__(self, *, name: str, scheme_name: str = None): - self.model = APIKey(**{"in": APIKeyIn.query}, name=name) + def __init__(self, *, name: str, scheme_name: str = None, auto_error: bool = True): + self.model: APIKey = APIKey(**{"in": APIKeyIn.query}, name=name) self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error - async def __call__(self, request: Request) -> str: + async def __call__(self, request: Request) -> Optional[str]: api_key: str = request.query_params.get(self.model.name) if not api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None return api_key class APIKeyHeader(APIKeyBase): - def __init__(self, *, name: str, scheme_name: str = None): - self.model = APIKey(**{"in": APIKeyIn.header}, name=name) + def __init__(self, *, name: str, scheme_name: str = None, auto_error: bool = True): + self.model: APIKey = APIKey(**{"in": APIKeyIn.header}, name=name) self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error - async def __call__(self, request: Request) -> str: + async def __call__(self, request: Request) -> Optional[str]: api_key: str = request.headers.get(self.model.name) if not api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None return api_key class APIKeyCookie(APIKeyBase): - def __init__(self, *, name: str, scheme_name: str = None): - self.model = APIKey(**{"in": APIKeyIn.cookie}, name=name) + def __init__(self, *, name: str, scheme_name: str = None, auto_error: bool = True): + self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name=name) self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error - async def __call__(self, request: Request) -> str: + async def __call__(self, request: Request) -> Optional[str]: api_key: str = request.cookies.get(self.model.name) if not api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None return api_key diff --git a/fastapi/security/http.py b/fastapi/security/http.py index 3e4aeb67c..b2da3fcb5 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -1,5 +1,6 @@ import binascii from base64 import b64decode +from typing import Optional from fastapi.openapi.models import ( HTTPBase as HTTPBaseModel, @@ -24,27 +25,38 @@ class HTTPAuthorizationCredentials(BaseModel): class HTTPBase(SecurityBase): - def __init__(self, *, scheme: str, scheme_name: str = None): + def __init__( + self, *, scheme: str, scheme_name: str = None, auto_error: bool = True + ): self.model = HTTPBaseModel(scheme=scheme) self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error - async def __call__(self, request: Request) -> str: + async def __call__( + self, request: Request + ) -> Optional[HTTPAuthorizationCredentials]: authorization: str = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) if not (authorization and scheme and credentials): - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) class HTTPBasic(HTTPBase): - def __init__(self, *, scheme_name: str = None, realm: str = None): + def __init__( + self, *, scheme_name: str = None, realm: str = None, auto_error: bool = True + ): self.model = HTTPBaseModel(scheme="basic") self.scheme_name = scheme_name or self.__class__.__name__ self.realm = realm + self.auto_error = auto_error - async def __call__(self, request: Request) -> str: + async def __call__(self, request: Request) -> Optional[HTTPBasicCredentials]: authorization: str = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) # before implementing headers with 401 errors, wait for: https://github.com/encode/starlette/issues/295 @@ -53,9 +65,12 @@ class HTTPBasic(HTTPBase): status_code=HTTP_403_FORBIDDEN, detail="Invalid authentication credentials" ) if not authorization or scheme.lower() != "basic": - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None try: data = b64decode(param).decode("ascii") except (ValueError, UnicodeDecodeError, binascii.Error): @@ -67,17 +82,29 @@ class HTTPBasic(HTTPBase): class HTTPBearer(HTTPBase): - def __init__(self, *, bearerFormat: str = None, scheme_name: str = None): + def __init__( + self, + *, + bearerFormat: str = None, + scheme_name: str = None, + auto_error: bool = True + ): self.model = HTTPBearerModel(bearerFormat=bearerFormat) self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error - async def __call__(self, request: Request) -> str: + async def __call__( + self, request: Request + ) -> Optional[HTTPAuthorizationCredentials]: authorization: str = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) if not (authorization and scheme and credentials): - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None if scheme.lower() != "bearer": raise HTTPException( status_code=HTTP_403_FORBIDDEN, @@ -87,17 +114,23 @@ class HTTPBearer(HTTPBase): class HTTPDigest(HTTPBase): - def __init__(self, *, scheme_name: str = None): + def __init__(self, *, scheme_name: str = None, auto_error: bool = True): self.model = HTTPBaseModel(scheme="digest") self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error - async def __call__(self, request: Request) -> str: + async def __call__( + self, request: Request + ) -> Optional[HTTPAuthorizationCredentials]: authorization: str = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) if not (authorization and scheme and credentials): - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None if scheme.lower() != "digest": raise HTTPException( status_code=HTTP_403_FORBIDDEN, diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index b1132fef1..d779bcae1 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -113,32 +113,49 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm): class OAuth2(SecurityBase): def __init__( - self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None + self, + *, + flows: OAuthFlowsModel = OAuthFlowsModel(), + scheme_name: str = None, + auto_error: bool = True ): self.model = OAuth2Model(flows=flows) self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error - async def __call__(self, request: Request) -> str: + async def __call__(self, request: Request) -> Optional[str]: authorization: str = request.headers.get("Authorization") if not authorization: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None return authorization class OAuth2PasswordBearer(OAuth2): - def __init__(self, tokenUrl: str, scheme_name: str = None, scopes: dict = None): + def __init__( + self, + tokenUrl: str, + scheme_name: str = None, + scopes: dict = None, + auto_error: bool = True, + ): if not scopes: scopes = {} flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes}) - super().__init__(flows=flows, scheme_name=scheme_name) + super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error) - async def __call__(self, request: Request) -> str: + async def __call__(self, request: Request) -> Optional[str]: authorization: str = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None return param diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index e10f4a510..f4d5ab3f4 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -1,3 +1,5 @@ +from typing import Optional + from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException @@ -6,14 +8,20 @@ from starlette.status import HTTP_403_FORBIDDEN class OpenIdConnect(SecurityBase): - def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None): + def __init__( + self, *, openIdConnectUrl: str, scheme_name: str = None, auto_error: bool = True + ): self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl) self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error - async def __call__(self, request: Request) -> str: + async def __call__(self, request: Request) -> Optional[str]: authorization: str = request.headers.get("Authorization") if not authorization: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None return authorization diff --git a/tests/test_security_api_key_cookie_optional.py b/tests/test_security_api_key_cookie_optional.py new file mode 100644 index 000000000..3a962d8fe --- /dev/null +++ b/tests/test_security_api_key_cookie_optional.py @@ -0,0 +1,75 @@ +from typing import Optional + +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyCookie +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +api_key = APIKeyCookie(name="key", auto_error=False) + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: Optional[str] = Security(api_key)): + if oauth_header is None: + return None + user = User(username=oauth_header) + return user + + +@app.get("/users/me") +def read_current_user(current_user: User = Depends(get_current_user)): + if current_user is None: + return {"msg": "Create an account first"} + else: + return current_user + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"APIKeyCookie": []}], + } + } + }, + "components": { + "securitySchemes": { + "APIKeyCookie": {"type": "apiKey", "name": "key", "in": "cookie"} + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_api_key(): + response = client.get("/users/me", cookies={"key": "secret"}) + assert response.status_code == 200 + assert response.json() == {"username": "secret"} + + +def test_security_api_key_no_key(): + response = client.get("/users/me") + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"} diff --git a/tests/test_security_api_key_header_optional.py b/tests/test_security_api_key_header_optional.py new file mode 100644 index 000000000..6dcb7b288 --- /dev/null +++ b/tests/test_security_api_key_header_optional.py @@ -0,0 +1,74 @@ +from typing import Optional + +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyHeader +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +api_key = APIKeyHeader(name="key", auto_error=False) + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: Optional[str] = Security(api_key)): + if oauth_header is None: + return None + user = User(username=oauth_header) + return user + + +@app.get("/users/me") +def read_current_user(current_user: Optional[User] = Depends(get_current_user)): + if current_user is None: + return {"msg": "Create an account first"} + return current_user + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"APIKeyHeader": []}], + } + } + }, + "components": { + "securitySchemes": { + "APIKeyHeader": {"type": "apiKey", "name": "key", "in": "header"} + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_api_key(): + response = client.get("/users/me", headers={"key": "secret"}) + assert response.status_code == 200 + assert response.json() == {"username": "secret"} + + +def test_security_api_key_no_key(): + response = client.get("/users/me") + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"} diff --git a/tests/test_security_api_key_query_optional.py b/tests/test_security_api_key_query_optional.py new file mode 100644 index 000000000..0edc502e0 --- /dev/null +++ b/tests/test_security_api_key_query_optional.py @@ -0,0 +1,74 @@ +from typing import Optional + +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyQuery +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +api_key = APIKeyQuery(name="key", auto_error=False) + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: Optional[str] = Security(api_key)): + if oauth_header is None: + return None + user = User(username=oauth_header) + return user + + +@app.get("/users/me") +def read_current_user(current_user: Optional[User] = Depends(get_current_user)): + if current_user is None: + return {"msg": "Create an account first"} + return current_user + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"APIKeyQuery": []}], + } + } + }, + "components": { + "securitySchemes": { + "APIKeyQuery": {"type": "apiKey", "name": "key", "in": "query"} + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_api_key(): + response = client.get("/users/me?key=secret") + assert response.status_code == 200 + assert response.json() == {"username": "secret"} + + +def test_security_api_key_no_key(): + response = client.get("/users/me") + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"} diff --git a/tests/test_security_http_base_optional.py b/tests/test_security_http_base_optional.py new file mode 100644 index 000000000..b42b63b2a --- /dev/null +++ b/tests/test_security_http_base_optional.py @@ -0,0 +1,62 @@ +from typing import Optional + +from fastapi import FastAPI, Security +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase +from starlette.testclient import TestClient + +app = FastAPI() + +security = HTTPBase(scheme="Other", auto_error=False) + + +@app.get("/users/me") +def read_current_user( + credentials: Optional[HTTPAuthorizationCredentials] = Security(security) +): + if credentials is None: + return {"msg": "Create an account first"} + return {"scheme": credentials.scheme, "credentials": credentials.credentials} + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"HTTPBase": []}], + } + } + }, + "components": { + "securitySchemes": {"HTTPBase": {"type": "http", "scheme": "Other"}} + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_http_base(): + response = client.get("/users/me", headers={"Authorization": "Other foobar"}) + assert response.status_code == 200 + assert response.json() == {"scheme": "Other", "credentials": "foobar"} + + +def test_security_http_base_no_credentials(): + response = client.get("/users/me") + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"} diff --git a/tests/test_security_http_basic_optional.py b/tests/test_security_http_basic_optional.py new file mode 100644 index 000000000..c5ee51a4d --- /dev/null +++ b/tests/test_security_http_basic_optional.py @@ -0,0 +1,79 @@ +from base64 import b64encode +from typing import Optional + +from fastapi import FastAPI, Security +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from requests.auth import HTTPBasicAuth +from starlette.testclient import TestClient + +app = FastAPI() + +security = HTTPBasic(auto_error=False) + + +@app.get("/users/me") +def read_current_user(credentials: Optional[HTTPBasicCredentials] = Security(security)): + if credentials is None: + return {"msg": "Create an account first"} + return {"username": credentials.username, "password": credentials.password} + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"HTTPBasic": []}], + } + } + }, + "components": { + "securitySchemes": {"HTTPBasic": {"type": "http", "scheme": "basic"}} + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_http_basic(): + auth = HTTPBasicAuth(username="john", password="secret") + response = client.get("/users/me", auth=auth) + assert response.status_code == 200 + assert response.json() == {"username": "john", "password": "secret"} + + +def test_security_http_basic_no_credentials(): + response = client.get("/users/me") + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"} + + +def test_security_http_basic_invalid_credentials(): + response = client.get( + "/users/me", headers={"Authorization": "Basic notabase64token"} + ) + assert response.status_code == 403 + assert response.json() == {"detail": "Invalid authentication credentials"} + + +def test_security_http_basic_non_basic_credentials(): + payload = b64encode(b"johnsecret").decode("ascii") + auth_header = f"Basic {payload}" + response = client.get("/users/me", headers={"Authorization": auth_header}) + assert response.status_code == 403 + assert response.json() == {"detail": "Invalid authentication credentials"} diff --git a/tests/test_security_http_bearer_optional.py b/tests/test_security_http_bearer_optional.py new file mode 100644 index 000000000..426793216 --- /dev/null +++ b/tests/test_security_http_bearer_optional.py @@ -0,0 +1,68 @@ +from typing import Optional + +from fastapi import FastAPI, Security +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from starlette.testclient import TestClient + +app = FastAPI() + +security = HTTPBearer(auto_error=False) + + +@app.get("/users/me") +def read_current_user( + credentials: Optional[HTTPAuthorizationCredentials] = Security(security) +): + if credentials is None: + return {"msg": "Create an account first"} + return {"scheme": credentials.scheme, "credentials": credentials.credentials} + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"HTTPBearer": []}], + } + } + }, + "components": { + "securitySchemes": {"HTTPBearer": {"type": "http", "scheme": "bearer"}} + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_http_bearer(): + response = client.get("/users/me", headers={"Authorization": "Bearer foobar"}) + assert response.status_code == 200 + assert response.json() == {"scheme": "Bearer", "credentials": "foobar"} + + +def test_security_http_bearer_no_credentials(): + response = client.get("/users/me") + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"} + + +def test_security_http_bearer_incorrect_scheme_credentials(): + response = client.get("/users/me", headers={"Authorization": "Basic notreally"}) + assert response.status_code == 403 + assert response.json() == {"detail": "Invalid authentication credentials"} diff --git a/tests/test_security_http_digest_optional.py b/tests/test_security_http_digest_optional.py new file mode 100644 index 000000000..072a564f5 --- /dev/null +++ b/tests/test_security_http_digest_optional.py @@ -0,0 +1,70 @@ +from typing import Optional + +from fastapi import FastAPI, Security +from fastapi.security import HTTPAuthorizationCredentials, HTTPDigest +from starlette.testclient import TestClient + +app = FastAPI() + +security = HTTPDigest(auto_error=False) + + +@app.get("/users/me") +def read_current_user( + credentials: Optional[HTTPAuthorizationCredentials] = Security(security) +): + if credentials is None: + return {"msg": "Create an account first"} + return {"scheme": credentials.scheme, "credentials": credentials.credentials} + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"HTTPDigest": []}], + } + } + }, + "components": { + "securitySchemes": {"HTTPDigest": {"type": "http", "scheme": "digest"}} + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_http_digest(): + response = client.get("/users/me", headers={"Authorization": "Digest foobar"}) + assert response.status_code == 200 + assert response.json() == {"scheme": "Digest", "credentials": "foobar"} + + +def test_security_http_digest_no_credentials(): + response = client.get("/users/me") + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"} + + +def test_security_http_digest_incorrect_scheme_credentials(): + response = client.get( + "/users/me", headers={"Authorization": "Other invalidauthorization"} + ) + assert response.status_code == 403 + assert response.json() == {"detail": "Invalid authentication credentials"} diff --git a/tests/test_security_oauth2_optional.py b/tests/test_security_oauth2_optional.py new file mode 100644 index 000000000..7c245a0b4 --- /dev/null +++ b/tests/test_security_oauth2_optional.py @@ -0,0 +1,254 @@ +from typing import Optional + +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security import OAuth2 +from fastapi.security.oauth2 import OAuth2PasswordRequestFormStrict +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +reusable_oauth2 = OAuth2( + flows={ + "password": { + "tokenUrl": "/token", + "scopes": {"read:users": "Read the users", "write:users": "Create users"}, + } + }, + auto_error=False, +) + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: Optional[str] = Security(reusable_oauth2)): + if oauth_header is None: + return None + user = User(username=oauth_header) + return user + + +@app.post("/login") +def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()): + return form_data + + +@app.get("/users/me") +def read_current_user(current_user: Optional[User] = Depends(get_current_user)): + if current_user is None: + return {"msg": "Create an account first"} + return current_user + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/login": { + "post": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Read Current User Post", + "operationId": "read_current_user_login_post", + "requestBody": { + "content": { + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Body_read_current_user" + } + } + }, + "required": True, + }, + } + }, + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"OAuth2": []}], + } + }, + }, + "components": { + "schemas": { + "Body_read_current_user": { + "title": "Body_read_current_user", + "required": ["grant_type", "username", "password"], + "type": "object", + "properties": { + "grant_type": { + "title": "Grant_Type", + "pattern": "password", + "type": "string", + }, + "username": {"title": "Username", "type": "string"}, + "password": {"title": "Password", "type": "string"}, + "scope": {"title": "Scope", "type": "string", "default": ""}, + "client_id": {"title": "Client_Id", "type": "string"}, + "client_secret": {"title": "Client_Secret", "type": "string"}, + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + }, + "securitySchemes": { + "OAuth2": { + "type": "oauth2", + "flows": { + "password": { + "scopes": { + "read:users": "Read the users", + "write:users": "Create users", + }, + "tokenUrl": "/token", + } + }, + } + }, + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_oauth2(): + response = client.get("/users/me", headers={"Authorization": "Bearer footokenbar"}) + assert response.status_code == 200 + assert response.json() == {"username": "Bearer footokenbar"} + + +def test_security_oauth2_password_other_header(): + response = client.get("/users/me", headers={"Authorization": "Other footokenbar"}) + assert response.status_code == 200 + assert response.json() == {"username": "Other footokenbar"} + + +def test_security_oauth2_password_bearer_no_header(): + response = client.get("/users/me") + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"} + + +required_params = { + "detail": [ + { + "loc": ["body", "grant_type"], + "msg": "field required", + "type": "value_error.missing", + }, + { + "loc": ["body", "username"], + "msg": "field required", + "type": "value_error.missing", + }, + { + "loc": ["body", "password"], + "msg": "field required", + "type": "value_error.missing", + }, + ] +} + +grant_type_required = { + "detail": [ + { + "loc": ["body", "grant_type"], + "msg": "field required", + "type": "value_error.missing", + } + ] +} + +grant_type_incorrect = { + "detail": [ + { + "loc": ["body", "grant_type"], + "msg": 'string does not match regex "password"', + "type": "value_error.str.regex", + "ctx": {"pattern": "password"}, + } + ] +} + + +@pytest.mark.parametrize( + "data,expected_status,expected_response", + [ + (None, 422, required_params), + ({"username": "johndoe", "password": "secret"}, 422, grant_type_required), + ( + {"username": "johndoe", "password": "secret", "grant_type": "incorrect"}, + 422, + grant_type_incorrect, + ), + ( + {"username": "johndoe", "password": "secret", "grant_type": "password"}, + 200, + { + "grant_type": "password", + "username": "johndoe", + "password": "secret", + "scopes": [], + "client_id": None, + "client_secret": None, + }, + ), + ], +) +def test_strict_login(data, expected_status, expected_response): + response = client.post("/login", data=data) + assert response.status_code == expected_status + assert response.json() == expected_response diff --git a/tests/test_security_oauth2_password_bearer_optional.py b/tests/test_security_oauth2_password_bearer_optional.py new file mode 100644 index 000000000..b50ff943b --- /dev/null +++ b/tests/test_security_oauth2_password_bearer_optional.py @@ -0,0 +1,71 @@ +from typing import Optional + +from fastapi import FastAPI, Security +from fastapi.security import OAuth2PasswordBearer +from starlette.testclient import TestClient + +app = FastAPI() + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token", auto_error=False) + + +@app.get("/items/") +async def read_items(token: Optional[str] = Security(oauth2_scheme)): + if token is None: + return {"msg": "Create an account first"} + return {"token": token} + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/items/": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Items Get", + "operationId": "read_items_items__get", + "security": [{"OAuth2PasswordBearer": []}], + } + } + }, + "components": { + "securitySchemes": { + "OAuth2PasswordBearer": { + "type": "oauth2", + "flows": {"password": {"scopes": {}, "tokenUrl": "/token"}}, + } + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_no_token(): + response = client.get("/items") + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"} + + +def test_token(): + response = client.get("/items", headers={"Authorization": "Bearer testtoken"}) + assert response.status_code == 200 + assert response.json() == {"token": "testtoken"} + + +def test_incorrect_token(): + response = client.get("/items", headers={"Authorization": "Notexistent testtoken"}) + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"} diff --git a/tests/test_security_openid_connect_optional.py b/tests/test_security_openid_connect_optional.py new file mode 100644 index 000000000..a37382603 --- /dev/null +++ b/tests/test_security_openid_connect_optional.py @@ -0,0 +1,80 @@ +from typing import Optional + +from fastapi import Depends, FastAPI, Security +from fastapi.security.open_id_connect_url import OpenIdConnect +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +oid = OpenIdConnect(openIdConnectUrl="/openid", auto_error=False) + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: Optional[str] = Security(oid)): + if oauth_header is None: + return None + user = User(username=oauth_header) + return user + + +@app.get("/users/me") +def read_current_user(current_user: Optional[User] = Depends(get_current_user)): + if current_user is None: + return {"msg": "Create an account first"} + return current_user + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"OpenIdConnect": []}], + } + } + }, + "components": { + "securitySchemes": { + "OpenIdConnect": {"type": "openIdConnect", "openIdConnectUrl": "/openid"} + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_oauth2(): + response = client.get("/users/me", headers={"Authorization": "Bearer footokenbar"}) + assert response.status_code == 200 + assert response.json() == {"username": "Bearer footokenbar"} + + +def test_security_oauth2_password_other_header(): + response = client.get("/users/me", headers={"Authorization": "Other footokenbar"}) + assert response.status_code == 200 + assert response.json() == {"username": "Other footokenbar"} + + +def test_security_oauth2_password_bearer_no_header(): + response = client.get("/users/me") + assert response.status_code == 200 + assert response.json() == {"msg": "Create an account first"}