From 0e19c24014c96e241bd73bede2805e21fc20c9d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 7 Dec 2018 19:12:16 +0400 Subject: [PATCH] :sparkles: Update parameter names and order fix mypy types, refactor, lint --- fastapi/applications.py | 381 +++++++++++---------- fastapi/dependencies/models.py | 8 +- fastapi/dependencies/utils.py | 65 ++-- fastapi/encoders.py | 8 +- fastapi/openapi/docs.py | 11 +- fastapi/openapi/models.py | 48 +-- fastapi/openapi/utils.py | 98 +++--- fastapi/params.py | 46 +-- fastapi/routing.py | 437 ++++++++++++------------ fastapi/security/api_key.py | 13 +- fastapi/security/base.py | 6 +- fastapi/security/http.py | 21 +- fastapi/security/oauth2.py | 10 +- fastapi/security/open_id_connect_url.py | 6 +- fastapi/utils.py | 24 +- 15 files changed, 613 insertions(+), 569 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index bb21076df..f1d405221 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1,19 +1,19 @@ -from typing import Any, Callable, Dict, List, Type +from typing import Any, Callable, Dict, List, Optional, Type +from pydantic import BaseModel from starlette.applications import Starlette from starlette.exceptions import ExceptionMiddleware, HTTPException from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.lifespan import LifespanMiddleware -from starlette.responses import JSONResponse - +from starlette.requests import Request +from starlette.responses import JSONResponse, Response from fastapi import routing -from fastapi.openapi.utils import get_openapi from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html +from fastapi.openapi.utils import get_openapi -async def http_exception(request, exc: HTTPException): - print(exc) +async def http_exception(request: Request, exc: HTTPException) -> JSONResponse: return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) @@ -31,7 +31,7 @@ class FastAPI(Starlette): **extra: Dict[str, Any], ) -> None: self._debug = debug - self.router = routing.APIRouter() + self.router: routing.APIRouter = routing.APIRouter() self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) self.error_middleware = ServerErrorMiddleware( self.exception_middleware, debug=debug @@ -56,33 +56,41 @@ class FastAPI(Starlette): if self.swagger_ui_url or self.redoc_url: assert self.openapi_url, "The openapi_url is required for the docs" + self.openapi_schema: Optional[Dict[str, Any]] = None self.setup() - def setup(self): + def openapi(self) -> Dict: + if not self.openapi_schema: + self.openapi_schema = get_openapi( + title=self.title, + version=self.version, + openapi_version=self.openapi_version, + description=self.description, + routes=self.routes, + ) + return self.openapi_schema + + def setup(self) -> None: if self.openapi_url: self.add_route( self.openapi_url, - lambda req: JSONResponse( - get_openapi( - title=self.title, - version=self.version, - openapi_version=self.openapi_version, - description=self.description, - routes=self.routes, - ) - ), + lambda req: JSONResponse(self.openapi()), include_in_schema=False, ) if self.swagger_ui_url: self.add_route( self.swagger_ui_url, - lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"), + lambda r: get_swagger_ui_html( + openapi_url=self.openapi_url, title=self.title + " - Swagger UI" + ), include_in_schema=False, ) if self.redoc_url: self.add_route( self.redoc_url, - lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"), + lambda r: get_redoc_html( + openapi_url=self.openapi_url, title=self.title + " - ReDoc" + ), include_in_schema=False, ) self.add_exception_handler(HTTPException, http_exception) @@ -91,311 +99,322 @@ class FastAPI(Starlette): self, path: str, endpoint: Callable, - methods: List[str] = None, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, + deprecated: bool = None, + name: str = None, + methods: List[str] = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, ) -> None: self.router.add_api_route( path, endpoint=endpoint, - methods=methods, - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=methods, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def api_route( self, path: str, - methods: List[str] = None, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, + deprecated: bool = None, + name: str = None, + methods: List[str] = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, ) -> Callable: def decorator(func: Callable) -> Callable: self.router.add_api_route( path, func, - methods=methods, - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=methods, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) return func + return decorator - - def include_router(self, router: "APIRouter", *, prefix=""): + + def include_router(self, router: routing.APIRouter, *, prefix: str = "") -> None: self.router.include_router(router, prefix=prefix) def get( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.router.get( - path=path, - name=name, - include_in_schema=include_in_schema, + path, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def put( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.router.put( - path=path, - name=name, - include_in_schema=include_in_schema, + path, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def post( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.router.post( - path=path, - name=name, - include_in_schema=include_in_schema, + path, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def delete( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.router.delete( - path=path, - name=name, - include_in_schema=include_in_schema, + path, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def options( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.router.options( - path=path, - name=name, - include_in_schema=include_in_schema, + path, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def head( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.router.head( - path=path, - name=name, - include_in_schema=include_in_schema, + path, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def patch( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.router.patch( - path=path, - name=name, - include_in_schema=include_in_schema, + path, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def trace( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.router.trace( - path=path, - name=name, - include_in_schema=include_in_schema, + path, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index ad9419db5..5857f9202 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,14 +1,14 @@ from typing import Any, Callable, Dict, List, Sequence, Tuple -from starlette.concurrency import run_in_threadpool -from starlette.requests import Request - -from fastapi.security.base import SecurityBase from pydantic import BaseConfig, Schema from pydantic.error_wrappers import ErrorWrapper from pydantic.errors import MissingError from pydantic.fields import Field, Required from pydantic.schema import get_annotation_from_schema +from starlette.concurrency import run_in_threadpool +from starlette.requests import Request + +from fastapi.security.base import SecurityBase param_supported_types = (str, int, float, bool) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 6e86de5a5..834774e1b 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,8 +1,14 @@ import asyncio import inspect from copy import deepcopy -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Type +from pydantic import BaseConfig, Schema, create_model +from pydantic.error_wrappers import ErrorWrapper +from pydantic.errors import MissingError +from pydantic.fields import Field, Required +from pydantic.schema import get_annotation_from_schema +from pydantic.utils import lenient_issubclass from starlette.concurrency import run_in_threadpool from starlette.requests import Request @@ -10,17 +16,11 @@ from fastapi import params from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.security.base import SecurityBase from fastapi.utils import get_path_param_names -from pydantic import BaseConfig, Schema, create_model -from pydantic.error_wrappers import ErrorWrapper -from pydantic.errors import MissingError -from pydantic.fields import Field, Required -from pydantic.schema import get_annotation_from_schema -from pydantic.utils import lenient_issubclass param_supported_types = (str, int, float, bool) -def get_sub_dependant(*, param: inspect.Parameter, path: str): +def get_sub_dependant(*, param: inspect.Parameter, path: str) -> Dependant: depends: params.Depends = param.default if depends.dependency: dependency = depends.dependency @@ -36,7 +36,7 @@ def get_sub_dependant(*, param: inspect.Parameter, path: str): return sub_dependant -def get_flat_dependant(dependant: Dependant): +def get_flat_dependant(dependant: Dependant) -> Dependant: flat_dependant = Dependant( path_params=dependant.path_params.copy(), query_params=dependant.query_params.copy(), @@ -58,7 +58,7 @@ def get_flat_dependant(dependant: Dependant): return flat_dependant -def get_dependant(*, path: str, call: Callable, name: str = None): +def get_dependant(*, path: str, call: Callable, name: str = None) -> Dependant: path_param_names = get_path_param_names(path) endpoint_signature = inspect.signature(call) signature_params = endpoint_signature.parameters @@ -73,9 +73,10 @@ def get_dependant(*, path: str, call: Callable, name: str = None): if ( (param.default == param.empty) or isinstance(param.default, params.Path) ) and (param_name in path_param_names): - assert lenient_issubclass( - param.annotation, param_supported_types - ) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}" + assert ( + lenient_issubclass(param.annotation, param_supported_types) + or param.annotation == param.empty + ), f"Path params must be of type str, int, float or boot: {param}" param = signature_params[param_name] add_param_to_fields( param=param, @@ -109,9 +110,9 @@ def add_param_to_fields( *, param: inspect.Parameter, dependant: Dependant, - default_schema=params.Param, + default_schema: Type[Schema] = params.Param, force_type: params.ParamTypes = None, -): +) -> None: default_value = Required if not param.default == param.empty: default_value = param.default @@ -125,15 +126,19 @@ def add_param_to_fields( else: schema = default_schema(default_value) required = default_value == Required - annotation = Any + annotation: Type = Type[Any] if not param.annotation == param.empty: annotation = param.annotation annotation = get_annotation_from_schema(annotation, schema) + if not schema.alias and getattr(schema, "alias_underscore_to_hyphen", None): + alias = param.name.replace("_", "-") + else: + alias = schema.alias or param.name field = Field( name=param.name, type_=annotation, default=None if required else default_value, - alias=schema.alias or param.name, + alias=alias, required=required, model_config=BaseConfig(), class_validators=[], @@ -152,7 +157,7 @@ def add_param_to_fields( dependant.cookie_params.append(field) -def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant): +def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant) -> None: default_value = Required if not param.default == param.empty: default_value = param.default @@ -176,7 +181,7 @@ def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant): dependant.body_params.append(field) -def is_coroutine_callable(call: Callable = None): +def is_coroutine_callable(call: Callable = None) -> bool: if not call: return False if inspect.isfunction(call): @@ -191,7 +196,7 @@ def is_coroutine_callable(call: Callable = None): async def solve_dependencies( *, request: Request, dependant: Dependant, body: Dict[str, Any] = None -): +) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: values: Dict[str, Any] = {} errors: List[ErrorWrapper] = [] for sub_dependant in dependant.dependencies: @@ -200,13 +205,13 @@ async def solve_dependencies( ) if sub_errors: return {}, errors - if sub_dependant.call and is_coroutine_callable(sub_dependant.call): + assert sub_dependant.call is not None, "sub_dependant.call must be a function" + if is_coroutine_callable(sub_dependant.call): solved = await sub_dependant.call(**sub_values) else: solved = await run_in_threadpool(sub_dependant.call, **sub_values) - values[ - sub_dependant.name - ] = solved # type: ignore # Sub-dependants always have a name + assert sub_dependant.name is not None, "Subdependants always have a name" + values[sub_dependant.name] = solved path_values, path_errors = request_params_to_args( dependant.path_params, request.path_params ) @@ -236,7 +241,7 @@ async def solve_dependencies( def request_params_to_args( - required_params: List[Field], received_params: Dict[str, Any] + required_params: Sequence[Field], received_params: Mapping[str, Any] ) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: values = {} errors = [] @@ -250,9 +255,9 @@ def request_params_to_args( else: values[field.name] = deepcopy(field.default) continue - v_, errors_ = field.validate( - value, values, loc=(field.schema.in_.value, field.alias) - ) + schema: params.Param = field.schema + assert isinstance(schema, params.Param), "Params must be subclasses of Param" + v_, errors_ = field.validate(value, values, loc=(schema.in_.value, field.alias)) if isinstance(errors_, ErrorWrapper): errors.append(errors_) elif isinstance(errors_, list): @@ -294,7 +299,7 @@ async def request_body_to_args( return values, errors -def get_body_field(*, dependant: Dependant, name: str): +def get_body_field(*, dependant: Dependant, name: str) -> Field: flat_dependant = get_flat_dependant(dependant) if not flat_dependant.body_params: return None @@ -308,7 +313,7 @@ def get_body_field(*, dependant: Dependant, name: str): BodyModel.__fields__[f.name] = f required = any(True for f in flat_dependant.body_params if f.required) if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params): - BodySchema = params.File + BodySchema: Type[params.Body] = params.File elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params): BodySchema = params.Form else: diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 95ce4479e..3234f8927 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -1,18 +1,18 @@ from enum import Enum from types import GeneratorType -from typing import Set +from typing import Any, Set from pydantic import BaseModel from pydantic.json import pydantic_encoder def jsonable_encoder( - obj, + obj: Any, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, - include_none=True, -): + include_none: bool = True, +) -> Any: if isinstance(obj, BaseModel): return jsonable_encoder( obj.dict(include=include, exclude=exclude, by_alias=by_alias), diff --git a/fastapi/openapi/docs.py b/fastapi/openapi/docs.py index c8a1d6178..955a99f00 100644 --- a/fastapi/openapi/docs.py +++ b/fastapi/openapi/docs.py @@ -1,6 +1,7 @@ from starlette.responses import HTMLResponse -def get_swagger_ui_html(*, openapi_url: str, title: str): + +def get_swagger_ui_html(*, openapi_url: str, title: str) -> HTMLResponse: return HTMLResponse( """ @@ -35,12 +36,11 @@ def get_swagger_ui_html(*, openapi_url: str, title: str): - """, - media_type="text/html", + """ ) -def get_redoc_html(*, openapi_url: str, title: str): +def get_redoc_html(*, openapi_url: str, title: str) -> HTMLResponse: return HTMLResponse( """ @@ -73,6 +73,5 @@ def get_redoc_html(*, openapi_url: str, title: str): - """, - media_type="text/html", + """ ) diff --git a/fastapi/openapi/models.py b/fastapi/openapi/models.py index e3d96bd7f..87eed07be 100644 --- a/fastapi/openapi/models.py +++ b/fastapi/openapi/models.py @@ -7,13 +7,13 @@ from pydantic.types import UrlStr try: import pydantic.types.EmailStr - from pydantic.types import EmailStr + from pydantic.types import EmailStr # type: ignore except ImportError: logging.warning( "email-validator not installed, email fields will be treated as str" ) - class EmailStr(str): + class EmailStr(str): # type: ignore pass @@ -50,7 +50,7 @@ class Server(BaseModel): class Reference(BaseModel): - ref: str = PSchema(..., alias="$ref") + ref: str = PSchema(..., alias="$ref") # type: ignore class Discriminator(BaseModel): @@ -72,28 +72,28 @@ class ExternalDocumentation(BaseModel): class SchemaBase(BaseModel): - ref: Optional[str] = PSchema(None, alias="$ref") + ref: Optional[str] = PSchema(None, alias="$ref") # type: ignore title: Optional[str] = None multipleOf: Optional[float] = None maximum: Optional[float] = None exclusiveMaximum: Optional[float] = None minimum: Optional[float] = None exclusiveMinimum: Optional[float] = None - maxLength: Optional[int] = PSchema(None, gte=0) - minLength: Optional[int] = PSchema(None, gte=0) + maxLength: Optional[int] = PSchema(None, gte=0) # type: ignore + minLength: Optional[int] = PSchema(None, gte=0) # type: ignore pattern: Optional[str] = None - maxItems: Optional[int] = PSchema(None, gte=0) - minItems: Optional[int] = PSchema(None, gte=0) + maxItems: Optional[int] = PSchema(None, gte=0) # type: ignore + minItems: Optional[int] = PSchema(None, gte=0) # type: ignore uniqueItems: Optional[bool] = None - maxProperties: Optional[int] = PSchema(None, gte=0) - minProperties: Optional[int] = PSchema(None, gte=0) + maxProperties: Optional[int] = PSchema(None, gte=0) # type: ignore + minProperties: Optional[int] = PSchema(None, gte=0) # type: ignore required: Optional[List[str]] = None enum: Optional[List[str]] = None type: Optional[str] = None allOf: Optional[List[Any]] = None oneOf: Optional[List[Any]] = None anyOf: Optional[List[Any]] = None - not_: Optional[List[Any]] = PSchema(None, alias="not") + not_: Optional[List[Any]] = PSchema(None, alias="not") # type: ignore items: Optional[Any] = None properties: Optional[Dict[str, Any]] = None additionalProperties: Optional[Union[bool, Any]] = None @@ -114,7 +114,7 @@ class Schema(SchemaBase): allOf: Optional[List[SchemaBase]] = None oneOf: Optional[List[SchemaBase]] = None anyOf: Optional[List[SchemaBase]] = None - not_: Optional[List[SchemaBase]] = PSchema(None, alias="not") + not_: Optional[List[SchemaBase]] = PSchema(None, alias="not") # type: ignore items: Optional[SchemaBase] = None properties: Optional[Dict[str, SchemaBase]] = None additionalProperties: Optional[Union[bool, SchemaBase]] = None @@ -144,7 +144,9 @@ class Encoding(BaseModel): class MediaType(BaseModel): - schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema") + schema_: Optional[Union[Schema, Reference]] = PSchema( + None, alias="schema" + ) # type: ignore example: Optional[Any] = None examples: Optional[Dict[str, Union[Example, Reference]]] = None encoding: Optional[Dict[str, Encoding]] = None @@ -158,7 +160,9 @@ class ParameterBase(BaseModel): style: Optional[str] = None explode: Optional[bool] = None allowReserved: Optional[bool] = None - schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema") + schema_: Optional[Union[Schema, Reference]] = PSchema( + None, alias="schema" + ) # type: ignore example: Optional[Any] = None examples: Optional[Dict[str, Union[Example, Reference]]] = None # Serialization rules for more complex scenarios @@ -167,7 +171,7 @@ class ParameterBase(BaseModel): class Parameter(ParameterBase): name: str - in_: ParameterInType = PSchema(..., alias="in") + in_: ParameterInType = PSchema(..., alias="in") # type: ignore class Header(ParameterBase): @@ -222,7 +226,7 @@ class Operation(BaseModel): class PathItem(BaseModel): - ref: Optional[str] = PSchema(None, alias="$ref") + ref: Optional[str] = PSchema(None, alias="$ref") # type: ignore summary: Optional[str] = None description: Optional[str] = None get: Optional[Operation] = None @@ -250,7 +254,7 @@ class SecuritySchemeType(Enum): class SecurityBase(BaseModel): - type_: SecuritySchemeType = PSchema(..., alias="type") + type_: SecuritySchemeType = PSchema(..., alias="type") # type: ignore description: Optional[str] = None @@ -261,13 +265,13 @@ class APIKeyIn(Enum): class APIKey(SecurityBase): - type_ = PSchema(SecuritySchemeType.apiKey, alias="type") - in_: APIKeyIn = PSchema(..., alias="in") + type_ = PSchema(SecuritySchemeType.apiKey, alias="type") # type: ignore + in_: APIKeyIn = PSchema(..., alias="in") # type: ignore name: str class HTTPBase(SecurityBase): - type_ = PSchema(SecuritySchemeType.http, alias="type") + type_ = PSchema(SecuritySchemeType.http, alias="type") # type: ignore scheme: str @@ -306,12 +310,12 @@ class OAuthFlows(BaseModel): class OAuth2(SecurityBase): - type_ = PSchema(SecuritySchemeType.oauth2, alias="type") + type_ = PSchema(SecuritySchemeType.oauth2, alias="type") # type: ignore flows: OAuthFlows class OpenIdConnect(SecurityBase): - type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type") + type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type") # type: ignore openIdConnectUrl: str diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 7dbeece73..1036d2012 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -1,23 +1,21 @@ -from typing import Any, Dict, Sequence, Type, List +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type from pydantic.fields import Field -from pydantic.schema import field_schema, get_model_name_map +from pydantic.schema import Schema, field_schema, get_model_name_map from pydantic.utils import lenient_issubclass - from starlette.responses import HTMLResponse, JSONResponse -from starlette.routing import BaseRoute +from starlette.routing import BaseRoute, Route from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY from fastapi import routing from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import get_flat_dependant from fastapi.encoders import jsonable_encoder -from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY +from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX from fastapi.openapi.models import OpenAPI -from fastapi.params import Body +from fastapi.params import Body, Param from fastapi.utils import get_flat_models_from_routes, get_model_definitions - validation_error_definition = { "title": "ValidationError", "type": "object", @@ -42,7 +40,7 @@ validation_error_response_definition = { } -def get_openapi_params(dependant: Dependant): +def get_openapi_params(dependant: Dependant) -> List[Field]: flat_dependant = get_flat_dependant(dependant) return ( flat_dependant.path_params @@ -52,7 +50,7 @@ def get_openapi_params(dependant: Dependant): ) -def get_openapi_security_definitions(flat_dependant: Dependant): +def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]: security_definitions = {} operation_security = [] for security_requirement in flat_dependant.security_requirements: @@ -61,59 +59,60 @@ def get_openapi_security_definitions(flat_dependant: Dependant): by_alias=True, include_none=False, ) - security_name = ( - security_requirement.security_scheme.scheme_name - - ) + security_name = security_requirement.security_scheme.scheme_name security_definitions[security_name] = security_definition operation_security.append({security_name: security_requirement.scopes}) return security_definitions, operation_security -def get_openapi_operation_parameters(all_route_params: List[Field]): +def get_openapi_operation_parameters( + all_route_params: Sequence[Field] +) -> Tuple[Dict[str, Dict], List[Dict[str, Any]]]: definitions: Dict[str, Dict] = {} parameters = [] for param in all_route_params: + schema: Param = param.schema if "ValidationError" not in definitions: definitions["ValidationError"] = validation_error_definition definitions["HTTPValidationError"] = validation_error_response_definition parameter = { "name": param.alias, - "in": param.schema.in_.value, + "in": schema.in_.value, "required": param.required, "schema": field_schema(param, model_name_map={})[0], } - if param.schema.description: - parameter["description"] = param.schema.description - if param.schema.deprecated: - parameter["deprecated"] = param.schema.deprecated + if schema.description: + parameter["description"] = schema.description + if schema.deprecated: + parameter["deprecated"] = schema.deprecated parameters.append(parameter) return definitions, parameters def get_openapi_operation_request_body( *, body_field: Field, model_name_map: Dict[Type, str] -): +) -> Optional[Dict]: if not body_field: return None assert isinstance(body_field, Field) body_schema, _ = field_schema( body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX ) - if isinstance(body_field.schema, Body): - request_media_type = body_field.schema.media_type + schema: Schema = body_field.schema + if isinstance(schema, Body): + request_media_type = schema.media_type else: # Includes not declared media types (Schema) request_media_type = "application/json" required = body_field.required - request_body_oai = {} + request_body_oai: Dict[str, Any] = {} if required: request_body_oai["required"] = required request_body_oai["content"] = {request_media_type: {"schema": body_schema}} return request_body_oai -def generate_operation_id(*, route: routing.APIRoute, method: str): +def generate_operation_id(*, route: routing.APIRoute, method: str) -> str: if route.operation_id: return route.operation_id path: str = route.path @@ -123,12 +122,13 @@ def generate_operation_id(*, route: routing.APIRoute, method: str): return operation_id -def generate_operation_summary(*, route: routing.APIRoute, method: str): +def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str: if route.summary: return route.summary return method.title() + " " + route.name.replace("_", " ").title() -def get_openapi_operation_metadata(*, route: BaseRoute, method: str): + +def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict: operation: Dict[str, Any] = {} if route.tags: operation["tags"] = route.tags @@ -141,12 +141,13 @@ def get_openapi_operation_metadata(*, route: BaseRoute, method: str): return operation -def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]): - if not (route.include_in_schema and isinstance(route, routing.APIRoute)): - return None +def get_openapi_path( + *, route: routing.APIRoute, model_name_map: Dict[Type, str] +) -> Tuple[Dict, Dict, Dict]: path = {} security_schemes: Dict[str, Any] = {} definitions: Dict[str, Any] = {} + assert route.methods is not None, "Methods must be a list" for method in route.methods: operation = get_openapi_operation_metadata(route=route, method=method) parameters: List[Dict] = [] @@ -172,10 +173,9 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]): ) if request_body_oai: operation["requestBody"] = request_body_oai - response_code = str(route.response_code) + status_code = str(route.status_code) response_schema = {"type": "string"} - if lenient_issubclass(route.response_wrapper, JSONResponse): - response_media_type = "application/json" + if lenient_issubclass(route.content_type, JSONResponse): if route.response_field: response_schema, _ = field_schema( route.response_field, @@ -184,16 +184,9 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]): ) else: response_schema = {} - elif lenient_issubclass(route.response_wrapper, HTMLResponse): - response_media_type = "text/html" - else: - response_media_type = "text/plain" - content = {response_media_type: {"schema": response_schema}} + content = {route.content_type.media_type: {"schema": response_schema}} operation["responses"] = { - response_code: { - "description": route.response_description, - "content": content, - } + status_code: {"description": route.response_description, "content": content} } if all_route_params or route.body_field: operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = { @@ -215,7 +208,7 @@ def get_openapi( openapi_version: str = "3.0.2", description: str = None, routes: Sequence[BaseRoute] -): +) -> Dict: info = {"title": title, "version": version} if description: info["description"] = description @@ -228,15 +221,18 @@ def get_openapi( flat_models=flat_models, model_name_map=model_name_map ) for route in routes: - result = get_openapi_path(route=route, model_name_map=model_name_map) - if result: - path, security_schemes, path_definitions = result - if path: - paths.setdefault(route.path, {}).update(path) - if security_schemes: - components.setdefault("securitySchemes", {}).update(security_schemes) - if path_definitions: - definitions.update(path_definitions) + if isinstance(route, routing.APIRoute): + result = get_openapi_path(route=route, model_name_map=model_name_map) + if result: + path, security_schemes, path_definitions = result + if path: + paths.setdefault(route.path, {}).update(path) + if security_schemes: + components.setdefault("securitySchemes", {}).update( + security_schemes + ) + if path_definitions: + definitions.update(path_definitions) if definitions: components.setdefault("schemas", {}).update(definitions) if components: diff --git a/fastapi/params.py b/fastapi/params.py index abbce8aeb..8df0112c8 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Sequence, Any, Dict +from typing import Any, Callable, Sequence from pydantic import Schema @@ -16,7 +16,7 @@ class Param(Schema): def __init__( self, - default, + default: Any, *, deprecated: bool = None, alias: str = None, @@ -29,7 +29,7 @@ class Param(Schema): min_length: int = None, max_length: int = None, regex: str = None, - **extra: Dict[str, Any], + **extra: Any, ): self.deprecated = deprecated super().__init__( @@ -53,7 +53,7 @@ class Path(Param): def __init__( self, - default, + default: Any, *, deprecated: bool = None, alias: str = None, @@ -66,7 +66,7 @@ class Path(Param): min_length: int = None, max_length: int = None, regex: str = None, - **extra: Dict[str, Any], + **extra: Any, ): self.description = description self.deprecated = deprecated @@ -92,7 +92,7 @@ class Query(Param): def __init__( self, - default, + default: Any, *, deprecated: bool = None, alias: str = None, @@ -105,7 +105,7 @@ class Query(Param): min_length: int = None, max_length: int = None, regex: str = None, - **extra: Dict[str, Any], + **extra: Any, ): self.description = description self.deprecated = deprecated @@ -130,10 +130,11 @@ class Header(Param): def __init__( self, - default, + default: Any, *, deprecated: bool = None, alias: str = None, + alias_underscore_to_hyphen: bool = True, title: str = None, description: str = None, gt: float = None, @@ -143,10 +144,11 @@ class Header(Param): min_length: int = None, max_length: int = None, regex: str = None, - **extra: Dict[str, Any], + **extra: Any, ): self.description = description self.deprecated = deprecated + self.alias_underscore_to_hyphen = alias_underscore_to_hyphen super().__init__( default, alias=alias, @@ -168,7 +170,7 @@ class Cookie(Param): def __init__( self, - default, + default: Any, *, deprecated: bool = None, alias: str = None, @@ -181,7 +183,7 @@ class Cookie(Param): min_length: int = None, max_length: int = None, regex: str = None, - **extra: Dict[str, Any], + **extra: Any, ): self.description = description self.deprecated = deprecated @@ -204,9 +206,9 @@ class Cookie(Param): class Body(Schema): def __init__( self, - default, + default: Any, *, - embed=False, + embed: bool = False, media_type: str = "application/json", alias: str = None, title: str = None, @@ -218,7 +220,7 @@ class Body(Schema): min_length: int = None, max_length: int = None, regex: str = None, - **extra: Dict[str, Any], + **extra: Any, ): self.embed = embed self.media_type = media_type @@ -241,9 +243,9 @@ class Body(Schema): class Form(Body): def __init__( self, - default, + default: Any, *, - sub_key=False, + sub_key: bool = False, media_type: str = "application/x-www-form-urlencoded", alias: str = None, title: str = None, @@ -255,7 +257,7 @@ class Form(Body): min_length: int = None, max_length: int = None, regex: str = None, - **extra: Dict[str, Any], + **extra: Any, ): super().__init__( default, @@ -278,9 +280,9 @@ class Form(Body): class File(Form): def __init__( self, - default, + default: Any, *, - sub_key=False, + sub_key: bool = False, media_type: str = "multipart/form-data", alias: str = None, title: str = None, @@ -292,7 +294,7 @@ class File(Form): min_length: int = None, max_length: int = None, regex: str = None, - **extra: Dict[str, Any], + **extra: Any, ): super().__init__( default, @@ -313,11 +315,11 @@ class File(Form): class Depends: - def __init__(self, dependency=None): + def __init__(self, dependency: Callable = None): self.dependency = dependency class Security(Depends): - def __init__(self, dependency=None, scopes: Sequence[str] = None): + def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None): self.scopes = scopes or [] super().__init__(dependency=dependency) diff --git a/fastapi/routing.py b/fastapi/routing.py index 22a62a53a..8620db5db 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,12 +1,11 @@ import asyncio import inspect -from typing import Callable, List, Type +from typing import Any, Callable, List, Optional, Type from pydantic import BaseConfig, BaseModel, Schema from pydantic.error_wrappers import ErrorWrapper, ValidationError from pydantic.fields import Field from pydantic.utils import lenient_issubclass - from starlette import routing from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException @@ -22,7 +21,7 @@ from fastapi.dependencies.utils import get_body_field, get_dependant, solve_depe from fastapi.encoders import jsonable_encoder -def serialize_response(*, field: Field = None, response): +def serialize_response(*, field: Field = None, response: Response) -> Any: if field: errors = [] value, errors_ = field.validate(response, {}, loc=("response",)) @@ -40,11 +39,12 @@ def serialize_response(*, field: Field = None, response): def get_app( dependant: Dependant, body_field: Field = None, - response_code: str = 200, - response_wrapper: Type[Response] = JSONResponse, - response_field: Type[Field] = None, -): - is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call) + status_code: int = 200, + content_type: Type[Response] = JSONResponse, + response_field: Field = None, +) -> Callable: + assert dependant.call is not None, "dependant.call must me a function" + is_coroutine = asyncio.iscoroutinefunction(dependant.call) is_body_form = body_field and isinstance(body_field.schema, params.Form) async def app(request: Request) -> Response: @@ -69,6 +69,7 @@ def get_app( status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors() ) else: + assert dependant.call is not None, "dependant.call must me a function" if is_coroutine: raw_response = await dependant.call(**values) else: @@ -76,32 +77,32 @@ def get_app( if isinstance(raw_response, Response): return raw_response if isinstance(raw_response, BaseModel): - return response_wrapper( - content=jsonable_encoder(raw_response), status_code=response_code + return content_type( + content=jsonable_encoder(raw_response), status_code=status_code ) errors = [] try: - return response_wrapper( + return content_type( content=serialize_response( field=response_field, response=raw_response ), - status_code=response_code, + status_code=status_code, ) except Exception as e: errors.append(e) try: response = dict(raw_response) - return response_wrapper( + return content_type( content=serialize_response(field=response_field, response=response), - status_code=response_code, + status_code=status_code, ) except Exception as e: errors.append(e) try: response = vars(raw_response) - return response_wrapper( + return content_type( content=serialize_response(field=response_field, response=response), - status_code=response_code, + status_code=status_code, ) except Exception as e: errors.append(e) @@ -116,43 +117,32 @@ class APIRoute(routing.Route): path: str, endpoint: Callable, *, - methods: List[str] = None, - name: str = None, - include_in_schema: bool = True, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, + deprecated: bool = None, + name: str = None, + methods: List[str] = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, ) -> None: assert path.startswith("/"), "Routed paths must always start with '/'" self.path = path self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name - self.include_in_schema = include_in_schema - self.tags = tags or [] - self.summary = summary - self.description = description or self.endpoint.__doc__ - self.operation_id = operation_id - self.deprecated = deprecated - self.body_field: Field = None - self.response_description = response_description - self.response_code = response_code - self.response_wrapper = response_wrapper - self.response_field = None - if response_type: + self.response_model = response_model + if self.response_model: assert lenient_issubclass( - response_wrapper, JSONResponse + content_type, JSONResponse ), "To declare a type the response must be a JSON response" - self.response_type = response_type response_name = "Response_" + self.name - self.response_field = Field( + self.response_field: Optional[Field] = Field( name=response_name, - type_=self.response_type, + type_=self.response_model, class_validators=[], default=None, required=False, @@ -160,25 +150,34 @@ class APIRoute(routing.Route): schema=Schema(None), ) else: - self.response_type = None + self.response_field = None + self.status_code = status_code + self.tags = tags or [] + self.summary = summary + self.description = description or self.endpoint.__doc__ + self.response_description = response_description + self.deprecated = deprecated if methods is None: methods = ["GET"] self.methods = methods + self.operation_id = operation_id + self.include_in_schema = include_in_schema + self.content_type = content_type + self.path_regex, self.path_format, self.param_convertors = self.compile_path( path ) assert inspect.isfunction(endpoint) or inspect.ismethod( endpoint ), f"An endpoint must be a function or method" - self.dependant = get_dependant(path=path, call=self.endpoint) self.body_field = get_body_field(dependant=self.dependant, name=self.name) self.app = request_response( get_app( dependant=self.dependant, body_field=self.body_field, - response_code=self.response_code, - response_wrapper=self.response_wrapper, + status_code=self.status_code, + content_type=self.content_type, response_field=self.response_field, ) ) @@ -189,75 +188,77 @@ class APIRouter(routing.Router): self, path: str, endpoint: Callable, - methods: List[str] = None, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, + deprecated: bool = None, + name: str = None, + methods: List[str] = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, ) -> None: route = APIRoute( path, endpoint=endpoint, - methods=methods, - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=methods, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) self.routes.append(route) def api_route( self, path: str, - methods: List[str] = None, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, + deprecated: bool = None, + name: str = None, + methods: List[str] = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, ) -> Callable: def decorator(func: Callable) -> Callable: self.add_api_route( path, func, - methods=methods, - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=methods, + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) return func return decorator - def include_router(self, router: "APIRouter", *, prefix=""): + def include_router(self, router: "APIRouter", *, prefix: str = "") -> None: if prefix: assert prefix.startswith("/"), "A path prefix must start with '/'" assert not prefix.endswith( @@ -268,18 +269,18 @@ class APIRouter(routing.Router): self.add_api_route( prefix + route.path, route.endpoint, - methods=route.methods, - name=route.name, - include_in_schema=route.include_in_schema, - tags=route.tags, + response_model=route.response_model, + status_code=route.status_code, + tags=route.tags or [], summary=route.summary, description=route.description, - operation_id=route.operation_id, - deprecated=route.deprecated, - response_type=route.response_type, response_description=route.response_description, - response_code=route.response_code, - response_wrapper=route.response_wrapper, + deprecated=route.deprecated, + name=route.name, + methods=route.methods, + operation_id=route.operation_id, + include_in_schema=route.include_in_schema, + content_type=route.content_type, ) elif isinstance(route, routing.Route): self.add_route( @@ -293,247 +294,255 @@ class APIRouter(routing.Router): def get( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.api_route( path=path, - methods=["GET"], - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=["GET"], + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def put( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.api_route( path=path, - methods=["PUT"], - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=["PUT"], + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def post( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.api_route( path=path, - methods=["POST"], - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=["POST"], + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def delete( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.api_route( path=path, - methods=["DELETE"], - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=["DELETE"], + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def options( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.api_route( path=path, - methods=["OPTIONS"], - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=["OPTIONS"], + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def head( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.api_route( path=path, - methods=["HEAD"], - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=["HEAD"], + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def patch( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.api_route( path=path, - methods=["PATCH"], - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=["PATCH"], + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) def trace( self, path: str, - name: str = None, - include_in_schema: bool = True, + *, + response_model: Type[BaseModel] = None, + status_code: int = 200, tags: List[str] = None, summary: str = None, description: str = None, - operation_id: str = None, - deprecated: bool = None, - response_type: Type = None, response_description: str = "Successful Response", - response_code=200, - response_wrapper=JSONResponse, - ): + deprecated: bool = None, + name: str = None, + operation_id: str = None, + include_in_schema: bool = True, + content_type: Type[Response] = JSONResponse, + ) -> Callable: return self.api_route( path=path, - methods=["TRACE"], - name=name, - include_in_schema=include_in_schema, + response_model=response_model, + status_code=status_code, tags=tags or [], summary=summary, description=description, - operation_id=operation_id, - deprecated=deprecated, - response_type=response_type, response_description=response_description, - response_code=response_code, - response_wrapper=response_wrapper, + deprecated=deprecated, + name=name, + methods=["TRACE"], + operation_id=operation_id, + include_in_schema=include_in_schema, + content_type=content_type, ) diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 047898dfe..c4b045b71 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -1,18 +1,19 @@ from starlette.requests import Request -from .base import SecurityBase -from fastapi.openapi.models import APIKeyIn, APIKey +from fastapi.openapi.models import APIKey, APIKeyIn +from fastapi.security.base import SecurityBase + class APIKeyBase(SecurityBase): pass -class APIKeyQuery(APIKeyBase): +class APIKeyQuery(APIKeyBase): def __init__(self, *, name: str, scheme_name: str = None): self.model = APIKey(in_=APIKeyIn.query, name=name) self.scheme_name = scheme_name or self.__class__.__name__ - async def __call__(self, requests: Request): + async def __call__(self, requests: Request) -> str: return requests.query_params.get(self.model.name) @@ -21,7 +22,7 @@ class APIKeyHeader(APIKeyBase): self.model = APIKey(in_=APIKeyIn.header, name=name) self.scheme_name = scheme_name or self.__class__.__name__ - async def __call__(self, requests: Request): + async def __call__(self, requests: Request) -> str: return requests.headers.get(self.model.name) @@ -30,5 +31,5 @@ class APIKeyCookie(APIKeyBase): self.model = APIKey(in_=APIKeyIn.cookie, name=name) self.scheme_name = scheme_name or self.__class__.__name__ - async def __call__(self, requests: Request): + async def __call__(self, requests: Request) -> str: return requests.cookies.get(self.model.name) diff --git a/fastapi/security/base.py b/fastapi/security/base.py index 8589da0be..c43555deb 100644 --- a/fastapi/security/base.py +++ b/fastapi/security/base.py @@ -1,6 +1,6 @@ -from pydantic import BaseModel - from fastapi.openapi.models import SecurityBase as SecurityBaseModel + class SecurityBase: - pass + model: SecurityBaseModel + scheme_name: str diff --git a/fastapi/security/http.py b/fastapi/security/http.py index cee42b868..480a1ae54 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -1,7 +1,10 @@ from starlette.requests import Request -from .base import SecurityBase -from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel +from fastapi.openapi.models import ( + HTTPBase as HTTPBaseModel, + HTTPBearer as HTTPBearerModel, +) +from fastapi.security.base import SecurityBase class HTTPBase(SecurityBase): @@ -9,7 +12,7 @@ class HTTPBase(SecurityBase): self.model = HTTPBaseModel(scheme=scheme) self.scheme_name = scheme_name or self.__class__.__name__ - async def __call__(self, request: Request): + async def __call__(self, request: Request) -> str: return request.headers.get("Authorization") @@ -17,8 +20,8 @@ class HTTPBasic(HTTPBase): def __init__(self, *, scheme_name: str = None): self.model = HTTPBaseModel(scheme="basic") self.scheme_name = scheme_name or self.__class__.__name__ - - async def __call__(self, request: Request): + + async def __call__(self, request: Request) -> str: return request.headers.get("Authorization") @@ -26,8 +29,8 @@ class HTTPBearer(HTTPBase): def __init__(self, *, bearerFormat: str = None, scheme_name: str = None): self.model = HTTPBearerModel(bearerFormat=bearerFormat) self.scheme_name = scheme_name or self.__class__.__name__ - - async def __call__(self, request: Request): + + async def __call__(self, request: Request) -> str: return request.headers.get("Authorization") @@ -35,6 +38,6 @@ class HTTPDigest(HTTPBase): def __init__(self, *, scheme_name: str = None): self.model = HTTPBaseModel(scheme="digest") self.scheme_name = scheme_name or self.__class__.__name__ - - async def __call__(self, request: Request): + + async def __call__(self, request: Request) -> str: return request.headers.get("Authorization") diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index 65517e962..90838fdad 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -1,13 +1,15 @@ from starlette.requests import Request -from .base import SecurityBase from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel +from fastapi.security.base import SecurityBase class OAuth2(SecurityBase): - def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None): + def __init__( + self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None + ): self.model = OAuth2Model(flows=flows) self.scheme_name = scheme_name or self.__class__.__name__ - - async def __call__(self, request: Request): + + async def __call__(self, request: Request) -> str: return request.headers.get("Authorization") diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index 49c5aae2d..b6c0a32dc 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -1,13 +1,13 @@ from starlette.requests import Request -from .base import SecurityBase from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel +from fastapi.security.base import SecurityBase class OpenIdConnect(SecurityBase): def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None): self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl) self.scheme_name = scheme_name or self.__class__.__name__ - - async def __call__(self, request: Request): + + async def __call__(self, request: Request) -> str: return request.headers.get("Authorization") diff --git a/fastapi/utils.py b/fastapi/utils.py index 091f868fe..81ca910cf 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -1,20 +1,24 @@ import re -from typing import Dict, Sequence, Set, Type +from typing import Any, Dict, List, Sequence, Set, Type +from pydantic import BaseModel +from pydantic.fields import Field +from pydantic.schema import get_flat_models_from_fields, model_process_schema from starlette.routing import BaseRoute from fastapi import routing from fastapi.openapi.constants import REF_PREFIX -from pydantic import BaseModel -from pydantic.fields import Field -from pydantic.schema import get_flat_models_from_fields, model_process_schema -def get_flat_models_from_routes(routes: Sequence[BaseRoute]): - body_fields_from_routes = [] - responses_from_routes = [] +def get_flat_models_from_routes( + routes: Sequence[Type[BaseRoute]] +) -> Set[Type[BaseModel]]: + body_fields_from_routes: List[Field] = [] + responses_from_routes: List[Field] = [] for route in routes: - if route.include_in_schema and isinstance(route, routing.APIRoute): + if getattr(route, "include_in_schema", None) and isinstance( + route, routing.APIRoute + ): if route.body_field: assert isinstance( route.body_field, Field @@ -30,7 +34,7 @@ def get_flat_models_from_routes(routes: Sequence[BaseRoute]): def get_model_definitions( *, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str] -): +) -> Dict[str, Any]: definitions: Dict[str, Dict] = {} for model in flat_models: m_schema, m_definitions = model_process_schema( @@ -42,5 +46,5 @@ def get_model_definitions( return definitions -def get_path_param_names(path: str): +def get_path_param_names(path: str) -> Set[str]: return {item.strip("{}") for item in re.findall("{[^}]*}", path)}