You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1673 lines
70 KiB

import asyncio
import dataclasses
import email.message
import inspect
import json
from enum import Enum, IntEnum
from typing import (
Any,
Callable,
Coroutine,
Dict,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from fastapi import params
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
get_body_field,
get_dependant,
get_parameterless_sub_dependant,
solve_dependencies,
)
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
from fastapi.types import DecoratedCallable
from fastapi.utils import (
create_cloned_field,
create_response_field,
generate_unique_id,
get_value_or_default,
)
from pydantic import BaseModel
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import ModelField, Undefined
from starlette import routing
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import BaseRoute, Match
from starlette.routing import Mount as Mount # noqa
from starlette.routing import (
compile_path,
get_name,
request_response,
websocket_session,
)
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.types import ASGIApp, Scope
from starlette.websockets import WebSocket
# Type variables bound (by forward reference) to the routing classes.
# ``APIRouteType`` and ``APIRouterType`` annotate ``self`` and the return value
# of the ``copy`` methods below, so a subclass's ``copy`` is typed as returning
# that subclass. ``APIMountType`` is presumably used the same way by
# ``APIMount``, which is defined outside this part of the file -- TODO confirm.
APIRouteType = TypeVar("APIRouteType", bound="APIRoute")
APIRouterType = TypeVar("APIRouterType", bound="APIRouter")
APIMountType = TypeVar("APIMountType", bound="APIMount")
def _prepare_response_content(
    res: Any,
    *,
    exclude_unset: bool,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
) -> Any:
    """Convert endpoint output into plain data ahead of validation.

    Pydantic models are dumped with ``.dict(by_alias=True, ...)`` using the
    given exclusion flags, lists and dicts are processed recursively (dict
    keys are kept as they are), and dataclass *instances* are converted with
    ``dataclasses.asdict``. Anything else is returned unchanged.

    Args:
        res: The value returned by the endpoint (or a nested item of it).
        exclude_unset: Forwarded to ``BaseModel.dict``.
        exclude_defaults: Forwarded to ``BaseModel.dict``.
        exclude_none: Forwarded to ``BaseModel.dict``.

    Returns:
        The converted value, or ``res`` itself when no conversion applies.
    """
    if isinstance(res, BaseModel):
        read_with_orm_mode = getattr(res.__config__, "read_with_orm_mode", None)
        if read_with_orm_mode:
            # Let from_orm extract the data from this model instead of converting
            # it now to a dict.
            # Otherwise there's no way to extract lazy data that requires attribute
            # access instead of dict iteration, e.g. lazy relationships.
            return res
        return res.dict(
            by_alias=True,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )
    elif isinstance(res, list):
        return [
            _prepare_response_content(
                item,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )
            for item in res
        ]
    elif isinstance(res, dict):
        return {
            k: _prepare_response_content(
                v,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )
            for k, v in res.items()
        }
    elif dataclasses.is_dataclass(res) and not isinstance(res, type):
        # is_dataclass() is also true for dataclass *classes*, on which
        # asdict() raises TypeError; only instances are converted.
        return dataclasses.asdict(res)
    return res
async def serialize_response(
    *,
    field: Optional[ModelField] = None,
    response_content: Any,
    include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    by_alias: bool = True,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
    is_coroutine: bool = True,
) -> Any:
    """Validate (when a field is given) and JSON-encode endpoint output.

    Without ``field`` the content goes straight through ``jsonable_encoder``.
    With ``field`` it is first normalised by ``_prepare_response_content``,
    validated against the field, and the validated value is encoded using the
    include/exclude/alias options.

    When ``is_coroutine`` is false the validation call is dispatched through
    ``run_in_threadpool`` instead of being called inline.

    Raises:
        ValidationError: If validating against ``field`` reports any error.
    """
    if not field:
        return jsonable_encoder(response_content)

    prepared = _prepare_response_content(
        response_content,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
    )
    if is_coroutine:
        value, reported = field.validate(prepared, {}, loc=("response",))
    else:
        value, reported = await run_in_threadpool(
            field.validate, prepared, {}, loc=("response",)
        )

    # The reported errors may be a single wrapper, a list of them, or neither.
    if isinstance(reported, ErrorWrapper):
        errors = [reported]
    elif isinstance(reported, list):
        errors = list(reported)
    else:
        errors = []
    if errors:
        raise ValidationError(errors, field.type_)

    return jsonable_encoder(
        value,
        include=include,
        exclude=exclude,
        by_alias=by_alias,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
    )
async def run_endpoint_function(
    *, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool
) -> Any:
    """Call the endpoint behind ``dependant`` with the solved ``values``.

    Coroutine endpoints are awaited directly; other callables are run through
    ``run_in_threadpool``. Returns whatever the endpoint returns.
    """
    # Only called by get_request_handler. Kept as a standalone function to
    # make profiling endpoints easier, since inner functions are harder to
    # profile.
    assert dependant.call is not None, "dependant.call must be a function"

    if not is_coroutine:
        return await run_in_threadpool(dependant.call, **values)
    return await dependant.call(**values)
def get_request_handler(
    dependant: Dependant,
    body_field: Optional[ModelField] = None,
    status_code: Optional[int] = None,
    response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
    response_field: Optional[ModelField] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    dependency_overrides_provider: Optional[Any] = None,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """Build the coroutine that serves one HTTP request for an endpoint.

    The returned ``app`` coroutine:

    1. reads the request body when ``body_field`` is set (form data, JSON, or
       raw bytes depending on the field and the ``content-type`` header),
    2. solves the dependencies of ``dependant``,
    3. calls the endpoint via ``run_endpoint_function``,
    4. returns the result as-is when it is already a ``Response``; otherwise
       serializes it with ``serialize_response`` and wraps it in the response
       class, applying headers and status code set on the sub-response.

    Args:
        dependant: Dependency tree whose ``call`` is the endpoint function.
        body_field: Field describing the request body, if the endpoint has one.
        status_code: Status code for the response; when ``None`` the response
            class default is used.
        response_class: Response class, possibly wrapped in a
            ``DefaultPlaceholder`` (then its ``value`` is used).
        response_field: Field used to validate/serialize the returned value.
        response_model_include, response_model_exclude,
        response_model_by_alias, response_model_exclude_unset,
        response_model_exclude_defaults, response_model_exclude_none:
            Forwarded to ``serialize_response``.
        dependency_overrides_provider: Forwarded to ``solve_dependencies``.

    Returns:
        The ``async def app(request)`` handler.

    The handler raises ``RequestValidationError`` for an undecodable JSON body
    or for dependency errors, lets an ``HTTPException`` raised while reading
    the body propagate unchanged, and converts any other body-reading error
    into ``HTTPException(400)``.
    """
    assert dependant.call is not None, "dependant.call must be a function"
    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
    is_body_form = body_field and isinstance(body_field.field_info, params.Form)
    if isinstance(response_class, DefaultPlaceholder):
        actual_response_class: Type[Response] = response_class.value
    else:
        actual_response_class = response_class

    async def app(request: Request) -> Response:
        try:
            body: Any = None
            if body_field:
                if is_body_form:
                    body = await request.form()
                else:
                    body_bytes = await request.body()
                    if body_bytes:
                        # Undefined marks "not parsed as JSON"; a parsed JSON
                        # body may itself be null/None.
                        json_body: Any = Undefined
                        content_type_value = request.headers.get("content-type")
                        if not content_type_value:
                            # No content type given: try to read it as JSON.
                            json_body = await request.json()
                        else:
                            # Use the stdlib header parser to split the
                            # media type into main type and subtype.
                            message = email.message.Message()
                            message["content-type"] = content_type_value
                            if message.get_content_maintype() == "application":
                                subtype = message.get_content_subtype()
                                if subtype == "json" or subtype.endswith("+json"):
                                    json_body = await request.json()
                        if json_body is not Undefined:
                            body = json_body
                        else:
                            body = body_bytes
        except json.JSONDecodeError as e:
            raise RequestValidationError(
                [ErrorWrapper(e, ("body", e.pos))], body=e.doc
            ) from e
        except HTTPException:
            # An HTTPException raised while reading the body already carries
            # its own status code and detail; do not mask it as a generic 400.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=400, detail="There was an error parsing the body"
            ) from e
        solved_result = await solve_dependencies(
            request=request,
            dependant=dependant,
            body=body,
            dependency_overrides_provider=dependency_overrides_provider,
        )
        values, errors, background_tasks, sub_response, _ = solved_result
        if errors:
            raise RequestValidationError(errors, body=body)
        else:
            raw_response = await run_endpoint_function(
                dependant=dependant, values=values, is_coroutine=is_coroutine
            )
            if isinstance(raw_response, Response):
                # The endpoint built its own response: only attach the
                # background tasks if it did not set any itself.
                if raw_response.background is None:
                    raw_response.background = background_tasks
                return raw_response
            response_data = await serialize_response(
                field=response_field,
                response_content=raw_response,
                include=response_model_include,
                exclude=response_model_exclude,
                by_alias=response_model_by_alias,
                exclude_unset=response_model_exclude_unset,
                exclude_defaults=response_model_exclude_defaults,
                exclude_none=response_model_exclude_none,
                is_coroutine=is_coroutine,
            )
            response_args: Dict[str, Any] = {"background": background_tasks}
            # If status_code was set, use it, otherwise use the default from the
            # response class, in the case of redirect it's 307
            if status_code is not None:
                response_args["status_code"] = status_code
            response = actual_response_class(response_data, **response_args)
            # Headers and status code set on the sub-response while solving
            # dependencies are applied on top of the final response.
            response.headers.raw.extend(sub_response.headers.raw)
            if sub_response.status_code:
                response.status_code = sub_response.status_code
            return response

    return app
def get_websocket_app(
    dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
    """Build the coroutine that serves a WebSocket connection for an endpoint.

    The returned ``app`` solves the dependencies of ``dependant`` and awaits
    ``dependant.call`` with the solved values. If solving reports errors, the
    socket is closed with ``WS_1008_POLICY_VIOLATION`` and
    ``WebSocketRequestValidationError`` is raised.
    """

    async def app(websocket: WebSocket) -> None:
        # Only the values and errors of the five-item result are used here.
        values, errors, _background, _sub_response, _unused = await solve_dependencies(
            request=websocket,
            dependant=dependant,
            dependency_overrides_provider=dependency_overrides_provider,
        )
        if errors:
            await websocket.close(code=WS_1008_POLICY_VIOLATION)
            raise WebSocketRequestValidationError(errors)
        assert dependant.call is not None, "dependant.call must be a function"
        await dependant.call(**values)

    return app
class APIWebSocketRoute(routing.WebSocketRoute):
    """WebSocket route whose endpoint is called with solved dependencies."""

    def __init__(
        self,
        path: str,
        endpoint: Callable[..., Any],
        *,
        name: Optional[str] = None,
        dependency_overrides_provider: Optional[Any] = None,
    ) -> None:
        """Store the route data and build the ASGI app for ``endpoint``.

        ``name`` defaults to ``get_name(endpoint)``. The ASGI app is the
        handler from ``get_websocket_app`` wrapped in ``websocket_session``.
        """
        self.path = path
        self.endpoint = endpoint
        if name is None:
            self.name = get_name(endpoint)
        else:
            self.name = name
        self.dependant = get_dependant(path=path, call=self.endpoint)
        websocket_handler = get_websocket_app(
            dependant=self.dependant,
            dependency_overrides_provider=dependency_overrides_provider,
        )
        self.app = websocket_session(websocket_handler)
        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> Tuple[Match, Scope]:
        """Match like the parent class and, on a match, record this route
        under the ``"route"`` key of the child scope."""
        match, child_scope = super().matches(scope)
        if match == Match.NONE:
            return match, child_scope
        child_scope["route"] = self
        return match, child_scope
class APIRoute(routing.Route):
    """HTTP route for an API endpoint.

    Values given to the constructor are kept as "route-level" attributes
    (``_route_*``). ``setup()`` combines them with the values of the owning
    ``router`` (when there is one) to compute the resolved attributes such as
    ``dependencies``, ``tags``, ``responses`` and ``response_class``, and then
    builds the dependant, body field and ASGI app.
    """

    _route_full_path_format: str  # only for mypy

    def __init__(
        self,
        path: str,
        endpoint: Callable[..., Any],
        *,
        response_model: Optional[Type[Any]] = None,
        status_code: Optional[int] = None,
        tags: Optional[List[Union[str, Enum]]] = None,
        dependencies: Optional[Sequence[params.Depends]] = None,
        summary: Optional[str] = None,
        description: Optional[str] = None,
        response_description: str = "Successful Response",
        responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
        deprecated: Optional[bool] = None,
        name: Optional[str] = None,
        methods: Optional[Union[Set[str], List[str]]] = None,
        operation_id: Optional[str] = None,
        response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
        response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
        response_model_by_alias: bool = True,
        response_model_exclude_unset: bool = False,
        response_model_exclude_defaults: bool = False,
        response_model_exclude_none: bool = False,
        include_in_schema: bool = True,
        response_class: Union[Type[Response], DefaultPlaceholder] = Default(
            JSONResponse
        ),
        dependency_overrides_provider: Optional[Any] = None,
        callbacks: Optional[List[BaseRoute]] = None,
        openapi_extra: Optional[Dict[str, Any]] = None,
        generate_unique_id_function: Union[
            Callable[["APIRoute"], str], DefaultPlaceholder
        ] = Default(generate_unique_id),
        router: Optional["APIRouter"] = None,
    ) -> None:
        """Record the route configuration and run ``setup()``.

        ``name`` defaults to ``get_name(endpoint)``, ``methods`` defaults to
        ``["GET"]`` and is stored upper-cased as a set, and ``description``
        falls back to the cleaned endpoint docstring, cut at the first form
        feed character.
        """
        # Values stored exactly as given.
        self.path = path
        self.endpoint = endpoint
        self.response_model = response_model
        self.summary = summary
        self.response_description = response_description
        self.operation_id = operation_id
        self.response_model_include = response_model_include
        self.response_model_exclude = response_model_exclude
        self.response_model_by_alias = response_model_by_alias
        self.response_model_exclude_unset = response_model_exclude_unset
        self.response_model_exclude_defaults = response_model_exclude_defaults
        self.response_model_exclude_none = response_model_exclude_none
        self.dependency_overrides_provider = dependency_overrides_provider
        self.openapi_extra = openapi_extra
        self.router = router

        if name is None:
            self.name = get_name(endpoint)
        else:
            self.name = name

        # normalize enums e.g. http.HTTPStatus
        self.status_code = (
            int(status_code) if isinstance(status_code, IntEnum) else status_code
        )

        requested_methods = ["GET"] if methods is None else methods
        self.methods: Set[str] = {method.upper() for method in requested_methods}

        full_description = description or inspect.cleandoc(
            self.endpoint.__doc__ or ""
        )
        # if a "form feed" character (page break) is found in the description
        # text, keep only the content preceding the first "form feed"
        self.description = full_description.split("\f")[0]

        assert callable(endpoint), "An endpoint must be a callable"
        self.path_regex, self.path_format, self.param_convertors = compile_path(
            self.path
        )

        # Attributes set in route used to compute resolved attributes
        self._route_deprecated = deprecated
        self._route_include_in_schema = include_in_schema
        self._route_response_class = response_class
        self._route_callbacks = callbacks
        self._route_generate_unique_id_function = generate_unique_id_function
        self._route_tags = tags or []
        self._route_responses = responses or {}
        self._route_dependencies = dependencies if dependencies else []
        self.setup()

    def setup(self) -> None:
        """(Re)compute every resolved attribute from route and router values.

        Called from ``__init__``; the owning router also calls it after
        assigning itself to ``self.router``.
        """
        parent = self.router

        # Full path: router full path (if any) followed by the route path.
        if parent:
            self._route_full_path = parent._router_full_path + self.path
        else:
            self._route_full_path = self.path

        # Dependencies: router ones first, then the route's own.
        inherited_dependencies = list(parent.dependencies) if parent else []
        self.dependencies: List[params.Depends] = [
            *inherited_dependencies,
            *self._route_dependencies,
        ]

        # Unique-id function: route value first, router value as fallback.
        id_function_candidates: List[
            Union[Callable[[APIRoute], str], DefaultPlaceholder]
        ] = [self._route_generate_unique_id_function]
        if parent:
            id_function_candidates.append(parent.generate_unique_id_function)
        self.generate_unique_id_function: Union[
            Callable[[APIRoute], str], DefaultPlaceholder
        ] = get_value_or_default(*id_function_candidates)

        # Responses: router entries, overridden by route entries.
        merged_responses: Dict[Union[int, str], Dict[str, Any]] = {}
        if parent:
            merged_responses.update(parent.responses)
        merged_responses.update(self._route_responses)
        self.responses: Dict[Union[int, str], Dict[str, Any]] = merged_responses

        # Response class: route value first, router default as fallback.
        response_class_candidates: List[
            Union[Type[Response], DefaultPlaceholder]
        ] = [self._route_response_class]
        if parent:
            response_class_candidates.append(parent.default_response_class)
        self.response_class: Union[
            Type[Response], DefaultPlaceholder
        ] = get_value_or_default(*response_class_candidates)

        # Tags: router tags first, then route tags.
        inherited_tags = list(parent.tags) if parent else []
        self.tags: List[Union[str, Enum]] = [*inherited_tags, *self._route_tags]

        # Callbacks: router callbacks first, then route callbacks.
        resolved_callbacks: List[BaseRoute] = []
        if parent:
            resolved_callbacks.extend(parent.callbacks)
        if self._route_callbacks:
            resolved_callbacks.extend(self._route_callbacks)
        self.callbacks = resolved_callbacks

        # Deprecated if either the route or its router says so.
        if parent:
            self.deprecated = self._route_deprecated or parent.deprecated
        else:
            self.deprecated = self._route_deprecated

        # In the schema only if both the route and its router allow it.
        if parent:
            self.include_in_schema = (
                self._route_include_in_schema and parent.include_in_schema
            )
        else:
            self.include_in_schema = self._route_include_in_schema

        _regex, self._route_full_path_format, _convertors = compile_path(
            self._route_full_path
        )

        # Unique id: explicit operation_id wins over the generated one.
        make_unique_id = self.generate_unique_id_function
        if isinstance(make_unique_id, DefaultPlaceholder):
            make_unique_id = make_unique_id.value
        self.unique_id = self.operation_id or make_unique_id(self)

        if not self.response_model:
            self.response_field = None  # type: ignore
            self.secure_cloned_response_field: Optional[ModelField] = None
        else:
            assert (
                self.status_code not in STATUS_CODES_WITH_NO_BODY
            ), f"Status code {self.status_code} must not have a response body"
            self.response_field = create_response_field(
                name="Response_" + self.unique_id, type_=self.response_model
            )
            # Work on a clone of the field, so that a Pydantic submodel is not
            # returned as is just because it's an instance of a subclass of a
            # more limited class, e.g. UserInDB (containing hashed_password)
            # could be a subclass of User that doesn't have the
            # hashed_password. Being a subclass, it would pass the validation
            # and be returned as is. With a new field no inheritance is passed
            # as is: a new model is always created.
            self.secure_cloned_response_field = create_cloned_field(
                self.response_field
            )

        # Fields for the additional responses that declare a "model".
        additional_fields: Dict[Union[int, str], ModelField] = {}
        for additional_status_code, response in self.responses.items():
            assert isinstance(response, dict), "An additional response must be a dict"
            model = response.get("model")
            if not model:
                continue
            assert (
                additional_status_code not in STATUS_CODES_WITH_NO_BODY
            ), f"Status code {additional_status_code} must not have a response body"
            additional_fields[additional_status_code] = create_response_field(
                name=f"Response_{additional_status_code}_{self.unique_id}",
                type_=model,
            )
        self.response_fields: Dict[Union[int, str], ModelField] = additional_fields

        self.dependant = get_dependant(
            path=self._route_full_path_format, call=self.endpoint
        )
        # Inserting at index 0 while walking backwards keeps the declared
        # order, ahead of the endpoint's own sub-dependencies.
        for depends in reversed(self.dependencies):
            sub_dependant = get_parameterless_sub_dependant(
                depends=depends, path=self._route_full_path_format
            )
            self.dependant.dependencies.insert(0, sub_dependant)
        self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
        self.app = request_response(self.get_route_handler())

    def copy(self: APIRouteType) -> APIRouteType:
        """Return a new route of the same class built from this route's values.

        Inheritable settings are passed in their route-level (``_route_*``)
        form; the new instance's ``__init__`` runs ``setup()`` again.
        """
        init_kwargs: Dict[str, Any] = dict(
            path=self.path,
            endpoint=self.endpoint,
            response_model=self.response_model,
            status_code=self.status_code,
            tags=self._route_tags,
            dependencies=self._route_dependencies,
            summary=self.summary,
            description=self.description,
            response_description=self.response_description,
            responses=self._route_responses,
            deprecated=self._route_deprecated,
            name=self.name,
            methods=self.methods,
            operation_id=self.operation_id,
            response_model_include=self.response_model_include,
            response_model_exclude=self.response_model_exclude,
            response_model_by_alias=self.response_model_by_alias,
            response_model_exclude_unset=self.response_model_exclude_unset,
            response_model_exclude_defaults=self.response_model_exclude_defaults,
            response_model_exclude_none=self.response_model_exclude_none,
            include_in_schema=self._route_include_in_schema,
            response_class=self._route_response_class,
            dependency_overrides_provider=self.dependency_overrides_provider,
            callbacks=self._route_callbacks,
            openapi_extra=self.openapi_extra,
            generate_unique_id_function=self._route_generate_unique_id_function,
            router=self.router,
        )
        return type(self)(**init_kwargs)

    def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
        """Build the request handler from the route's resolved attributes.

        The cloned response field is the one used for serialization.
        """
        handler_kwargs: Dict[str, Any] = dict(
            dependant=self.dependant,
            body_field=self.body_field,
            status_code=self.status_code,
            response_class=self.response_class,
            response_field=self.secure_cloned_response_field,
            response_model_include=self.response_model_include,
            response_model_exclude=self.response_model_exclude,
            response_model_by_alias=self.response_model_by_alias,
            response_model_exclude_unset=self.response_model_exclude_unset,
            response_model_exclude_defaults=self.response_model_exclude_defaults,
            response_model_exclude_none=self.response_model_exclude_none,
            dependency_overrides_provider=self.dependency_overrides_provider,
        )
        return get_request_handler(**handler_kwargs)

    def matches(self, scope: Scope) -> Tuple[Match, Scope]:
        """Match like the parent class and, on a match, record this route
        under the ``"route"`` key of the child scope."""
        match, child_scope = super().matches(scope)
        if match == Match.NONE:
            return match, child_scope
        child_scope["route"] = self
        return match, child_scope
class APIRouter(routing.Router):
def __init__(
    self,
    *,
    prefix: str = "",
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    default_response_class: Type[Response] = Default(JSONResponse),
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    routes: Optional[List[routing.BaseRoute]] = None,
    redirect_slashes: bool = True,
    default: Optional[ASGIApp] = None,
    dependency_overrides_provider: Optional[Any] = None,
    route_class: Type[APIRoute] = APIRoute,
    on_startup: Optional[Sequence[Callable[[], Any]]] = None,
    on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
    deprecated: Optional[bool] = None,
    include_in_schema: bool = True,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
    parent_router: Optional["APIRouter"] = None,
) -> None:
    """Initialise the router and compute its resolved attributes.

    The values given here are kept as "router-level" attributes
    (``_router_*``); ``setup()`` then merges them with the ones of
    ``parent_router`` when there is one. A non-empty ``prefix`` must start
    with ``/`` and must not end with ``/`` (checked with assertions).
    """
    super().__init__(
        routes=routes,  # type: ignore # in Starlette
        redirect_slashes=redirect_slashes,
        default=default,  # type: ignore # in Starlette
        on_startup=on_startup,  # type: ignore # in Starlette
        on_shutdown=on_shutdown,  # type: ignore # in Starlette
    )
    # An empty prefix is always accepted.
    assert not prefix or prefix.startswith("/"), "A path prefix must start with '/'"
    assert not prefix or not prefix.endswith(
        "/"
    ), "A path prefix must not end with '/', as the routes will start with '/'"

    self.prefix = prefix
    self.dependency_overrides_provider = dependency_overrides_provider
    self.route_class = route_class
    self.parent_router = parent_router

    # Attributes set in router used to compute resolved attributes
    self._router_dependencies = list(dependencies) if dependencies else []
    self._router_generate_unique_id_function = generate_unique_id_function
    self._router_responses = responses or {}
    self._router_default_response_class = default_response_class
    self._router_tags: List[Union[str, Enum]] = tags or []
    self._router_callbacks = callbacks or []
    self._router_deprecated = deprecated
    self._router_include_in_schema = include_in_schema

    # Flags later switched on by add_api_route for the paths "" and "/".
    self._router_has_empty_route = False
    self._router_has_root_route = False
    self.setup()
def setup(self) -> None:
    """(Re)compute the resolved attributes and propagate them to the routes.

    Router-level values (``_router_*``) are merged with the values of
    ``parent_router`` when there is one. Afterwards every ``APIRoute`` and
    ``APIMount`` in ``self.routes`` is attached to this router and has its own
    ``setup()`` run.
    """
    parent = self.parent_router

    # Full path: parent full path (if any) followed by this router's prefix.
    if parent:
        self._router_full_path = parent._router_full_path + self.prefix
    else:
        self._router_full_path = self.prefix

    # Dependencies: parent ones first, then this router's own.
    inherited_dependencies = list(parent.dependencies) if parent else []
    self.dependencies: List[params.Depends] = [
        *inherited_dependencies,
        *self._router_dependencies,
    ]

    # Unique-id function: own value first, parent value as fallback.
    id_function_candidates: List[
        Union[Callable[[APIRoute], str], DefaultPlaceholder]
    ] = [self._router_generate_unique_id_function]
    if parent:
        id_function_candidates.append(parent.generate_unique_id_function)
    self.generate_unique_id_function: Union[
        Callable[[APIRoute], str], DefaultPlaceholder
    ] = get_value_or_default(*id_function_candidates)

    # Responses: parent entries, overridden by this router's entries.
    merged_responses: Dict[Union[int, str], Dict[str, Any]] = {}
    if parent:
        merged_responses.update(parent.responses)
    merged_responses.update(self._router_responses)
    self.responses: Dict[Union[int, str], Dict[str, Any]] = merged_responses

    # Default response class: own value first, parent value as fallback.
    response_class_candidates: List[Union[Type[Response], DefaultPlaceholder]] = [
        self._router_default_response_class
    ]
    if parent:
        response_class_candidates.append(parent.default_response_class)
    self.default_response_class: Union[
        Type[Response], DefaultPlaceholder
    ] = get_value_or_default(*response_class_candidates)

    # Tags: parent tags first, then this router's tags.
    inherited_tags = list(parent.tags) if parent else []
    self.tags: List[Union[str, Enum]] = [*inherited_tags, *self._router_tags]

    # Callbacks: parent callbacks first, then this router's callbacks.
    inherited_callbacks = list(parent.callbacks) if parent else []
    self.callbacks: List[BaseRoute] = [
        *inherited_callbacks,
        *self._router_callbacks,
    ]

    # Deprecated if either this router or its parent says so.
    if parent:
        self.deprecated = self._router_deprecated or parent.deprecated
    else:
        self.deprecated = self._router_deprecated

    # In the schema only if both this router and its parent allow it.
    if parent:
        self.include_in_schema = (
            self._router_include_in_schema and parent.include_in_schema
        )
    else:
        self.include_in_schema = self._router_include_in_schema

    # Re-attach the children and let them recompute their own values.
    for route in self.routes:
        if isinstance(route, APIRoute):
            route.router = self
        elif isinstance(route, APIMount):
            route.parent_router = self
        else:
            continue
        route.setup()
def copy(self: APIRouterType) -> APIRouterType:
    """Return a new router of the same class with copied API routes.

    ``APIRoute`` and ``APIMount`` entries are copied through their own
    ``copy()``; any other route object is reused as is. The new router
    receives this router's router-level (``_router_*``) values, the same
    ``parent_router`` and the same empty/root-route flags.
    """
    routes: List[routing.BaseRoute] = [
        route.copy() if isinstance(route, (APIRoute, APIMount)) else route
        for route in self.routes
    ]
    init_kwargs: Dict[str, Any] = dict(
        prefix=self.prefix,
        tags=self._router_tags,
        dependencies=self._router_dependencies,
        default_response_class=self._router_default_response_class,
        responses=self._router_responses,
        callbacks=self._router_callbacks,
        routes=routes,
        redirect_slashes=self.redirect_slashes,
        default=self.default,
        dependency_overrides_provider=self.dependency_overrides_provider,
        route_class=self.route_class,
        on_startup=self.on_startup,
        on_shutdown=self.on_shutdown,
        deprecated=self._router_deprecated,
        include_in_schema=self._router_include_in_schema,
        generate_unique_id_function=self._router_generate_unique_id_function,
        parent_router=self.parent_router,
    )
    copied_router = type(self)(**init_kwargs)
    copied_router._router_has_empty_route = self._router_has_empty_route
    copied_router._router_has_root_route = self._router_has_root_route

    # Bind the copied routes to the new router and refresh mounted routers.
    for route in copied_router.routes:
        if isinstance(route, APIRoute):
            route.router = copied_router
            route.setup()
        elif isinstance(route, Mount) and isinstance(route.app, APIRouter):
            route.app.setup()
    return copied_router
def iter_all_routes(self) -> Iterator[routing.BaseRoute]:
    """Yield this router's routes, descending into mounted ``APIRouter`` apps.

    A ``Mount`` is never yielded itself: if its app is an ``APIRouter`` that
    router's routes are yielded recursively, otherwise the mount is skipped.
    """
    for route in self.routes:
        if not isinstance(route, Mount):
            yield route
        elif isinstance(route.app, APIRouter):
            yield from route.app.iter_all_routes()
def api_mount(self, router: "APIRouter", name: Optional[str] = None) -> None:
    """Append an ``APIMount`` for ``router``, with this router as its parent."""
    self.routes.append(APIMount(router=router, name=name, parent_router=self))
def add_api_route(
    self,
    path: str,
    endpoint: Callable[..., Any],
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: Optional[int] = None,
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    methods: Optional[Union[Set[str], List[str]]] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Union[Type[Response], DefaultPlaceholder] = Default(
        JSONResponse
    ),
    name: Optional[str] = None,
    route_class_override: Optional[Type[APIRoute]] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    openapi_extra: Optional[Dict[str, Any]] = None,
    generate_unique_id_function: Union[
        Callable[[APIRoute], str], DefaultPlaceholder
    ] = Default(generate_unique_id),
) -> None:
    """Create a route for ``endpoint`` and append it to this router.

    The route is built with ``route_class_override`` when given, otherwise
    with ``self.route_class``; it receives this router as ``router`` and this
    router's ``dependency_overrides_provider``. Adding the path ``""`` or
    ``"/"`` switches on the matching ``_router_has_*_route`` flag.
    """
    chosen_route_class = route_class_override or self.route_class
    route = chosen_route_class(
        path,
        endpoint=endpoint,
        response_model=response_model,
        status_code=status_code,
        tags=tags,
        dependencies=dependencies,
        summary=summary,
        description=description,
        response_description=response_description,
        responses=responses or {},
        deprecated=deprecated,
        methods=methods,
        operation_id=operation_id,
        response_model_include=response_model_include,
        response_model_exclude=response_model_exclude,
        response_model_by_alias=response_model_by_alias,
        response_model_exclude_unset=response_model_exclude_unset,
        response_model_exclude_defaults=response_model_exclude_defaults,
        response_model_exclude_none=response_model_exclude_none,
        include_in_schema=include_in_schema,
        response_class=response_class,
        name=name,
        dependency_overrides_provider=self.dependency_overrides_provider,
        callbacks=callbacks,
        openapi_extra=openapi_extra,
        generate_unique_id_function=generate_unique_id_function,
        router=self,
    )
    self.routes.append(route)

    # Remember whether an empty-path or a root-path route was registered.
    if path == "/":
        self._router_has_root_route = True
    elif not path:
        self._router_has_empty_route = True
def api_route(
    self,
    path: str,
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: Optional[int] = None,
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    methods: Optional[List[str]] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Type[Response] = Default(JSONResponse),
    name: Optional[str] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    openapi_extra: Optional[Dict[str, Any]] = None,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
) -> Callable[[DecoratedCallable], DecoratedCallable]:
    """Return a decorator that registers the decorated function at ``path``.

    The decorator forwards every option to ``add_api_route`` and returns the
    function unchanged.
    """
    # Collected once; forwarded untouched each time the decorator is applied.
    route_options: Dict[str, Any] = dict(
        response_model=response_model,
        status_code=status_code,
        tags=tags,
        dependencies=dependencies,
        summary=summary,
        description=description,
        response_description=response_description,
        responses=responses,
        deprecated=deprecated,
        methods=methods,
        operation_id=operation_id,
        response_model_include=response_model_include,
        response_model_exclude=response_model_exclude,
        response_model_by_alias=response_model_by_alias,
        response_model_exclude_unset=response_model_exclude_unset,
        response_model_exclude_defaults=response_model_exclude_defaults,
        response_model_exclude_none=response_model_exclude_none,
        include_in_schema=include_in_schema,
        response_class=response_class,
        name=name,
        callbacks=callbacks,
        openapi_extra=openapi_extra,
        generate_unique_id_function=generate_unique_id_function,
    )

    def decorator(func: DecoratedCallable) -> DecoratedCallable:
        self.add_api_route(path, func, **route_options)
        return func

    return decorator
def add_api_websocket_route(
    self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
) -> None:
    """Append an ``APIWebSocketRoute`` for ``endpoint`` at ``path``.

    The route uses this router's ``dependency_overrides_provider``.
    """
    self.routes.append(
        APIWebSocketRoute(
            path,
            endpoint=endpoint,
            name=name,
            dependency_overrides_provider=self.dependency_overrides_provider,
        )
    )
def websocket(
    self, path: str, name: Optional[str] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
    """Return a decorator that registers a function as a WebSocket endpoint.

    The decorated function is passed to ``add_api_websocket_route`` and then
    returned unchanged.
    """

    def decorator(func: DecoratedCallable) -> DecoratedCallable:
        self.add_api_websocket_route(path, func, name=name)
        return func

    return decorator
def include_router(
    self,
    router: "APIRouter",
    *,
    prefix: str = "",
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    default_response_class: Type[Response] = Default(JSONResponse),
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    deprecated: Optional[bool] = None,
    include_in_schema: bool = True,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
    copy_flat_routes: Optional[bool] = None,
) -> None:
    """Include ``router`` in this router, either mounted or copied route by route.

    Two strategies exist, selected by ``copy_flat_routes``:

    * Mount (``copy_flat_routes`` is false): a copy of ``router`` is mounted
      with ``self.api_mount``. If any override is supplied (prefix, tags,
      dependencies, a non-default response class, responses, callbacks,
      ``deprecated``, ``include_in_schema`` or a non-default unique-id
      function), an intermediate router of ``type(self)`` carrying those
      overrides is created, the copy is included in it, and the intermediate
      router is mounted instead.
    * Flat copy (``copy_flat_routes`` is true): every route of ``router`` is
      re-registered on this router, with prefix, tags, dependencies,
      responses, callbacks, response class and unique-id function merged
      from the route, its owning router, the arguments and ``self``.

    When ``copy_flat_routes`` is ``None`` the flat copy is used only if
    neither ``prefix`` nor ``router.prefix`` is set.

    The startup and shutdown handlers of ``router`` are added to this router
    in both cases.

    Args:
        router: The router to include.
        prefix: Path prefix for the included routes; when non-empty it must
            start with ``"/"`` and must not end with ``"/"``.
        tags: Extra tags for the included routes.
        dependencies: Extra dependencies for the included routes.
        default_response_class: Response class used when a route does not
            set its own.
        responses: Extra responses merged into each included route.
        callbacks: Extra callbacks for the included routes.
        deprecated: Deprecation flag combined with the route's and this
            router's own flag.
        include_in_schema: Combined with the route's and this router's flag;
            all must be true for the route to be included in the schema.
        generate_unique_id_function: Fallback function to generate unique IDs.
        copy_flat_routes: ``True`` to copy routes, ``False`` to mount,
            ``None`` to decide automatically as described above.

    Raises:
        AssertionError: If ``prefix`` is non-empty and does not start with
            ``"/"`` or ends with ``"/"``.
    """
    if prefix:
        assert prefix.startswith("/"), "A path prefix must start with '/'"
        assert not prefix.endswith(
            "/"
        ), "A path prefix must not end with '/', as the routes will start with '/'"
    resolved_copy_flat_routes = copy_flat_routes
    if resolved_copy_flat_routes is None:
        # Automatic mode: copy flat only when no prefix is involved at all.
        resolved_copy_flat_routes = not (prefix or router.prefix)
    if not resolved_copy_flat_routes:
        included_router = router.copy()
        if (
            prefix
            or tags
            or dependencies
            or not isinstance(default_response_class, DefaultPlaceholder)
            or responses
            or callbacks
            or deprecated is not None
            or include_in_schema is not True
            or not isinstance(generate_unique_id_function, DefaultPlaceholder)
        ):
            # Overrides were given: carry them on an intermediate router that
            # sits between this router and the included one.
            current_router = type(self)(
                prefix=prefix,
                tags=tags,
                dependencies=dependencies,
                default_response_class=default_response_class,
                responses=responses,
                callbacks=callbacks,
                deprecated=deprecated,
                include_in_schema=include_in_schema,
                generate_unique_id_function=generate_unique_id_function,
                parent_router=self,
            )
            # current_router.api_mount(included_router)
            current_router.include_router(included_router)
            if included_router._router_has_empty_route and not self.prefix:
                # Propagate the empty/root route flags to the intermediate
                # router, which is the one that actually gets mounted.
                current_router._router_has_empty_route = True
                current_router._router_has_root_route = (
                    included_router._router_has_root_route
                )
            self.api_mount(current_router)
            included_router.parent_router = current_router
        else:
            # No overrides: mount the copy directly under this router.
            self.api_mount(included_router)
            included_router.parent_router = self
            included_router.setup()
    else:
        # TODO: remove this and its test, as a subrouter can mount another
        # subrouter (done automatically if other things are overwritten) and both
        # can omit a prefix, this would error out
        # for r in router.routes:
        #     path = getattr(r, "path")
        #     name = getattr(r, "name", "unknown")
        #     if path is not None and not path:
        #         raise Exception(
        #             f"Prefix and path cannot be both empty (path operation: {name})"
        #         )
        if responses is None:
            responses = {}
        for route in router.routes:
            if isinstance(route, APIRoute):
                # Later updates win: owning router < arguments < the route itself.
                combined_responses = {}
                if route.router:
                    combined_responses.update(route.router.responses)
                combined_responses.update(responses)
                combined_responses.update(route.responses)
                response_classes: List[
                    Union[Type[Response], DefaultPlaceholder]
                ] = []
                if route.router:
                    response_classes.append(route.router.default_response_class)
                response_classes.extend(
                    [
                        route.response_class,
                        router.default_response_class,
                        default_response_class,
                        self.default_response_class,
                    ]
                )
                use_response_class = get_value_or_default(*response_classes)
                current_tags = []
                if route.router:
                    current_tags.extend(route.router.tags)
                if tags:
                    current_tags.extend(tags)
                if route.tags:
                    current_tags.extend(route.tags)
                current_dependencies: List[params.Depends] = []
                if route.router:
                    current_dependencies.extend(route.router.dependencies)
                if dependencies:
                    current_dependencies.extend(dependencies)
                if route.dependencies:
                    current_dependencies.extend(route.dependencies)
                current_callbacks = []
                if route.router:
                    current_callbacks.extend(route.router.callbacks)
                if callbacks:
                    current_callbacks.extend(callbacks)
                if route.callbacks:
                    current_callbacks.extend(route.callbacks)
                generate_unique_id_functions: List[
                    Union[Callable[[APIRoute], str], DefaultPlaceholder]
                ] = []
                if route.router:
                    generate_unique_id_functions.append(
                        route.router.generate_unique_id_function
                    )
                generate_unique_id_functions.extend(
                    [
                        route.generate_unique_id_function,
                        router.generate_unique_id_function,
                        generate_unique_id_function,
                        self.generate_unique_id_function,
                    ]
                )
                current_generate_unique_id_function = get_value_or_default(
                    *generate_unique_id_functions
                )
                # Join every prefix exactly once: the include prefix, then the
                # prefix of the route's owning router (if any), then the
                # route's own path. Building this on top of an already
                # prefixed path would apply ``prefix`` twice.
                path = prefix + route.path
                if route.router:
                    path = prefix + route.router.prefix + route.path
                self.add_api_route(
                    path,
                    route.endpoint,
                    response_model=route.response_model,
                    status_code=route.status_code,
                    tags=current_tags,
                    dependencies=current_dependencies,
                    summary=route.summary,
                    description=route.description,
                    response_description=route.response_description,
                    responses=combined_responses,
                    deprecated=route.deprecated or deprecated or self.deprecated,
                    methods=route.methods,
                    operation_id=route.operation_id,
                    response_model_include=route.response_model_include,
                    response_model_exclude=route.response_model_exclude,
                    response_model_by_alias=route.response_model_by_alias,
                    response_model_exclude_unset=route.response_model_exclude_unset,
                    response_model_exclude_defaults=route.response_model_exclude_defaults,
                    response_model_exclude_none=route.response_model_exclude_none,
                    include_in_schema=route.include_in_schema
                    and self.include_in_schema
                    and include_in_schema,
                    response_class=use_response_class,
                    name=route.name,
                    route_class_override=type(route),
                    callbacks=current_callbacks,
                    openapi_extra=route.openapi_extra,
                    generate_unique_id_function=current_generate_unique_id_function,
                )
            elif isinstance(route, APIMount):
                # NOTE(review): copy_flat_routes is not forwarded here, so a
                # nested mount re-runs the automatic mount/flat decision even
                # when the caller asked for a flat copy -- confirm intended.
                self.include_router(
                    route.app,
                    prefix=prefix,
                    tags=tags,
                    dependencies=dependencies,
                    default_response_class=default_response_class,
                    responses=responses,
                    callbacks=callbacks,
                    deprecated=deprecated,
                    include_in_schema=include_in_schema,
                    generate_unique_id_function=generate_unique_id_function,
                )
            elif isinstance(route, routing.Route):
                methods = list(route.methods or [])  # type: ignore # in Starlette
                self.add_route(
                    prefix + route.path,
                    route.endpoint,
                    methods=methods,
                    include_in_schema=route.include_in_schema,
                    name=route.name,
                )
            elif isinstance(route, APIWebSocketRoute):
                self.add_api_websocket_route(
                    prefix + route.path, route.endpoint, name=route.name
                )
            elif isinstance(route, routing.WebSocketRoute):
                self.add_websocket_route(
                    prefix + route.path, route.endpoint, name=route.name
                )
    # Lifespan handlers of the included router are registered on this router.
    for handler in router.on_startup:
        self.add_event_handler("startup", handler)
    for handler in router.on_shutdown:
        self.add_event_handler("shutdown", handler)
def get(
    self,
    path: str,
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: Optional[int] = None,
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Type[Response] = Default(JSONResponse),
    name: Optional[str] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    openapi_extra: Optional[Dict[str, Any]] = None,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
) -> Callable[[DecoratedCallable], DecoratedCallable]:
    """Return a decorator registering a path operation for HTTP ``GET``.

    All arguments are forwarded unchanged to ``self.api_route``; only
    ``methods`` is fixed, to ``["GET"]``.
    """
    register_operation = self.api_route(
        path=path,
        methods=["GET"],
        # Identification and OpenAPI metadata.
        name=name,
        operation_id=operation_id,
        generate_unique_id_function=generate_unique_id_function,
        tags=tags,
        summary=summary,
        description=description,
        deprecated=deprecated,
        include_in_schema=include_in_schema,
        openapi_extra=openapi_extra,
        callbacks=callbacks,
        # Dependencies.
        dependencies=dependencies,
        # Response declaration.
        status_code=status_code,
        response_class=response_class,
        response_description=response_description,
        responses=responses,
        # Response model and its serialization options.
        response_model=response_model,
        response_model_include=response_model_include,
        response_model_exclude=response_model_exclude,
        response_model_by_alias=response_model_by_alias,
        response_model_exclude_unset=response_model_exclude_unset,
        response_model_exclude_defaults=response_model_exclude_defaults,
        response_model_exclude_none=response_model_exclude_none,
    )
    return register_operation
def put(
    self,
    path: str,
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: Optional[int] = None,
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Type[Response] = Default(JSONResponse),
    name: Optional[str] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    openapi_extra: Optional[Dict[str, Any]] = None,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
) -> Callable[[DecoratedCallable], DecoratedCallable]:
    """Return a decorator registering a path operation for HTTP ``PUT``.

    All arguments are forwarded unchanged to ``self.api_route``; only
    ``methods`` is fixed, to ``["PUT"]``.
    """
    register_operation = self.api_route(
        path=path,
        methods=["PUT"],
        # Identification and OpenAPI metadata.
        name=name,
        operation_id=operation_id,
        generate_unique_id_function=generate_unique_id_function,
        tags=tags,
        summary=summary,
        description=description,
        deprecated=deprecated,
        include_in_schema=include_in_schema,
        openapi_extra=openapi_extra,
        callbacks=callbacks,
        # Dependencies.
        dependencies=dependencies,
        # Response declaration.
        status_code=status_code,
        response_class=response_class,
        response_description=response_description,
        responses=responses,
        # Response model and its serialization options.
        response_model=response_model,
        response_model_include=response_model_include,
        response_model_exclude=response_model_exclude,
        response_model_by_alias=response_model_by_alias,
        response_model_exclude_unset=response_model_exclude_unset,
        response_model_exclude_defaults=response_model_exclude_defaults,
        response_model_exclude_none=response_model_exclude_none,
    )
    return register_operation
def post(
    self,
    path: str,
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: Optional[int] = None,
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Type[Response] = Default(JSONResponse),
    name: Optional[str] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    openapi_extra: Optional[Dict[str, Any]] = None,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
) -> Callable[[DecoratedCallable], DecoratedCallable]:
    """Return a decorator registering a path operation for HTTP ``POST``.

    All arguments are forwarded unchanged to ``self.api_route``; only
    ``methods`` is fixed, to ``["POST"]``.
    """
    register_operation = self.api_route(
        path=path,
        methods=["POST"],
        # Identification and OpenAPI metadata.
        name=name,
        operation_id=operation_id,
        generate_unique_id_function=generate_unique_id_function,
        tags=tags,
        summary=summary,
        description=description,
        deprecated=deprecated,
        include_in_schema=include_in_schema,
        openapi_extra=openapi_extra,
        callbacks=callbacks,
        # Dependencies.
        dependencies=dependencies,
        # Response declaration.
        status_code=status_code,
        response_class=response_class,
        response_description=response_description,
        responses=responses,
        # Response model and its serialization options.
        response_model=response_model,
        response_model_include=response_model_include,
        response_model_exclude=response_model_exclude,
        response_model_by_alias=response_model_by_alias,
        response_model_exclude_unset=response_model_exclude_unset,
        response_model_exclude_defaults=response_model_exclude_defaults,
        response_model_exclude_none=response_model_exclude_none,
    )
    return register_operation
def delete(
    self,
    path: str,
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: Optional[int] = None,
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Type[Response] = Default(JSONResponse),
    name: Optional[str] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    openapi_extra: Optional[Dict[str, Any]] = None,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
) -> Callable[[DecoratedCallable], DecoratedCallable]:
    """Return a decorator registering a path operation for HTTP ``DELETE``.

    All arguments are forwarded unchanged to ``self.api_route``; only
    ``methods`` is fixed, to ``["DELETE"]``.
    """
    register_operation = self.api_route(
        path=path,
        methods=["DELETE"],
        # Identification and OpenAPI metadata.
        name=name,
        operation_id=operation_id,
        generate_unique_id_function=generate_unique_id_function,
        tags=tags,
        summary=summary,
        description=description,
        deprecated=deprecated,
        include_in_schema=include_in_schema,
        openapi_extra=openapi_extra,
        callbacks=callbacks,
        # Dependencies.
        dependencies=dependencies,
        # Response declaration.
        status_code=status_code,
        response_class=response_class,
        response_description=response_description,
        responses=responses,
        # Response model and its serialization options.
        response_model=response_model,
        response_model_include=response_model_include,
        response_model_exclude=response_model_exclude,
        response_model_by_alias=response_model_by_alias,
        response_model_exclude_unset=response_model_exclude_unset,
        response_model_exclude_defaults=response_model_exclude_defaults,
        response_model_exclude_none=response_model_exclude_none,
    )
    return register_operation
def options(
    self,
    path: str,
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: Optional[int] = None,
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Type[Response] = Default(JSONResponse),
    name: Optional[str] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    openapi_extra: Optional[Dict[str, Any]] = None,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
) -> Callable[[DecoratedCallable], DecoratedCallable]:
    """Return a decorator registering a path operation for HTTP ``OPTIONS``.

    All arguments are forwarded unchanged to ``self.api_route``; only
    ``methods`` is fixed, to ``["OPTIONS"]``.
    """
    register_operation = self.api_route(
        path=path,
        methods=["OPTIONS"],
        # Identification and OpenAPI metadata.
        name=name,
        operation_id=operation_id,
        generate_unique_id_function=generate_unique_id_function,
        tags=tags,
        summary=summary,
        description=description,
        deprecated=deprecated,
        include_in_schema=include_in_schema,
        openapi_extra=openapi_extra,
        callbacks=callbacks,
        # Dependencies.
        dependencies=dependencies,
        # Response declaration.
        status_code=status_code,
        response_class=response_class,
        response_description=response_description,
        responses=responses,
        # Response model and its serialization options.
        response_model=response_model,
        response_model_include=response_model_include,
        response_model_exclude=response_model_exclude,
        response_model_by_alias=response_model_by_alias,
        response_model_exclude_unset=response_model_exclude_unset,
        response_model_exclude_defaults=response_model_exclude_defaults,
        response_model_exclude_none=response_model_exclude_none,
    )
    return register_operation
def head(
    self,
    path: str,
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: Optional[int] = None,
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Type[Response] = Default(JSONResponse),
    name: Optional[str] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    openapi_extra: Optional[Dict[str, Any]] = None,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
) -> Callable[[DecoratedCallable], DecoratedCallable]:
    """Return a decorator registering a path operation for HTTP ``HEAD``.

    All arguments are forwarded unchanged to ``self.api_route``; only
    ``methods`` is fixed, to ``["HEAD"]``.
    """
    register_operation = self.api_route(
        path=path,
        methods=["HEAD"],
        # Identification and OpenAPI metadata.
        name=name,
        operation_id=operation_id,
        generate_unique_id_function=generate_unique_id_function,
        tags=tags,
        summary=summary,
        description=description,
        deprecated=deprecated,
        include_in_schema=include_in_schema,
        openapi_extra=openapi_extra,
        callbacks=callbacks,
        # Dependencies.
        dependencies=dependencies,
        # Response declaration.
        status_code=status_code,
        response_class=response_class,
        response_description=response_description,
        responses=responses,
        # Response model and its serialization options.
        response_model=response_model,
        response_model_include=response_model_include,
        response_model_exclude=response_model_exclude,
        response_model_by_alias=response_model_by_alias,
        response_model_exclude_unset=response_model_exclude_unset,
        response_model_exclude_defaults=response_model_exclude_defaults,
        response_model_exclude_none=response_model_exclude_none,
    )
    return register_operation
def patch(
    self,
    path: str,
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: Optional[int] = None,
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Type[Response] = Default(JSONResponse),
    name: Optional[str] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    openapi_extra: Optional[Dict[str, Any]] = None,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
) -> Callable[[DecoratedCallable], DecoratedCallable]:
    """Return a decorator registering a path operation for HTTP ``PATCH``.

    All arguments are forwarded unchanged to ``self.api_route``; only
    ``methods`` is fixed, to ``["PATCH"]``.
    """
    register_operation = self.api_route(
        path=path,
        methods=["PATCH"],
        # Identification and OpenAPI metadata.
        name=name,
        operation_id=operation_id,
        generate_unique_id_function=generate_unique_id_function,
        tags=tags,
        summary=summary,
        description=description,
        deprecated=deprecated,
        include_in_schema=include_in_schema,
        openapi_extra=openapi_extra,
        callbacks=callbacks,
        # Dependencies.
        dependencies=dependencies,
        # Response declaration.
        status_code=status_code,
        response_class=response_class,
        response_description=response_description,
        responses=responses,
        # Response model and its serialization options.
        response_model=response_model,
        response_model_include=response_model_include,
        response_model_exclude=response_model_exclude,
        response_model_by_alias=response_model_by_alias,
        response_model_exclude_unset=response_model_exclude_unset,
        response_model_exclude_defaults=response_model_exclude_defaults,
        response_model_exclude_none=response_model_exclude_none,
    )
    return register_operation
def trace(
    self,
    path: str,
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: Optional[int] = None,
    tags: Optional[List[Union[str, Enum]]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Type[Response] = Default(JSONResponse),
    name: Optional[str] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    openapi_extra: Optional[Dict[str, Any]] = None,
    generate_unique_id_function: Callable[[APIRoute], str] = Default(
        generate_unique_id
    ),
) -> Callable[[DecoratedCallable], DecoratedCallable]:
    """Return a decorator registering a path operation for HTTP ``TRACE``.

    All arguments are forwarded unchanged to ``self.api_route``; only
    ``methods`` is fixed, to ``["TRACE"]``.
    """
    register_operation = self.api_route(
        path=path,
        methods=["TRACE"],
        # Identification and OpenAPI metadata.
        name=name,
        operation_id=operation_id,
        generate_unique_id_function=generate_unique_id_function,
        tags=tags,
        summary=summary,
        description=description,
        deprecated=deprecated,
        include_in_schema=include_in_schema,
        openapi_extra=openapi_extra,
        callbacks=callbacks,
        # Dependencies.
        dependencies=dependencies,
        # Response declaration.
        status_code=status_code,
        response_class=response_class,
        response_description=response_description,
        responses=responses,
        # Response model and its serialization options.
        response_model=response_model,
        response_model_include=response_model_include,
        response_model_exclude=response_model_exclude,
        response_model_by_alias=response_model_by_alias,
        response_model_exclude_unset=response_model_exclude_unset,
        response_model_exclude_defaults=response_model_exclude_defaults,
        response_model_exclude_none=response_model_exclude_none,
    )
    return register_operation
class APIMount(routing.Mount):
    """Mount serving a copy of an ``APIRouter`` under that router's prefix.

    On top of the regular Starlette mount matching, the bare mount path
    (without trailing slash) is matched when the mounted router reports an
    empty route.
    """

    def __init__(
        self,
        router: APIRouter,
        *,
        name: Optional[str] = None,
        parent_router: Optional[APIRouter] = None,
    ) -> None:
        """Store ``router``, ``name`` and ``parent_router``, then call ``setup()``."""
        self.router = router
        self.parent_router = parent_router
        self.name = name  # type: ignore # in Starlette
        self.setup()

    def setup(self) -> None:
        """Build the mounted app from the stored router and compile the path patterns.

        The app is a copy of ``self.router``; when a parent router is set it
        is assigned to the copy before the copy's own ``setup()`` runs. The
        mount path is the copy's prefix.
        """
        mounted_app = self.router.copy()
        self.app: APIRouter = mounted_app
        if self.parent_router:
            mounted_app.parent_router = self.parent_router
        mounted_app.setup()
        self.path = mounted_app.prefix
        mount_path = self.path
        self.path_regex, self.path_format, self.param_convertors = compile_path(
            mount_path + "/{path:path}"
        )
        # Extra patterns for the mount root, with and without a trailing
        # slash, for compatibility with include_router and possibly app
        # migrations.
        # Ref: https://github.com/tiangolo/fastapi/issues/414
        root_compiled = compile_path(mount_path)
        (
            self._root_path_regex,
            self._root_path_format,
            self._root_param_convertors,
        ) = root_compiled
        trailing_compiled = compile_path(mount_path + "/")
        (
            self._root_path_regex_trailing,
            self._root_path_format_trailing,
            self._root_param_convertors_trailing,
        ) = trailing_compiled

    def copy(self: APIMountType) -> APIMountType:
        """Return a new mount of the same type around a copy of the router."""
        mount_cls = type(self)
        router_copy = self.router.copy()
        return mount_cls(
            router=router_copy,
            name=self.name,
            parent_router=self.parent_router,
        )

    def matches(self, scope: Scope) -> Tuple[Match, Scope]:
        """Match ``scope`` against this mount.

        Returns ``(Match.FULL, child_scope)`` when the path is the mount root
        (only if the mounted router has an empty route) or lies below the
        mount path; otherwise ``(Match.NONE, {})``. Only ``"http"`` and
        ``"websocket"`` scopes are considered.
        """
        if scope["type"] not in ("http", "websocket"):
            return Match.NONE, {}
        path = scope["path"]
        if self.app._router_has_empty_route:
            # Custom logic to support paths without trailing slash
            # Ref: https://github.com/tiangolo/fastapi/issues/414
            # This mixes the code in
            # starlette.routing.Route.matches() and starlette.routing.Mount.matches()
            root_match = self._root_path_regex.match(path)
            if root_match:
                root_params = {
                    key: self.param_convertors[key].convert(value)
                    for key, value in root_match.groupdict().items()
                }
                path_params = dict(scope.get("path_params", {}))
                path_params.update(root_params)
                root_path = scope.get("root_path", "")
                # The mounted app sees an empty path and an unchanged root_path.
                return Match.FULL, {
                    "path_params": path_params,
                    "app_root_path": scope.get("app_root_path", root_path),
                    "root_path": root_path,
                    "path": "",
                    "endpoint": self.app,
                }
            if not self.app._router_has_root_route:
                # Without a root route, the mount path plus a trailing slash
                # is explicitly not handled by this mount.
                if self._root_path_regex_trailing.match(path):
                    return Match.NONE, {}
        # End of custom logic
        # Duplicated code from Starlette
        mount_match = self.path_regex.match(path)
        if not mount_match:
            return Match.NONE, {}
        matched_params = {
            key: self.param_convertors[key].convert(value)
            for key, value in mount_match.groupdict().items()
        }
        remaining_path = "/" + matched_params.pop("path")
        matched_path = path[: -len(remaining_path)]
        path_params = dict(scope.get("path_params", {}))
        path_params.update(matched_params)
        root_path = scope.get("root_path", "")
        return Match.FULL, {
            "path_params": path_params,
            "app_root_path": scope.get("app_root_path", root_path),
            "root_path": root_path + matched_path,
            "path": remaining_path,
            "endpoint": self.app,
        }
        # End of duplicated code from Starlette