From f216d340ec852e72c002d920e68552bf0da7e364 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 21 Apr 2019 21:44:25 +0400 Subject: [PATCH] :sparkles: Add automatic header handling for HTTP Basic Auth (#175) * :sparkles: Add automatic header handling for HTTP Basic Auth * :art: Remove obsolete comment --- fastapi/security/http.py | 20 ++++-- tests/test_security_http_basic.py | 9 ++- tests/test_security_http_basic_optional.py | 6 +- tests/test_security_http_basic_realm.py | 79 ++++++++++++++++++++++ 4 files changed, 102 insertions(+), 12 deletions(-) create mode 100644 tests/test_security_http_basic_realm.py diff --git a/fastapi/security/http.py b/fastapi/security/http.py index b2da3fcb5..f41d8d944 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -2,6 +2,7 @@ import binascii from base64 import b64decode from typing import Optional +from fastapi.exceptions import HTTPException from fastapi.openapi.models import ( HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel, @@ -9,9 +10,8 @@ from fastapi.openapi.models import ( from fastapi.security.base import SecurityBase from fastapi.security.utils import get_authorization_scheme_param from pydantic import BaseModel -from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.status import HTTP_403_FORBIDDEN +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN class HTTPBasicCredentials(BaseModel): @@ -59,15 +59,21 @@ class HTTPBasic(HTTPBase): async def __call__(self, request: Request) -> Optional[HTTPBasicCredentials]: authorization: str = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) - # before implementing headers with 401 errors, wait for: https://github.com/encode/starlette/issues/295 - # unauthorized_headers = {"WWW-Authenticate": "Basic"} + if self.realm: + unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'} + else: + unauthorized_headers = {"WWW-Authenticate": "Basic"} invalid_user_credentials_exc = HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid authentication credentials" + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers=unauthorized_headers, ) if not authorization or scheme.lower() != "basic": if self.auto_error: raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers=unauthorized_headers, ) else: return None @@ -87,7 +93,7 @@ class HTTPBearer(HTTPBase): *, bearerFormat: str = None, scheme_name: str = None, - auto_error: bool = True + auto_error: bool = True, ): self.model = HTTPBearerModel(bearerFormat=bearerFormat) self.scheme_name = scheme_name or self.__class__.__name__ diff --git a/tests/test_security_http_basic.py b/tests/test_security_http_basic.py index dd289301d..7d380fef0 100644 --- a/tests/test_security_http_basic.py +++ b/tests/test_security_http_basic.py @@ -56,15 +56,17 @@ def test_security_http_basic(): def test_security_http_basic_no_credentials(): response = client.get("/users/me") - assert response.status_code == 403 assert response.json() == {"detail": "Not authenticated"} + assert response.status_code == 401 + assert response.headers["WWW-Authenticate"] == "Basic" def test_security_http_basic_invalid_credentials(): response = client.get( "/users/me", headers={"Authorization": "Basic notabase64token"} ) - assert response.status_code == 403 + assert response.status_code == 401 + assert response.headers["WWW-Authenticate"] == "Basic" assert response.json() == {"detail": "Invalid authentication credentials"} @@ -72,5 +74,6 @@ def test_security_http_basic_non_basic_credentials(): payload = b64encode(b"johnsecret").decode("ascii") auth_header = f"Basic {payload}" response = client.get("/users/me", headers={"Authorization": auth_header}) - assert response.status_code == 403 + assert response.status_code == 401 + assert response.headers["WWW-Authenticate"] == "Basic" assert response.json() == {"detail": "Invalid authentication credentials"} diff --git a/tests/test_security_http_basic_optional.py b/tests/test_security_http_basic_optional.py index 40d64d412..2a4686bb3 100644 --- a/tests/test_security_http_basic_optional.py +++ b/tests/test_security_http_basic_optional.py @@ -67,7 +67,8 @@ def test_security_http_basic_invalid_credentials(): response = client.get( "/users/me", headers={"Authorization": "Basic notabase64token"} ) - assert response.status_code == 403 + assert response.status_code == 401 + assert response.headers["WWW-Authenticate"] == "Basic" assert response.json() == {"detail": "Invalid authentication credentials"} @@ -75,5 +76,6 @@ def test_security_http_basic_non_basic_credentials(): payload = b64encode(b"johnsecret").decode("ascii") auth_header = f"Basic {payload}" response = client.get("/users/me", headers={"Authorization": auth_header}) - assert response.status_code == 403 + assert response.status_code == 401 + assert response.headers["WWW-Authenticate"] == "Basic" assert response.json() == {"detail": "Invalid authentication credentials"} diff --git a/tests/test_security_http_basic_realm.py b/tests/test_security_http_basic_realm.py new file mode 100644 index 000000000..6b5b4aeee --- /dev/null +++ b/tests/test_security_http_basic_realm.py @@ -0,0 +1,79 @@ +from base64 import b64encode + +from fastapi import FastAPI, Security +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from requests.auth import HTTPBasicAuth +from starlette.testclient import TestClient + +app = FastAPI() + +security = HTTPBasic(realm="simple") + + +@app.get("/users/me") +def read_current_user(credentials: HTTPBasicCredentials = Security(security)): + return {"username": credentials.username, "password": credentials.password} + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/me": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Read Current User", + "operationId": "read_current_user_users_me_get", + "security": [{"HTTPBasic": []}], + } + } + }, + "components": { + "securitySchemes": {"HTTPBasic": {"type": "http", "scheme": "basic"}} + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_security_http_basic(): + auth = HTTPBasicAuth(username="john", password="secret") + response = client.get("/users/me", auth=auth) + assert response.status_code == 200 + assert response.json() == {"username": "john", "password": "secret"} + + +def test_security_http_basic_no_credentials(): + response = client.get("/users/me") + assert response.json() == {"detail": "Not authenticated"} + assert response.status_code == 401 + assert response.headers["WWW-Authenticate"] == 'Basic realm="simple"' + + +def test_security_http_basic_invalid_credentials(): + response = client.get( + "/users/me", headers={"Authorization": "Basic notabase64token"} + ) + assert response.status_code == 401 + assert response.headers["WWW-Authenticate"] == 'Basic realm="simple"' + assert response.json() == {"detail": "Invalid authentication credentials"} + + +def test_security_http_basic_non_basic_credentials(): + payload = b64encode(b"johnsecret").decode("ascii") + auth_header = f"Basic {payload}" + response = client.get("/users/me", headers={"Authorization": auth_header}) + assert response.status_code == 401 + assert response.headers["WWW-Authenticate"] == 'Basic realm="simple"' + assert response.json() == {"detail": "Invalid authentication credentials"}