Browse Source

Improve security utilities and add tests

pull/11/head
Sebastián Ramírez 6 years ago
parent
commit
0393a093d3
  1. 9
      fastapi/dependencies/utils.py
  2. 35
      fastapi/security/api_key.py
  3. 64
      fastapi/security/http.py
  4. 14
      fastapi/security/oauth2.py
  5. 9
      fastapi/security/open_id_connect_url.py
  6. 5
      fastapi/security/utils.py
  7. 29
      tests/main.py
  8. 22
      tests/test_application.py
  9. 2
      tests/test_extra_routes.py
  10. 23
      tests/test_include_route.py
  11. 10
      tests/test_query.py
  12. 25
      tests/test_security.py
  13. 68
      tests/test_security_api_key_cookie.py
  14. 68
      tests/test_security_api_key_header.py
  15. 68
      tests/test_security_api_key_query.py
  16. 247
      tests/test_security_oauth2.py
  17. 74
      tests/test_security_openid_connect.py

9
fastapi/dependencies/utils.py

@ -97,7 +97,7 @@ def get_dependant(*, path: str, call: Callable, name: str = None) -> Dependant:
elif ( elif (
param.default == param.empty param.default == param.empty
or param.default is None or param.default is None
or type(param.default) in param_supported_types or isinstance(param.default, param_supported_types)
) and ( ) and (
param.annotation == param.empty param.annotation == param.empty
or lenient_issubclass(param.annotation, param_supported_types) or lenient_issubclass(param.annotation, param_supported_types)
@ -214,7 +214,8 @@ async def solve_dependencies(
request=request, dependant=sub_dependant, body=body request=request, dependant=sub_dependant, body=body
) )
if sub_errors: if sub_errors:
return {}, errors errors.extend(sub_errors)
continue
assert sub_dependant.call is not None, "sub_dependant.call must be a function" assert sub_dependant.call is not None, "sub_dependant.call must be a function"
if is_coroutine_callable(sub_dependant.call): if is_coroutine_callable(sub_dependant.call):
solved = await sub_dependant.call(**sub_values) solved = await sub_dependant.call(**sub_values)
@ -238,7 +239,7 @@ async def solve_dependencies(
values.update(query_values) values.update(query_values)
values.update(header_values) values.update(header_values)
values.update(cookie_values) values.update(cookie_values)
errors = path_errors + query_errors + header_errors + cookie_errors errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params: if dependant.body_params:
body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above
dependant.body_params, body dependant.body_params, body
@ -295,7 +296,7 @@ async def request_body_to_args(
received_body = {} received_body = {}
for field in required_params: for field in required_params:
value = received_body.get(field.alias) value = received_body.get(field.alias)
if value is None: if value is None or (isinstance(field.schema, params.Form) and value == ""):
if field.required: if field.required:
errors.append( errors.append(
ErrorWrapper( ErrorWrapper(

35
fastapi/security/api_key.py

@ -1,6 +1,8 @@
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
class APIKeyBase(SecurityBase): class APIKeyBase(SecurityBase):
@ -9,26 +11,41 @@ class APIKeyBase(SecurityBase):
class APIKeyQuery(APIKeyBase): class APIKeyQuery(APIKeyBase):
def __init__(self, *, name: str, scheme_name: str = None): def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.query, name=name) self.model = APIKey(**{"in": APIKeyIn.query}, name=name)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request) -> str: async def __call__(self, request: Request) -> str:
return requests.query_params.get(self.model.name) api_key: str = request.query_params.get(self.model.name)
if not api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
return api_key
class APIKeyHeader(APIKeyBase): class APIKeyHeader(APIKeyBase):
def __init__(self, *, name: str, scheme_name: str = None): def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.header, name=name) self.model = APIKey(**{"in": APIKeyIn.header}, name=name)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request) -> str: async def __call__(self, request: Request) -> str:
return requests.headers.get(self.model.name) api_key: str = request.headers.get(self.model.name)
if not api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
return api_key
class APIKeyCookie(APIKeyBase): class APIKeyCookie(APIKeyBase):
def __init__(self, *, name: str, scheme_name: str = None): def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.cookie, name=name) self.model = APIKey(**{"in": APIKeyIn.cookie}, name=name)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request) -> str: async def __call__(self, request: Request) -> str:
return requests.cookies.get(self.model.name) api_key: str = request.cookies.get(self.model.name)
if not api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
return api_key

64
fastapi/security/http.py

@ -1,9 +1,26 @@
import binascii
from base64 import b64decode
from fastapi.openapi.models import ( from fastapi.openapi.models import (
HTTPBase as HTTPBaseModel, HTTPBase as HTTPBaseModel,
HTTPBearer as HTTPBearerModel, HTTPBearer as HTTPBearerModel,
) )
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
class HTTPBasicCredentials(BaseModel):
username: str
password: str
class HTTPAuthorizationCredentials(BaseModel):
scheme: str
credentials: str
class HTTPBase(SecurityBase): class HTTPBase(SecurityBase):
@ -12,16 +29,41 @@ class HTTPBase(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
class HTTPBasic(HTTPBase): class HTTPBasic(HTTPBase):
def __init__(self, *, scheme_name: str = None): def __init__(self, *, scheme_name: str = None, realm: str = None):
self.model = HTTPBaseModel(scheme="basic") self.model = HTTPBaseModel(scheme="basic")
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.realm = realm
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
# before implementing headers with 401 errors, wait for: https://github.com/encode/starlette/issues/295
# unauthorized_headers = {"WWW-Authenticate": "Basic"}
invalid_user_credentials_exc = HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid authentication credentials"
)
if not authorization or scheme.lower() != "basic":
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
try:
data = b64decode(param).decode("ascii")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise invalid_user_credentials_exc
username, separator, password = data.partition(":")
if not (separator):
raise invalid_user_credentials_exc
return HTTPBasicCredentials(username=username, password=password)
class HTTPBearer(HTTPBase): class HTTPBearer(HTTPBase):
@ -30,7 +72,13 @@ class HTTPBearer(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
class HTTPDigest(HTTPBase): class HTTPDigest(HTTPBase):
@ -39,4 +87,10 @@ class HTTPDigest(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)

14
fastapi/security/oauth2.py

@ -3,6 +3,7 @@ from typing import Optional
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
from fastapi.params import Form from fastapi.params import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
@ -118,7 +119,12 @@ class OAuth2(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
if not authorization:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
return authorization
class OAuth2PasswordBearer(OAuth2): class OAuth2PasswordBearer(OAuth2):
@ -130,9 +136,9 @@ class OAuth2PasswordBearer(OAuth2):
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> str:
authorization: str = request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
if not authorization or "Bearer " not in authorization: scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
) )
token = authorization.replace("Bearer ", "") return param
return token

9
fastapi/security/open_id_connect_url.py

@ -1,6 +1,8 @@
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
class OpenIdConnect(SecurityBase): class OpenIdConnect(SecurityBase):
@ -9,4 +11,9 @@ class OpenIdConnect(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request) -> str: async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") authorization: str = request.headers.get("Authorization")
if not authorization:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
return authorization

5
fastapi/security/utils.py

@ -0,0 +1,5 @@
def get_authorization_scheme_param(authorization_header_value: str):
if not authorization_header_value:
return "", ""
scheme, _, param = authorization_header_value.partition(" ")
return scheme, param

29
tests/main.py

@ -1,6 +1,4 @@
from fastapi import Depends, FastAPI, Path, Query, Security from fastapi import FastAPI, Path, Query
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
app = FastAPI() app = FastAPI()
@ -144,8 +142,6 @@ def get_path_param_le_ge_int(item_id: int = Path(..., le=3, ge=1)):
@app.get("/query") @app.get("/query")
def get_query(query): def get_query(query):
if query is None:
return "foo bar"
return f"foo bar {query}" return f"foo bar {query}"
@ -158,8 +154,6 @@ def get_query_optional(query=None):
@app.get("/query/int") @app.get("/query/int")
def get_query_type(query: int): def get_query_type(query: int):
if query is None:
return "foo bar"
return f"foo bar {query}" return f"foo bar {query}"
@ -184,30 +178,9 @@ def get_query_param(query=Query(None)):
@app.get("/query/param-required") @app.get("/query/param-required")
def get_query_param_required(query=Query(...)): def get_query_param_required(query=Query(...)):
if query is None:
return "foo bar"
return f"foo bar {query}" return f"foo bar {query}"
@app.get("/query/param-required/int") @app.get("/query/param-required/int")
def get_query_param_required_type(query: int = Query(...)): def get_query_param_required_type(query: int = Query(...)):
if query is None:
return "foo bar"
return f"foo bar {query}" return f"foo bar {query}"
reusable_oauth2b = OAuth2PasswordBearer(tokenUrl="/token")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(reusable_oauth2b)):
user = User(username=oauth_header)
return user
@app.get("/security/oauth2b")
def read_current_user(current_user: User = Depends(get_current_user)):
return current_user

22
tests/test_application.py

@ -1078,19 +1078,6 @@ openapi_schema = {
], ],
} }
}, },
"/security/oauth2b": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_security_oauth2b_get",
"security": [{"OAuth2PasswordBearer": []}],
}
},
}, },
"components": { "components": {
"schemas": { "schemas": {
@ -1119,13 +1106,7 @@ openapi_schema = {
} }
}, },
}, },
}, }
"securitySchemes": {
"OAuth2PasswordBearer": {
"type": "oauth2",
"flows": {"password": {"scopes": {}, "tokenUrl": "/token"}},
}
},
}, },
} }
@ -1134,6 +1115,7 @@ openapi_schema = {
"path,expected_status,expected_response", "path,expected_status,expected_response",
[ [
("/api_route", 200, {"message": "Hello World"}), ("/api_route", 200, {"message": "Hello World"}),
("/non_decorated_route", 200, {"message": "Hello World"}),
("/nonexistent", 404, {"detail": "Not Found"}), ("/nonexistent", 404, {"detail": "Not Found"}),
("/openapi.json", 200, openapi_schema), ("/openapi.json", 200, openapi_schema),
], ],

2
tests/test_extra_routes.py

@ -343,7 +343,7 @@ def test_head():
def test_options(): def test_options():
response = client.head("/items/foo") response = client.options("/items/foo")
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["x-fastapi-item-id"] == "foo" assert response.headers["x-fastapi-item-id"] == "foo"

23
tests/test_include_route.py

@ -0,0 +1,23 @@
from fastapi import APIRouter, FastAPI
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.testclient import TestClient
app = FastAPI()
router = APIRouter()
@router.route("/items/")
def read_items(request: Request):
return JSONResponse({"hello": "world"})
app.include_router(router)
client = TestClient(app)
def test_sub_router():
response = client.get("/items/")
assert response.status_code == 200
assert response.json() == {"hello": "world"}

10
tests/test_query.py

@ -40,9 +40,19 @@ response_not_valid_int = {
("/query/int?query=42.5", 422, response_not_valid_int), ("/query/int?query=42.5", 422, response_not_valid_int),
("/query/int?query=baz", 422, response_not_valid_int), ("/query/int?query=baz", 422, response_not_valid_int),
("/query/int?not_declared=baz", 422, response_missing), ("/query/int?not_declared=baz", 422, response_missing),
("/query/int/optional", 200, "foo bar"),
("/query/int/optional?query=50", 200, "foo bar 50"),
("/query/int/optional?query=foo", 422, response_not_valid_int),
("/query/int/default", 200, "foo bar 10"), ("/query/int/default", 200, "foo bar 10"),
("/query/int/default?query=50", 200, "foo bar 50"), ("/query/int/default?query=50", 200, "foo bar 50"),
("/query/int/default?query=foo", 422, response_not_valid_int), ("/query/int/default?query=foo", 422, response_not_valid_int),
("/query/param", 200, "foo bar"),
("/query/param?query=50", 200, "foo bar 50"),
("/query/param-required", 422, response_missing),
("/query/param-required?query=50", 200, "foo bar 50"),
("/query/param-required/int", 422, response_missing),
("/query/param-required/int?query=50", 200, "foo bar 50"),
("/query/param-required/int?query=foo", 422, response_not_valid_int),
], ],
) )
def test_get_path(path, expected_status, expected_response): def test_get_path(path, expected_status, expected_response):

25
tests/test_security.py

@ -1,25 +0,0 @@
from starlette.testclient import TestClient
from .main import app
client = TestClient(app)
def test_security_oauth2_password_bearer():
response = client.get(
"/security/oauth2b", headers={"Authorization": "Bearer footokenbar"}
)
assert response.status_code == 200
assert response.json() == {"username": "footokenbar"}
def test_security_oauth2_password_bearer_wrong_header():
response = client.get("/security/oauth2b", headers={"Authorization": "footokenbar"})
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}
def test_security_oauth2_password_bearer_no_header():
response = client.get("/security/oauth2b")
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}

68
tests/test_security_api_key_cookie.py

@ -0,0 +1,68 @@
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyCookie
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
api_key = APIKeyCookie(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
user = User(username=oauth_header)
return user
@app.get("/users/me")
def read_current_user(current_user: User = Depends(get_current_user)):
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"APIKeyCookie": []}],
}
}
},
"components": {
"securitySchemes": {
"APIKeyCookie": {"type": "apiKey", "name": "key", "in": "cookie"}
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_api_key():
response = client.get("/users/me", cookies={"key": "secret"})
assert response.status_code == 200
assert response.json() == {"username": "secret"}
def test_security_api_key_no_key():
response = client.get("/users/me")
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}

68
tests/test_security_api_key_header.py

@ -0,0 +1,68 @@
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
api_key = APIKeyHeader(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
user = User(username=oauth_header)
return user
@app.get("/users/me")
def read_current_user(current_user: User = Depends(get_current_user)):
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"APIKeyHeader": []}],
}
}
},
"components": {
"securitySchemes": {
"APIKeyHeader": {"type": "apiKey", "name": "key", "in": "header"}
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_api_key():
response = client.get("/users/me", headers={"key": "secret"})
assert response.status_code == 200
assert response.json() == {"username": "secret"}
def test_security_api_key_no_key():
response = client.get("/users/me")
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}

68
tests/test_security_api_key_query.py

@ -0,0 +1,68 @@
from fastapi import Depends, FastAPI, Security
from fastapi.security import APIKeyQuery
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
api_key = APIKeyQuery(name="key")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(api_key)):
user = User(username=oauth_header)
return user
@app.get("/users/me")
def read_current_user(current_user: User = Depends(get_current_user)):
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"APIKeyQuery": []}],
}
}
},
"components": {
"securitySchemes": {
"APIKeyQuery": {"type": "apiKey", "name": "key", "in": "query"}
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_api_key():
response = client.get("/users/me?key=secret")
assert response.status_code == 200
assert response.json() == {"username": "secret"}
def test_security_api_key_no_key():
response = client.get("/users/me")
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}

247
tests/test_security_oauth2.py

@ -0,0 +1,247 @@
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security import OAuth2
from fastapi.security.oauth2 import OAuth2PasswordRequestFormStrict
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
reusable_oauth2 = OAuth2(
flows={
"password": {
"tokenUrl": "/token",
"scopes": {"read:users": "Read the users", "write:users": "Create users"},
}
}
)
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(reusable_oauth2)):
user = User(username=oauth_header)
return user
@app.post("/login")
def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()):
return form_data
@app.get("/users/me")
def read_current_user(current_user: User = Depends(get_current_user)):
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/login": {
"post": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"summary": "Read Current User Post",
"operationId": "read_current_user_login_post",
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"$ref": "#/components/schemas/Body_read_current_user"
}
}
},
"required": True,
},
}
},
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"OAuth2": []}],
}
},
},
"components": {
"schemas": {
"Body_read_current_user": {
"title": "Body_read_current_user",
"required": ["grant_type", "username", "password"],
"type": "object",
"properties": {
"grant_type": {
"title": "Grant_Type",
"pattern": "password",
"type": "string",
},
"username": {"title": "Username", "type": "string"},
"password": {"title": "Password", "type": "string"},
"scope": {"title": "Scope", "type": "string", "default": ""},
"client_id": {"title": "Client_Id", "type": "string"},
"client_secret": {"title": "Client_Secret", "type": "string"},
},
},
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
},
"securitySchemes": {
"OAuth2": {
"type": "oauth2",
"flows": {
"password": {
"scopes": {
"read:users": "Read the users",
"write:users": "Create users",
},
"tokenUrl": "/token",
}
},
}
},
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_oauth2():
response = client.get("/users/me", headers={"Authorization": "Bearer footokenbar"})
assert response.status_code == 200
assert response.json() == {"username": "Bearer footokenbar"}
def test_security_oauth2_password_other_header():
response = client.get("/users/me", headers={"Authorization": "Other footokenbar"})
assert response.status_code == 200
assert response.json() == {"username": "Other footokenbar"}
def test_security_oauth2_password_bearer_no_header():
response = client.get("/users/me")
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}
required_params = {
"detail": [
{
"loc": ["body", "grant_type"],
"msg": "field required",
"type": "value_error.missing",
},
{
"loc": ["body", "username"],
"msg": "field required",
"type": "value_error.missing",
},
{
"loc": ["body", "password"],
"msg": "field required",
"type": "value_error.missing",
},
]
}
grant_type_required = {
"detail": [
{
"loc": ["body", "grant_type"],
"msg": "field required",
"type": "value_error.missing",
}
]
}
grant_type_incorrect = {
"detail": [
{
"loc": ["body", "grant_type"],
"msg": 'string does not match regex "password"',
"type": "value_error.str.regex",
"ctx": {"pattern": "password"},
}
]
}
@pytest.mark.parametrize(
"data,expected_status,expected_response",
[
(None, 422, required_params),
({"username": "johndoe", "password": "secret"}, 422, grant_type_required),
(
{"username": "johndoe", "password": "secret", "grant_type": "incorrect"},
422,
grant_type_incorrect,
),
(
{"username": "johndoe", "password": "secret", "grant_type": "password"},
200,
{
"grant_type": "password",
"username": "johndoe",
"password": "secret",
"scopes": [],
"client_id": None,
"client_secret": None,
},
),
],
)
def test_strict_login(data, expected_status, expected_response):
response = client.post("/login", data=data)
assert response.status_code == expected_status
assert response.json() == expected_response

74
tests/test_security_openid_connect.py

@ -0,0 +1,74 @@
from fastapi import Depends, FastAPI, Security
from fastapi.security.open_id_connect_url import OpenIdConnect
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
oid = OpenIdConnect(openIdConnectUrl="/openid")
class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(oid)):
user = User(username=oauth_header)
return user
@app.get("/users/me")
def read_current_user(current_user: User = Depends(get_current_user)):
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User Get",
"operationId": "read_current_user_users_me_get",
"security": [{"OpenIdConnect": []}],
}
}
},
"components": {
"securitySchemes": {
"OpenIdConnect": {"type": "openIdConnect", "openIdConnectUrl": "/openid"}
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_security_oauth2():
response = client.get("/users/me", headers={"Authorization": "Bearer footokenbar"})
assert response.status_code == 200
assert response.json() == {"username": "Bearer footokenbar"}
def test_security_oauth2_password_other_header():
response = client.get("/users/me", headers={"Authorization": "Other footokenbar"})
assert response.status_code == 200
assert response.json() == {"username": "Other footokenbar"}
def test_security_oauth2_password_bearer_no_header():
response = client.get("/users/me")
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}
Loading…
Cancel
Save