From 19daeca2e22c9c3c2d299a74dbc2e0428c9b0f4c Mon Sep 17 00:00:00 2001 From: Yurii Motov Date: Mon, 9 Jun 2025 12:29:37 +0200 Subject: [PATCH] Move common logic of `APIKey**` classes to `APIKeyBase` class --- fastapi/security/api_key.py | 55 ++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 70c2dca8a..d0616c73f 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.base import SecurityBase @@ -9,10 +9,27 @@ from typing_extensions import Annotated, Doc class APIKeyBase(SecurityBase): - @staticmethod - def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]: + def __init__( + self, + location: APIKeyIn, + name: str, + description: Union[str, None], + scheme_name: Union[str, None], + auto_error: bool, + ): + self.parameter_location = location.value + self.parameter_name = name + self.auto_error = auto_error + self.model: APIKey = APIKey( + **{"in": location}, # type: ignore[arg-type] + name=name, + description=description, + ) + self.scheme_name = scheme_name or self.__class__.__name__ + + def check_api_key(self, api_key: Optional[str]) -> Optional[str]: if not api_key: - if auto_error: + if self.auto_error: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" ) @@ -99,17 +116,17 @@ class APIKeyQuery(APIKeyBase): ), ] = True, ): - self.model: APIKey = APIKey( - **{"in": APIKeyIn.query}, # type: ignore[arg-type] + super().__init__( + location=APIKeyIn.query, name=name, + scheme_name=scheme_name, description=description, + auto_error=auto_error, ) - self.scheme_name = scheme_name or self.__class__.__name__ - self.auto_error = auto_error async def __call__(self, request: Request) -> Optional[str]: api_key = request.query_params.get(self.model.name) - return self.check_api_key(api_key, self.auto_error) + return self.check_api_key(api_key) class APIKeyHeader(APIKeyBase): @@ -187,17 +204,17 @@ class APIKeyHeader(APIKeyBase): ), ] = True, ): - self.model: APIKey = APIKey( - **{"in": APIKeyIn.header}, # type: ignore[arg-type] + super().__init__( + location=APIKeyIn.header, name=name, + scheme_name=scheme_name, description=description, + auto_error=auto_error, ) - self.scheme_name = scheme_name or self.__class__.__name__ - self.auto_error = auto_error async def __call__(self, request: Request) -> Optional[str]: api_key = request.headers.get(self.model.name) - return self.check_api_key(api_key, self.auto_error) + return self.check_api_key(api_key) class APIKeyCookie(APIKeyBase): @@ -275,14 +292,14 @@ class APIKeyCookie(APIKeyBase): ), ] = True, ): - self.model: APIKey = APIKey( - **{"in": APIKeyIn.cookie}, # type: ignore[arg-type] + super().__init__( + location=APIKeyIn.cookie, name=name, + scheme_name=scheme_name, description=description, + auto_error=auto_error, ) - self.scheme_name = scheme_name or self.__class__.__name__ - self.auto_error = auto_error async def __call__(self, request: Request) -> Optional[str]: api_key = request.cookies.get(self.model.name) - return self.check_api_key(api_key, self.auto_error) + return self.check_api_key(api_key)