Browse Source

♻️ Update internal type annotations and upgrade mypy (#9658)

pull/9661/head
Sebastián Ramírez 2 years ago
committed by GitHub
parent
commit
4ac55af283
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 13
      fastapi/openapi/models.py
  2. 12
      fastapi/security/api_key.py
  3. 25
      fastapi/security/oauth2.py
  4. 2
      requirements-tests.txt

13
fastapi/openapi/models.py

@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from fastapi.logger import logger from fastapi.logger import logger
from pydantic import AnyUrl, BaseModel, Field from pydantic import AnyUrl, BaseModel, Field
from typing_extensions import Literal
try: try:
import email_validator # type: ignore import email_validator # type: ignore
@ -298,18 +299,18 @@ class APIKeyIn(Enum):
class APIKey(SecurityBase): class APIKey(SecurityBase):
type_ = Field(SecuritySchemeType.apiKey, alias="type") type_: SecuritySchemeType = Field(default=SecuritySchemeType.apiKey, alias="type")
in_: APIKeyIn = Field(alias="in") in_: APIKeyIn = Field(alias="in")
name: str name: str
class HTTPBase(SecurityBase): class HTTPBase(SecurityBase):
type_ = Field(SecuritySchemeType.http, alias="type") type_: SecuritySchemeType = Field(default=SecuritySchemeType.http, alias="type")
scheme: str scheme: str
class HTTPBearer(HTTPBase): class HTTPBearer(HTTPBase):
scheme = "bearer" scheme: Literal["bearer"] = "bearer"
bearerFormat: Optional[str] = None bearerFormat: Optional[str] = None
@ -349,12 +350,14 @@ class OAuthFlows(BaseModel):
class OAuth2(SecurityBase): class OAuth2(SecurityBase):
type_ = Field(SecuritySchemeType.oauth2, alias="type") type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type")
flows: OAuthFlows flows: OAuthFlows
class OpenIdConnect(SecurityBase): class OpenIdConnect(SecurityBase):
type_ = Field(SecuritySchemeType.openIdConnect, alias="type") type_: SecuritySchemeType = Field(
default=SecuritySchemeType.openIdConnect, alias="type"
)
openIdConnectUrl: str openIdConnectUrl: str

12
fastapi/security/api_key.py

@ -21,7 +21,9 @@ class APIKeyQuery(APIKeyBase):
auto_error: bool = True, auto_error: bool = True,
): ):
self.model: APIKey = APIKey( self.model: APIKey = APIKey(
**{"in": APIKeyIn.query}, name=name, description=description **{"in": APIKeyIn.query}, # type: ignore[arg-type]
name=name,
description=description,
) )
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@ -48,7 +50,9 @@ class APIKeyHeader(APIKeyBase):
auto_error: bool = True, auto_error: bool = True,
): ):
self.model: APIKey = APIKey( self.model: APIKey = APIKey(
**{"in": APIKeyIn.header}, name=name, description=description **{"in": APIKeyIn.header}, # type: ignore[arg-type]
name=name,
description=description,
) )
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@ -75,7 +79,9 @@ class APIKeyCookie(APIKeyBase):
auto_error: bool = True, auto_error: bool = True,
): ):
self.model: APIKey = APIKey( self.model: APIKey = APIKey(
**{"in": APIKeyIn.cookie}, name=name, description=description **{"in": APIKeyIn.cookie}, # type: ignore[arg-type]
name=name,
description=description,
) )
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error

25
fastapi/security/oauth2.py

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union, cast
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuth2 as OAuth2Model from fastapi.openapi.models import OAuth2 as OAuth2Model
@ -121,7 +121,9 @@ class OAuth2(SecurityBase):
description: Optional[str] = None, description: Optional[str] = None,
auto_error: bool = True, auto_error: bool = True,
): ):
self.model = OAuth2Model(flows=flows, description=description) self.model = OAuth2Model(
flows=cast(OAuthFlowsModel, flows), description=description
)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@ -148,7 +150,9 @@ class OAuth2PasswordBearer(OAuth2):
): ):
if not scopes: if not scopes:
scopes = {} scopes = {}
flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes}) flows = OAuthFlowsModel(
password=cast(Any, {"tokenUrl": tokenUrl, "scopes": scopes})
)
super().__init__( super().__init__(
flows=flows, flows=flows,
scheme_name=scheme_name, scheme_name=scheme_name,
@ -185,12 +189,15 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
if not scopes: if not scopes:
scopes = {} scopes = {}
flows = OAuthFlowsModel( flows = OAuthFlowsModel(
authorizationCode={ authorizationCode=cast(
"authorizationUrl": authorizationUrl, Any,
"tokenUrl": tokenUrl, {
"refreshUrl": refreshUrl, "authorizationUrl": authorizationUrl,
"scopes": scopes, "tokenUrl": tokenUrl,
} "refreshUrl": refreshUrl,
"scopes": scopes,
},
)
) )
super().__init__( super().__init__(
flows=flows, flows=flows,

2
requirements-tests.txt

@ -1,7 +1,7 @@
-e . -e .
pytest >=7.1.3,<8.0.0 pytest >=7.1.3,<8.0.0
coverage[toml] >= 6.5.0,< 8.0 coverage[toml] >= 6.5.0,< 8.0
mypy ==0.982 mypy ==1.3.0
ruff ==0.0.138 ruff ==0.0.138
black == 23.1.0 black == 23.1.0
isort >=5.0.6,<6.0.0 isort >=5.0.6,<6.0.0

Loading…
Cancel
Save