Browse Source

Improve type annotations, add support for mypy --strict, internally and for external packages (#2547)

pull/2548/head
Sebastián Ramírez 4 years ago
committed by GitHub
parent
commit
fdb6c9ccc5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      docs_src/openapi_callbacks/tutorial001.py
  2. 39
      fastapi/__init__.py
  3. 89
      fastapi/applications.py
  4. 2
      fastapi/background.py
  5. 22
      fastapi/concurrency.py
  6. 3
      fastapi/datastructures.py
  7. 4
      fastapi/dependencies/models.py
  8. 38
      fastapi/dependencies/utils.py
  9. 14
      fastapi/encoders.py
  10. 2
      fastapi/middleware/__init__.py
  11. 2
      fastapi/middleware/cors.py
  12. 2
      fastapi/middleware/gzip.py
  13. 4
      fastapi/middleware/httpsredirect.py
  14. 4
      fastapi/middleware/trustedhost.py
  15. 2
      fastapi/middleware/wsgi.py
  16. 4
      fastapi/openapi/docs.py
  17. 4
      fastapi/openapi/models.py
  18. 52
      fastapi/openapi/utils.py
  19. 4
      fastapi/param_functions.py
  20. 4
      fastapi/params.py
  21. 16
      fastapi/responses.py
  22. 106
      fastapi/routing.py
  23. 32
      fastapi/security/__init__.py
  24. 8
      fastapi/security/oauth2.py
  25. 2
      fastapi/staticfiles.py
  26. 2
      fastapi/templating.py
  27. 2
      fastapi/testclient.py
  28. 3
      fastapi/types.py
  29. 9
      fastapi/utils.py
  30. 4
      fastapi/websockets.py
  31. 22
      mypy.ini
  32. 4
      pyproject.toml
  33. 9
      tests/test_custom_route_class.py
  34. 2
      tests/test_get_request_body.py
  35. 10
      tests/test_include_router_defaults_overrides.py
  36. 2
      tests/test_inherited_custom_class.py
  37. 4
      tests/test_jsonable_encoder.py
  38. 10
      tests/test_local_docs.py
  39. 2
      tests/test_multi_body_errors.py
  40. 2
      tests/test_param_class.py
  41. 4
      tests/test_params_repr.py
  42. 4
      tests/test_starlette_urlconvertors.py
  43. 2
      tests/test_sub_callbacks.py

2
docs_src/openapi_callbacks/tutorial001.py

@ -26,7 +26,7 @@ invoices_callback_router = APIRouter()
@invoices_callback_router.post(
"{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived,
"{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived
)
def invoice_notification(body: InvoiceEvent):
pass

39
fastapi/__init__.py

@ -2,24 +2,23 @@
__version__ = "0.62.0"
from starlette import status
from starlette import status as status
from .applications import FastAPI
from .background import BackgroundTasks
from .datastructures import UploadFile
from .exceptions import HTTPException
from .param_functions import (
Body,
Cookie,
Depends,
File,
Form,
Header,
Path,
Query,
Security,
)
from .requests import Request
from .responses import Response
from .routing import APIRouter
from .websockets import WebSocket, WebSocketDisconnect
from .applications import FastAPI as FastAPI
from .background import BackgroundTasks as BackgroundTasks
from .datastructures import UploadFile as UploadFile
from .exceptions import HTTPException as HTTPException
from .param_functions import Body as Body
from .param_functions import Cookie as Cookie
from .param_functions import Depends as Depends
from .param_functions import File as File
from .param_functions import Form as Form
from .param_functions import Header as Header
from .param_functions import Path as Path
from .param_functions import Query as Query
from .param_functions import Security as Security
from .requests import Request as Request
from .responses import Response as Response
from .routing import APIRouter as APIRouter
from .websockets import WebSocket as WebSocket
from .websockets import WebSocketDisconnect as WebSocketDisconnect

89
fastapi/applications.py

@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
from fastapi import routing
from fastapi.concurrency import AsyncExitStack
@ -17,6 +17,7 @@ from fastapi.openapi.docs import (
)
from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends
from fastapi.types import DecoratedCallable
from starlette.applications import Starlette
from starlette.datastructures import State
from starlette.exceptions import HTTPException
@ -24,7 +25,7 @@ from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute
from starlette.types import Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send
class FastAPI(Starlette):
@ -44,24 +45,27 @@ class FastAPI(Starlette):
docs_url: Optional[str] = "/docs",
redoc_url: Optional[str] = "/redoc",
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
swagger_ui_init_oauth: Optional[dict] = None,
swagger_ui_init_oauth: Optional[Dict[str, Any]] = None,
middleware: Optional[Sequence[Middleware]] = None,
exception_handlers: Optional[
Dict[Union[int, Type[Exception]], Callable]
Dict[
Union[int, Type[Exception]],
Callable[[Request, Any], Coroutine[Any, Any, Response]],
]
] = None,
on_startup: Optional[Sequence[Callable]] = None,
on_shutdown: Optional[Sequence[Callable]] = None,
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
openapi_prefix: str = "",
root_path: str = "",
root_path_in_servers: bool = True,
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
deprecated: bool = None,
callbacks: Optional[List[BaseRoute]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
**extra: Any,
) -> None:
self._debug = debug
self.state = State()
self._debug: bool = debug
self.state: State = State()
self.router: routing.APIRouter = routing.APIRouter(
routes=routes,
dependency_overrides_provider=self,
@ -74,7 +78,10 @@ class FastAPI(Starlette):
include_in_schema=include_in_schema,
responses=responses,
)
self.exception_handlers = (
self.exception_handlers: Dict[
Union[int, Type[Exception]],
Callable[[Request, Any], Coroutine[Any, Any, Response]],
] = (
{} if exception_handlers is None else dict(exception_handlers)
)
self.exception_handlers.setdefault(HTTPException, http_exception_handler)
@ -82,8 +89,10 @@ class FastAPI(Starlette):
RequestValidationError, request_validation_exception_handler
)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack = self.build_middleware_stack()
self.user_middleware: List[Middleware] = (
[] if middleware is None else list(middleware)
)
self.middleware_stack: ASGIApp = self.build_middleware_stack()
self.title = title
self.description = description
@ -106,7 +115,7 @@ class FastAPI(Starlette):
self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
self.swagger_ui_init_oauth = swagger_ui_init_oauth
self.extra = extra
self.dependency_overrides: Dict[Callable, Callable] = {}
self.dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
self.openapi_version = "3.0.2"
@ -116,7 +125,7 @@ class FastAPI(Starlette):
self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup()
def openapi(self) -> Dict:
def openapi(self) -> Dict[str, Any]:
if not self.openapi_schema:
self.openapi_schema = get_openapi(
title=self.title,
@ -194,7 +203,7 @@ class FastAPI(Starlette):
def add_api_route(
self,
path: str,
endpoint: Callable,
endpoint: Callable[..., Coroutine[Any, Any, Response]],
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
@ -268,8 +277,8 @@ class FastAPI(Starlette):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.router.add_api_route(
path,
func,
@ -299,12 +308,14 @@ class FastAPI(Starlette):
return decorator
def add_api_websocket_route(
self, path: str, endpoint: Callable, name: Optional[str] = None
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
) -> None:
self.router.add_api_websocket_route(path, endpoint, name=name)
def websocket(self, path: str, name: Optional[str] = None) -> Callable:
def decorator(func: Callable) -> Callable:
def websocket(
self, path: str, name: Optional[str] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
return func
@ -318,10 +329,10 @@ class FastAPI(Starlette):
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: bool = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
default_response_class: Type[Response] = Default(JSONResponse),
callbacks: Optional[List[routing.APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> None:
self.router.include_router(
router,
@ -358,8 +369,8 @@ class FastAPI(Starlette):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.get(
path,
response_model=response_model,
@ -407,8 +418,8 @@ class FastAPI(Starlette):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.put(
path,
response_model=response_model,
@ -456,8 +467,8 @@ class FastAPI(Starlette):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.post(
path,
response_model=response_model,
@ -505,8 +516,8 @@ class FastAPI(Starlette):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.delete(
path,
response_model=response_model,
@ -554,8 +565,8 @@ class FastAPI(Starlette):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.options(
path,
response_model=response_model,
@ -603,8 +614,8 @@ class FastAPI(Starlette):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.head(
path,
response_model=response_model,
@ -652,8 +663,8 @@ class FastAPI(Starlette):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.patch(
path,
response_model=response_model,
@ -701,8 +712,8 @@ class FastAPI(Starlette):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.trace(
path,
response_model=response_model,

2
fastapi/background.py

@ -1 +1 @@
from starlette.background import BackgroundTasks # noqa
from starlette.background import BackgroundTasks as BackgroundTasks # noqa

22
fastapi/concurrency.py

@ -1,8 +1,10 @@
from typing import Any, Callable
from starlette.concurrency import iterate_in_threadpool # noqa
from starlette.concurrency import run_in_threadpool # noqa
from starlette.concurrency import run_until_first_complete # noqa
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
from starlette.concurrency import ( # noqa
run_until_first_complete as run_until_first_complete,
)
asynccontextmanager_error_message = """
FastAPI's contextmanager_in_threadpool require Python 3.7 or above,
@ -11,7 +13,7 @@ or the backport for Python 3.6, installed with:
"""
def _fake_asynccontextmanager(func: Callable) -> Callable:
def _fake_asynccontextmanager(func: Callable[..., Any]) -> Callable[..., Any]:
def raiser(*args: Any, **kwargs: Any) -> Any:
raise RuntimeError(asynccontextmanager_error_message)
@ -19,23 +21,25 @@ def _fake_asynccontextmanager(func: Callable) -> Callable:
try:
from contextlib import asynccontextmanager # type: ignore
from contextlib import asynccontextmanager as asynccontextmanager # type: ignore
except ImportError:
try:
from async_generator import asynccontextmanager # type: ignore
from async_generator import ( # type: ignore # isort: skip
asynccontextmanager as asynccontextmanager,
)
except ImportError: # pragma: no cover
asynccontextmanager = _fake_asynccontextmanager
try:
from contextlib import AsyncExitStack # type: ignore
from contextlib import AsyncExitStack as AsyncExitStack # type: ignore
except ImportError:
try:
from async_exit_stack import AsyncExitStack # type: ignore
from async_exit_stack import AsyncExitStack as AsyncExitStack # type: ignore
except ImportError: # pragma: no cover
AsyncExitStack = None # type: ignore
@asynccontextmanager
@asynccontextmanager # type: ignore
async def contextmanager_in_threadpool(cm: Any) -> Any:
try:
yield await run_in_threadpool(cm.__enter__)

3
fastapi/datastructures.py

@ -1,11 +1,12 @@
from typing import Any, Callable, Iterable, Type, TypeVar
from starlette.datastructures import State as State # noqa: F401
from starlette.datastructures import UploadFile as StarletteUploadFile
class UploadFile(StarletteUploadFile):
@classmethod
def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable]:
def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable[..., Any]]:
yield cls.validate
@classmethod

4
fastapi/dependencies/models.py

@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence
from fastapi.security.base import SecurityBase
from pydantic.fields import ModelField
@ -24,7 +24,7 @@ class Dependant:
dependencies: Optional[List["Dependant"]] = None,
security_schemes: Optional[List[SecurityRequirement]] = None,
name: Optional[str] = None,
call: Optional[Callable] = None,
call: Optional[Callable[..., Any]] = None,
request_param_name: Optional[str] = None,
websocket_param_name: Optional[str] = None,
http_connection_param_name: Optional[str] = None,

38
fastapi/dependencies/utils.py

@ -90,12 +90,12 @@ def check_file_field(field: ModelField) -> None:
if isinstance(field_info, params.Form):
try:
# __version__ is available in both multiparts, and can be mocked
from multipart import __version__
from multipart import __version__ # type: ignore
assert __version__
try:
# parse_options_header is only available in the right multipart
from multipart.multipart import parse_options_header
from multipart.multipart import parse_options_header # type: ignore
assert parse_options_header
except ImportError:
@ -133,7 +133,7 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
def get_sub_dependant(
*,
depends: params.Depends,
dependency: Callable,
dependency: Callable[..., Any],
path: str,
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
@ -163,7 +163,7 @@ def get_sub_dependant(
return sub_dependant
CacheKey = Tuple[Optional[Callable], Tuple[str, ...]]
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
def get_flat_dependant(
@ -240,7 +240,7 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
return False
def get_typed_signature(call: Callable) -> inspect.Signature:
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
@ -259,9 +259,7 @@ def get_typed_signature(call: Callable) -> inspect.Signature:
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
annotation = param.annotation
if isinstance(annotation, str):
# Temporary ignore type
# Ref: https://github.com/samuelcolvin/pydantic/issues/1738
annotation = ForwardRef(annotation) # type: ignore
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
return annotation
@ -281,7 +279,7 @@ def check_dependency_contextmanagers() -> None:
def get_dependant(
*,
path: str,
call: Callable,
call: Callable[..., Any],
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
@ -423,7 +421,7 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
dependant.cookie_params.append(field)
def is_coroutine_callable(call: Callable) -> bool:
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
@ -432,14 +430,14 @@ def is_coroutine_callable(call: Callable) -> bool:
return inspect.iscoroutinefunction(call)
def is_async_gen_callable(call: Callable) -> bool:
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call):
return True
call = getattr(call, "__call__", None)
return inspect.isasyncgenfunction(call)
def is_gen_callable(call: Callable) -> bool:
def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call):
return True
call = getattr(call, "__call__", None)
@ -447,7 +445,7 @@ def is_gen_callable(call: Callable) -> bool:
async def solve_generator(
*, call: Callable, stack: AsyncExitStack, sub_values: Dict[str, Any]
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
@ -472,29 +470,29 @@ async def solve_dependencies(
background_tasks: Optional[BackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable, Tuple[str]], Any]] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
) -> Tuple[
Dict[str, Any],
List[ErrorWrapper],
Optional[BackgroundTasks],
Response,
Dict[Tuple[Callable, Tuple[str]], Any],
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
response = response or Response(
content=None,
status_code=None, # type: ignore
headers=None,
media_type=None,
background=None,
headers=None, # type: ignore # in Starlette
media_type=None, # type: ignore # in Starlette
background=None, # type: ignore # in Starlette
)
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable, sub_dependant.call)
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable, Tuple[str]], sub_dependant.cache_key
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call
use_sub_dependant = sub_dependant

14
fastapi/encoders.py

@ -12,9 +12,11 @@ DictIntStrAny = Dict[Union[int, str], Any]
def generate_encoders_by_class_tuples(
type_encoder_map: Dict[Any, Callable]
) -> Dict[Callable, Tuple]:
encoders_by_class_tuples: Dict[Callable, Tuple] = defaultdict(tuple)
type_encoder_map: Dict[Any, Callable[[Any], Any]]
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
tuple
)
for type_, encoder in type_encoder_map.items():
encoders_by_class_tuples[encoder] += (type_,)
return encoders_by_class_tuples
@ -31,7 +33,7 @@ def jsonable_encoder(
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: dict = {},
custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
sqlalchemy_safe: bool = True,
) -> Any:
if include is not None and not isinstance(include, set):
@ -43,8 +45,8 @@ def jsonable_encoder(
if custom_encoder:
encoder.update(custom_encoder)
obj_dict = obj.dict(
include=include,
exclude=exclude,
include=include, # type: ignore # in Pydantic
exclude=exclude, # type: ignore # in Pydantic
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,

2
fastapi/middleware/__init__.py

@ -1 +1 @@
from starlette.middleware import Middleware
from starlette.middleware import Middleware as Middleware

2
fastapi/middleware/cors.py

@ -1 +1 @@
from starlette.middleware.cors import CORSMiddleware # noqa
from starlette.middleware.cors import CORSMiddleware as CORSMiddleware # noqa

2
fastapi/middleware/gzip.py

@ -1 +1 @@
from starlette.middleware.gzip import GZipMiddleware # noqa
from starlette.middleware.gzip import GZipMiddleware as GZipMiddleware # noqa

4
fastapi/middleware/httpsredirect.py

@ -1 +1,3 @@
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware # noqa
from starlette.middleware.httpsredirect import ( # noqa
HTTPSRedirectMiddleware as HTTPSRedirectMiddleware,
)

4
fastapi/middleware/trustedhost.py

@ -1 +1,3 @@
from starlette.middleware.trustedhost import TrustedHostMiddleware # noqa
from starlette.middleware.trustedhost import ( # noqa
TrustedHostMiddleware as TrustedHostMiddleware,
)

2
fastapi/middleware/wsgi.py

@ -1 +1 @@
from starlette.middleware.wsgi import WSGIMiddleware # noqa
from starlette.middleware.wsgi import WSGIMiddleware as WSGIMiddleware # noqa

4
fastapi/openapi/docs.py

@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Any, Dict, Optional
from fastapi.encoders import jsonable_encoder
from starlette.responses import HTMLResponse
@ -13,7 +13,7 @@ def get_swagger_ui_html(
swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css",
swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
oauth2_redirect_url: Optional[str] = None,
init_oauth: Optional[dict] = None,
init_oauth: Optional[Dict[str, Any]] = None,
) -> HTMLResponse:
html = f"""

4
fastapi/openapi/models.py

@ -5,7 +5,7 @@ from fastapi.logger import logger
from pydantic import AnyUrl, BaseModel, Field
try:
import email_validator
import email_validator # type: ignore
assert email_validator # make autoflake ignore the unused import
from pydantic import EmailStr
@ -13,7 +13,7 @@ except ImportError: # pragma: no cover
class EmailStr(str): # type: ignore
@classmethod
def __get_validators__(cls) -> Iterable[Callable]:
def __get_validators__(cls) -> Iterable[Callable[..., Any]]:
yield cls.validate
@classmethod

52
fastapi/openapi/utils.py

@ -14,6 +14,7 @@ from fastapi.openapi.constants import (
)
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param
from fastapi.responses import Response
from fastapi.utils import (
deep_dict_update,
generate_operation_id_for_path,
@ -64,7 +65,9 @@ status_code_ranges: Dict[str, str] = {
}
def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
def get_openapi_security_definitions(
flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
@ -88,13 +91,12 @@ def get_openapi_operation_parameters(
for param in all_route_params:
field_info = param.field_info
field_info = cast(Param, field_info)
# ignore mypy error until enum schemas are released
parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": field_schema(
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)[0],
}
if field_info.description:
@ -109,13 +111,12 @@ def get_openapi_operation_request_body(
*,
body_field: Optional[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Optional[Dict]:
) -> Optional[Dict[str, Any]]:
if not body_field:
return None
assert isinstance(body_field, ModelField)
# ignore mypy error until enum schemas are released
body_schema, _, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
@ -140,7 +141,9 @@ def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
return route.name.replace("_", " ").title()
def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict:
def get_openapi_operation_metadata(
*, route: routing.APIRoute, method: str
) -> Dict[str, Any]:
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
@ -154,14 +157,14 @@ def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> D
def get_openapi_path(
*, route: routing.APIRoute, model_name_map: Dict[Type, str]
) -> Tuple[Dict, Dict, Dict]:
*, route: routing.APIRoute, model_name_map: Dict[type, str]
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
path = {}
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
if isinstance(route.response_class, DefaultPlaceholder):
current_response_class: Type[routing.Response] = route.response_class.value
current_response_class: Type[Response] = route.response_class.value
else:
current_response_class = route.response_class
assert current_response_class, "A response class is needed to generate OpenAPI"
@ -169,7 +172,7 @@ def get_openapi_path(
if route.include_in_schema:
for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method)
parameters: List[Dict] = []
parameters: List[Dict[str, Any]] = []
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
@ -196,10 +199,15 @@ def get_openapi_path(
if route.callbacks:
callbacks = {}
for callback in route.callbacks:
cb_path, cb_security_schemes, cb_definitions, = get_openapi_path(
route=callback, model_name_map=model_name_map
)
callbacks[callback.name] = {callback.path: cb_path}
if isinstance(callback, routing.APIRoute):
(
cb_path,
cb_security_schemes,
cb_definitions,
) = get_openapi_path(
route=callback, model_name_map=model_name_map
)
callbacks[callback.name] = {callback.path: cb_path}
operation["callbacks"] = callbacks
status_code = str(route.status_code)
operation.setdefault("responses", {}).setdefault(status_code, {})[
@ -332,21 +340,19 @@ def get_openapi(
routes: Sequence[BaseRoute],
tags: Optional[List[Dict[str, Any]]] = None,
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
) -> Dict:
) -> Dict[str, Any]:
info = {"title": title, "version": version}
if description:
info["description"] = description
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
if servers:
output["servers"] = servers
components: Dict[str, Dict] = {}
paths: Dict[str, Dict] = {}
components: Dict[str, Dict[str, Any]] = {}
paths: Dict[str, Dict[str, Any]] = {}
flat_models = get_flat_models_from_routes(routes)
# ignore mypy error until enum schemas are released
model_name_map = get_model_name_map(flat_models) # type: ignore
# ignore mypy error until enum schemas are released
model_name_map = get_model_name_map(flat_models)
definitions = get_model_definitions(
flat_models=flat_models, model_name_map=model_name_map # type: ignore
flat_models=flat_models, model_name_map=model_name_map
)
for route in routes:
if isinstance(route, routing.APIRoute):
@ -368,4 +374,4 @@ def get_openapi(
output["paths"] = paths
if tags:
output["tags"] = tags
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore

4
fastapi/param_functions.py

@ -239,13 +239,13 @@ def File( # noqa: N802
def Depends( # noqa: N802
dependency: Optional[Callable] = None, *, use_cache: bool = True
dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
) -> Any:
return params.Depends(dependency=dependency, use_cache=use_cache)
def Security( # noqa: N802
dependency: Optional[Callable] = None,
dependency: Optional[Callable[..., Any]] = None,
*,
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,

4
fastapi/params.py

@ -315,7 +315,7 @@ class File(Form):
class Depends:
def __init__(
self, dependency: Optional[Callable] = None, *, use_cache: bool = True
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
):
self.dependency = dependency
self.use_cache = use_cache
@ -329,7 +329,7 @@ class Depends:
class Security(Depends):
def __init__(
self,
dependency: Optional[Callable] = None,
dependency: Optional[Callable[..., Any]] = None,
*,
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,

16
fastapi/responses.py

@ -1,13 +1,13 @@
from typing import Any
from starlette.responses import FileResponse # noqa
from starlette.responses import HTMLResponse # noqa
from starlette.responses import JSONResponse # noqa
from starlette.responses import PlainTextResponse # noqa
from starlette.responses import RedirectResponse # noqa
from starlette.responses import Response # noqa
from starlette.responses import StreamingResponse # noqa
from starlette.responses import UJSONResponse # noqa
from starlette.responses import FileResponse as FileResponse # noqa
from starlette.responses import HTMLResponse as HTMLResponse # noqa
from starlette.responses import JSONResponse as JSONResponse # noqa
from starlette.responses import PlainTextResponse as PlainTextResponse # noqa
from starlette.responses import RedirectResponse as RedirectResponse # noqa
from starlette.responses import Response as Response # noqa
from starlette.responses import StreamingResponse as StreamingResponse # noqa
from starlette.responses import UJSONResponse as UJSONResponse # noqa
try:
import orjson

106
fastapi/routing.py

@ -2,7 +2,18 @@ import asyncio
import enum
import inspect
import json
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Union
from typing import (
Any,
Callable,
Coroutine,
Dict,
List,
Optional,
Sequence,
Set,
Type,
Union,
)
from fastapi import params
from fastapi.datastructures import Default, DefaultPlaceholder
@ -16,6 +27,7 @@ from fastapi.dependencies.utils import (
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
from fastapi.types import DecoratedCallable
from fastapi.utils import (
create_cloned_field,
create_response_field,
@ -30,7 +42,8 @@ from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Mount # noqa
from starlette.routing import BaseRoute
from starlette.routing import Mount as Mount # noqa
from starlette.routing import (
compile_path,
get_name,
@ -150,7 +163,7 @@ def get_request_handler(
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
dependency_overrides_provider: Optional[Any] = None,
) -> Callable:
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
@ -207,7 +220,7 @@ def get_request_handler(
response = actual_response_class(
content=response_data,
status_code=status_code,
background=background_tasks,
background=background_tasks, # type: ignore # in Starlette
)
response.headers.raw.extend(sub_response.headers.raw)
if sub_response.status_code:
@ -219,7 +232,7 @@ def get_request_handler(
def get_websocket_app(
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
) -> Callable:
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None:
solved_result = await solve_dependencies(
request=websocket,
@ -240,7 +253,7 @@ class APIWebSocketRoute(routing.WebSocketRoute):
def __init__(
self,
path: str,
endpoint: Callable,
endpoint: Callable[..., Any],
*,
name: Optional[str] = None,
dependency_overrides_provider: Optional[Any] = None,
@ -262,7 +275,7 @@ class APIRoute(routing.Route):
def __init__(
self,
path: str,
endpoint: Callable,
endpoint: Callable[..., Any],
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
@ -287,7 +300,7 @@ class APIRoute(routing.Route):
JSONResponse
),
dependency_overrides_provider: Optional[Any] = None,
callbacks: Optional[List["APIRoute"]] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> None:
# normalise enums e.g. http.HTTPStatus
if isinstance(status_code, enum.IntEnum):
@ -298,7 +311,7 @@ class APIRoute(routing.Route):
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
if methods is None:
methods = ["GET"]
self.methods = set([method.upper() for method in methods])
self.methods: Set[str] = set([method.upper() for method in methods])
self.unique_id = generate_operation_id_for_path(
name=self.name, path=self.path_format, method=list(methods)[0]
)
@ -375,7 +388,7 @@ class APIRoute(routing.Route):
self.callbacks = callbacks
self.app = request_response(self.get_route_handler())
def get_route_handler(self) -> Callable:
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
return get_request_handler(
dependant=self.dependant,
body_field=self.body_field,
@ -401,23 +414,23 @@ class APIRouter(routing.Router):
dependencies: Optional[Sequence[params.Depends]] = None,
default_response_class: Type[Response] = Default(JSONResponse),
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
callbacks: Optional[List[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
routes: Optional[List[routing.BaseRoute]] = None,
redirect_slashes: bool = True,
default: Optional[ASGIApp] = None,
dependency_overrides_provider: Optional[Any] = None,
route_class: Type[APIRoute] = APIRoute,
on_startup: Optional[Sequence[Callable]] = None,
on_shutdown: Optional[Sequence[Callable]] = None,
deprecated: bool = None,
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
) -> None:
super().__init__(
routes=routes,
routes=routes, # type: ignore # in Starlette
redirect_slashes=redirect_slashes,
default=default,
on_startup=on_startup,
on_shutdown=on_shutdown,
default=default, # type: ignore # in Starlette
on_startup=on_startup, # type: ignore # in Starlette
on_shutdown=on_shutdown, # type: ignore # in Starlette
)
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
@ -438,7 +451,7 @@ class APIRouter(routing.Router):
def add_api_route(
self,
path: str,
endpoint: Callable,
endpoint: Callable[..., Any],
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
@ -463,7 +476,7 @@ class APIRouter(routing.Router):
),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> None:
route_class = route_class_override or self.route_class
responses = responses or {}
@ -532,9 +545,9 @@ class APIRouter(routing.Router):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_route(
path,
func,
@ -565,7 +578,7 @@ class APIRouter(routing.Router):
return decorator
def add_api_websocket_route(
self, path: str, endpoint: Callable, name: Optional[str] = None
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
) -> None:
route = APIWebSocketRoute(
path,
@ -575,8 +588,10 @@ class APIRouter(routing.Router):
)
self.routes.append(route)
def websocket(self, path: str, name: Optional[str] = None) -> Callable:
def decorator(func: Callable) -> Callable:
def websocket(
self, path: str, name: Optional[str] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
return func
@ -591,8 +606,8 @@ class APIRouter(routing.Router):
dependencies: Optional[Sequence[params.Depends]] = None,
default_response_class: Type[Response] = Default(JSONResponse),
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
callbacks: Optional[List[APIRoute]] = None,
deprecated: bool = None,
callbacks: Optional[List[BaseRoute]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
) -> None:
if prefix:
@ -663,10 +678,11 @@ class APIRouter(routing.Router):
callbacks=current_callbacks,
)
elif isinstance(route, routing.Route):
methods = list(route.methods or []) # type: ignore # in Starlette
self.add_route(
prefix + route.path,
route.endpoint,
methods=list(route.methods or []),
methods=methods,
include_in_schema=route.include_in_schema,
name=route.name,
)
@ -706,8 +722,8 @@ class APIRouter(routing.Router):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.api_route(
path=path,
response_model=response_model,
@ -756,8 +772,8 @@ class APIRouter(routing.Router):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.api_route(
path=path,
response_model=response_model,
@ -806,8 +822,8 @@ class APIRouter(routing.Router):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.api_route(
path=path,
response_model=response_model,
@ -856,8 +872,8 @@ class APIRouter(routing.Router):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.api_route(
path=path,
response_model=response_model,
@ -906,8 +922,8 @@ class APIRouter(routing.Router):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.api_route(
path=path,
response_model=response_model,
@ -956,8 +972,8 @@ class APIRouter(routing.Router):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.api_route(
path=path,
response_model=response_model,
@ -1006,8 +1022,8 @@ class APIRouter(routing.Router):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.api_route(
path=path,
response_model=response_model,
@ -1056,8 +1072,8 @@ class APIRouter(routing.Router):
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.api_route(
path=path,

32
fastapi/security/__init__.py

@ -1,17 +1,15 @@
from .api_key import APIKeyCookie, APIKeyHeader, APIKeyQuery
from .http import (
HTTPAuthorizationCredentials,
HTTPBasic,
HTTPBasicCredentials,
HTTPBearer,
HTTPDigest,
)
from .oauth2 import (
OAuth2,
OAuth2AuthorizationCodeBearer,
OAuth2PasswordBearer,
OAuth2PasswordRequestForm,
OAuth2PasswordRequestFormStrict,
SecurityScopes,
)
from .open_id_connect_url import OpenIdConnect
from .api_key import APIKeyCookie as APIKeyCookie
from .api_key import APIKeyHeader as APIKeyHeader
from .api_key import APIKeyQuery as APIKeyQuery
from .http import HTTPAuthorizationCredentials as HTTPAuthorizationCredentials
from .http import HTTPBasic as HTTPBasic
from .http import HTTPBasicCredentials as HTTPBasicCredentials
from .http import HTTPBearer as HTTPBearer
from .http import HTTPDigest as HTTPDigest
from .oauth2 import OAuth2 as OAuth2
from .oauth2 import OAuth2AuthorizationCodeBearer as OAuth2AuthorizationCodeBearer
from .oauth2 import OAuth2PasswordBearer as OAuth2PasswordBearer
from .oauth2 import OAuth2PasswordRequestForm as OAuth2PasswordRequestForm
from .oauth2 import OAuth2PasswordRequestFormStrict as OAuth2PasswordRequestFormStrict
from .oauth2 import SecurityScopes as SecurityScopes
from .open_id_connect_url import OpenIdConnect as OpenIdConnect

8
fastapi/security/oauth2.py

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Any, Dict, List, Optional, Union
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuth2 as OAuth2Model
@ -116,7 +116,7 @@ class OAuth2(SecurityBase):
def __init__(
self,
*,
flows: OAuthFlowsModel = OAuthFlowsModel(),
flows: Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]] = OAuthFlowsModel(),
scheme_name: Optional[str] = None,
auto_error: Optional[bool] = True
):
@ -141,7 +141,7 @@ class OAuth2PasswordBearer(OAuth2):
self,
tokenUrl: str,
scheme_name: Optional[str] = None,
scopes: Optional[dict] = None,
scopes: Optional[Dict[str, str]] = None,
auto_error: bool = True,
):
if not scopes:
@ -171,7 +171,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
tokenUrl: str,
refreshUrl: Optional[str] = None,
scheme_name: Optional[str] = None,
scopes: Optional[dict] = None,
scopes: Optional[Dict[str, str]] = None,
auto_error: bool = True,
):
if not scopes:

2
fastapi/staticfiles.py

@ -1 +1 @@
from starlette.staticfiles import StaticFiles # noqa
from starlette.staticfiles import StaticFiles as StaticFiles # noqa

2
fastapi/templating.py

@ -1 +1 @@
from starlette.templating import Jinja2Templates # noqa
from starlette.templating import Jinja2Templates as Jinja2Templates # noqa

2
fastapi/testclient.py

@ -1 +1 @@
from starlette.testclient import TestClient # noqa
from starlette.testclient import TestClient as TestClient # noqa

3
fastapi/types.py

@ -0,0 +1,3 @@
from typing import Any, Callable, TypeVar
DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])

9
fastapi/utils.py

@ -19,11 +19,10 @@ def get_model_definitions(
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
definitions: Dict[str, Dict] = {}
definitions: Dict[str, Dict[str, Any]] = {}
for model in flat_models:
# ignore mypy error until enum schemas are released
m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
definitions.update(m_definitions)
model_name = model_name_map[model]
@ -80,7 +79,7 @@ def create_cloned_field(
cloned_types = dict()
original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
original_type = original_type.__pydantic_model__ # type: ignore
original_type = original_type.__pydantic_model__
use_type = original_type
if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type)
@ -127,7 +126,7 @@ def generate_operation_id_for_path(*, name: str, path: str, method: str) -> str:
return operation_id
def deep_dict_update(main_dict: dict, update_dict: dict) -> None:
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
for key in update_dict:
if (
key in main_dict

4
fastapi/websockets.py

@ -1,2 +1,2 @@
from starlette.websockets import WebSocket # noqa
from starlette.websockets import WebSocketDisconnect # noqa
from starlette.websockets import WebSocket as WebSocket # noqa
from starlette.websockets import WebSocketDisconnect as WebSocketDisconnect # noqa

22
mypy.ini

@ -1,3 +1,25 @@
[mypy]
# --strict
disallow_any_generics = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
check_untyped_defs = True
disallow_untyped_decorators = True
no_implicit_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_return_any = True
implicit_reexport = False
strict_equality = True
# --strict end
[mypy-fastapi.concurrency]
warn_unused_ignores = False
ignore_missing_imports = True
[mypy-fastapi.tests.*]
ignore_missing_imports = True
check_untyped_defs = True

4
pyproject.toml

@ -46,9 +46,9 @@ test = [
"pytest ==5.4.3",
"pytest-cov ==2.10.0",
"pytest-asyncio >=0.14.0,<0.15.0",
"mypy ==0.782",
"mypy ==0.790",
"flake8 >=3.8.3,<4.0.0",
"black ==19.10b0",
"black ==20.8b1",
"isort >=5.0.6,<6.0.0",
"requests >=2.24.0,<3.0.0",
"httpx >=0.14.0,<0.15.0",

9
tests/test_custom_route_class.py

@ -2,6 +2,7 @@ import pytest
from fastapi import APIRouter, FastAPI
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient
from starlette.routing import Route
app = FastAPI()
@ -106,9 +107,9 @@ def test_get_path(path, expected_status, expected_response):
def test_route_classes():
routes = {}
r: APIRoute
for r in app.router.routes:
assert isinstance(r, Route)
routes[r.path] = r
assert routes["/a/"].x_type == "A"
assert routes["/a/b/"].x_type == "B"
assert routes["/a/b/c/"].x_type == "C"
assert getattr(routes["/a/"], "x_type") == "A"
assert getattr(routes["/a/b/"], "x_type") == "B"
assert getattr(routes["/a/b/c/"], "x_type") == "C"

2
tests/test_get_request_body.py

@ -7,7 +7,7 @@ app = FastAPI()
class Product(BaseModel):
name: str
description: str = None
description: str = None # type: ignore
price: float

10
tests/test_include_router_defaults_overrides.py

@ -175,7 +175,7 @@ async def path3_override_router2_override(level3: str):
return level3
@router2_override.get("/default3",)
@router2_override.get("/default3")
async def path3_default_router2_override(level3: str):
return level3
@ -217,7 +217,9 @@ async def path5_override_router4_override(level5: str):
return level5
@router4_override.get("/default5",)
@router4_override.get(
"/default5",
)
async def path5_default_router4_override(level5: str):
return level5
@ -238,7 +240,9 @@ async def path5_override_router4_default(level5: str):
return level5
@router4_default.get("/default5",)
@router4_default.get(
"/default5",
)
async def path5_default_router4_default(level5: str):
return level5

2
tests/test_inherited_custom_class.py

@ -15,7 +15,7 @@ class MyUuid:
def __str__(self):
return self.uuid
@property
@property # type: ignore
def __class__(self):
return uuid.UUID

4
tests/test_jsonable_encoder.py

@ -71,7 +71,7 @@ class ModelWithAlias(BaseModel):
class ModelWithDefault(BaseModel):
foo: str = ...
foo: str = ... # type: ignore
bar: str = "bar"
bla: str = "bla"
@ -88,7 +88,7 @@ def fixture_model_with_path(request):
arbitrary_types_allowed = True
ModelWithPath = create_model(
"ModelWithPath", path=(request.param, ...), __config__=Config
"ModelWithPath", path=(request.param, ...), __config__=Config # type: ignore
)
return ModelWithPath(path=request.param("/foo", "bar"))

10
tests/test_local_docs.py

@ -5,9 +5,9 @@ from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
def test_strings_in_generated_swagger():
sig = inspect.signature(get_swagger_ui_html)
swagger_js_url = sig.parameters.get("swagger_js_url").default
swagger_css_url = sig.parameters.get("swagger_css_url").default
swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default
swagger_js_url = sig.parameters.get("swagger_js_url").default # type: ignore
swagger_css_url = sig.parameters.get("swagger_css_url").default # type: ignore
swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default # type: ignore
html = get_swagger_ui_html(openapi_url="/docs", title="title")
body_content = html.body.decode()
assert swagger_js_url in body_content
@ -34,8 +34,8 @@ def test_strings_in_custom_swagger():
def test_strings_in_generated_redoc():
sig = inspect.signature(get_redoc_html)
redoc_js_url = sig.parameters.get("redoc_js_url").default
redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default
redoc_js_url = sig.parameters.get("redoc_js_url").default # type: ignore
redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default # type: ignore
html = get_redoc_html(openapi_url="/docs", title="title")
body_content = html.body.decode()
assert redoc_js_url in body_content

2
tests/test_multi_body_errors.py

@ -10,7 +10,7 @@ app = FastAPI()
class Item(BaseModel):
name: str
age: condecimal(gt=Decimal(0.0))
age: condecimal(gt=Decimal(0.0)) # type: ignore
@app.post("/items/")

2
tests/test_param_class.py

@ -8,7 +8,7 @@ app = FastAPI()
@app.get("/items/")
def read_items(q: Optional[str] = Param(None)):
def read_items(q: Optional[str] = Param(None)): # type: ignore
return {"q": q}

4
tests/test_params_repr.py

@ -1,7 +1,9 @@
from typing import Any, List
import pytest
from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query
test_data = ["teststr", None, ..., 1, []]
test_data: List[Any] = ["teststr", None, ..., 1, []]
def get_user():

4
tests/test_starlette_urlconvertors.py

@ -27,7 +27,7 @@ def test_route_converters_int():
response = client.get("/int/5")
assert response.status_code == 200, response.text
assert response.json() == {"int": 5}
assert app.url_path_for("int_convertor", param=5) == "/int/5"
assert app.url_path_for("int_convertor", param=5) == "/int/5" # type: ignore
def test_route_converters_float():
@ -35,7 +35,7 @@ def test_route_converters_float():
response = client.get("/float/25.5")
assert response.status_code == 200, response.text
assert response.json() == {"float": 25.5}
assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5"
assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5" # type: ignore
def test_route_converters_path():

2
tests/test_sub_callbacks.py

@ -27,7 +27,7 @@ invoices_callback_router = APIRouter()
@invoices_callback_router.post(
"{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived,
"{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived
)
def invoice_notification(body: InvoiceEvent):
pass # pragma: nocover

Loading…
Cancel
Save