Browse Source

Update parameter names and order

fix mypy types, refactor, lint
pull/1/head
Sebastián Ramírez 7 years ago
parent
commit
0e19c24014
  1. 381
      fastapi/applications.py
  2. 8
      fastapi/dependencies/models.py
  3. 65
      fastapi/dependencies/utils.py
  4. 8
      fastapi/encoders.py
  5. 11
      fastapi/openapi/docs.py
  6. 48
      fastapi/openapi/models.py
  7. 98
      fastapi/openapi/utils.py
  8. 46
      fastapi/params.py
  9. 437
      fastapi/routing.py
  10. 13
      fastapi/security/api_key.py
  11. 6
      fastapi/security/base.py
  12. 21
      fastapi/security/http.py
  13. 10
      fastapi/security/oauth2.py
  14. 6
      fastapi/security/open_id_connect_url.py
  15. 24
      fastapi/utils.py

381
fastapi/applications.py

@ -1,19 +1,19 @@
from typing import Any, Callable, Dict, List, Type
from typing import Any, Callable, Dict, List, Optional, Type
from pydantic import BaseModel
from starlette.applications import Starlette
from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.lifespan import LifespanMiddleware
from starlette.responses import JSONResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from fastapi import routing
from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
async def http_exception(request, exc: HTTPException):
print(exc)
async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
@ -31,7 +31,7 @@ class FastAPI(Starlette):
**extra: Dict[str, Any],
) -> None:
self._debug = debug
self.router = routing.APIRouter()
self.router: routing.APIRouter = routing.APIRouter()
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.error_middleware = ServerErrorMiddleware(
self.exception_middleware, debug=debug
@ -56,33 +56,41 @@ class FastAPI(Starlette):
if self.swagger_ui_url or self.redoc_url:
assert self.openapi_url, "The openapi_url is required for the docs"
self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup()
def setup(self):
def openapi(self) -> Dict:
if not self.openapi_schema:
self.openapi_schema = get_openapi(
title=self.title,
version=self.version,
openapi_version=self.openapi_version,
description=self.description,
routes=self.routes,
)
return self.openapi_schema
def setup(self) -> None:
if self.openapi_url:
self.add_route(
self.openapi_url,
lambda req: JSONResponse(
get_openapi(
title=self.title,
version=self.version,
openapi_version=self.openapi_version,
description=self.description,
routes=self.routes,
)
),
lambda req: JSONResponse(self.openapi()),
include_in_schema=False,
)
if self.swagger_ui_url:
self.add_route(
self.swagger_ui_url,
lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"),
lambda r: get_swagger_ui_html(
openapi_url=self.openapi_url, title=self.title + " - Swagger UI"
),
include_in_schema=False,
)
if self.redoc_url:
self.add_route(
self.redoc_url,
lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"),
lambda r: get_redoc_html(
openapi_url=self.openapi_url, title=self.title + " - ReDoc"
),
include_in_schema=False,
)
self.add_exception_handler(HTTPException, http_exception)
@ -91,311 +99,322 @@ class FastAPI(Starlette):
self,
path: str,
endpoint: Callable,
methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
deprecated: bool = None,
name: str = None,
methods: List[str] = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> None:
self.router.add_api_route(
path,
endpoint=endpoint,
methods=methods,
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=methods,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def api_route(
self,
path: str,
methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
deprecated: bool = None,
name: str = None,
methods: List[str] = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
def decorator(func: Callable) -> Callable:
self.router.add_api_route(
path,
func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=methods,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
return func
return decorator
def include_router(self, router: "APIRouter", *, prefix=""):
def include_router(self, router: routing.APIRouter, *, prefix: str = "") -> None:
self.router.include_router(router, prefix=prefix)
def get(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.get(
path=path,
name=name,
include_in_schema=include_in_schema,
path,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def put(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.put(
path=path,
name=name,
include_in_schema=include_in_schema,
path,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def post(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.post(
path=path,
name=name,
include_in_schema=include_in_schema,
path,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def delete(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.delete(
path=path,
name=name,
include_in_schema=include_in_schema,
path,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def options(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.options(
path=path,
name=name,
include_in_schema=include_in_schema,
path,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def head(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.head(
path=path,
name=name,
include_in_schema=include_in_schema,
path,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def patch(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.patch(
path=path,
name=name,
include_in_schema=include_in_schema,
path,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def trace(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.router.trace(
path=path,
name=name,
include_in_schema=include_in_schema,
path,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)

8
fastapi/dependencies/models.py

@ -1,14 +1,14 @@
from typing import Any, Callable, Dict, List, Sequence, Tuple
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from fastapi.security.base import SecurityBase
from pydantic import BaseConfig, Schema
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from fastapi.security.base import SecurityBase
param_supported_types = (str, int, float, bool)

65
fastapi/dependencies/utils.py

@ -1,8 +1,14 @@
import asyncio
import inspect
from copy import deepcopy
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Type
from pydantic import BaseConfig, Schema, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema
from pydantic.utils import lenient_issubclass
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
@ -10,17 +16,11 @@ from fastapi import params
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.security.base import SecurityBase
from fastapi.utils import get_path_param_names
from pydantic import BaseConfig, Schema, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema
from pydantic.utils import lenient_issubclass
param_supported_types = (str, int, float, bool)
def get_sub_dependant(*, param: inspect.Parameter, path: str):
def get_sub_dependant(*, param: inspect.Parameter, path: str) -> Dependant:
depends: params.Depends = param.default
if depends.dependency:
dependency = depends.dependency
@ -36,7 +36,7 @@ def get_sub_dependant(*, param: inspect.Parameter, path: str):
return sub_dependant
def get_flat_dependant(dependant: Dependant):
def get_flat_dependant(dependant: Dependant) -> Dependant:
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
@ -58,7 +58,7 @@ def get_flat_dependant(dependant: Dependant):
return flat_dependant
def get_dependant(*, path: str, call: Callable, name: str = None):
def get_dependant(*, path: str, call: Callable, name: str = None) -> Dependant:
path_param_names = get_path_param_names(path)
endpoint_signature = inspect.signature(call)
signature_params = endpoint_signature.parameters
@ -73,9 +73,10 @@ def get_dependant(*, path: str, call: Callable, name: str = None):
if (
(param.default == param.empty) or isinstance(param.default, params.Path)
) and (param_name in path_param_names):
assert lenient_issubclass(
param.annotation, param_supported_types
) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}"
assert (
lenient_issubclass(param.annotation, param_supported_types)
or param.annotation == param.empty
), f"Path params must be of type str, int, float or boot: {param}"
param = signature_params[param_name]
add_param_to_fields(
param=param,
@ -109,9 +110,9 @@ def add_param_to_fields(
*,
param: inspect.Parameter,
dependant: Dependant,
default_schema=params.Param,
default_schema: Type[Schema] = params.Param,
force_type: params.ParamTypes = None,
):
) -> None:
default_value = Required
if not param.default == param.empty:
default_value = param.default
@ -125,15 +126,19 @@ def add_param_to_fields(
else:
schema = default_schema(default_value)
required = default_value == Required
annotation = Any
annotation: Type = Type[Any]
if not param.annotation == param.empty:
annotation = param.annotation
annotation = get_annotation_from_schema(annotation, schema)
if not schema.alias and getattr(schema, "alias_underscore_to_hyphen", None):
alias = param.name.replace("_", "-")
else:
alias = schema.alias or param.name
field = Field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=schema.alias or param.name,
alias=alias,
required=required,
model_config=BaseConfig(),
class_validators=[],
@ -152,7 +157,7 @@ def add_param_to_fields(
dependant.cookie_params.append(field)
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant) -> None:
default_value = Required
if not param.default == param.empty:
default_value = param.default
@ -176,7 +181,7 @@ def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
dependant.body_params.append(field)
def is_coroutine_callable(call: Callable = None):
def is_coroutine_callable(call: Callable = None) -> bool:
if not call:
return False
if inspect.isfunction(call):
@ -191,7 +196,7 @@ def is_coroutine_callable(call: Callable = None):
async def solve_dependencies(
*, request: Request, dependant: Dependant, body: Dict[str, Any] = None
):
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
for sub_dependant in dependant.dependencies:
@ -200,13 +205,13 @@ async def solve_dependencies(
)
if sub_errors:
return {}, errors
if sub_dependant.call and is_coroutine_callable(sub_dependant.call):
assert sub_dependant.call is not None, "sub_dependant.call must be a function"
if is_coroutine_callable(sub_dependant.call):
solved = await sub_dependant.call(**sub_values)
else:
solved = await run_in_threadpool(sub_dependant.call, **sub_values)
values[
sub_dependant.name
] = solved # type: ignore # Sub-dependants always have a name
assert sub_dependant.name is not None, "Subdependants always have a name"
values[sub_dependant.name] = solved
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
@ -236,7 +241,7 @@ async def solve_dependencies(
def request_params_to_args(
required_params: List[Field], received_params: Dict[str, Any]
required_params: Sequence[Field], received_params: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values = {}
errors = []
@ -250,9 +255,9 @@ def request_params_to_args(
else:
values[field.name] = deepcopy(field.default)
continue
v_, errors_ = field.validate(
value, values, loc=(field.schema.in_.value, field.alias)
)
schema: params.Param = field.schema
assert isinstance(schema, params.Param), "Params must be subclasses of Param"
v_, errors_ = field.validate(value, values, loc=(schema.in_.value, field.alias))
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
@ -294,7 +299,7 @@ async def request_body_to_args(
return values, errors
def get_body_field(*, dependant: Dependant, name: str):
def get_body_field(*, dependant: Dependant, name: str) -> Field:
flat_dependant = get_flat_dependant(dependant)
if not flat_dependant.body_params:
return None
@ -308,7 +313,7 @@ def get_body_field(*, dependant: Dependant, name: str):
BodyModel.__fields__[f.name] = f
required = any(True for f in flat_dependant.body_params if f.required)
if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
BodySchema = params.File
BodySchema: Type[params.Body] = params.File
elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
BodySchema = params.Form
else:

8
fastapi/encoders.py

@ -1,18 +1,18 @@
from enum import Enum
from types import GeneratorType
from typing import Set
from typing import Any, Set
from pydantic import BaseModel
from pydantic.json import pydantic_encoder
def jsonable_encoder(
obj,
obj: Any,
include: Set[str] = None,
exclude: Set[str] = set(),
by_alias: bool = False,
include_none=True,
):
include_none: bool = True,
) -> Any:
if isinstance(obj, BaseModel):
return jsonable_encoder(
obj.dict(include=include, exclude=exclude, by_alias=by_alias),

11
fastapi/openapi/docs.py

@ -1,6 +1,7 @@
from starlette.responses import HTMLResponse
def get_swagger_ui_html(*, openapi_url: str, title: str):
def get_swagger_ui_html(*, openapi_url: str, title: str) -> HTMLResponse:
return HTMLResponse(
"""
<! doctype html>
@ -35,12 +36,11 @@ def get_swagger_ui_html(*, openapi_url: str, title: str):
</script>
</body>
</html>
""",
media_type="text/html",
"""
)
def get_redoc_html(*, openapi_url: str, title: str):
def get_redoc_html(*, openapi_url: str, title: str) -> HTMLResponse:
return HTMLResponse(
"""
<!DOCTYPE html>
@ -73,6 +73,5 @@ def get_redoc_html(*, openapi_url: str, title: str):
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
</body>
</html>
""",
media_type="text/html",
"""
)

48
fastapi/openapi/models.py

@ -7,13 +7,13 @@ from pydantic.types import UrlStr
try:
import pydantic.types.EmailStr
from pydantic.types import EmailStr
from pydantic.types import EmailStr # type: ignore
except ImportError:
logging.warning(
"email-validator not installed, email fields will be treated as str"
)
class EmailStr(str):
class EmailStr(str): # type: ignore
pass
@ -50,7 +50,7 @@ class Server(BaseModel):
class Reference(BaseModel):
ref: str = PSchema(..., alias="$ref")
ref: str = PSchema(..., alias="$ref") # type: ignore
class Discriminator(BaseModel):
@ -72,28 +72,28 @@ class ExternalDocumentation(BaseModel):
class SchemaBase(BaseModel):
ref: Optional[str] = PSchema(None, alias="$ref")
ref: Optional[str] = PSchema(None, alias="$ref") # type: ignore
title: Optional[str] = None
multipleOf: Optional[float] = None
maximum: Optional[float] = None
exclusiveMaximum: Optional[float] = None
minimum: Optional[float] = None
exclusiveMinimum: Optional[float] = None
maxLength: Optional[int] = PSchema(None, gte=0)
minLength: Optional[int] = PSchema(None, gte=0)
maxLength: Optional[int] = PSchema(None, gte=0) # type: ignore
minLength: Optional[int] = PSchema(None, gte=0) # type: ignore
pattern: Optional[str] = None
maxItems: Optional[int] = PSchema(None, gte=0)
minItems: Optional[int] = PSchema(None, gte=0)
maxItems: Optional[int] = PSchema(None, gte=0) # type: ignore
minItems: Optional[int] = PSchema(None, gte=0) # type: ignore
uniqueItems: Optional[bool] = None
maxProperties: Optional[int] = PSchema(None, gte=0)
minProperties: Optional[int] = PSchema(None, gte=0)
maxProperties: Optional[int] = PSchema(None, gte=0) # type: ignore
minProperties: Optional[int] = PSchema(None, gte=0) # type: ignore
required: Optional[List[str]] = None
enum: Optional[List[str]] = None
type: Optional[str] = None
allOf: Optional[List[Any]] = None
oneOf: Optional[List[Any]] = None
anyOf: Optional[List[Any]] = None
not_: Optional[List[Any]] = PSchema(None, alias="not")
not_: Optional[List[Any]] = PSchema(None, alias="not") # type: ignore
items: Optional[Any] = None
properties: Optional[Dict[str, Any]] = None
additionalProperties: Optional[Union[bool, Any]] = None
@ -114,7 +114,7 @@ class Schema(SchemaBase):
allOf: Optional[List[SchemaBase]] = None
oneOf: Optional[List[SchemaBase]] = None
anyOf: Optional[List[SchemaBase]] = None
not_: Optional[List[SchemaBase]] = PSchema(None, alias="not")
not_: Optional[List[SchemaBase]] = PSchema(None, alias="not") # type: ignore
items: Optional[SchemaBase] = None
properties: Optional[Dict[str, SchemaBase]] = None
additionalProperties: Optional[Union[bool, SchemaBase]] = None
@ -144,7 +144,9 @@ class Encoding(BaseModel):
class MediaType(BaseModel):
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
schema_: Optional[Union[Schema, Reference]] = PSchema(
None, alias="schema"
) # type: ignore
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
encoding: Optional[Dict[str, Encoding]] = None
@ -158,7 +160,9 @@ class ParameterBase(BaseModel):
style: Optional[str] = None
explode: Optional[bool] = None
allowReserved: Optional[bool] = None
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
schema_: Optional[Union[Schema, Reference]] = PSchema(
None, alias="schema"
) # type: ignore
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
# Serialization rules for more complex scenarios
@ -167,7 +171,7 @@ class ParameterBase(BaseModel):
class Parameter(ParameterBase):
name: str
in_: ParameterInType = PSchema(..., alias="in")
in_: ParameterInType = PSchema(..., alias="in") # type: ignore
class Header(ParameterBase):
@ -222,7 +226,7 @@ class Operation(BaseModel):
class PathItem(BaseModel):
ref: Optional[str] = PSchema(None, alias="$ref")
ref: Optional[str] = PSchema(None, alias="$ref") # type: ignore
summary: Optional[str] = None
description: Optional[str] = None
get: Optional[Operation] = None
@ -250,7 +254,7 @@ class SecuritySchemeType(Enum):
class SecurityBase(BaseModel):
type_: SecuritySchemeType = PSchema(..., alias="type")
type_: SecuritySchemeType = PSchema(..., alias="type") # type: ignore
description: Optional[str] = None
@ -261,13 +265,13 @@ class APIKeyIn(Enum):
class APIKey(SecurityBase):
type_ = PSchema(SecuritySchemeType.apiKey, alias="type")
in_: APIKeyIn = PSchema(..., alias="in")
type_ = PSchema(SecuritySchemeType.apiKey, alias="type") # type: ignore
in_: APIKeyIn = PSchema(..., alias="in") # type: ignore
name: str
class HTTPBase(SecurityBase):
type_ = PSchema(SecuritySchemeType.http, alias="type")
type_ = PSchema(SecuritySchemeType.http, alias="type") # type: ignore
scheme: str
@ -306,12 +310,12 @@ class OAuthFlows(BaseModel):
class OAuth2(SecurityBase):
type_ = PSchema(SecuritySchemeType.oauth2, alias="type")
type_ = PSchema(SecuritySchemeType.oauth2, alias="type") # type: ignore
flows: OAuthFlows
class OpenIdConnect(SecurityBase):
type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type")
type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type") # type: ignore
openIdConnectUrl: str

98
fastapi/openapi/utils.py

@ -1,23 +1,21 @@
from typing import Any, Dict, Sequence, Type, List
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
from pydantic.fields import Field
from pydantic.schema import field_schema, get_model_name_map
from pydantic.schema import Schema, field_schema, get_model_name_map
from pydantic.utils import lenient_issubclass
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import BaseRoute
from starlette.routing import BaseRoute, Route
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from fastapi import routing
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body
from fastapi.params import Body, Param
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
validation_error_definition = {
"title": "ValidationError",
"type": "object",
@ -42,7 +40,7 @@ validation_error_response_definition = {
}
def get_openapi_params(dependant: Dependant):
def get_openapi_params(dependant: Dependant) -> List[Field]:
flat_dependant = get_flat_dependant(dependant)
return (
flat_dependant.path_params
@ -52,7 +50,7 @@ def get_openapi_params(dependant: Dependant):
)
def get_openapi_security_definitions(flat_dependant: Dependant):
def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
@ -61,59 +59,60 @@ def get_openapi_security_definitions(flat_dependant: Dependant):
by_alias=True,
include_none=False,
)
security_name = (
security_requirement.security_scheme.scheme_name
)
security_name = security_requirement.security_scheme.scheme_name
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes})
return security_definitions, operation_security
def get_openapi_operation_parameters(all_route_params: List[Field]):
def get_openapi_operation_parameters(
all_route_params: Sequence[Field]
) -> Tuple[Dict[str, Dict], List[Dict[str, Any]]]:
definitions: Dict[str, Dict] = {}
parameters = []
for param in all_route_params:
schema: Param = param.schema
if "ValidationError" not in definitions:
definitions["ValidationError"] = validation_error_definition
definitions["HTTPValidationError"] = validation_error_response_definition
parameter = {
"name": param.alias,
"in": param.schema.in_.value,
"in": schema.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
}
if param.schema.description:
parameter["description"] = param.schema.description
if param.schema.deprecated:
parameter["deprecated"] = param.schema.deprecated
if schema.description:
parameter["description"] = schema.description
if schema.deprecated:
parameter["deprecated"] = schema.deprecated
parameters.append(parameter)
return definitions, parameters
def get_openapi_operation_request_body(
*, body_field: Field, model_name_map: Dict[Type, str]
):
) -> Optional[Dict]:
if not body_field:
return None
assert isinstance(body_field, Field)
body_schema, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
if isinstance(body_field.schema, Body):
request_media_type = body_field.schema.media_type
schema: Schema = body_field.schema
if isinstance(schema, Body):
request_media_type = schema.media_type
else:
# Includes not declared media types (Schema)
request_media_type = "application/json"
required = body_field.required
request_body_oai = {}
request_body_oai: Dict[str, Any] = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
return request_body_oai
def generate_operation_id(*, route: routing.APIRoute, method: str):
def generate_operation_id(*, route: routing.APIRoute, method: str) -> str:
if route.operation_id:
return route.operation_id
path: str = route.path
@ -123,12 +122,13 @@ def generate_operation_id(*, route: routing.APIRoute, method: str):
return operation_id
def generate_operation_summary(*, route: routing.APIRoute, method: str):
def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
if route.summary:
return route.summary
return method.title() + " " + route.name.replace("_", " ").title()
def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict:
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
@ -141,12 +141,13 @@ def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
return operation
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
return None
def get_openapi_path(
*, route: routing.APIRoute, model_name_map: Dict[Type, str]
) -> Tuple[Dict, Dict, Dict]:
path = {}
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method)
parameters: List[Dict] = []
@ -172,10 +173,9 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
)
if request_body_oai:
operation["requestBody"] = request_body_oai
response_code = str(route.response_code)
status_code = str(route.status_code)
response_schema = {"type": "string"}
if lenient_issubclass(route.response_wrapper, JSONResponse):
response_media_type = "application/json"
if lenient_issubclass(route.content_type, JSONResponse):
if route.response_field:
response_schema, _ = field_schema(
route.response_field,
@ -184,16 +184,9 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
)
else:
response_schema = {}
elif lenient_issubclass(route.response_wrapper, HTMLResponse):
response_media_type = "text/html"
else:
response_media_type = "text/plain"
content = {response_media_type: {"schema": response_schema}}
content = {route.content_type.media_type: {"schema": response_schema}}
operation["responses"] = {
response_code: {
"description": route.response_description,
"content": content,
}
status_code: {"description": route.response_description, "content": content}
}
if all_route_params or route.body_field:
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
@ -215,7 +208,7 @@ def get_openapi(
openapi_version: str = "3.0.2",
description: str = None,
routes: Sequence[BaseRoute]
):
) -> Dict:
info = {"title": title, "version": version}
if description:
info["description"] = description
@ -228,15 +221,18 @@ def get_openapi(
flat_models=flat_models, model_name_map=model_name_map
)
for route in routes:
result = get_openapi_path(route=route, model_name_map=model_name_map)
if result:
path, security_schemes, path_definitions = result
if path:
paths.setdefault(route.path, {}).update(path)
if security_schemes:
components.setdefault("securitySchemes", {}).update(security_schemes)
if path_definitions:
definitions.update(path_definitions)
if isinstance(route, routing.APIRoute):
result = get_openapi_path(route=route, model_name_map=model_name_map)
if result:
path, security_schemes, path_definitions = result
if path:
paths.setdefault(route.path, {}).update(path)
if security_schemes:
components.setdefault("securitySchemes", {}).update(
security_schemes
)
if path_definitions:
definitions.update(path_definitions)
if definitions:
components.setdefault("schemas", {}).update(definitions)
if components:

46
fastapi/params.py

@ -1,5 +1,5 @@
from enum import Enum
from typing import Sequence, Any, Dict
from typing import Any, Callable, Sequence
from pydantic import Schema
@ -16,7 +16,7 @@ class Param(Schema):
def __init__(
self,
default,
default: Any,
*,
deprecated: bool = None,
alias: str = None,
@ -29,7 +29,7 @@ class Param(Schema):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
**extra: Any,
):
self.deprecated = deprecated
super().__init__(
@ -53,7 +53,7 @@ class Path(Param):
def __init__(
self,
default,
default: Any,
*,
deprecated: bool = None,
alias: str = None,
@ -66,7 +66,7 @@ class Path(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
**extra: Any,
):
self.description = description
self.deprecated = deprecated
@ -92,7 +92,7 @@ class Query(Param):
def __init__(
self,
default,
default: Any,
*,
deprecated: bool = None,
alias: str = None,
@ -105,7 +105,7 @@ class Query(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
**extra: Any,
):
self.description = description
self.deprecated = deprecated
@ -130,10 +130,11 @@ class Header(Param):
def __init__(
self,
default,
default: Any,
*,
deprecated: bool = None,
alias: str = None,
alias_underscore_to_hyphen: bool = True,
title: str = None,
description: str = None,
gt: float = None,
@ -143,10 +144,11 @@ class Header(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
**extra: Any,
):
self.description = description
self.deprecated = deprecated
self.alias_underscore_to_hyphen = alias_underscore_to_hyphen
super().__init__(
default,
alias=alias,
@ -168,7 +170,7 @@ class Cookie(Param):
def __init__(
self,
default,
default: Any,
*,
deprecated: bool = None,
alias: str = None,
@ -181,7 +183,7 @@ class Cookie(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
**extra: Any,
):
self.description = description
self.deprecated = deprecated
@ -204,9 +206,9 @@ class Cookie(Param):
class Body(Schema):
def __init__(
self,
default,
default: Any,
*,
embed=False,
embed: bool = False,
media_type: str = "application/json",
alias: str = None,
title: str = None,
@ -218,7 +220,7 @@ class Body(Schema):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
**extra: Any,
):
self.embed = embed
self.media_type = media_type
@ -241,9 +243,9 @@ class Body(Schema):
class Form(Body):
def __init__(
self,
default,
default: Any,
*,
sub_key=False,
sub_key: bool = False,
media_type: str = "application/x-www-form-urlencoded",
alias: str = None,
title: str = None,
@ -255,7 +257,7 @@ class Form(Body):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
**extra: Any,
):
super().__init__(
default,
@ -278,9 +280,9 @@ class Form(Body):
class File(Form):
def __init__(
self,
default,
default: Any,
*,
sub_key=False,
sub_key: bool = False,
media_type: str = "multipart/form-data",
alias: str = None,
title: str = None,
@ -292,7 +294,7 @@ class File(Form):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
**extra: Any,
):
super().__init__(
default,
@ -313,11 +315,11 @@ class File(Form):
class Depends:
def __init__(self, dependency=None):
def __init__(self, dependency: Callable = None):
self.dependency = dependency
class Security(Depends):
def __init__(self, dependency=None, scopes: Sequence[str] = None):
def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None):
self.scopes = scopes or []
super().__init__(dependency=dependency)

437
fastapi/routing.py

@ -1,12 +1,11 @@
import asyncio
import inspect
from typing import Callable, List, Type
from typing import Any, Callable, List, Optional, Type
from pydantic import BaseConfig, BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import Field
from pydantic.utils import lenient_issubclass
from starlette import routing
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
@ -22,7 +21,7 @@ from fastapi.dependencies.utils import get_body_field, get_dependant, solve_depe
from fastapi.encoders import jsonable_encoder
def serialize_response(*, field: Field = None, response):
def serialize_response(*, field: Field = None, response: Response) -> Any:
if field:
errors = []
value, errors_ = field.validate(response, {}, loc=("response",))
@ -40,11 +39,12 @@ def serialize_response(*, field: Field = None, response):
def get_app(
dependant: Dependant,
body_field: Field = None,
response_code: str = 200,
response_wrapper: Type[Response] = JSONResponse,
response_field: Type[Field] = None,
):
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
status_code: int = 200,
content_type: Type[Response] = JSONResponse,
response_field: Field = None,
) -> Callable:
assert dependant.call is not None, "dependant.call must me a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(body_field.schema, params.Form)
async def app(request: Request) -> Response:
@ -69,6 +69,7 @@ def get_app(
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
)
else:
assert dependant.call is not None, "dependant.call must me a function"
if is_coroutine:
raw_response = await dependant.call(**values)
else:
@ -76,32 +77,32 @@ def get_app(
if isinstance(raw_response, Response):
return raw_response
if isinstance(raw_response, BaseModel):
return response_wrapper(
content=jsonable_encoder(raw_response), status_code=response_code
return content_type(
content=jsonable_encoder(raw_response), status_code=status_code
)
errors = []
try:
return response_wrapper(
return content_type(
content=serialize_response(
field=response_field, response=raw_response
),
status_code=response_code,
status_code=status_code,
)
except Exception as e:
errors.append(e)
try:
response = dict(raw_response)
return response_wrapper(
return content_type(
content=serialize_response(field=response_field, response=response),
status_code=response_code,
status_code=status_code,
)
except Exception as e:
errors.append(e)
try:
response = vars(raw_response)
return response_wrapper(
return content_type(
content=serialize_response(field=response_field, response=response),
status_code=response_code,
status_code=status_code,
)
except Exception as e:
errors.append(e)
@ -116,43 +117,32 @@ class APIRoute(routing.Route):
path: str,
endpoint: Callable,
*,
methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
deprecated: bool = None,
name: str = None,
methods: List[str] = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> None:
assert path.startswith("/"), "Routed paths must always start with '/'"
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.include_in_schema = include_in_schema
self.tags = tags or []
self.summary = summary
self.description = description or self.endpoint.__doc__
self.operation_id = operation_id
self.deprecated = deprecated
self.body_field: Field = None
self.response_description = response_description
self.response_code = response_code
self.response_wrapper = response_wrapper
self.response_field = None
if response_type:
self.response_model = response_model
if self.response_model:
assert lenient_issubclass(
response_wrapper, JSONResponse
content_type, JSONResponse
), "To declare a type the response must be a JSON response"
self.response_type = response_type
response_name = "Response_" + self.name
self.response_field = Field(
self.response_field: Optional[Field] = Field(
name=response_name,
type_=self.response_type,
type_=self.response_model,
class_validators=[],
default=None,
required=False,
@ -160,25 +150,34 @@ class APIRoute(routing.Route):
schema=Schema(None),
)
else:
self.response_type = None
self.response_field = None
self.status_code = status_code
self.tags = tags or []
self.summary = summary
self.description = description or self.endpoint.__doc__
self.response_description = response_description
self.deprecated = deprecated
if methods is None:
methods = ["GET"]
self.methods = methods
self.operation_id = operation_id
self.include_in_schema = include_in_schema
self.content_type = content_type
self.path_regex, self.path_format, self.param_convertors = self.compile_path(
path
)
assert inspect.isfunction(endpoint) or inspect.ismethod(
endpoint
), f"An endpoint must be a function or method"
self.dependant = get_dependant(path=path, call=self.endpoint)
self.body_field = get_body_field(dependant=self.dependant, name=self.name)
self.app = request_response(
get_app(
dependant=self.dependant,
body_field=self.body_field,
response_code=self.response_code,
response_wrapper=self.response_wrapper,
status_code=self.status_code,
content_type=self.content_type,
response_field=self.response_field,
)
)
@ -189,75 +188,77 @@ class APIRouter(routing.Router):
self,
path: str,
endpoint: Callable,
methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
deprecated: bool = None,
name: str = None,
methods: List[str] = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> None:
route = APIRoute(
path,
endpoint=endpoint,
methods=methods,
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=methods,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
self.routes.append(route)
def api_route(
self,
path: str,
methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
deprecated: bool = None,
name: str = None,
methods: List[str] = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
def decorator(func: Callable) -> Callable:
self.add_api_route(
path,
func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=methods,
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
return func
return decorator
def include_router(self, router: "APIRouter", *, prefix=""):
def include_router(self, router: "APIRouter", *, prefix: str = "") -> None:
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
@ -268,18 +269,18 @@ class APIRouter(routing.Router):
self.add_api_route(
prefix + route.path,
route.endpoint,
methods=route.methods,
name=route.name,
include_in_schema=route.include_in_schema,
tags=route.tags,
response_model=route.response_model,
status_code=route.status_code,
tags=route.tags or [],
summary=route.summary,
description=route.description,
operation_id=route.operation_id,
deprecated=route.deprecated,
response_type=route.response_type,
response_description=route.response_description,
response_code=route.response_code,
response_wrapper=route.response_wrapper,
deprecated=route.deprecated,
name=route.name,
methods=route.methods,
operation_id=route.operation_id,
include_in_schema=route.include_in_schema,
content_type=route.content_type,
)
elif isinstance(route, routing.Route):
self.add_route(
@ -293,247 +294,255 @@ class APIRouter(routing.Router):
def get(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route(
path=path,
methods=["GET"],
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=["GET"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def put(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route(
path=path,
methods=["PUT"],
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=["PUT"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def post(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route(
path=path,
methods=["POST"],
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=["POST"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def delete(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route(
path=path,
methods=["DELETE"],
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=["DELETE"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def options(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route(
path=path,
methods=["OPTIONS"],
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=["OPTIONS"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def head(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route(
path=path,
methods=["HEAD"],
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=["HEAD"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def patch(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route(
path=path,
methods=["PATCH"],
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=["PATCH"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)
def trace(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
*,
response_model: Type[BaseModel] = None,
status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
deprecated: bool = None,
name: str = None,
operation_id: str = None,
include_in_schema: bool = True,
content_type: Type[Response] = JSONResponse,
) -> Callable:
return self.api_route(
path=path,
methods=["TRACE"],
name=name,
include_in_schema=include_in_schema,
response_model=response_model,
status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
deprecated=deprecated,
name=name,
methods=["TRACE"],
operation_id=operation_id,
include_in_schema=include_in_schema,
content_type=content_type,
)

13
fastapi/security/api_key.py

@ -1,18 +1,19 @@
from starlette.requests import Request
from .base import SecurityBase
from fastapi.openapi.models import APIKeyIn, APIKey
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
class APIKeyBase(SecurityBase):
pass
class APIKeyQuery(APIKeyBase):
class APIKeyQuery(APIKeyBase):
def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.query, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request):
async def __call__(self, requests: Request) -> str:
return requests.query_params.get(self.model.name)
@ -21,7 +22,7 @@ class APIKeyHeader(APIKeyBase):
self.model = APIKey(in_=APIKeyIn.header, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request):
async def __call__(self, requests: Request) -> str:
return requests.headers.get(self.model.name)
@ -30,5 +31,5 @@ class APIKeyCookie(APIKeyBase):
self.model = APIKey(in_=APIKeyIn.cookie, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request):
async def __call__(self, requests: Request) -> str:
return requests.cookies.get(self.model.name)

6
fastapi/security/base.py

@ -1,6 +1,6 @@
from pydantic import BaseModel
from fastapi.openapi.models import SecurityBase as SecurityBaseModel
class SecurityBase:
pass
model: SecurityBaseModel
scheme_name: str

21
fastapi/security/http.py

@ -1,7 +1,10 @@
from starlette.requests import Request
from .base import SecurityBase
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel
from fastapi.openapi.models import (
HTTPBase as HTTPBaseModel,
HTTPBearer as HTTPBearerModel,
)
from fastapi.security.base import SecurityBase
class HTTPBase(SecurityBase):
@ -9,7 +12,7 @@ class HTTPBase(SecurityBase):
self.model = HTTPBaseModel(scheme=scheme)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")
@ -17,8 +20,8 @@ class HTTPBasic(HTTPBase):
def __init__(self, *, scheme_name: str = None):
self.model = HTTPBaseModel(scheme="basic")
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")
@ -26,8 +29,8 @@ class HTTPBearer(HTTPBase):
def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
self.model = HTTPBearerModel(bearerFormat=bearerFormat)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")
@ -35,6 +38,6 @@ class HTTPDigest(HTTPBase):
def __init__(self, *, scheme_name: str = None):
self.model = HTTPBaseModel(scheme="digest")
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")

10
fastapi/security/oauth2.py

@ -1,13 +1,15 @@
from starlette.requests import Request
from .base import SecurityBase
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
from fastapi.security.base import SecurityBase
class OAuth2(SecurityBase):
def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None):
def __init__(
self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None
):
self.model = OAuth2Model(flows=flows)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")

6
fastapi/security/open_id_connect_url.py

@ -1,13 +1,13 @@
from starlette.requests import Request
from .base import SecurityBase
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
class OpenIdConnect(SecurityBase):
def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")

24
fastapi/utils.py

@ -1,20 +1,24 @@
import re
from typing import Dict, Sequence, Set, Type
from typing import Any, Dict, List, Sequence, Set, Type
from pydantic import BaseModel
from pydantic.fields import Field
from pydantic.schema import get_flat_models_from_fields, model_process_schema
from starlette.routing import BaseRoute
from fastapi import routing
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseModel
from pydantic.fields import Field
from pydantic.schema import get_flat_models_from_fields, model_process_schema
def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
body_fields_from_routes = []
responses_from_routes = []
def get_flat_models_from_routes(
routes: Sequence[Type[BaseRoute]]
) -> Set[Type[BaseModel]]:
body_fields_from_routes: List[Field] = []
responses_from_routes: List[Field] = []
for route in routes:
if route.include_in_schema and isinstance(route, routing.APIRoute):
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert isinstance(
route.body_field, Field
@ -30,7 +34,7 @@ def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
def get_model_definitions(
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
):
) -> Dict[str, Any]:
definitions: Dict[str, Dict] = {}
for model in flat_models:
m_schema, m_definitions = model_process_schema(
@ -42,5 +46,5 @@ def get_model_definitions(
return definitions
def get_path_param_names(path: str):
def get_path_param_names(path: str) -> Set[str]:
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}

Loading…
Cancel
Save