Browse Source

Update parameter names and order

fix mypy types, refactor, lint
pull/1/head
Sebastián Ramírez 7 years ago
parent
commit
0e19c24014
  1. 379
      fastapi/applications.py
  2. 8
      fastapi/dependencies/models.py
  3. 65
      fastapi/dependencies/utils.py
  4. 8
      fastapi/encoders.py
  5. 11
      fastapi/openapi/docs.py
  6. 48
      fastapi/openapi/models.py
  7. 98
      fastapi/openapi/utils.py
  8. 46
      fastapi/params.py
  9. 437
      fastapi/routing.py
  10. 13
      fastapi/security/api_key.py
  11. 6
      fastapi/security/base.py
  12. 15
      fastapi/security/http.py
  13. 8
      fastapi/security/oauth2.py
  14. 4
      fastapi/security/open_id_connect_url.py
  15. 24
      fastapi/utils.py

379
fastapi/applications.py

@ -1,19 +1,19 @@
from typing import Any, Callable, Dict, List, Type from typing import Any, Callable, Dict, List, Optional, Type
from pydantic import BaseModel
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.exceptions import ExceptionMiddleware, HTTPException from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.lifespan import LifespanMiddleware from starlette.middleware.lifespan import LifespanMiddleware
from starlette.responses import JSONResponse from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from fastapi import routing from fastapi import routing
from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
async def http_exception(request, exc: HTTPException): async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
print(exc)
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
@ -31,7 +31,7 @@ class FastAPI(Starlette):
**extra: Dict[str, Any], **extra: Dict[str, Any],
) -> None: ) -> None:
self._debug = debug self._debug = debug
self.router = routing.APIRouter() self.router: routing.APIRouter = routing.APIRouter()
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.error_middleware = ServerErrorMiddleware( self.error_middleware = ServerErrorMiddleware(
self.exception_middleware, debug=debug self.exception_middleware, debug=debug
@ -56,33 +56,41 @@ class FastAPI(Starlette):
if self.swagger_ui_url or self.redoc_url: if self.swagger_ui_url or self.redoc_url:
assert self.openapi_url, "The openapi_url is required for the docs" assert self.openapi_url, "The openapi_url is required for the docs"
self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup() self.setup()
def setup(self): def openapi(self) -> Dict:
if not self.openapi_schema:
self.openapi_schema = get_openapi(
title=self.title,
version=self.version,
openapi_version=self.openapi_version,
description=self.description,
routes=self.routes,
)
return self.openapi_schema
def setup(self) -> None:
if self.openapi_url: if self.openapi_url:
self.add_route( self.add_route(
self.openapi_url, self.openapi_url,
lambda req: JSONResponse( lambda req: JSONResponse(self.openapi()),
get_openapi(
title=self.title,
version=self.version,
openapi_version=self.openapi_version,
description=self.description,
routes=self.routes,
)
),
include_in_schema=False, include_in_schema=False,
) )
if self.swagger_ui_url: if self.swagger_ui_url:
self.add_route( self.add_route(
self.swagger_ui_url, self.swagger_ui_url,
lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"), lambda r: get_swagger_ui_html(
openapi_url=self.openapi_url, title=self.title + " - Swagger UI"
),
include_in_schema=False, include_in_schema=False,
) )
if self.redoc_url: if self.redoc_url:
self.add_route( self.add_route(
self.redoc_url, self.redoc_url,
lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"), lambda r: get_redoc_html(
openapi_url=self.openapi_url, title=self.title + " - ReDoc"
),
include_in_schema=False, include_in_schema=False,
) )
self.add_exception_handler(HTTPException, http_exception) self.add_exception_handler(HTTPException, http_exception)
@ -91,311 +99,322 @@ class FastAPI(Starlette):
self, self,
path: str, path: str,
endpoint: Callable, endpoint: Callable,
methods: List[str] = None, *,
name: str = None, response_model: Type[BaseModel] = None,
include_in_schema: bool = True, status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
methods: List[str] = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> None: ) -> None:
self.router.add_api_route( self.router.add_api_route(
path, path,
endpoint=endpoint, endpoint=endpoint,
methods=methods, response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=methods,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def api_route( def api_route(
self, self,
path: str, path: str,
methods: List[str] = None, *,
name: str = None, response_model: Type[BaseModel] = None,
include_in_schema: bool = True, status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
methods: List[str] = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable: ) -> Callable:
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
self.router.add_api_route( self.router.add_api_route(
path, path,
func, func,
methods=methods, response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=methods,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
return func return func
return decorator return decorator
def include_router(self, router: "APIRouter", *, prefix=""): def include_router(self, router: routing.APIRouter, *, prefix: str = "") -> None:
self.router.include_router(router, prefix=prefix) self.router.include_router(router, prefix=prefix)
def get( def get(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.get( return self.router.get(
path=path, path,
name=name, response_model=response_model,
include_in_schema=include_in_schema, status_code=status_code,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def put( def put(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.put( return self.router.put(
path=path, path,
name=name, response_model=response_model,
include_in_schema=include_in_schema, status_code=status_code,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def post( def post(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.post( return self.router.post(
path=path, path,
name=name, response_model=response_model,
include_in_schema=include_in_schema, status_code=status_code,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def delete( def delete(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.delete( return self.router.delete(
path=path, path,
name=name, response_model=response_model,
include_in_schema=include_in_schema, status_code=status_code,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def options( def options(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.options( return self.router.options(
path=path, path,
name=name, response_model=response_model,
include_in_schema=include_in_schema, status_code=status_code,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def head( def head(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.head( return self.router.head(
path=path, path,
name=name, response_model=response_model,
include_in_schema=include_in_schema, status_code=status_code,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def patch( def patch(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.patch( return self.router.patch(
path=path, path,
name=name, response_model=response_model,
include_in_schema=include_in_schema, status_code=status_code,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def trace( def trace(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.trace( return self.router.trace(
path=path, path,
name=name, response_model=response_model,
include_in_schema=include_in_schema, status_code=status_code,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )

8
fastapi/dependencies/models.py

@ -1,14 +1,14 @@
from typing import Any, Callable, Dict, List, Sequence, Tuple from typing import Any, Callable, Dict, List, Sequence, Tuple
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from fastapi.security.base import SecurityBase
from pydantic import BaseConfig, Schema from pydantic import BaseConfig, Schema
from pydantic.error_wrappers import ErrorWrapper from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError from pydantic.errors import MissingError
from pydantic.fields import Field, Required from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema from pydantic.schema import get_annotation_from_schema
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from fastapi.security.base import SecurityBase
param_supported_types = (str, int, float, bool) param_supported_types = (str, int, float, bool)

65
fastapi/dependencies/utils.py

@ -1,8 +1,14 @@
import asyncio import asyncio
import inspect import inspect
from copy import deepcopy from copy import deepcopy
from typing import Any, Callable, Dict, List, Tuple from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Type
from pydantic import BaseConfig, Schema, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema
from pydantic.utils import lenient_issubclass
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
from starlette.requests import Request from starlette.requests import Request
@ -10,17 +16,11 @@ from fastapi import params
from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.utils import get_path_param_names from fastapi.utils import get_path_param_names
from pydantic import BaseConfig, Schema, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema
from pydantic.utils import lenient_issubclass
param_supported_types = (str, int, float, bool) param_supported_types = (str, int, float, bool)
def get_sub_dependant(*, param: inspect.Parameter, path: str): def get_sub_dependant(*, param: inspect.Parameter, path: str) -> Dependant:
depends: params.Depends = param.default depends: params.Depends = param.default
if depends.dependency: if depends.dependency:
dependency = depends.dependency dependency = depends.dependency
@ -36,7 +36,7 @@ def get_sub_dependant(*, param: inspect.Parameter, path: str):
return sub_dependant return sub_dependant
def get_flat_dependant(dependant: Dependant): def get_flat_dependant(dependant: Dependant) -> Dependant:
flat_dependant = Dependant( flat_dependant = Dependant(
path_params=dependant.path_params.copy(), path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(), query_params=dependant.query_params.copy(),
@ -58,7 +58,7 @@ def get_flat_dependant(dependant: Dependant):
return flat_dependant return flat_dependant
def get_dependant(*, path: str, call: Callable, name: str = None): def get_dependant(*, path: str, call: Callable, name: str = None) -> Dependant:
path_param_names = get_path_param_names(path) path_param_names = get_path_param_names(path)
endpoint_signature = inspect.signature(call) endpoint_signature = inspect.signature(call)
signature_params = endpoint_signature.parameters signature_params = endpoint_signature.parameters
@ -73,9 +73,10 @@ def get_dependant(*, path: str, call: Callable, name: str = None):
if ( if (
(param.default == param.empty) or isinstance(param.default, params.Path) (param.default == param.empty) or isinstance(param.default, params.Path)
) and (param_name in path_param_names): ) and (param_name in path_param_names):
assert lenient_issubclass( assert (
param.annotation, param_supported_types lenient_issubclass(param.annotation, param_supported_types)
) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}" or param.annotation == param.empty
), f"Path params must be of type str, int, float or boot: {param}"
param = signature_params[param_name] param = signature_params[param_name]
add_param_to_fields( add_param_to_fields(
param=param, param=param,
@ -109,9 +110,9 @@ def add_param_to_fields(
*, *,
param: inspect.Parameter, param: inspect.Parameter,
dependant: Dependant, dependant: Dependant,
default_schema=params.Param, default_schema: Type[Schema] = params.Param,
force_type: params.ParamTypes = None, force_type: params.ParamTypes = None,
): ) -> None:
default_value = Required default_value = Required
if not param.default == param.empty: if not param.default == param.empty:
default_value = param.default default_value = param.default
@ -125,15 +126,19 @@ def add_param_to_fields(
else: else:
schema = default_schema(default_value) schema = default_schema(default_value)
required = default_value == Required required = default_value == Required
annotation = Any annotation: Type = Type[Any]
if not param.annotation == param.empty: if not param.annotation == param.empty:
annotation = param.annotation annotation = param.annotation
annotation = get_annotation_from_schema(annotation, schema) annotation = get_annotation_from_schema(annotation, schema)
if not schema.alias and getattr(schema, "alias_underscore_to_hyphen", None):
alias = param.name.replace("_", "-")
else:
alias = schema.alias or param.name
field = Field( field = Field(
name=param.name, name=param.name,
type_=annotation, type_=annotation,
default=None if required else default_value, default=None if required else default_value,
alias=schema.alias or param.name, alias=alias,
required=required, required=required,
model_config=BaseConfig(), model_config=BaseConfig(),
class_validators=[], class_validators=[],
@ -152,7 +157,7 @@ def add_param_to_fields(
dependant.cookie_params.append(field) dependant.cookie_params.append(field)
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant): def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant) -> None:
default_value = Required default_value = Required
if not param.default == param.empty: if not param.default == param.empty:
default_value = param.default default_value = param.default
@ -176,7 +181,7 @@ def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
dependant.body_params.append(field) dependant.body_params.append(field)
def is_coroutine_callable(call: Callable = None): def is_coroutine_callable(call: Callable = None) -> bool:
if not call: if not call:
return False return False
if inspect.isfunction(call): if inspect.isfunction(call):
@ -191,7 +196,7 @@ def is_coroutine_callable(call: Callable = None):
async def solve_dependencies( async def solve_dependencies(
*, request: Request, dependant: Dependant, body: Dict[str, Any] = None *, request: Request, dependant: Dependant, body: Dict[str, Any] = None
): ) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = [] errors: List[ErrorWrapper] = []
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.dependencies:
@ -200,13 +205,13 @@ async def solve_dependencies(
) )
if sub_errors: if sub_errors:
return {}, errors return {}, errors
if sub_dependant.call and is_coroutine_callable(sub_dependant.call): assert sub_dependant.call is not None, "sub_dependant.call must be a function"
if is_coroutine_callable(sub_dependant.call):
solved = await sub_dependant.call(**sub_values) solved = await sub_dependant.call(**sub_values)
else: else:
solved = await run_in_threadpool(sub_dependant.call, **sub_values) solved = await run_in_threadpool(sub_dependant.call, **sub_values)
values[ assert sub_dependant.name is not None, "Subdependants always have a name"
sub_dependant.name values[sub_dependant.name] = solved
] = solved # type: ignore # Sub-dependants always have a name
path_values, path_errors = request_params_to_args( path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params dependant.path_params, request.path_params
) )
@ -236,7 +241,7 @@ async def solve_dependencies(
def request_params_to_args( def request_params_to_args(
required_params: List[Field], received_params: Dict[str, Any] required_params: Sequence[Field], received_params: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: ) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values = {} values = {}
errors = [] errors = []
@ -250,9 +255,9 @@ def request_params_to_args(
else: else:
values[field.name] = deepcopy(field.default) values[field.name] = deepcopy(field.default)
continue continue
v_, errors_ = field.validate( schema: params.Param = field.schema
value, values, loc=(field.schema.in_.value, field.alias) assert isinstance(schema, params.Param), "Params must be subclasses of Param"
) v_, errors_ = field.validate(value, values, loc=(schema.in_.value, field.alias))
if isinstance(errors_, ErrorWrapper): if isinstance(errors_, ErrorWrapper):
errors.append(errors_) errors.append(errors_)
elif isinstance(errors_, list): elif isinstance(errors_, list):
@ -294,7 +299,7 @@ async def request_body_to_args(
return values, errors return values, errors
def get_body_field(*, dependant: Dependant, name: str): def get_body_field(*, dependant: Dependant, name: str) -> Field:
flat_dependant = get_flat_dependant(dependant) flat_dependant = get_flat_dependant(dependant)
if not flat_dependant.body_params: if not flat_dependant.body_params:
return None return None
@ -308,7 +313,7 @@ def get_body_field(*, dependant: Dependant, name: str):
BodyModel.__fields__[f.name] = f BodyModel.__fields__[f.name] = f
required = any(True for f in flat_dependant.body_params if f.required) required = any(True for f in flat_dependant.body_params if f.required)
if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params): if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
BodySchema = params.File BodySchema: Type[params.Body] = params.File
elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params): elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
BodySchema = params.Form BodySchema = params.Form
else: else:

8
fastapi/encoders.py

@ -1,18 +1,18 @@
from enum import Enum from enum import Enum
from types import GeneratorType from types import GeneratorType
from typing import Set from typing import Any, Set
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.json import pydantic_encoder from pydantic.json import pydantic_encoder
def jsonable_encoder( def jsonable_encoder(
obj, obj: Any,
include: Set[str] = None, include: Set[str] = None,
exclude: Set[str] = set(), exclude: Set[str] = set(),
by_alias: bool = False, by_alias: bool = False,
include_none=True, include_none: bool = True,
): ) -> Any:
if isinstance(obj, BaseModel): if isinstance(obj, BaseModel):
return jsonable_encoder( return jsonable_encoder(
obj.dict(include=include, exclude=exclude, by_alias=by_alias), obj.dict(include=include, exclude=exclude, by_alias=by_alias),

11
fastapi/openapi/docs.py

@ -1,6 +1,7 @@
from starlette.responses import HTMLResponse from starlette.responses import HTMLResponse
def get_swagger_ui_html(*, openapi_url: str, title: str):
def get_swagger_ui_html(*, openapi_url: str, title: str) -> HTMLResponse:
return HTMLResponse( return HTMLResponse(
""" """
<! doctype html> <! doctype html>
@ -35,12 +36,11 @@ def get_swagger_ui_html(*, openapi_url: str, title: str):
</script> </script>
</body> </body>
</html> </html>
""", """
media_type="text/html",
) )
def get_redoc_html(*, openapi_url: str, title: str): def get_redoc_html(*, openapi_url: str, title: str) -> HTMLResponse:
return HTMLResponse( return HTMLResponse(
""" """
<!DOCTYPE html> <!DOCTYPE html>
@ -73,6 +73,5 @@ def get_redoc_html(*, openapi_url: str, title: str):
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script> <script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
</body> </body>
</html> </html>
""", """
media_type="text/html",
) )

48
fastapi/openapi/models.py

@ -7,13 +7,13 @@ from pydantic.types import UrlStr
try: try:
import pydantic.types.EmailStr import pydantic.types.EmailStr
from pydantic.types import EmailStr from pydantic.types import EmailStr # type: ignore
except ImportError: except ImportError:
logging.warning( logging.warning(
"email-validator not installed, email fields will be treated as str" "email-validator not installed, email fields will be treated as str"
) )
class EmailStr(str): class EmailStr(str): # type: ignore
pass pass
@ -50,7 +50,7 @@ class Server(BaseModel):
class Reference(BaseModel): class Reference(BaseModel):
ref: str = PSchema(..., alias="$ref") ref: str = PSchema(..., alias="$ref") # type: ignore
class Discriminator(BaseModel): class Discriminator(BaseModel):
@ -72,28 +72,28 @@ class ExternalDocumentation(BaseModel):
class SchemaBase(BaseModel): class SchemaBase(BaseModel):
ref: Optional[str] = PSchema(None, alias="$ref") ref: Optional[str] = PSchema(None, alias="$ref") # type: ignore
title: Optional[str] = None title: Optional[str] = None
multipleOf: Optional[float] = None multipleOf: Optional[float] = None
maximum: Optional[float] = None maximum: Optional[float] = None
exclusiveMaximum: Optional[float] = None exclusiveMaximum: Optional[float] = None
minimum: Optional[float] = None minimum: Optional[float] = None
exclusiveMinimum: Optional[float] = None exclusiveMinimum: Optional[float] = None
maxLength: Optional[int] = PSchema(None, gte=0) maxLength: Optional[int] = PSchema(None, gte=0) # type: ignore
minLength: Optional[int] = PSchema(None, gte=0) minLength: Optional[int] = PSchema(None, gte=0) # type: ignore
pattern: Optional[str] = None pattern: Optional[str] = None
maxItems: Optional[int] = PSchema(None, gte=0) maxItems: Optional[int] = PSchema(None, gte=0) # type: ignore
minItems: Optional[int] = PSchema(None, gte=0) minItems: Optional[int] = PSchema(None, gte=0) # type: ignore
uniqueItems: Optional[bool] = None uniqueItems: Optional[bool] = None
maxProperties: Optional[int] = PSchema(None, gte=0) maxProperties: Optional[int] = PSchema(None, gte=0) # type: ignore
minProperties: Optional[int] = PSchema(None, gte=0) minProperties: Optional[int] = PSchema(None, gte=0) # type: ignore
required: Optional[List[str]] = None required: Optional[List[str]] = None
enum: Optional[List[str]] = None enum: Optional[List[str]] = None
type: Optional[str] = None type: Optional[str] = None
allOf: Optional[List[Any]] = None allOf: Optional[List[Any]] = None
oneOf: Optional[List[Any]] = None oneOf: Optional[List[Any]] = None
anyOf: Optional[List[Any]] = None anyOf: Optional[List[Any]] = None
not_: Optional[List[Any]] = PSchema(None, alias="not") not_: Optional[List[Any]] = PSchema(None, alias="not") # type: ignore
items: Optional[Any] = None items: Optional[Any] = None
properties: Optional[Dict[str, Any]] = None properties: Optional[Dict[str, Any]] = None
additionalProperties: Optional[Union[bool, Any]] = None additionalProperties: Optional[Union[bool, Any]] = None
@ -114,7 +114,7 @@ class Schema(SchemaBase):
allOf: Optional[List[SchemaBase]] = None allOf: Optional[List[SchemaBase]] = None
oneOf: Optional[List[SchemaBase]] = None oneOf: Optional[List[SchemaBase]] = None
anyOf: Optional[List[SchemaBase]] = None anyOf: Optional[List[SchemaBase]] = None
not_: Optional[List[SchemaBase]] = PSchema(None, alias="not") not_: Optional[List[SchemaBase]] = PSchema(None, alias="not") # type: ignore
items: Optional[SchemaBase] = None items: Optional[SchemaBase] = None
properties: Optional[Dict[str, SchemaBase]] = None properties: Optional[Dict[str, SchemaBase]] = None
additionalProperties: Optional[Union[bool, SchemaBase]] = None additionalProperties: Optional[Union[bool, SchemaBase]] = None
@ -144,7 +144,9 @@ class Encoding(BaseModel):
class MediaType(BaseModel): class MediaType(BaseModel):
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema") schema_: Optional[Union[Schema, Reference]] = PSchema(
None, alias="schema"
) # type: ignore
example: Optional[Any] = None example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None examples: Optional[Dict[str, Union[Example, Reference]]] = None
encoding: Optional[Dict[str, Encoding]] = None encoding: Optional[Dict[str, Encoding]] = None
@ -158,7 +160,9 @@ class ParameterBase(BaseModel):
style: Optional[str] = None style: Optional[str] = None
explode: Optional[bool] = None explode: Optional[bool] = None
allowReserved: Optional[bool] = None allowReserved: Optional[bool] = None
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema") schema_: Optional[Union[Schema, Reference]] = PSchema(
None, alias="schema"
) # type: ignore
example: Optional[Any] = None example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None examples: Optional[Dict[str, Union[Example, Reference]]] = None
# Serialization rules for more complex scenarios # Serialization rules for more complex scenarios
@ -167,7 +171,7 @@ class ParameterBase(BaseModel):
class Parameter(ParameterBase): class Parameter(ParameterBase):
name: str name: str
in_: ParameterInType = PSchema(..., alias="in") in_: ParameterInType = PSchema(..., alias="in") # type: ignore
class Header(ParameterBase): class Header(ParameterBase):
@ -222,7 +226,7 @@ class Operation(BaseModel):
class PathItem(BaseModel): class PathItem(BaseModel):
ref: Optional[str] = PSchema(None, alias="$ref") ref: Optional[str] = PSchema(None, alias="$ref") # type: ignore
summary: Optional[str] = None summary: Optional[str] = None
description: Optional[str] = None description: Optional[str] = None
get: Optional[Operation] = None get: Optional[Operation] = None
@ -250,7 +254,7 @@ class SecuritySchemeType(Enum):
class SecurityBase(BaseModel): class SecurityBase(BaseModel):
type_: SecuritySchemeType = PSchema(..., alias="type") type_: SecuritySchemeType = PSchema(..., alias="type") # type: ignore
description: Optional[str] = None description: Optional[str] = None
@ -261,13 +265,13 @@ class APIKeyIn(Enum):
class APIKey(SecurityBase): class APIKey(SecurityBase):
type_ = PSchema(SecuritySchemeType.apiKey, alias="type") type_ = PSchema(SecuritySchemeType.apiKey, alias="type") # type: ignore
in_: APIKeyIn = PSchema(..., alias="in") in_: APIKeyIn = PSchema(..., alias="in") # type: ignore
name: str name: str
class HTTPBase(SecurityBase): class HTTPBase(SecurityBase):
type_ = PSchema(SecuritySchemeType.http, alias="type") type_ = PSchema(SecuritySchemeType.http, alias="type") # type: ignore
scheme: str scheme: str
@ -306,12 +310,12 @@ class OAuthFlows(BaseModel):
class OAuth2(SecurityBase): class OAuth2(SecurityBase):
type_ = PSchema(SecuritySchemeType.oauth2, alias="type") type_ = PSchema(SecuritySchemeType.oauth2, alias="type") # type: ignore
flows: OAuthFlows flows: OAuthFlows
class OpenIdConnect(SecurityBase): class OpenIdConnect(SecurityBase):
type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type") type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type") # type: ignore
openIdConnectUrl: str openIdConnectUrl: str

98
fastapi/openapi/utils.py

@ -1,23 +1,21 @@
from typing import Any, Dict, Sequence, Type, List from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
from pydantic.fields import Field from pydantic.fields import Field
from pydantic.schema import field_schema, get_model_name_map from pydantic.schema import Schema, field_schema, get_model_name_map
from pydantic.utils import lenient_issubclass from pydantic.utils import lenient_issubclass
from starlette.responses import HTMLResponse, JSONResponse from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import BaseRoute from starlette.routing import BaseRoute, Route
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from fastapi import routing from fastapi import routing
from fastapi.dependencies.models import Dependant from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant from fastapi.dependencies.utils import get_flat_dependant
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
from fastapi.openapi.models import OpenAPI from fastapi.openapi.models import OpenAPI
from fastapi.params import Body from fastapi.params import Body, Param
from fastapi.utils import get_flat_models_from_routes, get_model_definitions from fastapi.utils import get_flat_models_from_routes, get_model_definitions
validation_error_definition = { validation_error_definition = {
"title": "ValidationError", "title": "ValidationError",
"type": "object", "type": "object",
@ -42,7 +40,7 @@ validation_error_response_definition = {
} }
def get_openapi_params(dependant: Dependant): def get_openapi_params(dependant: Dependant) -> List[Field]:
flat_dependant = get_flat_dependant(dependant) flat_dependant = get_flat_dependant(dependant)
return ( return (
flat_dependant.path_params flat_dependant.path_params
@ -52,7 +50,7 @@ def get_openapi_params(dependant: Dependant):
) )
def get_openapi_security_definitions(flat_dependant: Dependant): def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
security_definitions = {} security_definitions = {}
operation_security = [] operation_security = []
for security_requirement in flat_dependant.security_requirements: for security_requirement in flat_dependant.security_requirements:
@ -61,59 +59,60 @@ def get_openapi_security_definitions(flat_dependant: Dependant):
by_alias=True, by_alias=True,
include_none=False, include_none=False,
) )
security_name = ( security_name = security_requirement.security_scheme.scheme_name
security_requirement.security_scheme.scheme_name
)
security_definitions[security_name] = security_definition security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes}) operation_security.append({security_name: security_requirement.scopes})
return security_definitions, operation_security return security_definitions, operation_security
def get_openapi_operation_parameters(all_route_params: List[Field]): def get_openapi_operation_parameters(
all_route_params: Sequence[Field]
) -> Tuple[Dict[str, Dict], List[Dict[str, Any]]]:
definitions: Dict[str, Dict] = {} definitions: Dict[str, Dict] = {}
parameters = [] parameters = []
for param in all_route_params: for param in all_route_params:
schema: Param = param.schema
if "ValidationError" not in definitions: if "ValidationError" not in definitions:
definitions["ValidationError"] = validation_error_definition definitions["ValidationError"] = validation_error_definition
definitions["HTTPValidationError"] = validation_error_response_definition definitions["HTTPValidationError"] = validation_error_response_definition
parameter = { parameter = {
"name": param.alias, "name": param.alias,
"in": param.schema.in_.value, "in": schema.in_.value,
"required": param.required, "required": param.required,
"schema": field_schema(param, model_name_map={})[0], "schema": field_schema(param, model_name_map={})[0],
} }
if param.schema.description: if schema.description:
parameter["description"] = param.schema.description parameter["description"] = schema.description
if param.schema.deprecated: if schema.deprecated:
parameter["deprecated"] = param.schema.deprecated parameter["deprecated"] = schema.deprecated
parameters.append(parameter) parameters.append(parameter)
return definitions, parameters return definitions, parameters
def get_openapi_operation_request_body( def get_openapi_operation_request_body(
*, body_field: Field, model_name_map: Dict[Type, str] *, body_field: Field, model_name_map: Dict[Type, str]
): ) -> Optional[Dict]:
if not body_field: if not body_field:
return None return None
assert isinstance(body_field, Field) assert isinstance(body_field, Field)
body_schema, _ = field_schema( body_schema, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
) )
if isinstance(body_field.schema, Body): schema: Schema = body_field.schema
request_media_type = body_field.schema.media_type if isinstance(schema, Body):
request_media_type = schema.media_type
else: else:
# Includes not declared media types (Schema) # Includes not declared media types (Schema)
request_media_type = "application/json" request_media_type = "application/json"
required = body_field.required required = body_field.required
request_body_oai = {} request_body_oai: Dict[str, Any] = {}
if required: if required:
request_body_oai["required"] = required request_body_oai["required"] = required
request_body_oai["content"] = {request_media_type: {"schema": body_schema}} request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
return request_body_oai return request_body_oai
def generate_operation_id(*, route: routing.APIRoute, method: str): def generate_operation_id(*, route: routing.APIRoute, method: str) -> str:
if route.operation_id: if route.operation_id:
return route.operation_id return route.operation_id
path: str = route.path path: str = route.path
@ -123,12 +122,13 @@ def generate_operation_id(*, route: routing.APIRoute, method: str):
return operation_id return operation_id
def generate_operation_summary(*, route: routing.APIRoute, method: str): def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
if route.summary: if route.summary:
return route.summary return route.summary
return method.title() + " " + route.name.replace("_", " ").title() return method.title() + " " + route.name.replace("_", " ").title()
def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict:
operation: Dict[str, Any] = {} operation: Dict[str, Any] = {}
if route.tags: if route.tags:
operation["tags"] = route.tags operation["tags"] = route.tags
@ -141,12 +141,13 @@ def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
return operation return operation
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]): def get_openapi_path(
if not (route.include_in_schema and isinstance(route, routing.APIRoute)): *, route: routing.APIRoute, model_name_map: Dict[Type, str]
return None ) -> Tuple[Dict, Dict, Dict]:
path = {} path = {}
security_schemes: Dict[str, Any] = {} security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {} definitions: Dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
for method in route.methods: for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method) operation = get_openapi_operation_metadata(route=route, method=method)
parameters: List[Dict] = [] parameters: List[Dict] = []
@ -172,10 +173,9 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
) )
if request_body_oai: if request_body_oai:
operation["requestBody"] = request_body_oai operation["requestBody"] = request_body_oai
response_code = str(route.response_code) status_code = str(route.status_code)
response_schema = {"type": "string"} response_schema = {"type": "string"}
if lenient_issubclass(route.response_wrapper, JSONResponse): if lenient_issubclass(route.content_type, JSONResponse):
response_media_type = "application/json"
if route.response_field: if route.response_field:
response_schema, _ = field_schema( response_schema, _ = field_schema(
route.response_field, route.response_field,
@ -184,16 +184,9 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
) )
else: else:
response_schema = {} response_schema = {}
elif lenient_issubclass(route.response_wrapper, HTMLResponse): content = {route.content_type.media_type: {"schema": response_schema}}
response_media_type = "text/html"
else:
response_media_type = "text/plain"
content = {response_media_type: {"schema": response_schema}}
operation["responses"] = { operation["responses"] = {
response_code: { status_code: {"description": route.response_description, "content": content}
"description": route.response_description,
"content": content,
}
} }
if all_route_params or route.body_field: if all_route_params or route.body_field:
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = { operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
@ -215,7 +208,7 @@ def get_openapi(
openapi_version: str = "3.0.2", openapi_version: str = "3.0.2",
description: str = None, description: str = None,
routes: Sequence[BaseRoute] routes: Sequence[BaseRoute]
): ) -> Dict:
info = {"title": title, "version": version} info = {"title": title, "version": version}
if description: if description:
info["description"] = description info["description"] = description
@ -228,15 +221,18 @@ def get_openapi(
flat_models=flat_models, model_name_map=model_name_map flat_models=flat_models, model_name_map=model_name_map
) )
for route in routes: for route in routes:
result = get_openapi_path(route=route, model_name_map=model_name_map) if isinstance(route, routing.APIRoute):
if result: result = get_openapi_path(route=route, model_name_map=model_name_map)
path, security_schemes, path_definitions = result if result:
if path: path, security_schemes, path_definitions = result
paths.setdefault(route.path, {}).update(path) if path:
if security_schemes: paths.setdefault(route.path, {}).update(path)
components.setdefault("securitySchemes", {}).update(security_schemes) if security_schemes:
if path_definitions: components.setdefault("securitySchemes", {}).update(
definitions.update(path_definitions) security_schemes
)
if path_definitions:
definitions.update(path_definitions)
if definitions: if definitions:
components.setdefault("schemas", {}).update(definitions) components.setdefault("schemas", {}).update(definitions)
if components: if components:

46
fastapi/params.py

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Sequence, Any, Dict from typing import Any, Callable, Sequence
from pydantic import Schema from pydantic import Schema
@ -16,7 +16,7 @@ class Param(Schema):
def __init__( def __init__(
self, self,
default, default: Any,
*, *,
deprecated: bool = None, deprecated: bool = None,
alias: str = None, alias: str = None,
@ -29,7 +29,7 @@ class Param(Schema):
min_length: int = None, min_length: int = None,
max_length: int = None, max_length: int = None,
regex: str = None, regex: str = None,
**extra: Dict[str, Any], **extra: Any,
): ):
self.deprecated = deprecated self.deprecated = deprecated
super().__init__( super().__init__(
@ -53,7 +53,7 @@ class Path(Param):
def __init__( def __init__(
self, self,
default, default: Any,
*, *,
deprecated: bool = None, deprecated: bool = None,
alias: str = None, alias: str = None,
@ -66,7 +66,7 @@ class Path(Param):
min_length: int = None, min_length: int = None,
max_length: int = None, max_length: int = None,
regex: str = None, regex: str = None,
**extra: Dict[str, Any], **extra: Any,
): ):
self.description = description self.description = description
self.deprecated = deprecated self.deprecated = deprecated
@ -92,7 +92,7 @@ class Query(Param):
def __init__( def __init__(
self, self,
default, default: Any,
*, *,
deprecated: bool = None, deprecated: bool = None,
alias: str = None, alias: str = None,
@ -105,7 +105,7 @@ class Query(Param):
min_length: int = None, min_length: int = None,
max_length: int = None, max_length: int = None,
regex: str = None, regex: str = None,
**extra: Dict[str, Any], **extra: Any,
): ):
self.description = description self.description = description
self.deprecated = deprecated self.deprecated = deprecated
@ -130,10 +130,11 @@ class Header(Param):
def __init__( def __init__(
self, self,
default, default: Any,
*, *,
deprecated: bool = None, deprecated: bool = None,
alias: str = None, alias: str = None,
alias_underscore_to_hyphen: bool = True,
title: str = None, title: str = None,
description: str = None, description: str = None,
gt: float = None, gt: float = None,
@ -143,10 +144,11 @@ class Header(Param):
min_length: int = None, min_length: int = None,
max_length: int = None, max_length: int = None,
regex: str = None, regex: str = None,
**extra: Dict[str, Any], **extra: Any,
): ):
self.description = description self.description = description
self.deprecated = deprecated self.deprecated = deprecated
self.alias_underscore_to_hyphen = alias_underscore_to_hyphen
super().__init__( super().__init__(
default, default,
alias=alias, alias=alias,
@ -168,7 +170,7 @@ class Cookie(Param):
def __init__( def __init__(
self, self,
default, default: Any,
*, *,
deprecated: bool = None, deprecated: bool = None,
alias: str = None, alias: str = None,
@ -181,7 +183,7 @@ class Cookie(Param):
min_length: int = None, min_length: int = None,
max_length: int = None, max_length: int = None,
regex: str = None, regex: str = None,
**extra: Dict[str, Any], **extra: Any,
): ):
self.description = description self.description = description
self.deprecated = deprecated self.deprecated = deprecated
@ -204,9 +206,9 @@ class Cookie(Param):
class Body(Schema): class Body(Schema):
def __init__( def __init__(
self, self,
default, default: Any,
*, *,
embed=False, embed: bool = False,
media_type: str = "application/json", media_type: str = "application/json",
alias: str = None, alias: str = None,
title: str = None, title: str = None,
@ -218,7 +220,7 @@ class Body(Schema):
min_length: int = None, min_length: int = None,
max_length: int = None, max_length: int = None,
regex: str = None, regex: str = None,
**extra: Dict[str, Any], **extra: Any,
): ):
self.embed = embed self.embed = embed
self.media_type = media_type self.media_type = media_type
@ -241,9 +243,9 @@ class Body(Schema):
class Form(Body): class Form(Body):
def __init__( def __init__(
self, self,
default, default: Any,
*, *,
sub_key=False, sub_key: bool = False,
media_type: str = "application/x-www-form-urlencoded", media_type: str = "application/x-www-form-urlencoded",
alias: str = None, alias: str = None,
title: str = None, title: str = None,
@ -255,7 +257,7 @@ class Form(Body):
min_length: int = None, min_length: int = None,
max_length: int = None, max_length: int = None,
regex: str = None, regex: str = None,
**extra: Dict[str, Any], **extra: Any,
): ):
super().__init__( super().__init__(
default, default,
@ -278,9 +280,9 @@ class Form(Body):
class File(Form): class File(Form):
def __init__( def __init__(
self, self,
default, default: Any,
*, *,
sub_key=False, sub_key: bool = False,
media_type: str = "multipart/form-data", media_type: str = "multipart/form-data",
alias: str = None, alias: str = None,
title: str = None, title: str = None,
@ -292,7 +294,7 @@ class File(Form):
min_length: int = None, min_length: int = None,
max_length: int = None, max_length: int = None,
regex: str = None, regex: str = None,
**extra: Dict[str, Any], **extra: Any,
): ):
super().__init__( super().__init__(
default, default,
@ -313,11 +315,11 @@ class File(Form):
class Depends: class Depends:
def __init__(self, dependency=None): def __init__(self, dependency: Callable = None):
self.dependency = dependency self.dependency = dependency
class Security(Depends): class Security(Depends):
def __init__(self, dependency=None, scopes: Sequence[str] = None): def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None):
self.scopes = scopes or [] self.scopes = scopes or []
super().__init__(dependency=dependency) super().__init__(dependency=dependency)

437
fastapi/routing.py

@ -1,12 +1,11 @@
import asyncio import asyncio
import inspect import inspect
from typing import Callable, List, Type from typing import Any, Callable, List, Optional, Type
from pydantic import BaseConfig, BaseModel, Schema from pydantic import BaseConfig, BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import Field from pydantic.fields import Field
from pydantic.utils import lenient_issubclass from pydantic.utils import lenient_issubclass
from starlette import routing from starlette import routing
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
@ -22,7 +21,7 @@ from fastapi.dependencies.utils import get_body_field, get_dependant, solve_depe
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
def serialize_response(*, field: Field = None, response): def serialize_response(*, field: Field = None, response: Response) -> Any:
if field: if field:
errors = [] errors = []
value, errors_ = field.validate(response, {}, loc=("response",)) value, errors_ = field.validate(response, {}, loc=("response",))
@ -40,11 +39,12 @@ def serialize_response(*, field: Field = None, response):
def get_app( def get_app(
dependant: Dependant, dependant: Dependant,
body_field: Field = None, body_field: Field = None,
response_code: str = 200, status_code: int = 200,
response_wrapper: Type[Response] = JSONResponse, content_type: Type[Response] = JSONResponse,
response_field: Type[Field] = None, response_field: Field = None,
): ) -> Callable:
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call) assert dependant.call is not None, "dependant.call must me a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(body_field.schema, params.Form) is_body_form = body_field and isinstance(body_field.schema, params.Form)
async def app(request: Request) -> Response: async def app(request: Request) -> Response:
@ -69,6 +69,7 @@ def get_app(
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors() status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
) )
else: else:
assert dependant.call is not None, "dependant.call must me a function"
if is_coroutine: if is_coroutine:
raw_response = await dependant.call(**values) raw_response = await dependant.call(**values)
else: else:
@ -76,32 +77,32 @@ def get_app(
if isinstance(raw_response, Response): if isinstance(raw_response, Response):
return raw_response return raw_response
if isinstance(raw_response, BaseModel): if isinstance(raw_response, BaseModel):
return response_wrapper( return content_type(
content=jsonable_encoder(raw_response), status_code=response_code content=jsonable_encoder(raw_response), status_code=status_code
) )
errors = [] errors = []
try: try:
return response_wrapper( return content_type(
content=serialize_response( content=serialize_response(
field=response_field, response=raw_response field=response_field, response=raw_response
), ),
status_code=response_code, status_code=status_code,
) )
except Exception as e: except Exception as e:
errors.append(e) errors.append(e)
try: try:
response = dict(raw_response) response = dict(raw_response)
return response_wrapper( return content_type(
content=serialize_response(field=response_field, response=response), content=serialize_response(field=response_field, response=response),
status_code=response_code, status_code=status_code,
) )
except Exception as e: except Exception as e:
errors.append(e) errors.append(e)
try: try:
response = vars(raw_response) response = vars(raw_response)
return response_wrapper( return content_type(
content=serialize_response(field=response_field, response=response), content=serialize_response(field=response_field, response=response),
status_code=response_code, status_code=status_code,
) )
except Exception as e: except Exception as e:
errors.append(e) errors.append(e)
@ -116,43 +117,32 @@ class APIRoute(routing.Route):
path: str, path: str,
endpoint: Callable, endpoint: Callable,
*, *,
methods: List[str] = None, response_model: Type[BaseModel] = None,
name: str = None, status_code: int = 200,
include_in_schema: bool = True,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
methods: List[str] = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> None: ) -> None:
assert path.startswith("/"), "Routed paths must always start with '/'" assert path.startswith("/"), "Routed paths must always start with '/'"
self.path = path self.path = path
self.endpoint = endpoint self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
self.include_in_schema = include_in_schema self.response_model = response_model
self.tags = tags or [] if self.response_model:
self.summary = summary
self.description = description or self.endpoint.__doc__
self.operation_id = operation_id
self.deprecated = deprecated
self.body_field: Field = None
self.response_description = response_description
self.response_code = response_code
self.response_wrapper = response_wrapper
self.response_field = None
if response_type:
assert lenient_issubclass( assert lenient_issubclass(
response_wrapper, JSONResponse content_type, JSONResponse
), "To declare a type the response must be a JSON response" ), "To declare a type the response must be a JSON response"
self.response_type = response_type
response_name = "Response_" + self.name response_name = "Response_" + self.name
self.response_field = Field( self.response_field: Optional[Field] = Field(
name=response_name, name=response_name,
type_=self.response_type, type_=self.response_model,
class_validators=[], class_validators=[],
default=None, default=None,
required=False, required=False,
@ -160,25 +150,34 @@ class APIRoute(routing.Route):
schema=Schema(None), schema=Schema(None),
) )
else: else:
self.response_type = None self.response_field = None
self.status_code = status_code
self.tags = tags or []
self.summary = summary
self.description = description or self.endpoint.__doc__
self.response_description = response_description
self.deprecated = deprecated
if methods is None: if methods is None:
methods = ["GET"] methods = ["GET"]
self.methods = methods self.methods = methods
self.operation_id = operation_id
self.include_in_schema = include_in_schema
self.content_type = content_type
self.path_regex, self.path_format, self.param_convertors = self.compile_path( self.path_regex, self.path_format, self.param_convertors = self.compile_path(
path path
) )
assert inspect.isfunction(endpoint) or inspect.ismethod( assert inspect.isfunction(endpoint) or inspect.ismethod(
endpoint endpoint
), f"An endpoint must be a function or method" ), f"An endpoint must be a function or method"
self.dependant = get_dependant(path=path, call=self.endpoint) self.dependant = get_dependant(path=path, call=self.endpoint)
self.body_field = get_body_field(dependant=self.dependant, name=self.name) self.body_field = get_body_field(dependant=self.dependant, name=self.name)
self.app = request_response( self.app = request_response(
get_app( get_app(
dependant=self.dependant, dependant=self.dependant,
body_field=self.body_field, body_field=self.body_field,
response_code=self.response_code, status_code=self.status_code,
response_wrapper=self.response_wrapper, content_type=self.content_type,
response_field=self.response_field, response_field=self.response_field,
) )
) )
@ -189,75 +188,77 @@ class APIRouter(routing.Router):
self, self,
path: str, path: str,
endpoint: Callable, endpoint: Callable,
methods: List[str] = None, *,
name: str = None, response_model: Type[BaseModel] = None,
include_in_schema: bool = True, status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
methods: List[str] = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> None: ) -> None:
route = APIRoute( route = APIRoute(
path, path,
endpoint=endpoint, endpoint=endpoint,
methods=methods, response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=methods,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
self.routes.append(route) self.routes.append(route)
def api_route( def api_route(
self, self,
path: str, path: str,
methods: List[str] = None, *,
name: str = None, response_model: Type[BaseModel] = None,
include_in_schema: bool = True, status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
methods: List[str] = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable: ) -> Callable:
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
self.add_api_route( self.add_api_route(
path, path,
func, func,
methods=methods, response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=methods,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
return func return func
return decorator return decorator
def include_router(self, router: "APIRouter", *, prefix=""): def include_router(self, router: "APIRouter", *, prefix: str = "") -> None:
if prefix: if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'" assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith( assert not prefix.endswith(
@ -268,18 +269,18 @@ class APIRouter(routing.Router):
self.add_api_route( self.add_api_route(
prefix + route.path, prefix + route.path,
route.endpoint, route.endpoint,
methods=route.methods, response_model=route.response_model,
name=route.name, status_code=route.status_code,
include_in_schema=route.include_in_schema, tags=route.tags or [],
tags=route.tags,
summary=route.summary, summary=route.summary,
description=route.description, description=route.description,
operation_id=route.operation_id,
deprecated=route.deprecated,
response_type=route.response_type,
response_description=route.response_description, response_description=route.response_description,
response_code=route.response_code, deprecated=route.deprecated,
response_wrapper=route.response_wrapper, name=route.name,
methods=route.methods,
operation_id=route.operation_id,
include_in_schema=route.include_in_schema,
content_type=route.content_type,
) )
elif isinstance(route, routing.Route): elif isinstance(route, routing.Route):
self.add_route( self.add_route(
@ -293,247 +294,255 @@ class APIRouter(routing.Router):
def get( def get(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route( return self.api_route(
path=path, path=path,
methods=["GET"], response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=["GET"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def put( def put(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route( return self.api_route(
path=path, path=path,
methods=["PUT"], response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=["PUT"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def post( def post(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route( return self.api_route(
path=path, path=path,
methods=["POST"], response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=["POST"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def delete( def delete(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route( return self.api_route(
path=path, path=path,
methods=["DELETE"], response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=["DELETE"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def options( def options(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route( return self.api_route(
path=path, path=path,
methods=["OPTIONS"], response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=["OPTIONS"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def head( def head(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route( return self.api_route(
path=path, path=path,
methods=["HEAD"], response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=["HEAD"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def patch( def patch(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route( return self.api_route(
path=path, path=path,
methods=["PATCH"], response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=["PATCH"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )
def trace( def trace(
self, self,
path: str, path: str,
name: str = None, *,
include_in_schema: bool = True, response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None, tags: List[str] = None,
summary: str = None, summary: str = None,
description: str = None, description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response", response_description: str = "Successful Response",
response_code=200, deprecated: bool = None,
response_wrapper=JSONResponse, name: str = None,
): operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route( return self.api_route(
path=path, path=path,
methods=["TRACE"], response_model=response_model,
name=name, status_code=status_code,
include_in_schema=include_in_schema,
tags=tags or [], tags=tags or [],
summary=summary, summary=summary,
description=description, description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description, response_description=response_description,
response_code=response_code, deprecated=deprecated,
response_wrapper=response_wrapper, name=name,
methods=["TRACE"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
) )

13
fastapi/security/api_key.py

@ -1,18 +1,19 @@
from starlette.requests import Request from starlette.requests import Request
from .base import SecurityBase from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.openapi.models import APIKeyIn, APIKey from fastapi.security.base import SecurityBase
class APIKeyBase(SecurityBase): class APIKeyBase(SecurityBase):
pass pass
class APIKeyQuery(APIKeyBase):
class APIKeyQuery(APIKeyBase):
def __init__(self, *, name: str, scheme_name: str = None): def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.query, name=name) self.model = APIKey(in_=APIKeyIn.query, name=name)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request): async def __call__(self, requests: Request) -> str:
return requests.query_params.get(self.model.name) return requests.query_params.get(self.model.name)
@ -21,7 +22,7 @@ class APIKeyHeader(APIKeyBase):
self.model = APIKey(in_=APIKeyIn.header, name=name) self.model = APIKey(in_=APIKeyIn.header, name=name)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request): async def __call__(self, requests: Request) -> str:
return requests.headers.get(self.model.name) return requests.headers.get(self.model.name)
@ -30,5 +31,5 @@ class APIKeyCookie(APIKeyBase):
self.model = APIKey(in_=APIKeyIn.cookie, name=name) self.model = APIKey(in_=APIKeyIn.cookie, name=name)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request): async def __call__(self, requests: Request) -> str:
return requests.cookies.get(self.model.name) return requests.cookies.get(self.model.name)

6
fastapi/security/base.py

@ -1,6 +1,6 @@
from pydantic import BaseModel
from fastapi.openapi.models import SecurityBase as SecurityBaseModel from fastapi.openapi.models import SecurityBase as SecurityBaseModel
class SecurityBase: class SecurityBase:
pass model: SecurityBaseModel
scheme_name: str

15
fastapi/security/http.py

@ -1,7 +1,10 @@
from starlette.requests import Request from starlette.requests import Request
from .base import SecurityBase from fastapi.openapi.models import (
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel HTTPBase as HTTPBaseModel,
HTTPBearer as HTTPBearerModel,
)
from fastapi.security.base import SecurityBase
class HTTPBase(SecurityBase): class HTTPBase(SecurityBase):
@ -9,7 +12,7 @@ class HTTPBase(SecurityBase):
self.model = HTTPBaseModel(scheme=scheme) self.model = HTTPBaseModel(scheme=scheme)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request): async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") return request.headers.get("Authorization")
@ -18,7 +21,7 @@ class HTTPBasic(HTTPBase):
self.model = HTTPBaseModel(scheme="basic") self.model = HTTPBaseModel(scheme="basic")
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request): async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") return request.headers.get("Authorization")
@ -27,7 +30,7 @@ class HTTPBearer(HTTPBase):
self.model = HTTPBearerModel(bearerFormat=bearerFormat) self.model = HTTPBearerModel(bearerFormat=bearerFormat)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request): async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") return request.headers.get("Authorization")
@ -36,5 +39,5 @@ class HTTPDigest(HTTPBase):
self.model = HTTPBaseModel(scheme="digest") self.model = HTTPBaseModel(scheme="digest")
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request): async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") return request.headers.get("Authorization")

8
fastapi/security/oauth2.py

@ -1,13 +1,15 @@
from starlette.requests import Request from starlette.requests import Request
from .base import SecurityBase
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
from fastapi.security.base import SecurityBase
class OAuth2(SecurityBase): class OAuth2(SecurityBase):
def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None): def __init__(
self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None
):
self.model = OAuth2Model(flows=flows) self.model = OAuth2Model(flows=flows)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request): async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") return request.headers.get("Authorization")

4
fastapi/security/open_id_connect_url.py

@ -1,7 +1,7 @@
from starlette.requests import Request from starlette.requests import Request
from .base import SecurityBase
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
class OpenIdConnect(SecurityBase): class OpenIdConnect(SecurityBase):
@ -9,5 +9,5 @@ class OpenIdConnect(SecurityBase):
self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl) self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request): async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization") return request.headers.get("Authorization")

24
fastapi/utils.py

@ -1,20 +1,24 @@
import re import re
from typing import Dict, Sequence, Set, Type from typing import Any, Dict, List, Sequence, Set, Type
from pydantic import BaseModel
from pydantic.fields import Field
from pydantic.schema import get_flat_models_from_fields, model_process_schema
from starlette.routing import BaseRoute from starlette.routing import BaseRoute
from fastapi import routing from fastapi import routing
from fastapi.openapi.constants import REF_PREFIX from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseModel
from pydantic.fields import Field
from pydantic.schema import get_flat_models_from_fields, model_process_schema
def get_flat_models_from_routes(routes: Sequence[BaseRoute]): def get_flat_models_from_routes(
body_fields_from_routes = [] routes: Sequence[Type[BaseRoute]]
responses_from_routes = [] ) -> Set[Type[BaseModel]]:
body_fields_from_routes: List[Field] = []
responses_from_routes: List[Field] = []
for route in routes: for route in routes:
if route.include_in_schema and isinstance(route, routing.APIRoute): if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field: if route.body_field:
assert isinstance( assert isinstance(
route.body_field, Field route.body_field, Field
@ -30,7 +34,7 @@ def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
def get_model_definitions( def get_model_definitions(
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str] *, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
): ) -> Dict[str, Any]:
definitions: Dict[str, Dict] = {} definitions: Dict[str, Dict] = {}
for model in flat_models: for model in flat_models:
m_schema, m_definitions = model_process_schema( m_schema, m_definitions = model_process_schema(
@ -42,5 +46,5 @@ def get_model_definitions(
return definitions return definitions
def get_path_param_names(path: str): def get_path_param_names(path: str) -> Set[str]:
return {item.strip("{}") for item in re.findall("{[^}]*}", path)} return {item.strip("{}") for item in re.findall("{[^}]*}", path)}

Loading…
Cancel
Save