From fdb6c9ccc504f90afd0fbcec53f3ea0bfebc261a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 20 Dec 2020 19:50:00 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Improve=20type=20annotations,=20add?= =?UTF-8?q?=20support=20for=20mypy=20--strict,=20internally=20and=20for=20?= =?UTF-8?q?external=20packages=20(#2547)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs_src/openapi_callbacks/tutorial001.py | 2 +- fastapi/__init__.py | 39 ++++--- fastapi/applications.py | 89 ++++++++------- fastapi/background.py | 2 +- fastapi/concurrency.py | 22 ++-- fastapi/datastructures.py | 3 +- fastapi/dependencies/models.py | 4 +- fastapi/dependencies/utils.py | 38 +++---- fastapi/encoders.py | 14 ++- fastapi/middleware/__init__.py | 2 +- fastapi/middleware/cors.py | 2 +- fastapi/middleware/gzip.py | 2 +- fastapi/middleware/httpsredirect.py | 4 +- fastapi/middleware/trustedhost.py | 4 +- fastapi/middleware/wsgi.py | 2 +- fastapi/openapi/docs.py | 4 +- fastapi/openapi/models.py | 4 +- fastapi/openapi/utils.py | 52 +++++---- fastapi/param_functions.py | 4 +- fastapi/params.py | 4 +- fastapi/responses.py | 16 +-- fastapi/routing.py | 106 ++++++++++-------- fastapi/security/__init__.py | 32 +++--- fastapi/security/oauth2.py | 8 +- fastapi/staticfiles.py | 2 +- fastapi/templating.py | 2 +- fastapi/testclient.py | 2 +- fastapi/types.py | 3 + fastapi/utils.py | 9 +- fastapi/websockets.py | 4 +- mypy.ini | 22 ++++ pyproject.toml | 4 +- tests/test_custom_route_class.py | 9 +- tests/test_get_request_body.py | 2 +- .../test_include_router_defaults_overrides.py | 10 +- tests/test_inherited_custom_class.py | 2 +- tests/test_jsonable_encoder.py | 4 +- tests/test_local_docs.py | 10 +- tests/test_multi_body_errors.py | 2 +- tests/test_param_class.py | 2 +- tests/test_params_repr.py | 4 +- tests/test_starlette_urlconvertors.py | 4 +- tests/test_sub_callbacks.py | 2 +- 43 files changed, 314 insertions(+), 244 deletions(-) create mode 100644 fastapi/types.py diff --git a/docs_src/openapi_callbacks/tutorial001.py b/docs_src/openapi_callbacks/tutorial001.py index f04fec4d7..2fb836751 100644 --- a/docs_src/openapi_callbacks/tutorial001.py +++ b/docs_src/openapi_callbacks/tutorial001.py @@ -26,7 +26,7 @@ invoices_callback_router = APIRouter() @invoices_callback_router.post( - "{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived, + "{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived ) def invoice_notification(body: InvoiceEvent): pass diff --git a/fastapi/__init__.py b/fastapi/__init__.py index 3d1a699d9..858da48c5 100644 --- a/fastapi/__init__.py +++ b/fastapi/__init__.py @@ -2,24 +2,23 @@ __version__ = "0.62.0" -from starlette import status +from starlette import status as status -from .applications import FastAPI -from .background import BackgroundTasks -from .datastructures import UploadFile -from .exceptions import HTTPException -from .param_functions import ( - Body, - Cookie, - Depends, - File, - Form, - Header, - Path, - Query, - Security, -) -from .requests import Request -from .responses import Response -from .routing import APIRouter -from .websockets import WebSocket, WebSocketDisconnect +from .applications import FastAPI as FastAPI +from .background import BackgroundTasks as BackgroundTasks +from .datastructures import UploadFile as UploadFile +from .exceptions import HTTPException as HTTPException +from .param_functions import Body as Body +from .param_functions import Cookie as Cookie +from .param_functions import Depends as Depends +from .param_functions import File as File +from .param_functions import Form as Form +from .param_functions import Header as Header +from .param_functions import Path as Path +from .param_functions import Query as Query +from .param_functions import Security as Security +from .requests import Request as Request +from .responses import Response as Response +from .routing import APIRouter as APIRouter +from .websockets import WebSocket as WebSocket +from .websockets import WebSocketDisconnect as WebSocketDisconnect diff --git a/fastapi/applications.py b/fastapi/applications.py index 519dc74ae..92d041c5c 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union +from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union from fastapi import routing from fastapi.concurrency import AsyncExitStack @@ -17,6 +17,7 @@ from fastapi.openapi.docs import ( ) from fastapi.openapi.utils import get_openapi from fastapi.params import Depends +from fastapi.types import DecoratedCallable from starlette.applications import Starlette from starlette.datastructures import State from starlette.exceptions import HTTPException @@ -24,7 +25,7 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse, Response from starlette.routing import BaseRoute -from starlette.types import Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send class FastAPI(Starlette): @@ -44,24 +45,27 @@ class FastAPI(Starlette): docs_url: Optional[str] = "/docs", redoc_url: Optional[str] = "/redoc", swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect", - swagger_ui_init_oauth: Optional[dict] = None, + swagger_ui_init_oauth: Optional[Dict[str, Any]] = None, middleware: Optional[Sequence[Middleware]] = None, exception_handlers: Optional[ - Dict[Union[int, Type[Exception]], Callable] + Dict[ + Union[int, Type[Exception]], + Callable[[Request, Any], Coroutine[Any, Any, Response]], + ] ] = None, - on_startup: Optional[Sequence[Callable]] = None, - on_shutdown: Optional[Sequence[Callable]] = None, + on_startup: Optional[Sequence[Callable[[], Any]]] = None, + on_shutdown: Optional[Sequence[Callable[[], Any]]] = None, openapi_prefix: str = "", root_path: str = "", root_path_in_servers: bool = True, responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, - callbacks: Optional[List[routing.APIRoute]] = None, - deprecated: bool = None, + callbacks: Optional[List[BaseRoute]] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, **extra: Any, ) -> None: - self._debug = debug - self.state = State() + self._debug: bool = debug + self.state: State = State() self.router: routing.APIRouter = routing.APIRouter( routes=routes, dependency_overrides_provider=self, @@ -74,7 +78,10 @@ class FastAPI(Starlette): include_in_schema=include_in_schema, responses=responses, ) - self.exception_handlers = ( + self.exception_handlers: Dict[ + Union[int, Type[Exception]], + Callable[[Request, Any], Coroutine[Any, Any, Response]], + ] = ( {} if exception_handlers is None else dict(exception_handlers) ) self.exception_handlers.setdefault(HTTPException, http_exception_handler) @@ -82,8 +89,10 @@ class FastAPI(Starlette): RequestValidationError, request_validation_exception_handler ) - self.user_middleware = [] if middleware is None else list(middleware) - self.middleware_stack = self.build_middleware_stack() + self.user_middleware: List[Middleware] = ( + [] if middleware is None else list(middleware) + ) + self.middleware_stack: ASGIApp = self.build_middleware_stack() self.title = title self.description = description @@ -106,7 +115,7 @@ class FastAPI(Starlette): self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url self.swagger_ui_init_oauth = swagger_ui_init_oauth self.extra = extra - self.dependency_overrides: Dict[Callable, Callable] = {} + self.dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {} self.openapi_version = "3.0.2" @@ -116,7 +125,7 @@ class FastAPI(Starlette): self.openapi_schema: Optional[Dict[str, Any]] = None self.setup() - def openapi(self) -> Dict: + def openapi(self) -> Dict[str, Any]: if not self.openapi_schema: self.openapi_schema = get_openapi( title=self.title, @@ -194,7 +203,7 @@ class FastAPI(Starlette): def add_api_route( self, path: str, - endpoint: Callable, + endpoint: Callable[..., Coroutine[Any, Any, Response]], *, response_model: Optional[Type[Any]] = None, status_code: int = 200, @@ -268,8 +277,8 @@ class FastAPI(Starlette): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - ) -> Callable: - def decorator(func: Callable) -> Callable: + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: self.router.add_api_route( path, func, @@ -299,12 +308,14 @@ class FastAPI(Starlette): return decorator def add_api_websocket_route( - self, path: str, endpoint: Callable, name: Optional[str] = None + self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None ) -> None: self.router.add_api_websocket_route(path, endpoint, name=name) - def websocket(self, path: str, name: Optional[str] = None) -> Callable: - def decorator(func: Callable) -> Callable: + def websocket( + self, path: str, name: Optional[str] = None + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_websocket_route(path, func, name=name) return func @@ -318,10 +329,10 @@ class FastAPI(Starlette): tags: Optional[List[str]] = None, dependencies: Optional[Sequence[Depends]] = None, responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, - deprecated: bool = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, default_response_class: Type[Response] = Default(JSONResponse), - callbacks: Optional[List[routing.APIRoute]] = None, + callbacks: Optional[List[BaseRoute]] = None, ) -> None: self.router.include_router( router, @@ -358,8 +369,8 @@ class FastAPI(Starlette): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[routing.APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.router.get( path, response_model=response_model, @@ -407,8 +418,8 @@ class FastAPI(Starlette): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[routing.APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.router.put( path, response_model=response_model, @@ -456,8 +467,8 @@ class FastAPI(Starlette): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[routing.APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.router.post( path, response_model=response_model, @@ -505,8 +516,8 @@ class FastAPI(Starlette): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[routing.APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.router.delete( path, response_model=response_model, @@ -554,8 +565,8 @@ class FastAPI(Starlette): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[routing.APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.router.options( path, response_model=response_model, @@ -603,8 +614,8 @@ class FastAPI(Starlette): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[routing.APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.router.head( path, response_model=response_model, @@ -652,8 +663,8 @@ class FastAPI(Starlette): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[routing.APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.router.patch( path, response_model=response_model, @@ -701,8 +712,8 @@ class FastAPI(Starlette): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[routing.APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.router.trace( path, response_model=response_model, diff --git a/fastapi/background.py b/fastapi/background.py index 2d0d3d35e..dd3bbe249 100644 --- a/fastapi/background.py +++ b/fastapi/background.py @@ -1 +1 @@ -from starlette.background import BackgroundTasks # noqa +from starlette.background import BackgroundTasks as BackgroundTasks # noqa diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index 451923c55..d1fdfe5f6 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -1,8 +1,10 @@ from typing import Any, Callable -from starlette.concurrency import iterate_in_threadpool # noqa -from starlette.concurrency import run_in_threadpool # noqa -from starlette.concurrency import run_until_first_complete # noqa +from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa +from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa +from starlette.concurrency import ( # noqa + run_until_first_complete as run_until_first_complete, +) asynccontextmanager_error_message = """ FastAPI's contextmanager_in_threadpool require Python 3.7 or above, @@ -11,7 +13,7 @@ or the backport for Python 3.6, installed with: """ -def _fake_asynccontextmanager(func: Callable) -> Callable: +def _fake_asynccontextmanager(func: Callable[..., Any]) -> Callable[..., Any]: def raiser(*args: Any, **kwargs: Any) -> Any: raise RuntimeError(asynccontextmanager_error_message) @@ -19,23 +21,25 @@ def _fake_asynccontextmanager(func: Callable) -> Callable: try: - from contextlib import asynccontextmanager # type: ignore + from contextlib import asynccontextmanager as asynccontextmanager # type: ignore except ImportError: try: - from async_generator import asynccontextmanager # type: ignore + from async_generator import ( # type: ignore # isort: skip + asynccontextmanager as asynccontextmanager, + ) except ImportError: # pragma: no cover asynccontextmanager = _fake_asynccontextmanager try: - from contextlib import AsyncExitStack # type: ignore + from contextlib import AsyncExitStack as AsyncExitStack # type: ignore except ImportError: try: - from async_exit_stack import AsyncExitStack # type: ignore + from async_exit_stack import AsyncExitStack as AsyncExitStack # type: ignore except ImportError: # pragma: no cover AsyncExitStack = None # type: ignore -@asynccontextmanager +@asynccontextmanager # type: ignore async def contextmanager_in_threadpool(cm: Any) -> Any: try: yield await run_in_threadpool(cm.__enter__) diff --git a/fastapi/datastructures.py b/fastapi/datastructures.py index 1fe8ebdad..f22409c51 100644 --- a/fastapi/datastructures.py +++ b/fastapi/datastructures.py @@ -1,11 +1,12 @@ from typing import Any, Callable, Iterable, Type, TypeVar +from starlette.datastructures import State as State # noqa: F401 from starlette.datastructures import UploadFile as StarletteUploadFile class UploadFile(StarletteUploadFile): @classmethod - def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable]: + def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable[..., Any]]: yield cls.validate @classmethod diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 4e2294bd7..443590b9c 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence from fastapi.security.base import SecurityBase from pydantic.fields import ModelField @@ -24,7 +24,7 @@ class Dependant: dependencies: Optional[List["Dependant"]] = None, security_schemes: Optional[List[SecurityRequirement]] = None, name: Optional[str] = None, - call: Optional[Callable] = None, + call: Optional[Callable[..., Any]] = None, request_param_name: Optional[str] = None, websocket_param_name: Optional[str] = None, http_connection_param_name: Optional[str] = None, diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 35329a46a..fcfaa2cb1 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -90,12 +90,12 @@ def check_file_field(field: ModelField) -> None: if isinstance(field_info, params.Form): try: # __version__ is available in both multiparts, and can be mocked - from multipart import __version__ + from multipart import __version__ # type: ignore assert __version__ try: # parse_options_header is only available in the right multipart - from multipart.multipart import parse_options_header + from multipart.multipart import parse_options_header # type: ignore assert parse_options_header except ImportError: @@ -133,7 +133,7 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De def get_sub_dependant( *, depends: params.Depends, - dependency: Callable, + dependency: Callable[..., Any], path: str, name: Optional[str] = None, security_scopes: Optional[List[str]] = None, @@ -163,7 +163,7 @@ def get_sub_dependant( return sub_dependant -CacheKey = Tuple[Optional[Callable], Tuple[str, ...]] +CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] def get_flat_dependant( @@ -240,7 +240,7 @@ def is_scalar_sequence_field(field: ModelField) -> bool: return False -def get_typed_signature(call: Callable) -> inspect.Signature: +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) typed_params = [ @@ -259,9 +259,7 @@ def get_typed_signature(call: Callable) -> inspect.Signature: def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any: annotation = param.annotation if isinstance(annotation, str): - # Temporary ignore type - # Ref: https://github.com/samuelcolvin/pydantic/issues/1738 - annotation = ForwardRef(annotation) # type: ignore + annotation = ForwardRef(annotation) annotation = evaluate_forwardref(annotation, globalns, globalns) return annotation @@ -281,7 +279,7 @@ def check_dependency_contextmanagers() -> None: def get_dependant( *, path: str, - call: Callable, + call: Callable[..., Any], name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, @@ -423,7 +421,7 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: dependant.cookie_params.append(field) -def is_coroutine_callable(call: Callable) -> bool: +def is_coroutine_callable(call: Callable[..., Any]) -> bool: if inspect.isroutine(call): return inspect.iscoroutinefunction(call) if inspect.isclass(call): @@ -432,14 +430,14 @@ def is_coroutine_callable(call: Callable) -> bool: return inspect.iscoroutinefunction(call) -def is_async_gen_callable(call: Callable) -> bool: +def is_async_gen_callable(call: Callable[..., Any]) -> bool: if inspect.isasyncgenfunction(call): return True call = getattr(call, "__call__", None) return inspect.isasyncgenfunction(call) -def is_gen_callable(call: Callable) -> bool: +def is_gen_callable(call: Callable[..., Any]) -> bool: if inspect.isgeneratorfunction(call): return True call = getattr(call, "__call__", None) @@ -447,7 +445,7 @@ def is_gen_callable(call: Callable) -> bool: async def solve_generator( - *, call: Callable, stack: AsyncExitStack, sub_values: Dict[str, Any] + *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] ) -> Any: if is_gen_callable(call): cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) @@ -472,29 +470,29 @@ async def solve_dependencies( background_tasks: Optional[BackgroundTasks] = None, response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[Tuple[Callable, Tuple[str]], Any]] = None, + dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, ) -> Tuple[ Dict[str, Any], List[ErrorWrapper], Optional[BackgroundTasks], Response, - Dict[Tuple[Callable, Tuple[str]], Any], + Dict[Tuple[Callable[..., Any], Tuple[str]], Any], ]: values: Dict[str, Any] = {} errors: List[ErrorWrapper] = [] response = response or Response( content=None, status_code=None, # type: ignore - headers=None, - media_type=None, - background=None, + headers=None, # type: ignore # in Starlette + media_type=None, # type: ignore # in Starlette + background=None, # type: ignore # in Starlette ) dependency_cache = dependency_cache or {} sub_dependant: Dependant for sub_dependant in dependant.dependencies: - sub_dependant.call = cast(Callable, sub_dependant.call) + sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.cache_key = cast( - Tuple[Callable, Tuple[str]], sub_dependant.cache_key + Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key ) call = sub_dependant.call use_sub_dependant = sub_dependant diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 1255b7497..6a2a75dda 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -12,9 +12,11 @@ DictIntStrAny = Dict[Union[int, str], Any] def generate_encoders_by_class_tuples( - type_encoder_map: Dict[Any, Callable] -) -> Dict[Callable, Tuple]: - encoders_by_class_tuples: Dict[Callable, Tuple] = defaultdict(tuple) + type_encoder_map: Dict[Any, Callable[[Any], Any]] +) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: + encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict( + tuple + ) for type_, encoder in type_encoder_map.items(): encoders_by_class_tuples[encoder] += (type_,) return encoders_by_class_tuples @@ -31,7 +33,7 @@ def jsonable_encoder( exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - custom_encoder: dict = {}, + custom_encoder: Dict[Any, Callable[[Any], Any]] = {}, sqlalchemy_safe: bool = True, ) -> Any: if include is not None and not isinstance(include, set): @@ -43,8 +45,8 @@ def jsonable_encoder( if custom_encoder: encoder.update(custom_encoder) obj_dict = obj.dict( - include=include, - exclude=exclude, + include=include, # type: ignore # in Pydantic + exclude=exclude, # type: ignore # in Pydantic by_alias=by_alias, exclude_unset=exclude_unset, exclude_none=exclude_none, diff --git a/fastapi/middleware/__init__.py b/fastapi/middleware/__init__.py index 6601b1783..620296d5a 100644 --- a/fastapi/middleware/__init__.py +++ b/fastapi/middleware/__init__.py @@ -1 +1 @@ -from starlette.middleware import Middleware +from starlette.middleware import Middleware as Middleware diff --git a/fastapi/middleware/cors.py b/fastapi/middleware/cors.py index 4c08a161a..8dfaad0db 100644 --- a/fastapi/middleware/cors.py +++ b/fastapi/middleware/cors.py @@ -1 +1 @@ -from starlette.middleware.cors import CORSMiddleware # noqa +from starlette.middleware.cors import CORSMiddleware as CORSMiddleware # noqa diff --git a/fastapi/middleware/gzip.py b/fastapi/middleware/gzip.py index 08460d07e..bbeb2cc78 100644 --- a/fastapi/middleware/gzip.py +++ b/fastapi/middleware/gzip.py @@ -1 +1 @@ -from starlette.middleware.gzip import GZipMiddleware # noqa +from starlette.middleware.gzip import GZipMiddleware as GZipMiddleware # noqa diff --git a/fastapi/middleware/httpsredirect.py b/fastapi/middleware/httpsredirect.py index 674263af3..b7a3d8e07 100644 --- a/fastapi/middleware/httpsredirect.py +++ b/fastapi/middleware/httpsredirect.py @@ -1 +1,3 @@ -from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware # noqa +from starlette.middleware.httpsredirect import ( # noqa + HTTPSRedirectMiddleware as HTTPSRedirectMiddleware, +) diff --git a/fastapi/middleware/trustedhost.py b/fastapi/middleware/trustedhost.py index b16aee872..08d7e0353 100644 --- a/fastapi/middleware/trustedhost.py +++ b/fastapi/middleware/trustedhost.py @@ -1 +1,3 @@ -from starlette.middleware.trustedhost import TrustedHostMiddleware # noqa +from starlette.middleware.trustedhost import ( # noqa + TrustedHostMiddleware as TrustedHostMiddleware, +) diff --git a/fastapi/middleware/wsgi.py b/fastapi/middleware/wsgi.py index bf8d3e66e..c4c6a797d 100644 --- a/fastapi/middleware/wsgi.py +++ b/fastapi/middleware/wsgi.py @@ -1 +1 @@ -from starlette.middleware.wsgi import WSGIMiddleware # noqa +from starlette.middleware.wsgi import WSGIMiddleware as WSGIMiddleware # noqa diff --git a/fastapi/openapi/docs.py b/fastapi/openapi/docs.py index 44c4e69a3..fd22e4e8c 100644 --- a/fastapi/openapi/docs.py +++ b/fastapi/openapi/docs.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Dict, Optional from fastapi.encoders import jsonable_encoder from starlette.responses import HTMLResponse @@ -13,7 +13,7 @@ def get_swagger_ui_html( swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css", swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png", oauth2_redirect_url: Optional[str] = None, - init_oauth: Optional[dict] = None, + init_oauth: Optional[Dict[str, Any]] = None, ) -> HTMLResponse: html = f""" diff --git a/fastapi/openapi/models.py b/fastapi/openapi/models.py index 3b716766d..fd480946d 100644 --- a/fastapi/openapi/models.py +++ b/fastapi/openapi/models.py @@ -5,7 +5,7 @@ from fastapi.logger import logger from pydantic import AnyUrl, BaseModel, Field try: - import email_validator + import email_validator # type: ignore assert email_validator # make autoflake ignore the unused import from pydantic import EmailStr @@ -13,7 +13,7 @@ except ImportError: # pragma: no cover class EmailStr(str): # type: ignore @classmethod - def __get_validators__(cls) -> Iterable[Callable]: + def __get_validators__(cls) -> Iterable[Callable[..., Any]]: yield cls.validate @classmethod diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 5547cce4f..410ba9389 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -14,6 +14,7 @@ from fastapi.openapi.constants import ( ) from fastapi.openapi.models import OpenAPI from fastapi.params import Body, Param +from fastapi.responses import Response from fastapi.utils import ( deep_dict_update, generate_operation_id_for_path, @@ -64,7 +65,9 @@ status_code_ranges: Dict[str, str] = { } -def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]: +def get_openapi_security_definitions( + flat_dependant: Dependant, +) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: security_definitions = {} operation_security = [] for security_requirement in flat_dependant.security_requirements: @@ -88,13 +91,12 @@ def get_openapi_operation_parameters( for param in all_route_params: field_info = param.field_info field_info = cast(Param, field_info) - # ignore mypy error until enum schemas are released parameter = { "name": param.alias, "in": field_info.in_.value, "required": param.required, "schema": field_schema( - param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore + param, model_name_map=model_name_map, ref_prefix=REF_PREFIX )[0], } if field_info.description: @@ -109,13 +111,12 @@ def get_openapi_operation_request_body( *, body_field: Optional[ModelField], model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], -) -> Optional[Dict]: +) -> Optional[Dict[str, Any]]: if not body_field: return None assert isinstance(body_field, ModelField) - # ignore mypy error until enum schemas are released body_schema, _, _ = field_schema( - body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore + body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX ) field_info = cast(Body, body_field.field_info) request_media_type = field_info.media_type @@ -140,7 +141,9 @@ def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str: return route.name.replace("_", " ").title() -def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict: +def get_openapi_operation_metadata( + *, route: routing.APIRoute, method: str +) -> Dict[str, Any]: operation: Dict[str, Any] = {} if route.tags: operation["tags"] = route.tags @@ -154,14 +157,14 @@ def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> D def get_openapi_path( - *, route: routing.APIRoute, model_name_map: Dict[Type, str] -) -> Tuple[Dict, Dict, Dict]: + *, route: routing.APIRoute, model_name_map: Dict[type, str] +) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: path = {} security_schemes: Dict[str, Any] = {} definitions: Dict[str, Any] = {} assert route.methods is not None, "Methods must be a list" if isinstance(route.response_class, DefaultPlaceholder): - current_response_class: Type[routing.Response] = route.response_class.value + current_response_class: Type[Response] = route.response_class.value else: current_response_class = route.response_class assert current_response_class, "A response class is needed to generate OpenAPI" @@ -169,7 +172,7 @@ def get_openapi_path( if route.include_in_schema: for method in route.methods: operation = get_openapi_operation_metadata(route=route, method=method) - parameters: List[Dict] = [] + parameters: List[Dict[str, Any]] = [] flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True) security_definitions, operation_security = get_openapi_security_definitions( flat_dependant=flat_dependant @@ -196,10 +199,15 @@ def get_openapi_path( if route.callbacks: callbacks = {} for callback in route.callbacks: - cb_path, cb_security_schemes, cb_definitions, = get_openapi_path( - route=callback, model_name_map=model_name_map - ) - callbacks[callback.name] = {callback.path: cb_path} + if isinstance(callback, routing.APIRoute): + ( + cb_path, + cb_security_schemes, + cb_definitions, + ) = get_openapi_path( + route=callback, model_name_map=model_name_map + ) + callbacks[callback.name] = {callback.path: cb_path} operation["callbacks"] = callbacks status_code = str(route.status_code) operation.setdefault("responses", {}).setdefault(status_code, {})[ @@ -332,21 +340,19 @@ def get_openapi( routes: Sequence[BaseRoute], tags: Optional[List[Dict[str, Any]]] = None, servers: Optional[List[Dict[str, Union[str, Any]]]] = None, -) -> Dict: +) -> Dict[str, Any]: info = {"title": title, "version": version} if description: info["description"] = description output: Dict[str, Any] = {"openapi": openapi_version, "info": info} if servers: output["servers"] = servers - components: Dict[str, Dict] = {} - paths: Dict[str, Dict] = {} + components: Dict[str, Dict[str, Any]] = {} + paths: Dict[str, Dict[str, Any]] = {} flat_models = get_flat_models_from_routes(routes) - # ignore mypy error until enum schemas are released - model_name_map = get_model_name_map(flat_models) # type: ignore - # ignore mypy error until enum schemas are released + model_name_map = get_model_name_map(flat_models) definitions = get_model_definitions( - flat_models=flat_models, model_name_map=model_name_map # type: ignore + flat_models=flat_models, model_name_map=model_name_map ) for route in routes: if isinstance(route, routing.APIRoute): @@ -368,4 +374,4 @@ def get_openapi( output["paths"] = paths if tags: output["tags"] = tags - return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) + return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 91620c7c0..9ebb59100 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -239,13 +239,13 @@ def File( # noqa: N802 def Depends( # noqa: N802 - dependency: Optional[Callable] = None, *, use_cache: bool = True + dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True ) -> Any: return params.Depends(dependency=dependency, use_cache=use_cache) def Security( # noqa: N802 - dependency: Optional[Callable] = None, + dependency: Optional[Callable[..., Any]] = None, *, scopes: Optional[Sequence[str]] = None, use_cache: bool = True, diff --git a/fastapi/params.py b/fastapi/params.py index f53e2dba9..aa3269a80 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -315,7 +315,7 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable] = None, *, use_cache: bool = True + self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True ): self.dependency = dependency self.use_cache = use_cache @@ -329,7 +329,7 @@ class Depends: class Security(Depends): def __init__( self, - dependency: Optional[Callable] = None, + dependency: Optional[Callable[..., Any]] = None, *, scopes: Optional[Sequence[str]] = None, use_cache: bool = True, diff --git a/fastapi/responses.py b/fastapi/responses.py index 0aeff61d0..8d9d62dfb 100644 --- a/fastapi/responses.py +++ b/fastapi/responses.py @@ -1,13 +1,13 @@ from typing import Any -from starlette.responses import FileResponse # noqa -from starlette.responses import HTMLResponse # noqa -from starlette.responses import JSONResponse # noqa -from starlette.responses import PlainTextResponse # noqa -from starlette.responses import RedirectResponse # noqa -from starlette.responses import Response # noqa -from starlette.responses import StreamingResponse # noqa -from starlette.responses import UJSONResponse # noqa +from starlette.responses import FileResponse as FileResponse # noqa +from starlette.responses import HTMLResponse as HTMLResponse # noqa +from starlette.responses import JSONResponse as JSONResponse # noqa +from starlette.responses import PlainTextResponse as PlainTextResponse # noqa +from starlette.responses import RedirectResponse as RedirectResponse # noqa +from starlette.responses import Response as Response # noqa +from starlette.responses import StreamingResponse as StreamingResponse # noqa +from starlette.responses import UJSONResponse as UJSONResponse # noqa try: import orjson diff --git a/fastapi/routing.py b/fastapi/routing.py index 53f35a4a5..ac5e19d99 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -2,7 +2,18 @@ import asyncio import enum import inspect import json -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Union +from typing import ( + Any, + Callable, + Coroutine, + Dict, + List, + Optional, + Sequence, + Set, + Type, + Union, +) from fastapi import params from fastapi.datastructures import Default, DefaultPlaceholder @@ -16,6 +27,7 @@ from fastapi.dependencies.utils import ( from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY +from fastapi.types import DecoratedCallable from fastapi.utils import ( create_cloned_field, create_response_field, @@ -30,7 +42,8 @@ from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response -from starlette.routing import Mount # noqa +from starlette.routing import BaseRoute +from starlette.routing import Mount as Mount # noqa from starlette.routing import ( compile_path, get_name, @@ -150,7 +163,7 @@ def get_request_handler( response_model_exclude_defaults: bool = False, response_model_exclude_none: bool = False, dependency_overrides_provider: Optional[Any] = None, -) -> Callable: +) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" is_coroutine = asyncio.iscoroutinefunction(dependant.call) is_body_form = body_field and isinstance(body_field.field_info, params.Form) @@ -207,7 +220,7 @@ def get_request_handler( response = actual_response_class( content=response_data, status_code=status_code, - background=background_tasks, + background=background_tasks, # type: ignore # in Starlette ) response.headers.raw.extend(sub_response.headers.raw) if sub_response.status_code: @@ -219,7 +232,7 @@ def get_request_handler( def get_websocket_app( dependant: Dependant, dependency_overrides_provider: Optional[Any] = None -) -> Callable: +) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: async def app(websocket: WebSocket) -> None: solved_result = await solve_dependencies( request=websocket, @@ -240,7 +253,7 @@ class APIWebSocketRoute(routing.WebSocketRoute): def __init__( self, path: str, - endpoint: Callable, + endpoint: Callable[..., Any], *, name: Optional[str] = None, dependency_overrides_provider: Optional[Any] = None, @@ -262,7 +275,7 @@ class APIRoute(routing.Route): def __init__( self, path: str, - endpoint: Callable, + endpoint: Callable[..., Any], *, response_model: Optional[Type[Any]] = None, status_code: int = 200, @@ -287,7 +300,7 @@ class APIRoute(routing.Route): JSONResponse ), dependency_overrides_provider: Optional[Any] = None, - callbacks: Optional[List["APIRoute"]] = None, + callbacks: Optional[List[BaseRoute]] = None, ) -> None: # normalise enums e.g. http.HTTPStatus if isinstance(status_code, enum.IntEnum): @@ -298,7 +311,7 @@ class APIRoute(routing.Route): self.path_regex, self.path_format, self.param_convertors = compile_path(path) if methods is None: methods = ["GET"] - self.methods = set([method.upper() for method in methods]) + self.methods: Set[str] = set([method.upper() for method in methods]) self.unique_id = generate_operation_id_for_path( name=self.name, path=self.path_format, method=list(methods)[0] ) @@ -375,7 +388,7 @@ class APIRoute(routing.Route): self.callbacks = callbacks self.app = request_response(self.get_route_handler()) - def get_route_handler(self) -> Callable: + def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: return get_request_handler( dependant=self.dependant, body_field=self.body_field, @@ -401,23 +414,23 @@ class APIRouter(routing.Router): dependencies: Optional[Sequence[params.Depends]] = None, default_response_class: Type[Response] = Default(JSONResponse), responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, - callbacks: Optional[List[APIRoute]] = None, + callbacks: Optional[List[BaseRoute]] = None, routes: Optional[List[routing.BaseRoute]] = None, redirect_slashes: bool = True, default: Optional[ASGIApp] = None, dependency_overrides_provider: Optional[Any] = None, route_class: Type[APIRoute] = APIRoute, - on_startup: Optional[Sequence[Callable]] = None, - on_shutdown: Optional[Sequence[Callable]] = None, - deprecated: bool = None, + on_startup: Optional[Sequence[Callable[[], Any]]] = None, + on_shutdown: Optional[Sequence[Callable[[], Any]]] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, ) -> None: super().__init__( - routes=routes, + routes=routes, # type: ignore # in Starlette redirect_slashes=redirect_slashes, - default=default, - on_startup=on_startup, - on_shutdown=on_shutdown, + default=default, # type: ignore # in Starlette + on_startup=on_startup, # type: ignore # in Starlette + on_shutdown=on_shutdown, # type: ignore # in Starlette ) if prefix: assert prefix.startswith("/"), "A path prefix must start with '/'" @@ -438,7 +451,7 @@ class APIRouter(routing.Router): def add_api_route( self, path: str, - endpoint: Callable, + endpoint: Callable[..., Any], *, response_model: Optional[Type[Any]] = None, status_code: int = 200, @@ -463,7 +476,7 @@ class APIRouter(routing.Router): ), name: Optional[str] = None, route_class_override: Optional[Type[APIRoute]] = None, - callbacks: Optional[List[APIRoute]] = None, + callbacks: Optional[List[BaseRoute]] = None, ) -> None: route_class = route_class_override or self.route_class responses = responses or {} @@ -532,9 +545,9 @@ class APIRouter(routing.Router): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[APIRoute]] = None, - ) -> Callable: - def decorator(func: Callable) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_route( path, func, @@ -565,7 +578,7 @@ class APIRouter(routing.Router): return decorator def add_api_websocket_route( - self, path: str, endpoint: Callable, name: Optional[str] = None + self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None ) -> None: route = APIWebSocketRoute( path, @@ -575,8 +588,10 @@ class APIRouter(routing.Router): ) self.routes.append(route) - def websocket(self, path: str, name: Optional[str] = None) -> Callable: - def decorator(func: Callable) -> Callable: + def websocket( + self, path: str, name: Optional[str] = None + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_websocket_route(path, func, name=name) return func @@ -591,8 +606,8 @@ class APIRouter(routing.Router): dependencies: Optional[Sequence[params.Depends]] = None, default_response_class: Type[Response] = Default(JSONResponse), responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, - callbacks: Optional[List[APIRoute]] = None, - deprecated: bool = None, + callbacks: Optional[List[BaseRoute]] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, ) -> None: if prefix: @@ -663,10 +678,11 @@ class APIRouter(routing.Router): callbacks=current_callbacks, ) elif isinstance(route, routing.Route): + methods = list(route.methods or []) # type: ignore # in Starlette self.add_route( prefix + route.path, route.endpoint, - methods=list(route.methods or []), + methods=methods, include_in_schema=route.include_in_schema, name=route.name, ) @@ -706,8 +722,8 @@ class APIRouter(routing.Router): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.api_route( path=path, response_model=response_model, @@ -756,8 +772,8 @@ class APIRouter(routing.Router): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.api_route( path=path, response_model=response_model, @@ -806,8 +822,8 @@ class APIRouter(routing.Router): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.api_route( path=path, response_model=response_model, @@ -856,8 +872,8 @@ class APIRouter(routing.Router): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.api_route( path=path, response_model=response_model, @@ -906,8 +922,8 @@ class APIRouter(routing.Router): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.api_route( path=path, response_model=response_model, @@ -956,8 +972,8 @@ class APIRouter(routing.Router): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.api_route( path=path, response_model=response_model, @@ -1006,8 +1022,8 @@ class APIRouter(routing.Router): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.api_route( path=path, response_model=response_model, @@ -1056,8 +1072,8 @@ class APIRouter(routing.Router): include_in_schema: bool = True, response_class: Type[Response] = Default(JSONResponse), name: Optional[str] = None, - callbacks: Optional[List[APIRoute]] = None, - ) -> Callable: + callbacks: Optional[List[BaseRoute]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: return self.api_route( path=path, diff --git a/fastapi/security/__init__.py b/fastapi/security/__init__.py index ad727742c..3aa6bf21e 100644 --- a/fastapi/security/__init__.py +++ b/fastapi/security/__init__.py @@ -1,17 +1,15 @@ -from .api_key import APIKeyCookie, APIKeyHeader, APIKeyQuery -from .http import ( - HTTPAuthorizationCredentials, - HTTPBasic, - HTTPBasicCredentials, - HTTPBearer, - HTTPDigest, -) -from .oauth2 import ( - OAuth2, - OAuth2AuthorizationCodeBearer, - OAuth2PasswordBearer, - OAuth2PasswordRequestForm, - OAuth2PasswordRequestFormStrict, - SecurityScopes, -) -from .open_id_connect_url import OpenIdConnect +from .api_key import APIKeyCookie as APIKeyCookie +from .api_key import APIKeyHeader as APIKeyHeader +from .api_key import APIKeyQuery as APIKeyQuery +from .http import HTTPAuthorizationCredentials as HTTPAuthorizationCredentials +from .http import HTTPBasic as HTTPBasic +from .http import HTTPBasicCredentials as HTTPBasicCredentials +from .http import HTTPBearer as HTTPBearer +from .http import HTTPDigest as HTTPDigest +from .oauth2 import OAuth2 as OAuth2 +from .oauth2 import OAuth2AuthorizationCodeBearer as OAuth2AuthorizationCodeBearer +from .oauth2 import OAuth2PasswordBearer as OAuth2PasswordBearer +from .oauth2 import OAuth2PasswordRequestForm as OAuth2PasswordRequestForm +from .oauth2 import OAuth2PasswordRequestFormStrict as OAuth2PasswordRequestFormStrict +from .oauth2 import SecurityScopes as SecurityScopes +from .open_id_connect_url import OpenIdConnect as OpenIdConnect diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index 0d1a5f12f..46571ad53 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, Dict, List, Optional, Union from fastapi.exceptions import HTTPException from fastapi.openapi.models import OAuth2 as OAuth2Model @@ -116,7 +116,7 @@ class OAuth2(SecurityBase): def __init__( self, *, - flows: OAuthFlowsModel = OAuthFlowsModel(), + flows: Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]] = OAuthFlowsModel(), scheme_name: Optional[str] = None, auto_error: Optional[bool] = True ): @@ -141,7 +141,7 @@ class OAuth2PasswordBearer(OAuth2): self, tokenUrl: str, scheme_name: Optional[str] = None, - scopes: Optional[dict] = None, + scopes: Optional[Dict[str, str]] = None, auto_error: bool = True, ): if not scopes: @@ -171,7 +171,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2): tokenUrl: str, refreshUrl: Optional[str] = None, scheme_name: Optional[str] = None, - scopes: Optional[dict] = None, + scopes: Optional[Dict[str, str]] = None, auto_error: bool = True, ): if not scopes: diff --git a/fastapi/staticfiles.py b/fastapi/staticfiles.py index 78359dd1e..299015d4f 100644 --- a/fastapi/staticfiles.py +++ b/fastapi/staticfiles.py @@ -1 +1 @@ -from starlette.staticfiles import StaticFiles # noqa +from starlette.staticfiles import StaticFiles as StaticFiles # noqa diff --git a/fastapi/templating.py b/fastapi/templating.py index d4c035cf8..0cb868486 100644 --- a/fastapi/templating.py +++ b/fastapi/templating.py @@ -1 +1 @@ -from starlette.templating import Jinja2Templates # noqa +from starlette.templating import Jinja2Templates as Jinja2Templates # noqa diff --git a/fastapi/testclient.py b/fastapi/testclient.py index 0288f694c..4012406aa 100644 --- a/fastapi/testclient.py +++ b/fastapi/testclient.py @@ -1 +1 @@ -from starlette.testclient import TestClient # noqa +from starlette.testclient import TestClient as TestClient # noqa diff --git a/fastapi/types.py b/fastapi/types.py new file mode 100644 index 000000000..e0bca4632 --- /dev/null +++ b/fastapi/types.py @@ -0,0 +1,3 @@ +from typing import Any, Callable, TypeVar + +DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) diff --git a/fastapi/utils.py b/fastapi/utils.py index 058956e32..8913d85b2 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -19,11 +19,10 @@ def get_model_definitions( flat_models: Set[Union[Type[BaseModel], Type[Enum]]], model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], ) -> Dict[str, Any]: - definitions: Dict[str, Dict] = {} + definitions: Dict[str, Dict[str, Any]] = {} for model in flat_models: - # ignore mypy error until enum schemas are released m_schema, m_definitions, m_nested_models = model_process_schema( - model, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore + model, model_name_map=model_name_map, ref_prefix=REF_PREFIX ) definitions.update(m_definitions) model_name = model_name_map[model] @@ -80,7 +79,7 @@ def create_cloned_field( cloned_types = dict() original_type = field.type_ if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): - original_type = original_type.__pydantic_model__ # type: ignore + original_type = original_type.__pydantic_model__ use_type = original_type if lenient_issubclass(original_type, BaseModel): original_type = cast(Type[BaseModel], original_type) @@ -127,7 +126,7 @@ def generate_operation_id_for_path(*, name: str, path: str, method: str) -> str: return operation_id -def deep_dict_update(main_dict: dict, update_dict: dict) -> None: +def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None: for key in update_dict: if ( key in main_dict diff --git a/fastapi/websockets.py b/fastapi/websockets.py index 2edf97328..bed672acf 100644 --- a/fastapi/websockets.py +++ b/fastapi/websockets.py @@ -1,2 +1,2 @@ -from starlette.websockets import WebSocket # noqa -from starlette.websockets import WebSocketDisconnect # noqa +from starlette.websockets import WebSocket as WebSocket # noqa +from starlette.websockets import WebSocketDisconnect as WebSocketDisconnect # noqa diff --git a/mypy.ini b/mypy.ini index 4ff4483ab..e6a33cffb 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,25 @@ [mypy] + +# --strict +disallow_any_generics = True +disallow_subclassing_any = True +disallow_untyped_calls = True disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = True +no_implicit_optional = True +warn_redundant_casts = True +warn_unused_ignores = True +warn_return_any = True +implicit_reexport = False +strict_equality = True +# --strict end + +[mypy-fastapi.concurrency] +warn_unused_ignores = False ignore_missing_imports = True + +[mypy-fastapi.tests.*] +ignore_missing_imports = True +check_untyped_defs = True diff --git a/pyproject.toml b/pyproject.toml index c17f63e84..3dc6b6f83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,9 +46,9 @@ test = [ "pytest ==5.4.3", "pytest-cov ==2.10.0", "pytest-asyncio >=0.14.0,<0.15.0", - "mypy ==0.782", + "mypy ==0.790", "flake8 >=3.8.3,<4.0.0", - "black ==19.10b0", + "black ==20.8b1", "isort >=5.0.6,<6.0.0", "requests >=2.24.0,<3.0.0", "httpx >=0.14.0,<0.15.0", diff --git a/tests/test_custom_route_class.py b/tests/test_custom_route_class.py index afca4732f..1a9ea7199 100644 --- a/tests/test_custom_route_class.py +++ b/tests/test_custom_route_class.py @@ -2,6 +2,7 @@ import pytest from fastapi import APIRouter, FastAPI from fastapi.routing import APIRoute from fastapi.testclient import TestClient +from starlette.routing import Route app = FastAPI() @@ -106,9 +107,9 @@ def test_get_path(path, expected_status, expected_response): def test_route_classes(): routes = {} - r: APIRoute for r in app.router.routes: + assert isinstance(r, Route) routes[r.path] = r - assert routes["/a/"].x_type == "A" - assert routes["/a/b/"].x_type == "B" - assert routes["/a/b/c/"].x_type == "C" + assert getattr(routes["/a/"], "x_type") == "A" + assert getattr(routes["/a/b/"], "x_type") == "B" + assert getattr(routes["/a/b/c/"], "x_type") == "C" diff --git a/tests/test_get_request_body.py b/tests/test_get_request_body.py index 348aee5f9..b12f499eb 100644 --- a/tests/test_get_request_body.py +++ b/tests/test_get_request_body.py @@ -7,7 +7,7 @@ app = FastAPI() class Product(BaseModel): name: str - description: str = None + description: str = None # type: ignore price: float diff --git a/tests/test_include_router_defaults_overrides.py b/tests/test_include_router_defaults_overrides.py index ecfa0b2fa..c46cb6701 100644 --- a/tests/test_include_router_defaults_overrides.py +++ b/tests/test_include_router_defaults_overrides.py @@ -175,7 +175,7 @@ async def path3_override_router2_override(level3: str): return level3 -@router2_override.get("/default3",) +@router2_override.get("/default3") async def path3_default_router2_override(level3: str): return level3 @@ -217,7 +217,9 @@ async def path5_override_router4_override(level5: str): return level5 -@router4_override.get("/default5",) +@router4_override.get( + "/default5", +) async def path5_default_router4_override(level5: str): return level5 @@ -238,7 +240,9 @@ async def path5_override_router4_default(level5: str): return level5 -@router4_default.get("/default5",) +@router4_default.get( + "/default5", +) async def path5_default_router4_default(level5: str): return level5 diff --git a/tests/test_inherited_custom_class.py b/tests/test_inherited_custom_class.py index 1ed5bf1b9..bac7eec1b 100644 --- a/tests/test_inherited_custom_class.py +++ b/tests/test_inherited_custom_class.py @@ -15,7 +15,7 @@ class MyUuid: def __str__(self): return self.uuid - @property + @property # type: ignore def __class__(self): return uuid.UUID diff --git a/tests/test_jsonable_encoder.py b/tests/test_jsonable_encoder.py index 87b2466e8..e2aa8adf8 100644 --- a/tests/test_jsonable_encoder.py +++ b/tests/test_jsonable_encoder.py @@ -71,7 +71,7 @@ class ModelWithAlias(BaseModel): class ModelWithDefault(BaseModel): - foo: str = ... + foo: str = ... # type: ignore bar: str = "bar" bla: str = "bla" @@ -88,7 +88,7 @@ def fixture_model_with_path(request): arbitrary_types_allowed = True ModelWithPath = create_model( - "ModelWithPath", path=(request.param, ...), __config__=Config + "ModelWithPath", path=(request.param, ...), __config__=Config # type: ignore ) return ModelWithPath(path=request.param("/foo", "bar")) diff --git a/tests/test_local_docs.py b/tests/test_local_docs.py index 0ef777030..5f102edf1 100644 --- a/tests/test_local_docs.py +++ b/tests/test_local_docs.py @@ -5,9 +5,9 @@ from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html def test_strings_in_generated_swagger(): sig = inspect.signature(get_swagger_ui_html) - swagger_js_url = sig.parameters.get("swagger_js_url").default - swagger_css_url = sig.parameters.get("swagger_css_url").default - swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default + swagger_js_url = sig.parameters.get("swagger_js_url").default # type: ignore + swagger_css_url = sig.parameters.get("swagger_css_url").default # type: ignore + swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default # type: ignore html = get_swagger_ui_html(openapi_url="/docs", title="title") body_content = html.body.decode() assert swagger_js_url in body_content @@ -34,8 +34,8 @@ def test_strings_in_custom_swagger(): def test_strings_in_generated_redoc(): sig = inspect.signature(get_redoc_html) - redoc_js_url = sig.parameters.get("redoc_js_url").default - redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default + redoc_js_url = sig.parameters.get("redoc_js_url").default # type: ignore + redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default # type: ignore html = get_redoc_html(openapi_url="/docs", title="title") body_content = html.body.decode() assert redoc_js_url in body_content diff --git a/tests/test_multi_body_errors.py b/tests/test_multi_body_errors.py index 4719f0b27..c1be82806 100644 --- a/tests/test_multi_body_errors.py +++ b/tests/test_multi_body_errors.py @@ -10,7 +10,7 @@ app = FastAPI() class Item(BaseModel): name: str - age: condecimal(gt=Decimal(0.0)) + age: condecimal(gt=Decimal(0.0)) # type: ignore @app.post("/items/") diff --git a/tests/test_param_class.py b/tests/test_param_class.py index c2a9096d4..f5767ec96 100644 --- a/tests/test_param_class.py +++ b/tests/test_param_class.py @@ -8,7 +8,7 @@ app = FastAPI() @app.get("/items/") -def read_items(q: Optional[str] = Param(None)): +def read_items(q: Optional[str] = Param(None)): # type: ignore return {"q": q} diff --git a/tests/test_params_repr.py b/tests/test_params_repr.py index e21772aca..d721257d7 100644 --- a/tests/test_params_repr.py +++ b/tests/test_params_repr.py @@ -1,7 +1,9 @@ +from typing import Any, List + import pytest from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query -test_data = ["teststr", None, ..., 1, []] +test_data: List[Any] = ["teststr", None, ..., 1, []] def get_user(): diff --git a/tests/test_starlette_urlconvertors.py b/tests/test_starlette_urlconvertors.py index 1ea22116c..2320c7005 100644 --- a/tests/test_starlette_urlconvertors.py +++ b/tests/test_starlette_urlconvertors.py @@ -27,7 +27,7 @@ def test_route_converters_int(): response = client.get("/int/5") assert response.status_code == 200, response.text assert response.json() == {"int": 5} - assert app.url_path_for("int_convertor", param=5) == "/int/5" + assert app.url_path_for("int_convertor", param=5) == "/int/5" # type: ignore def test_route_converters_float(): @@ -35,7 +35,7 @@ def test_route_converters_float(): response = client.get("/float/25.5") assert response.status_code == 200, response.text assert response.json() == {"float": 25.5} - assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5" + assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5" # type: ignore def test_route_converters_path(): diff --git a/tests/test_sub_callbacks.py b/tests/test_sub_callbacks.py index 40ca1475d..16644b556 100644 --- a/tests/test_sub_callbacks.py +++ b/tests/test_sub_callbacks.py @@ -27,7 +27,7 @@ invoices_callback_router = APIRouter() @invoices_callback_router.post( - "{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived, + "{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived ) def invoice_notification(body: InvoiceEvent): pass # pragma: nocover