diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 9cb9f3f14..e72ac8f01 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -97,7 +97,7 @@ def get_dependant(*, path: str, call: Callable, name: str = None) -> Dependant: elif ( param.default == param.empty or param.default is None - or type(param.default) in param_supported_types + or isinstance(param.default, param_supported_types) ) and ( param.annotation == param.empty or lenient_issubclass(param.annotation, param_supported_types) @@ -214,7 +214,8 @@ async def solve_dependencies( request=request, dependant=sub_dependant, body=body ) if sub_errors: - return {}, errors + errors.extend(sub_errors) + continue assert sub_dependant.call is not None, "sub_dependant.call must be a function" if is_coroutine_callable(sub_dependant.call): solved = await sub_dependant.call(**sub_values) @@ -238,7 +239,7 @@ async def solve_dependencies( values.update(query_values) values.update(header_values) values.update(cookie_values) - errors = path_errors + query_errors + header_errors + cookie_errors + errors += path_errors + query_errors + header_errors + cookie_errors if dependant.body_params: body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above dependant.body_params, body @@ -295,7 +296,7 @@ async def request_body_to_args( received_body = {} for field in required_params: value = received_body.get(field.alias) - if value is None: + if value is None or (isinstance(field.schema, params.Form) and value == ""): if field.required: errors.append( ErrorWrapper( diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 12eba37ee..018e4f99e 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -1,6 +1,8 @@ from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.base import SecurityBase +from starlette.exceptions import HTTPException from starlette.requests import Request +from starlette.status import HTTP_403_FORBIDDEN class APIKeyBase(SecurityBase): @@ -9,26 +11,41 @@ class APIKeyBase(SecurityBase): class APIKeyQuery(APIKeyBase): def __init__(self, *, name: str, scheme_name: str = None): - self.model = APIKey(in_=APIKeyIn.query, name=name) + self.model = APIKey(**{"in": APIKeyIn.query}, name=name) self.scheme_name = scheme_name or self.__class__.__name__ - async def __call__(self, requests: Request) -> str: - return requests.query_params.get(self.model.name) + async def __call__(self, request: Request) -> str: + api_key: str = request.query_params.get(self.model.name) + if not api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + return api_key class APIKeyHeader(APIKeyBase): def __init__(self, *, name: str, scheme_name: str = None): - self.model = APIKey(in_=APIKeyIn.header, name=name) + self.model = APIKey(**{"in": APIKeyIn.header}, name=name) self.scheme_name = scheme_name or self.__class__.__name__ - async def __call__(self, requests: Request) -> str: - return requests.headers.get(self.model.name) + async def __call__(self, request: Request) -> str: + api_key: str = request.headers.get(self.model.name) + if not api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + return api_key class APIKeyCookie(APIKeyBase): def __init__(self, *, name: str, scheme_name: str = None): - self.model = APIKey(in_=APIKeyIn.cookie, name=name) + self.model = APIKey(**{"in": APIKeyIn.cookie}, name=name) self.scheme_name = scheme_name or self.__class__.__name__ - async def __call__(self, requests: Request) -> str: - return requests.cookies.get(self.model.name) + async def __call__(self, request: Request) -> str: + api_key: str = request.cookies.get(self.model.name) + if not api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + return api_key diff --git a/fastapi/security/http.py b/fastapi/security/http.py index b1cba1921..287beee58 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -1,9 +1,26 @@ +import binascii +from base64 import b64decode + from fastapi.openapi.models import ( HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel, ) from fastapi.security.base import SecurityBase +from fastapi.security.utils import get_authorization_scheme_param +from pydantic import BaseModel +from starlette.exceptions import HTTPException from starlette.requests import Request +from starlette.status import HTTP_403_FORBIDDEN + + +class HTTPBasicCredentials(BaseModel): + username: str + password: str + + +class HTTPAuthorizationCredentials(BaseModel): + scheme: str + credentials: str class HTTPBase(SecurityBase): @@ -12,16 +29,41 @@ class HTTPBase(SecurityBase): self.scheme_name = scheme_name or self.__class__.__name__ async def __call__(self, request: Request) -> str: - return request.headers.get("Authorization") + authorization: str = request.headers.get("Authorization") + scheme, credentials = get_authorization_scheme_param(authorization) + if not (authorization and scheme and credentials): + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) class HTTPBasic(HTTPBase): - def __init__(self, *, scheme_name: str = None): + def __init__(self, *, scheme_name: str = None, realm: str = None): self.model = HTTPBaseModel(scheme="basic") self.scheme_name = scheme_name or self.__class__.__name__ + self.realm = realm async def __call__(self, request: Request) -> str: - return request.headers.get("Authorization") + authorization: str = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + # before implementing headers with 401 errors, wait for: https://github.com/encode/starlette/issues/295 + # unauthorized_headers = {"WWW-Authenticate": "Basic"} + invalid_user_credentials_exc = HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Invalid authentication credentials" + ) + if not authorization or scheme.lower() != "basic": + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + try: + data = b64decode(param).decode("ascii") + except (ValueError, UnicodeDecodeError, binascii.Error): + raise invalid_user_credentials_exc + username, separator, password = data.partition(":") + if not (separator): + raise invalid_user_credentials_exc + return HTTPBasicCredentials(username=username, password=password) class HTTPBearer(HTTPBase): @@ -30,7 +72,13 @@ class HTTPBearer(HTTPBase): self.scheme_name = scheme_name or self.__class__.__name__ async def __call__(self, request: Request) -> str: - return request.headers.get("Authorization") + authorization: str = request.headers.get("Authorization") + scheme, credentials = get_authorization_scheme_param(authorization) + if not (authorization and scheme and credentials): + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) class HTTPDigest(HTTPBase): @@ -39,4 +87,10 @@ class HTTPDigest(HTTPBase): self.scheme_name = scheme_name or self.__class__.__name__ async def __call__(self, request: Request) -> str: - return request.headers.get("Authorization") + authorization: str = request.headers.get("Authorization") + scheme, credentials = get_authorization_scheme_param(authorization) + if not (authorization and scheme and credentials): + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index 4fd767ec6..b1132fef1 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -3,6 +3,7 @@ from typing import Optional from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel from fastapi.params import Form from fastapi.security.base import SecurityBase +from fastapi.security.utils import get_authorization_scheme_param from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.status import HTTP_403_FORBIDDEN @@ -118,7 +119,12 @@ class OAuth2(SecurityBase): self.scheme_name = scheme_name or self.__class__.__name__ async def __call__(self, request: Request) -> str: - return request.headers.get("Authorization") + authorization: str = request.headers.get("Authorization") + if not authorization: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + return authorization class OAuth2PasswordBearer(OAuth2): @@ -130,9 +136,9 @@ class OAuth2PasswordBearer(OAuth2): async def __call__(self, request: Request) -> str: authorization: str = request.headers.get("Authorization") - if not authorization or "Bearer " not in authorization: + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "bearer": raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" ) - token = authorization.replace("Bearer ", "") - return token + return param diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index 7d73ed81f..e10f4a510 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -1,6 +1,8 @@ from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.security.base import SecurityBase +from starlette.exceptions import HTTPException from starlette.requests import Request +from starlette.status import HTTP_403_FORBIDDEN class OpenIdConnect(SecurityBase): @@ -9,4 +11,9 @@ class OpenIdConnect(SecurityBase): self.scheme_name = scheme_name or self.__class__.__name__ async def __call__(self, request: Request) -> str: - return request.headers.get("Authorization") + authorization: str = request.headers.get("Authorization") + if not authorization: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + return authorization diff --git a/fastapi/security/utils.py b/fastapi/security/utils.py new file mode 100644 index 000000000..3ddd83a43 --- /dev/null +++ b/fastapi/security/utils.py @@ -0,0 +1,5 @@ +def get_authorization_scheme_param(authorization_header_value: str): + if not authorization_header_value: + return "", "" + scheme, _, param = authorization_header_value.partition(" ") + return scheme, param diff --git a/tests/main.py b/tests/main.py index c384bbb75..ab0b18607 100644 --- a/tests/main.py +++ b/tests/main.py @@ -1,6 +1,4 @@ -from fastapi import Depends, FastAPI, Path, Query, Security -from fastapi.security import OAuth2PasswordBearer -from pydantic import BaseModel +from fastapi import FastAPI, Path, Query app = FastAPI() @@ -144,8 +142,6 @@ def get_path_param_le_ge_int(item_id: int = Path(..., le=3, ge=1)): @app.get("/query") def get_query(query): - if query is None: - return "foo bar" return f"foo bar {query}" @@ -158,8 +154,6 @@ def get_query_optional(query=None): @app.get("/query/int") def get_query_type(query: int): - if query is None: - return "foo bar" return f"foo bar {query}" @@ -184,30 +178,9 @@ def get_query_param(query=Query(None)): @app.get("/query/param-required") def get_query_param_required(query=Query(...)): - if query is None: - return "foo bar" return f"foo bar {query}" @app.get("/query/param-required/int") def get_query_param_required_type(query: int = Query(...)): - if query is None: - return "foo bar" return f"foo bar {query}" - - -reusable_oauth2b = OAuth2PasswordBearer(tokenUrl="/token") - - -class User(BaseModel): - username: str - - -def get_current_user(oauth_header: str = Security(reusable_oauth2b)): - user = User(username=oauth_header) - return user - - -@app.get("/security/oauth2b") -def read_current_user(current_user: User = Depends(get_current_user)): - return current_user diff --git a/tests/test_application.py b/tests/test_application.py index 55394c19d..fb0f53975 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1078,19 +1078,6 @@ openapi_schema = { ], } }, - "/security/oauth2b": { - "get": { - "responses": { - "200": { - "description": "Successful Response", - "content": {"application/json": {"schema": {}}}, - } - }, - "summary": "Read Current User Get", - "operationId": "read_current_user_security_oauth2b_get", - "security": [{"OAuth2PasswordBearer": []}], - } - }, }, "components": { "schemas": { @@ -1119,13 +1106,7 @@ openapi_schema = { } }, }, - }, - "securitySchemes": { - "OAuth2PasswordBearer": { - "type": "oauth2", - "flows": {"password": {"scopes": {}, "tokenUrl": "/token"}}, - } - }, + } }, } @@ -1134,6 +1115,7 @@ openapi_schema = { "path,expected_status,expected_response", [ ("/api_route", 200, {"message": "Hello World"}), + ("/non_decorated_route", 200, {"message": "Hello World"}), ("/nonexistent", 404, {"detail": "Not Found"}), ("/openapi.json", 200, openapi_schema), ], diff --git a/tests/test_extra_routes.py b/tests/test_extra_routes.py index d07b90d3f..6147c3414 100644 --- a/tests/test_extra_routes.py +++ b/tests/test_extra_routes.py @@ -343,7 +343,7 @@ def test_head(): def test_options(): - response = client.head("/items/foo") + response = client.options("/items/foo") assert response.status_code == 200 assert response.headers["x-fastapi-item-id"] == "foo" diff --git a/tests/test_include_route.py b/tests/test_include_route.py new file mode 100644 index 000000000..c194d2060 --- /dev/null +++ b/tests/test_include_route.py @@ -0,0 +1,23 @@ +from fastapi import APIRouter, FastAPI +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.testclient import TestClient + +app = FastAPI() +router = APIRouter() + + +@router.route("/items/") +def read_items(request: Request): + return JSONResponse({"hello": "world"}) + + +app.include_router(router) + +client = TestClient(app) + + +def test_sub_router(): + response = client.get("/items/") + assert response.status_code == 200 + assert response.json() == {"hello": "world"} diff --git a/tests/test_query.py b/tests/test_query.py index 17d120287..92cff2bb5 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -40,9 +40,19 @@ response_not_valid_int = { ("/query/int?query=42.5", 422, response_not_valid_int), ("/query/int?query=baz", 422, response_not_valid_int), ("/query/int?not_declared=baz", 422, response_missing), + ("/query/int/optional", 200, "foo bar"), + ("/query/int/optional?query=50", 200, "foo bar 50"), + ("/query/int/optional?query=foo", 422, response_not_valid_int), ("/query/int/default", 200, "foo bar 10"), ("/query/int/default?query=50", 200, "foo bar 50"), ("/query/int/default?query=foo", 422, response_not_valid_int), + ("/query/param", 200, "foo bar"), + ("/query/param?query=50", 200, "foo bar 50"), + ("/query/param-required", 422, response_missing), + ("/query/param-required?query=50", 200, "foo bar 50"), + ("/query/param-required/int", 422, response_missing), + ("/query/param-required/int?query=50", 200, "foo bar 50"), + ("/query/param-required/int?query=foo", 422, response_not_valid_int), ], ) def test_get_path(path, expected_status, expected_response): diff --git a/tests/test_security.py b/tests/test_security.py deleted file mode 100644 index 672a8460f..000000000 --- a/tests/test_security.py +++ /dev/null @@ -1,25 +0,0 @@ -from starlette.testclient import TestClient - -from .main import app - -client = TestClient(app) - - -def test_security_oauth2_password_bearer(): - response = client.get( - "/security/oauth2b", headers={"Authorization": "Bearer footokenbar"} - ) - assert response.status_code == 200 - assert response.json() == {"username": "footokenbar"} - - -def test_security_oauth2_password_bearer_wrong_header(): - response = client.get("/security/oauth2b", headers={"Authorization": "footokenbar"}) - assert response.status_code == 403 - assert response.json() == {"detail": "Not authenticated"} - - -def test_security_oauth2_password_bearer_no_header(): - response = client.get("/security/oauth2b") - assert response.status_code == 403 - assert response.json() == {"detail": "Not authenticated"} diff --git a/tests/test_security_api_key_cookie.py b/tests/test_security_api_key_cookie.py new file mode 100644 index 000000000..88b3eef01 --- /dev/null +++ b/tests/test_security_api_key_cookie.py @@ -0,0 +1,68 @@ +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyCookie +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +api_key = APIKeyCookie(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + user = User(username=oauth_header) + return user + + +@app.get("/users/me") +def read_current_user(current_user: User = Depends(get_current_user)): + return current_user + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"APIKeyCookie": []}], + } + } + }, + "components": { + "securitySchemes": { + "APIKeyCookie": {"type": "apiKey", "name": "key", "in": "cookie"} + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_api_key(): + response = client.get("/users/me", cookies={"key": "secret"}) + assert response.status_code == 200 + assert response.json() == {"username": "secret"} + + +def test_security_api_key_no_key(): + response = client.get("/users/me") + assert response.status_code == 403 + assert response.json() == {"detail": "Not authenticated"} diff --git a/tests/test_security_api_key_header.py b/tests/test_security_api_key_header.py new file mode 100644 index 000000000..2d6114d3e --- /dev/null +++ b/tests/test_security_api_key_header.py @@ -0,0 +1,68 @@ +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyHeader +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +api_key = APIKeyHeader(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + user = User(username=oauth_header) + return user + + +@app.get("/users/me") +def read_current_user(current_user: User = Depends(get_current_user)): + return current_user + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"APIKeyHeader": []}], + } + } + }, + "components": { + "securitySchemes": { + "APIKeyHeader": {"type": "apiKey", "name": "key", "in": "header"} + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_api_key(): + response = client.get("/users/me", headers={"key": "secret"}) + assert response.status_code == 200 + assert response.json() == {"username": "secret"} + + +def test_security_api_key_no_key(): + response = client.get("/users/me") + assert response.status_code == 403 + assert response.json() == {"detail": "Not authenticated"} diff --git a/tests/test_security_api_key_query.py b/tests/test_security_api_key_query.py new file mode 100644 index 000000000..599b2540c --- /dev/null +++ b/tests/test_security_api_key_query.py @@ -0,0 +1,68 @@ +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyQuery +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +api_key = APIKeyQuery(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + user = User(username=oauth_header) + return user + + +@app.get("/users/me") +def read_current_user(current_user: User = Depends(get_current_user)): + return current_user + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"APIKeyQuery": []}], + } + } + }, + "components": { + "securitySchemes": { + "APIKeyQuery": {"type": "apiKey", "name": "key", "in": "query"} + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_api_key(): + response = client.get("/users/me?key=secret") + assert response.status_code == 200 + assert response.json() == {"username": "secret"} + + +def test_security_api_key_no_key(): + response = client.get("/users/me") + assert response.status_code == 403 + assert response.json() == {"detail": "Not authenticated"} diff --git a/tests/test_security_oauth2.py b/tests/test_security_oauth2.py new file mode 100644 index 000000000..050f17da3 --- /dev/null +++ b/tests/test_security_oauth2.py @@ -0,0 +1,247 @@ +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security import OAuth2 +from fastapi.security.oauth2 import OAuth2PasswordRequestFormStrict +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +reusable_oauth2 = OAuth2( + flows={ + "password": { + "tokenUrl": "/token", + "scopes": {"read:users": "Read the users", "write:users": "Create users"}, + } + } +) + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(reusable_oauth2)): + user = User(username=oauth_header) + return user + + +@app.post("/login") +def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()): + return form_data + + +@app.get("/users/me") +def read_current_user(current_user: User = Depends(get_current_user)): + return current_user + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/login": { + "post": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Read Current User Post", + "operationId": "read_current_user_login_post", + "requestBody": { + "content": { + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Body_read_current_user" + } + } + }, + "required": True, + }, + } + }, + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"OAuth2": []}], + } + }, + }, + "components": { + "schemas": { + "Body_read_current_user": { + "title": "Body_read_current_user", + "required": ["grant_type", "username", "password"], + "type": "object", + "properties": { + "grant_type": { + "title": "Grant_Type", + "pattern": "password", + "type": "string", + }, + "username": {"title": "Username", "type": "string"}, + "password": {"title": "Password", "type": "string"}, + "scope": {"title": "Scope", "type": "string", "default": ""}, + "client_id": {"title": "Client_Id", "type": "string"}, + "client_secret": {"title": "Client_Secret", "type": "string"}, + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + }, + "securitySchemes": { + "OAuth2": { + "type": "oauth2", + "flows": { + "password": { + "scopes": { + "read:users": "Read the users", + "write:users": "Create users", + }, + "tokenUrl": "/token", + } + }, + } + }, + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_oauth2(): + response = client.get("/users/me", headers={"Authorization": "Bearer footokenbar"}) + assert response.status_code == 200 + assert response.json() == {"username": "Bearer footokenbar"} + + +def test_security_oauth2_password_other_header(): + response = client.get("/users/me", headers={"Authorization": "Other footokenbar"}) + assert response.status_code == 200 + assert response.json() == {"username": "Other footokenbar"} + + +def test_security_oauth2_password_bearer_no_header(): + response = client.get("/users/me") + assert response.status_code == 403 + assert response.json() == {"detail": "Not authenticated"} + + +required_params = { + "detail": [ + { + "loc": ["body", "grant_type"], + "msg": "field required", + "type": "value_error.missing", + }, + { + "loc": ["body", "username"], + "msg": "field required", + "type": "value_error.missing", + }, + { + "loc": ["body", "password"], + "msg": "field required", + "type": "value_error.missing", + }, + ] +} + +grant_type_required = { + "detail": [ + { + "loc": ["body", "grant_type"], + "msg": "field required", + "type": "value_error.missing", + } + ] +} + +grant_type_incorrect = { + "detail": [ + { + "loc": ["body", "grant_type"], + "msg": 'string does not match regex "password"', + "type": "value_error.str.regex", + "ctx": {"pattern": "password"}, + } + ] +} + + +@pytest.mark.parametrize( + "data,expected_status,expected_response", + [ + (None, 422, required_params), + ({"username": "johndoe", "password": "secret"}, 422, grant_type_required), + ( + {"username": "johndoe", "password": "secret", "grant_type": "incorrect"}, + 422, + grant_type_incorrect, + ), + ( + {"username": "johndoe", "password": "secret", "grant_type": "password"}, + 200, + { + "grant_type": "password", + "username": "johndoe", + "password": "secret", + "scopes": [], + "client_id": None, + "client_secret": None, + }, + ), + ], +) +def test_strict_login(data, expected_status, expected_response): + response = client.post("/login", data=data) + assert response.status_code == expected_status + assert response.json() == expected_response diff --git a/tests/test_security_openid_connect.py b/tests/test_security_openid_connect.py new file mode 100644 index 000000000..ce19dd92e --- /dev/null +++ b/tests/test_security_openid_connect.py @@ -0,0 +1,74 @@ +from fastapi import Depends, FastAPI, Security +from fastapi.security.open_id_connect_url import OpenIdConnect +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + +oid = OpenIdConnect(openIdConnectUrl="/openid") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(oid)): + user = User(username=oauth_header) + return user + + +@app.get("/users/me") +def read_current_user(current_user: User = Depends(get_current_user)): + return current_user + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User Get", + "operationId": "read_current_user_users_me_get", + "security": [{"OpenIdConnect": []}], + } + } + }, + "components": { + "securitySchemes": { + "OpenIdConnect": {"type": "openIdConnect", "openIdConnectUrl": "/openid"} + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_oauth2(): + response = client.get("/users/me", headers={"Authorization": "Bearer footokenbar"}) + assert response.status_code == 200 + assert response.json() == {"username": "Bearer footokenbar"} + + +def test_security_oauth2_password_other_header(): + response = client.get("/users/me", headers={"Authorization": "Other footokenbar"}) + assert response.status_code == 200 + assert response.json() == {"username": "Other footokenbar"} + + +def test_security_oauth2_password_bearer_no_header(): + response = client.get("/users/me") + assert response.status_code == 403 + assert response.json() == {"detail": "Not authenticated"}