From 4ac55af283457d7279711224c5f9a3810d4d6534 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 12 Jun 2023 00:16:01 +0200 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Update=20internal=20type?= =?UTF-8?q?=20annotations=20and=20upgrade=20mypy=20(#9658)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/openapi/models.py | 13 ++++++++----- fastapi/security/api_key.py | 12 +++++++++--- fastapi/security/oauth2.py | 25 ++++++++++++++++--------- requirements-tests.txt | 2 +- 4 files changed, 34 insertions(+), 18 deletions(-) diff --git a/fastapi/openapi/models.py b/fastapi/openapi/models.py index 11edfe38a..81a24f389 100644 --- a/fastapi/openapi/models.py +++ b/fastapi/openapi/models.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union from fastapi.logger import logger from pydantic import AnyUrl, BaseModel, Field +from typing_extensions import Literal try: import email_validator # type: ignore @@ -298,18 +299,18 @@ class APIKeyIn(Enum): class APIKey(SecurityBase): - type_ = Field(SecuritySchemeType.apiKey, alias="type") + type_: SecuritySchemeType = Field(default=SecuritySchemeType.apiKey, alias="type") in_: APIKeyIn = Field(alias="in") name: str class HTTPBase(SecurityBase): - type_ = Field(SecuritySchemeType.http, alias="type") + type_: SecuritySchemeType = Field(default=SecuritySchemeType.http, alias="type") scheme: str class HTTPBearer(HTTPBase): - scheme = "bearer" + scheme: Literal["bearer"] = "bearer" bearerFormat: Optional[str] = None @@ -349,12 +350,14 @@ class OAuthFlows(BaseModel): class OAuth2(SecurityBase): - type_ = Field(SecuritySchemeType.oauth2, alias="type") + type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type") flows: OAuthFlows class OpenIdConnect(SecurityBase): - type_ = Field(SecuritySchemeType.openIdConnect, alias="type") + type_: SecuritySchemeType = Field( + default=SecuritySchemeType.openIdConnect, alias="type" + ) openIdConnectUrl: str diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 61730187a..8b2c5c080 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -21,7 +21,9 @@ class APIKeyQuery(APIKeyBase): auto_error: bool = True, ): self.model: APIKey = APIKey( - **{"in": APIKeyIn.query}, name=name, description=description + **{"in": APIKeyIn.query}, # type: ignore[arg-type] + name=name, + description=description, ) self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error @@ -48,7 +50,9 @@ class APIKeyHeader(APIKeyBase): auto_error: bool = True, ): self.model: APIKey = APIKey( - **{"in": APIKeyIn.header}, name=name, description=description + **{"in": APIKeyIn.header}, # type: ignore[arg-type] + name=name, + description=description, ) self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error @@ -75,7 +79,9 @@ class APIKeyCookie(APIKeyBase): auto_error: bool = True, ): self.model: APIKey = APIKey( - **{"in": APIKeyIn.cookie}, name=name, description=description + **{"in": APIKeyIn.cookie}, # type: ignore[arg-type] + name=name, + description=description, ) self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index dc75dc9fe..938dec37c 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, cast from fastapi.exceptions import HTTPException from fastapi.openapi.models import OAuth2 as OAuth2Model @@ -121,7 +121,9 @@ class OAuth2(SecurityBase): description: Optional[str] = None, auto_error: bool = True, ): - self.model = OAuth2Model(flows=flows, description=description) + self.model = OAuth2Model( + flows=cast(OAuthFlowsModel, flows), description=description + ) self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error @@ -148,7 +150,9 @@ class OAuth2PasswordBearer(OAuth2): ): if not scopes: scopes = {} - flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes}) + flows = OAuthFlowsModel( + password=cast(Any, {"tokenUrl": tokenUrl, "scopes": scopes}) + ) super().__init__( flows=flows, scheme_name=scheme_name, @@ -185,12 +189,15 @@ class OAuth2AuthorizationCodeBearer(OAuth2): if not scopes: scopes = {} flows = OAuthFlowsModel( - authorizationCode={ - "authorizationUrl": authorizationUrl, - "tokenUrl": tokenUrl, - "refreshUrl": refreshUrl, - "scopes": scopes, - } + authorizationCode=cast( + Any, + { + "authorizationUrl": authorizationUrl, + "tokenUrl": tokenUrl, + "refreshUrl": refreshUrl, + "scopes": scopes, + }, + ) ) super().__init__( flows=flows, diff --git a/requirements-tests.txt b/requirements-tests.txt index 52a44cec5..5105071be 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -1,7 +1,7 @@ -e . pytest >=7.1.3,<8.0.0 coverage[toml] >= 6.5.0,< 8.0 -mypy ==0.982 +mypy ==1.3.0 ruff ==0.0.138 black == 23.1.0 isort >=5.0.6,<6.0.0