|
|
@ -9,12 +9,14 @@ from typing import ( |
|
|
|
Callable, |
|
|
|
Coroutine, |
|
|
|
Dict, |
|
|
|
Iterator, |
|
|
|
List, |
|
|
|
Optional, |
|
|
|
Sequence, |
|
|
|
Set, |
|
|
|
Tuple, |
|
|
|
Type, |
|
|
|
TypeVar, |
|
|
|
Union, |
|
|
|
) |
|
|
|
|
|
|
@ -57,6 +59,10 @@ from starlette.status import WS_1008_POLICY_VIOLATION |
|
|
|
from starlette.types import ASGIApp, Scope |
|
|
|
from starlette.websockets import WebSocket |
|
|
|
|
|
|
|
APIRouteType = TypeVar("APIRouteType", bound="APIRoute") |
|
|
|
APIRouterType = TypeVar("APIRouterType", bound="APIRouter") |
|
|
|
APIMountType = TypeVar("APIMountType", bound="APIMount") |
|
|
|
|
|
|
|
|
|
|
|
def _prepare_response_content( |
|
|
|
res: Any, |
|
|
@ -338,13 +344,13 @@ class APIRoute(routing.Route): |
|
|
|
generate_unique_id_function: Union[ |
|
|
|
Callable[["APIRoute"], str], DefaultPlaceholder |
|
|
|
] = Default(generate_unique_id), |
|
|
|
router: Optional["APIRouter"] = None, |
|
|
|
) -> None: |
|
|
|
self.path = path |
|
|
|
self.endpoint = endpoint |
|
|
|
self.response_model = response_model |
|
|
|
self.summary = summary |
|
|
|
self.response_description = response_description |
|
|
|
self.deprecated = deprecated |
|
|
|
self.operation_id = operation_id |
|
|
|
self.response_model_include = response_model_include |
|
|
|
self.response_model_exclude = response_model_exclude |
|
|
@ -352,34 +358,128 @@ class APIRoute(routing.Route): |
|
|
|
self.response_model_exclude_unset = response_model_exclude_unset |
|
|
|
self.response_model_exclude_defaults = response_model_exclude_defaults |
|
|
|
self.response_model_exclude_none = response_model_exclude_none |
|
|
|
self.include_in_schema = include_in_schema |
|
|
|
self.response_class = response_class |
|
|
|
self.dependency_overrides_provider = dependency_overrides_provider |
|
|
|
self.callbacks = callbacks |
|
|
|
self.openapi_extra = openapi_extra |
|
|
|
self.generate_unique_id_function = generate_unique_id_function |
|
|
|
self.tags = tags or [] |
|
|
|
self.responses = responses or {} |
|
|
|
self.router = router |
|
|
|
|
|
|
|
self.name = get_name(endpoint) if name is None else name |
|
|
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path) |
|
|
|
# normalize enums e.g. http.HTTPStatus |
|
|
|
if isinstance(status_code, IntEnum): |
|
|
|
status_code = int(status_code) |
|
|
|
self.status_code = status_code |
|
|
|
if methods is None: |
|
|
|
methods = ["GET"] |
|
|
|
self.methods: Set[str] = set([method.upper() for method in methods]) |
|
|
|
if isinstance(generate_unique_id_function, DefaultPlaceholder): |
|
|
|
current_generate_unique_id: Callable[ |
|
|
|
|
|
|
|
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") |
|
|
|
# if a "form feed" character (page break) is found in the description text, |
|
|
|
# truncate description text to the content preceding the first "form feed" |
|
|
|
self.description = self.description.split("\f")[0] |
|
|
|
|
|
|
|
assert callable(endpoint), "An endpoint must be a callable" |
|
|
|
|
|
|
|
self.path_regex, self.path_format, self.param_convertors = compile_path( |
|
|
|
self.path |
|
|
|
) |
|
|
|
|
|
|
|
# Attributes set in route used to compute resolved attributes |
|
|
|
self._route_deprecated = deprecated |
|
|
|
self._route_include_in_schema = include_in_schema |
|
|
|
self._route_response_class = response_class |
|
|
|
self._route_callbacks = callbacks |
|
|
|
self._route_generate_unique_id_function = generate_unique_id_function |
|
|
|
self._route_tags = tags or [] |
|
|
|
self._route_responses = responses or {} |
|
|
|
if dependencies: |
|
|
|
self._route_dependencies = dependencies |
|
|
|
else: |
|
|
|
self._route_dependencies = [] |
|
|
|
|
|
|
|
self.setup() |
|
|
|
|
|
|
|
def setup(self) -> None: |
|
|
|
# setup full path |
|
|
|
self._route_full_path = self.path |
|
|
|
if self.router: |
|
|
|
self._route_full_path = self.router._router_full_path + self.path |
|
|
|
|
|
|
|
# setup dependencies |
|
|
|
self.dependencies: List[params.Depends] = [] |
|
|
|
if self.router: |
|
|
|
self.dependencies.extend(self.router.dependencies) |
|
|
|
self.dependencies.extend(self._route_dependencies) |
|
|
|
|
|
|
|
# setup generate_unique_id |
|
|
|
generate_unique_id_functions: List[ |
|
|
|
Union[Callable[[APIRoute], str], DefaultPlaceholder] |
|
|
|
] = [self._route_generate_unique_id_function] |
|
|
|
if self.router: |
|
|
|
generate_unique_id_functions.append(self.router.generate_unique_id_function) |
|
|
|
current_generate_unique_id_function = get_value_or_default( |
|
|
|
*generate_unique_id_functions |
|
|
|
) |
|
|
|
self.generate_unique_id_function: Union[ |
|
|
|
Callable[[APIRoute], str], DefaultPlaceholder |
|
|
|
] = current_generate_unique_id_function |
|
|
|
|
|
|
|
# setup responses |
|
|
|
responses: Dict[Union[int, str], Dict[str, Any]] = {} |
|
|
|
if self.router: |
|
|
|
responses.update(self.router.responses) |
|
|
|
responses.update(self._route_responses) |
|
|
|
self.responses: Dict[Union[int, str], Dict[str, Any]] = responses |
|
|
|
|
|
|
|
# setup default_response_class |
|
|
|
default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [ |
|
|
|
self._route_response_class |
|
|
|
] |
|
|
|
if self.router: |
|
|
|
default_response_classes.append(self.router.default_response_class) |
|
|
|
current_default_response_class = get_value_or_default(*default_response_classes) |
|
|
|
self.response_class: Union[ |
|
|
|
Type[Response], DefaultPlaceholder |
|
|
|
] = current_default_response_class |
|
|
|
|
|
|
|
# setup tags |
|
|
|
self.tags: List[Union[str, Enum]] = [] |
|
|
|
if self.router: |
|
|
|
self.tags.extend(self.router.tags) |
|
|
|
self.tags.extend(self._route_tags) |
|
|
|
|
|
|
|
# setup callbacks |
|
|
|
callbacks: List[BaseRoute] = [] |
|
|
|
if self.router: |
|
|
|
callbacks.extend(self.router.callbacks) |
|
|
|
if self._route_callbacks: |
|
|
|
callbacks.extend(self._route_callbacks) |
|
|
|
self.callbacks = callbacks |
|
|
|
|
|
|
|
# setup deprecated |
|
|
|
self.deprecated = self._route_deprecated |
|
|
|
if self.router: |
|
|
|
self.deprecated = self._route_deprecated or self.router.deprecated |
|
|
|
|
|
|
|
# setup include_in_schema |
|
|
|
self.include_in_schema = self._route_include_in_schema |
|
|
|
if self.router: |
|
|
|
self.include_in_schema = ( |
|
|
|
self._route_include_in_schema and self.router.include_in_schema |
|
|
|
) |
|
|
|
|
|
|
|
_, self._route_full_path_format, _ = compile_path(self._route_full_path) |
|
|
|
|
|
|
|
if isinstance(self.generate_unique_id_function, DefaultPlaceholder): |
|
|
|
resolved_generate_unique_id: Callable[ |
|
|
|
["APIRoute"], str |
|
|
|
] = generate_unique_id_function.value |
|
|
|
] = self.generate_unique_id_function.value |
|
|
|
else: |
|
|
|
current_generate_unique_id = generate_unique_id_function |
|
|
|
self.unique_id = self.operation_id or current_generate_unique_id(self) |
|
|
|
# normalize enums e.g. http.HTTPStatus |
|
|
|
if isinstance(status_code, IntEnum): |
|
|
|
status_code = int(status_code) |
|
|
|
self.status_code = status_code |
|
|
|
resolved_generate_unique_id = self.generate_unique_id_function |
|
|
|
self.unique_id = self.operation_id or resolved_generate_unique_id(self) |
|
|
|
|
|
|
|
if self.response_model: |
|
|
|
assert ( |
|
|
|
status_code not in STATUS_CODES_WITH_NO_BODY |
|
|
|
), f"Status code {status_code} must not have a response body" |
|
|
|
self.status_code not in STATUS_CODES_WITH_NO_BODY |
|
|
|
), f"Status code {self.status_code} must not have a response body" |
|
|
|
response_name = "Response_" + self.unique_id |
|
|
|
self.response_field = create_response_field( |
|
|
|
name=response_name, type_=self.response_model |
|
|
@ -397,14 +497,7 @@ class APIRoute(routing.Route): |
|
|
|
else: |
|
|
|
self.response_field = None # type: ignore |
|
|
|
self.secure_cloned_response_field = None |
|
|
|
if dependencies: |
|
|
|
self.dependencies = list(dependencies) |
|
|
|
else: |
|
|
|
self.dependencies = [] |
|
|
|
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") |
|
|
|
# if a "form feed" character (page break) is found in the description text, |
|
|
|
# truncate description text to the content preceding the first "form feed" |
|
|
|
self.description = self.description.split("\f")[0] |
|
|
|
|
|
|
|
response_fields = {} |
|
|
|
for additional_status_code, response in self.responses.items(): |
|
|
|
assert isinstance(response, dict), "An additional response must be a dict" |
|
|
@ -421,16 +514,50 @@ class APIRoute(routing.Route): |
|
|
|
else: |
|
|
|
self.response_fields = {} |
|
|
|
|
|
|
|
assert callable(endpoint), "An endpoint must be a callable" |
|
|
|
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) |
|
|
|
self.dependant = get_dependant( |
|
|
|
path=self._route_full_path_format, call=self.endpoint |
|
|
|
) |
|
|
|
for depends in self.dependencies[::-1]: |
|
|
|
self.dependant.dependencies.insert( |
|
|
|
0, |
|
|
|
get_parameterless_sub_dependant(depends=depends, path=self.path_format), |
|
|
|
get_parameterless_sub_dependant( |
|
|
|
depends=depends, path=self._route_full_path_format |
|
|
|
), |
|
|
|
) |
|
|
|
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id) |
|
|
|
self.app = request_response(self.get_route_handler()) |
|
|
|
|
|
|
|
def copy(self: APIRouteType) -> APIRouteType: |
|
|
|
return type(self)( |
|
|
|
path=self.path, |
|
|
|
endpoint=self.endpoint, |
|
|
|
response_model=self.response_model, |
|
|
|
status_code=self.status_code, |
|
|
|
tags=self._route_tags, |
|
|
|
dependencies=self._route_dependencies, |
|
|
|
summary=self.summary, |
|
|
|
description=self.description, |
|
|
|
response_description=self.response_description, |
|
|
|
responses=self._route_responses, |
|
|
|
deprecated=self._route_deprecated, |
|
|
|
name=self.name, |
|
|
|
methods=self.methods, |
|
|
|
operation_id=self.operation_id, |
|
|
|
response_model_include=self.response_model_include, |
|
|
|
response_model_exclude=self.response_model_exclude, |
|
|
|
response_model_by_alias=self.response_model_by_alias, |
|
|
|
response_model_exclude_unset=self.response_model_exclude_unset, |
|
|
|
response_model_exclude_defaults=self.response_model_exclude_defaults, |
|
|
|
response_model_exclude_none=self.response_model_exclude_none, |
|
|
|
include_in_schema=self._route_include_in_schema, |
|
|
|
response_class=self._route_response_class, |
|
|
|
dependency_overrides_provider=self.dependency_overrides_provider, |
|
|
|
callbacks=self._route_callbacks, |
|
|
|
openapi_extra=self.openapi_extra, |
|
|
|
generate_unique_id_function=self._route_generate_unique_id_function, |
|
|
|
router=self.router, |
|
|
|
) |
|
|
|
|
|
|
|
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: |
|
|
|
return get_request_handler( |
|
|
|
dependant=self.dependant, |
|
|
@ -476,6 +603,7 @@ class APIRouter(routing.Router): |
|
|
|
generate_unique_id_function: Callable[[APIRoute], str] = Default( |
|
|
|
generate_unique_id |
|
|
|
), |
|
|
|
parent_router: Optional["APIRouter"] = None, |
|
|
|
) -> None: |
|
|
|
super().__init__( |
|
|
|
routes=routes, # type: ignore # in Starlette |
|
|
@ -490,16 +618,151 @@ class APIRouter(routing.Router): |
|
|
|
"/" |
|
|
|
), "A path prefix must not end with '/', as the routes will start with '/'" |
|
|
|
self.prefix = prefix |
|
|
|
self.tags: List[Union[str, Enum]] = tags or [] |
|
|
|
self.dependencies = list(dependencies or []) or [] |
|
|
|
self.deprecated = deprecated |
|
|
|
self.include_in_schema = include_in_schema |
|
|
|
self.responses = responses or {} |
|
|
|
self.callbacks = callbacks or [] |
|
|
|
self.dependency_overrides_provider = dependency_overrides_provider |
|
|
|
self.route_class = route_class |
|
|
|
self.default_response_class = default_response_class |
|
|
|
self.generate_unique_id_function = generate_unique_id_function |
|
|
|
|
|
|
|
self.parent_router = parent_router |
|
|
|
|
|
|
|
# Attributes set in router used to compute resolved attributes |
|
|
|
self._router_dependencies = list(dependencies or []) or [] |
|
|
|
self._router_generate_unique_id_function = generate_unique_id_function |
|
|
|
self._router_responses = responses or {} |
|
|
|
self._router_default_response_class = default_response_class |
|
|
|
self._router_tags: List[Union[str, Enum]] = tags or [] |
|
|
|
self._router_callbacks = callbacks or [] |
|
|
|
self._router_deprecated = deprecated |
|
|
|
self._router_include_in_schema = include_in_schema |
|
|
|
self._router_has_empty_route = False |
|
|
|
self._router_has_root_route = False |
|
|
|
self.setup() |
|
|
|
|
|
|
|
def setup(self) -> None: |
|
|
|
# setup full path |
|
|
|
self._router_full_path = self.prefix |
|
|
|
if self.parent_router: |
|
|
|
self._router_full_path = self.parent_router._router_full_path + self.prefix |
|
|
|
# setup dependencies |
|
|
|
self.dependencies: List[params.Depends] = [] |
|
|
|
if self.parent_router: |
|
|
|
self.dependencies.extend(self.parent_router.dependencies) |
|
|
|
self.dependencies.extend(self._router_dependencies) |
|
|
|
|
|
|
|
# setup generate_unique_id |
|
|
|
generate_unique_id_functions: List[ |
|
|
|
Union[Callable[[APIRoute], str], DefaultPlaceholder] |
|
|
|
] = [self._router_generate_unique_id_function] |
|
|
|
if self.parent_router: |
|
|
|
generate_unique_id_functions.append( |
|
|
|
self.parent_router.generate_unique_id_function |
|
|
|
) |
|
|
|
current_generate_unique_id_function = get_value_or_default( |
|
|
|
*generate_unique_id_functions |
|
|
|
) |
|
|
|
self.generate_unique_id_function: Union[ |
|
|
|
Callable[[APIRoute], str], DefaultPlaceholder |
|
|
|
] = current_generate_unique_id_function |
|
|
|
|
|
|
|
# setup responses |
|
|
|
responses: Dict[Union[int, str], Dict[str, Any]] = {} |
|
|
|
if self.parent_router: |
|
|
|
responses.update(self.parent_router.responses) |
|
|
|
responses.update(self._router_responses) |
|
|
|
self.responses: Dict[Union[int, str], Dict[str, Any]] = responses |
|
|
|
|
|
|
|
# setup default_response_class |
|
|
|
default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [ |
|
|
|
self._router_default_response_class |
|
|
|
] |
|
|
|
if self.parent_router: |
|
|
|
default_response_classes.append(self.parent_router.default_response_class) |
|
|
|
current_default_response_class = get_value_or_default(*default_response_classes) |
|
|
|
self.default_response_class: Union[ |
|
|
|
Type[Response], DefaultPlaceholder |
|
|
|
] = current_default_response_class |
|
|
|
|
|
|
|
# setup tags |
|
|
|
self.tags: List[Union[str, Enum]] = [] |
|
|
|
if self.parent_router: |
|
|
|
self.tags.extend(self.parent_router.tags) |
|
|
|
self.tags.extend(self._router_tags) |
|
|
|
|
|
|
|
# setup callbacks |
|
|
|
self.callbacks: List[BaseRoute] = [] |
|
|
|
if self.parent_router: |
|
|
|
self.callbacks.extend(self.parent_router.callbacks) |
|
|
|
self.callbacks.extend(self._router_callbacks) |
|
|
|
|
|
|
|
# setup deprecated |
|
|
|
self.deprecated = self._router_deprecated |
|
|
|
if self.parent_router: |
|
|
|
self.deprecated = self._router_deprecated or self.parent_router.deprecated |
|
|
|
|
|
|
|
# setup include_in_schema |
|
|
|
self.include_in_schema = self._router_include_in_schema |
|
|
|
if self.parent_router: |
|
|
|
self.include_in_schema = ( |
|
|
|
self._router_include_in_schema and self.parent_router.include_in_schema |
|
|
|
) |
|
|
|
|
|
|
|
# setup routes |
|
|
|
for route in self.routes: |
|
|
|
if isinstance(route, APIRoute): |
|
|
|
route.router = self |
|
|
|
route.setup() |
|
|
|
elif isinstance(route, APIMount): |
|
|
|
route.parent_router = self |
|
|
|
route.setup() |
|
|
|
|
|
|
|
def copy(self: APIRouterType) -> APIRouterType: |
|
|
|
routes: List[routing.BaseRoute] = [] |
|
|
|
for route in self.routes: |
|
|
|
if isinstance(route, APIRoute): |
|
|
|
routes.append(route.copy()) |
|
|
|
elif isinstance(route, APIMount): |
|
|
|
routes.append(route.copy()) |
|
|
|
else: |
|
|
|
routes.append(route) |
|
|
|
copied_router = type(self)( |
|
|
|
prefix=self.prefix, |
|
|
|
tags=self._router_tags, |
|
|
|
dependencies=self._router_dependencies, |
|
|
|
default_response_class=self._router_default_response_class, |
|
|
|
responses=self._router_responses, |
|
|
|
callbacks=self._router_callbacks, |
|
|
|
routes=routes, |
|
|
|
redirect_slashes=self.redirect_slashes, |
|
|
|
default=self.default, |
|
|
|
dependency_overrides_provider=self.dependency_overrides_provider, |
|
|
|
route_class=self.route_class, |
|
|
|
on_startup=self.on_startup, |
|
|
|
on_shutdown=self.on_shutdown, |
|
|
|
deprecated=self._router_deprecated, |
|
|
|
include_in_schema=self._router_include_in_schema, |
|
|
|
generate_unique_id_function=self._router_generate_unique_id_function, |
|
|
|
parent_router=self.parent_router, |
|
|
|
) |
|
|
|
copied_router._router_has_empty_route = self._router_has_empty_route |
|
|
|
copied_router._router_has_root_route = self._router_has_root_route |
|
|
|
for route in copied_router.routes: |
|
|
|
if isinstance(route, APIRoute): |
|
|
|
route.router = copied_router |
|
|
|
route.setup() |
|
|
|
elif isinstance(route, Mount): |
|
|
|
if isinstance(route.app, APIRouter): |
|
|
|
route.app.setup() |
|
|
|
return copied_router |
|
|
|
|
|
|
|
def iter_all_routes(self) -> Iterator[routing.BaseRoute]: |
|
|
|
for route in self.routes: |
|
|
|
if isinstance(route, Mount): |
|
|
|
if isinstance(route.app, APIRouter): |
|
|
|
yield from route.app.iter_all_routes() |
|
|
|
else: |
|
|
|
yield route |
|
|
|
|
|
|
|
def api_mount(self, router: "APIRouter", name: Optional[str] = None) -> None: |
|
|
|
route = APIMount(router=router, name=name, parent_router=self) |
|
|
|
self.routes.append(route) |
|
|
|
|
|
|
|
def add_api_route( |
|
|
|
self, |
|
|
@ -537,34 +800,18 @@ class APIRouter(routing.Router): |
|
|
|
) -> None: |
|
|
|
route_class = route_class_override or self.route_class |
|
|
|
responses = responses or {} |
|
|
|
combined_responses = {**self.responses, **responses} |
|
|
|
current_response_class = get_value_or_default( |
|
|
|
response_class, self.default_response_class |
|
|
|
) |
|
|
|
current_tags = self.tags.copy() |
|
|
|
if tags: |
|
|
|
current_tags.extend(tags) |
|
|
|
current_dependencies = self.dependencies.copy() |
|
|
|
if dependencies: |
|
|
|
current_dependencies.extend(dependencies) |
|
|
|
current_callbacks = self.callbacks.copy() |
|
|
|
if callbacks: |
|
|
|
current_callbacks.extend(callbacks) |
|
|
|
current_generate_unique_id = get_value_or_default( |
|
|
|
generate_unique_id_function, self.generate_unique_id_function |
|
|
|
) |
|
|
|
route = route_class( |
|
|
|
self.prefix + path, |
|
|
|
path, |
|
|
|
endpoint=endpoint, |
|
|
|
response_model=response_model, |
|
|
|
status_code=status_code, |
|
|
|
tags=current_tags, |
|
|
|
dependencies=current_dependencies, |
|
|
|
tags=tags, |
|
|
|
dependencies=dependencies, |
|
|
|
summary=summary, |
|
|
|
description=description, |
|
|
|
response_description=response_description, |
|
|
|
responses=combined_responses, |
|
|
|
deprecated=deprecated or self.deprecated, |
|
|
|
responses=responses, |
|
|
|
deprecated=deprecated, |
|
|
|
methods=methods, |
|
|
|
operation_id=operation_id, |
|
|
|
response_model_include=response_model_include, |
|
|
@ -573,15 +820,20 @@ class APIRouter(routing.Router): |
|
|
|
response_model_exclude_unset=response_model_exclude_unset, |
|
|
|
response_model_exclude_defaults=response_model_exclude_defaults, |
|
|
|
response_model_exclude_none=response_model_exclude_none, |
|
|
|
include_in_schema=include_in_schema and self.include_in_schema, |
|
|
|
response_class=current_response_class, |
|
|
|
include_in_schema=include_in_schema, |
|
|
|
response_class=response_class, |
|
|
|
name=name, |
|
|
|
dependency_overrides_provider=self.dependency_overrides_provider, |
|
|
|
callbacks=current_callbacks, |
|
|
|
callbacks=callbacks, |
|
|
|
openapi_extra=openapi_extra, |
|
|
|
generate_unique_id_function=current_generate_unique_id, |
|
|
|
generate_unique_id_function=generate_unique_id_function, |
|
|
|
router=self, |
|
|
|
) |
|
|
|
self.routes.append(route) |
|
|
|
if not path: |
|
|
|
self._router_has_empty_route = True |
|
|
|
if path == "/": |
|
|
|
self._router_has_root_route = True |
|
|
|
|
|
|
|
def api_route( |
|
|
|
self, |
|
|
@ -680,103 +932,197 @@ class APIRouter(routing.Router): |
|
|
|
generate_unique_id_function: Callable[[APIRoute], str] = Default( |
|
|
|
generate_unique_id |
|
|
|
), |
|
|
|
copy_flat_routes: Optional[bool] = None, |
|
|
|
) -> None: |
|
|
|
if prefix: |
|
|
|
assert prefix.startswith("/"), "A path prefix must start with '/'" |
|
|
|
assert not prefix.endswith( |
|
|
|
"/" |
|
|
|
), "A path prefix must not end with '/', as the routes will start with '/'" |
|
|
|
resolved_copy_flat_routes = copy_flat_routes |
|
|
|
if resolved_copy_flat_routes is None: |
|
|
|
resolved_copy_flat_routes = not (prefix or router.prefix) |
|
|
|
if not resolved_copy_flat_routes: |
|
|
|
included_router = router.copy() |
|
|
|
if ( |
|
|
|
prefix |
|
|
|
or tags |
|
|
|
or dependencies |
|
|
|
or not isinstance(default_response_class, DefaultPlaceholder) |
|
|
|
or responses |
|
|
|
or callbacks |
|
|
|
or deprecated is not None |
|
|
|
or include_in_schema is not True |
|
|
|
or not isinstance(generate_unique_id_function, DefaultPlaceholder) |
|
|
|
): |
|
|
|
current_router = type(self)( |
|
|
|
prefix=prefix, |
|
|
|
tags=tags, |
|
|
|
dependencies=dependencies, |
|
|
|
default_response_class=default_response_class, |
|
|
|
responses=responses, |
|
|
|
callbacks=callbacks, |
|
|
|
deprecated=deprecated, |
|
|
|
include_in_schema=include_in_schema, |
|
|
|
generate_unique_id_function=generate_unique_id_function, |
|
|
|
parent_router=self, |
|
|
|
) |
|
|
|
# current_router.api_mount(included_router) |
|
|
|
current_router.include_router(included_router) |
|
|
|
if included_router._router_has_empty_route and not self.prefix: |
|
|
|
current_router._router_has_empty_route = True |
|
|
|
current_router._router_has_root_route = ( |
|
|
|
included_router._router_has_root_route |
|
|
|
) |
|
|
|
self.api_mount(current_router) |
|
|
|
included_router.parent_router = current_router |
|
|
|
else: |
|
|
|
self.api_mount(included_router) |
|
|
|
included_router.parent_router = self |
|
|
|
|
|
|
|
included_router.setup() |
|
|
|
else: |
|
|
|
for r in router.routes: |
|
|
|
path = getattr(r, "path") |
|
|
|
name = getattr(r, "name", "unknown") |
|
|
|
if path is not None and not path: |
|
|
|
raise Exception( |
|
|
|
f"Prefix and path cannot be both empty (path operation: {name})" |
|
|
|
# TODO: remove this and its test, as a subrouter can mount another |
|
|
|
# subrouter (done automatically of other things are overwritten) and both |
|
|
|
# can omit a prefix, this would error out |
|
|
|
# for r in router.routes: |
|
|
|
# path = getattr(r, "path") |
|
|
|
# name = getattr(r, "name", "unknown") |
|
|
|
# if path is not None and not path: |
|
|
|
# raise Exception( |
|
|
|
# f"Prefix and path cannot be both empty (path operation: {name})" |
|
|
|
# ) |
|
|
|
if responses is None: |
|
|
|
responses = {} |
|
|
|
for route in router.routes: |
|
|
|
if isinstance(route, APIRoute): |
|
|
|
combined_responses = {} |
|
|
|
if route.router: |
|
|
|
combined_responses.update(route.router.responses) |
|
|
|
combined_responses.update(responses) |
|
|
|
combined_responses.update(route.responses) |
|
|
|
|
|
|
|
response_classes: List[ |
|
|
|
Union[Type[Response], DefaultPlaceholder] |
|
|
|
] = [] |
|
|
|
if route.router: |
|
|
|
response_classes.append(route.router.default_response_class) |
|
|
|
response_classes.extend( |
|
|
|
[ |
|
|
|
route.response_class, |
|
|
|
router.default_response_class, |
|
|
|
default_response_class, |
|
|
|
self.default_response_class, |
|
|
|
] |
|
|
|
) |
|
|
|
if responses is None: |
|
|
|
responses = {} |
|
|
|
for route in router.routes: |
|
|
|
if isinstance(route, APIRoute): |
|
|
|
combined_responses = {**responses, **route.responses} |
|
|
|
use_response_class = get_value_or_default( |
|
|
|
route.response_class, |
|
|
|
router.default_response_class, |
|
|
|
default_response_class, |
|
|
|
self.default_response_class, |
|
|
|
) |
|
|
|
current_tags = [] |
|
|
|
if tags: |
|
|
|
current_tags.extend(tags) |
|
|
|
if route.tags: |
|
|
|
current_tags.extend(route.tags) |
|
|
|
current_dependencies: List[params.Depends] = [] |
|
|
|
if dependencies: |
|
|
|
current_dependencies.extend(dependencies) |
|
|
|
if route.dependencies: |
|
|
|
current_dependencies.extend(route.dependencies) |
|
|
|
current_callbacks = [] |
|
|
|
if callbacks: |
|
|
|
current_callbacks.extend(callbacks) |
|
|
|
if route.callbacks: |
|
|
|
current_callbacks.extend(route.callbacks) |
|
|
|
current_generate_unique_id = get_value_or_default( |
|
|
|
route.generate_unique_id_function, |
|
|
|
router.generate_unique_id_function, |
|
|
|
generate_unique_id_function, |
|
|
|
self.generate_unique_id_function, |
|
|
|
) |
|
|
|
self.add_api_route( |
|
|
|
prefix + route.path, |
|
|
|
route.endpoint, |
|
|
|
response_model=route.response_model, |
|
|
|
status_code=route.status_code, |
|
|
|
tags=current_tags, |
|
|
|
dependencies=current_dependencies, |
|
|
|
summary=route.summary, |
|
|
|
description=route.description, |
|
|
|
response_description=route.response_description, |
|
|
|
responses=combined_responses, |
|
|
|
deprecated=route.deprecated or deprecated or self.deprecated, |
|
|
|
methods=route.methods, |
|
|
|
operation_id=route.operation_id, |
|
|
|
response_model_include=route.response_model_include, |
|
|
|
response_model_exclude=route.response_model_exclude, |
|
|
|
response_model_by_alias=route.response_model_by_alias, |
|
|
|
response_model_exclude_unset=route.response_model_exclude_unset, |
|
|
|
response_model_exclude_defaults=route.response_model_exclude_defaults, |
|
|
|
response_model_exclude_none=route.response_model_exclude_none, |
|
|
|
include_in_schema=route.include_in_schema |
|
|
|
and self.include_in_schema |
|
|
|
and include_in_schema, |
|
|
|
response_class=use_response_class, |
|
|
|
name=route.name, |
|
|
|
route_class_override=type(route), |
|
|
|
callbacks=current_callbacks, |
|
|
|
openapi_extra=route.openapi_extra, |
|
|
|
generate_unique_id_function=current_generate_unique_id, |
|
|
|
) |
|
|
|
elif isinstance(route, routing.Route): |
|
|
|
methods = list(route.methods or []) # type: ignore # in Starlette |
|
|
|
self.add_route( |
|
|
|
prefix + route.path, |
|
|
|
route.endpoint, |
|
|
|
methods=methods, |
|
|
|
include_in_schema=route.include_in_schema, |
|
|
|
name=route.name, |
|
|
|
) |
|
|
|
elif isinstance(route, APIWebSocketRoute): |
|
|
|
self.add_api_websocket_route( |
|
|
|
prefix + route.path, route.endpoint, name=route.name |
|
|
|
) |
|
|
|
elif isinstance(route, routing.WebSocketRoute): |
|
|
|
self.add_websocket_route( |
|
|
|
prefix + route.path, route.endpoint, name=route.name |
|
|
|
) |
|
|
|
for handler in router.on_startup: |
|
|
|
self.add_event_handler("startup", handler) |
|
|
|
for handler in router.on_shutdown: |
|
|
|
self.add_event_handler("shutdown", handler) |
|
|
|
use_response_class = get_value_or_default(*response_classes) |
|
|
|
current_tags = [] |
|
|
|
if route.router: |
|
|
|
current_tags.extend(route.router.tags) |
|
|
|
if tags: |
|
|
|
current_tags.extend(tags) |
|
|
|
if route.tags: |
|
|
|
current_tags.extend(route.tags) |
|
|
|
current_dependencies: List[params.Depends] = [] |
|
|
|
if route.router: |
|
|
|
current_dependencies.extend(route.router.dependencies) |
|
|
|
if dependencies: |
|
|
|
current_dependencies.extend(dependencies) |
|
|
|
if route.dependencies: |
|
|
|
current_dependencies.extend(route.dependencies) |
|
|
|
current_callbacks = [] |
|
|
|
if route.router: |
|
|
|
current_callbacks.extend(route.router.callbacks) |
|
|
|
if callbacks: |
|
|
|
current_callbacks.extend(callbacks) |
|
|
|
if route.callbacks: |
|
|
|
current_callbacks.extend(route.callbacks) |
|
|
|
|
|
|
|
generate_unique_id_functions: List[ |
|
|
|
Union[Callable[[APIRoute], str], DefaultPlaceholder] |
|
|
|
] = [] |
|
|
|
if route.router: |
|
|
|
generate_unique_id_functions.append( |
|
|
|
route.router.generate_unique_id_function |
|
|
|
) |
|
|
|
generate_unique_id_functions.extend( |
|
|
|
[ |
|
|
|
route.generate_unique_id_function, |
|
|
|
router.generate_unique_id_function, |
|
|
|
generate_unique_id_function, |
|
|
|
self.generate_unique_id_function, |
|
|
|
] |
|
|
|
) |
|
|
|
current_generate_unique_id_function = get_value_or_default( |
|
|
|
*generate_unique_id_functions |
|
|
|
) |
|
|
|
path = prefix + route.path |
|
|
|
if route.router: |
|
|
|
path = prefix + route.router.prefix + path |
|
|
|
self.add_api_route( |
|
|
|
path, |
|
|
|
route.endpoint, |
|
|
|
response_model=route.response_model, |
|
|
|
status_code=route.status_code, |
|
|
|
tags=current_tags, |
|
|
|
dependencies=current_dependencies, |
|
|
|
summary=route.summary, |
|
|
|
description=route.description, |
|
|
|
response_description=route.response_description, |
|
|
|
responses=combined_responses, |
|
|
|
deprecated=route.deprecated or deprecated or self.deprecated, |
|
|
|
methods=route.methods, |
|
|
|
operation_id=route.operation_id, |
|
|
|
response_model_include=route.response_model_include, |
|
|
|
response_model_exclude=route.response_model_exclude, |
|
|
|
response_model_by_alias=route.response_model_by_alias, |
|
|
|
response_model_exclude_unset=route.response_model_exclude_unset, |
|
|
|
response_model_exclude_defaults=route.response_model_exclude_defaults, |
|
|
|
response_model_exclude_none=route.response_model_exclude_none, |
|
|
|
include_in_schema=route.include_in_schema |
|
|
|
and self.include_in_schema |
|
|
|
and include_in_schema, |
|
|
|
response_class=use_response_class, |
|
|
|
name=route.name, |
|
|
|
route_class_override=type(route), |
|
|
|
callbacks=current_callbacks, |
|
|
|
openapi_extra=route.openapi_extra, |
|
|
|
generate_unique_id_function=current_generate_unique_id_function, |
|
|
|
) |
|
|
|
elif isinstance(route, APIMount): |
|
|
|
self.include_router( |
|
|
|
route.app, |
|
|
|
prefix=prefix, |
|
|
|
tags=tags, |
|
|
|
dependencies=dependencies, |
|
|
|
default_response_class=default_response_class, |
|
|
|
responses=responses, |
|
|
|
callbacks=callbacks, |
|
|
|
deprecated=deprecated, |
|
|
|
include_in_schema=include_in_schema, |
|
|
|
generate_unique_id_function=generate_unique_id_function, |
|
|
|
) |
|
|
|
elif isinstance(route, routing.Route): |
|
|
|
methods = list(route.methods or []) # type: ignore # in Starlette |
|
|
|
self.add_route( |
|
|
|
prefix + route.path, |
|
|
|
route.endpoint, |
|
|
|
methods=methods, |
|
|
|
include_in_schema=route.include_in_schema, |
|
|
|
name=route.name, |
|
|
|
) |
|
|
|
elif isinstance(route, APIWebSocketRoute): |
|
|
|
self.add_api_websocket_route( |
|
|
|
prefix + route.path, route.endpoint, name=route.name |
|
|
|
) |
|
|
|
elif isinstance(route, routing.WebSocketRoute): |
|
|
|
self.add_websocket_route( |
|
|
|
prefix + route.path, route.endpoint, name=route.name |
|
|
|
) |
|
|
|
for handler in router.on_startup: |
|
|
|
self.add_event_handler("startup", handler) |
|
|
|
for handler in router.on_shutdown: |
|
|
|
self.add_event_handler("shutdown", handler) |
|
|
|
|
|
|
|
def get( |
|
|
|
self, |
|
|
@ -1226,3 +1572,100 @@ class APIRouter(routing.Router): |
|
|
|
openapi_extra=openapi_extra, |
|
|
|
generate_unique_id_function=generate_unique_id_function, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class APIMount(routing.Mount): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
router: APIRouter, |
|
|
|
*, |
|
|
|
name: Optional[str] = None, |
|
|
|
parent_router: Optional[APIRouter] = None, |
|
|
|
) -> None: |
|
|
|
self.name = name # type: ignore # in Starlette |
|
|
|
self.parent_router = parent_router |
|
|
|
self.router = router |
|
|
|
|
|
|
|
self.setup() |
|
|
|
|
|
|
|
def setup(self) -> None: |
|
|
|
self.app: APIRouter = self.router.copy() |
|
|
|
if self.parent_router: |
|
|
|
self.app.parent_router = self.parent_router |
|
|
|
self.app.setup() |
|
|
|
self.path = self.app.prefix |
|
|
|
self.path_regex, self.path_format, self.param_convertors = compile_path( |
|
|
|
self.path + "/{path:path}" |
|
|
|
) |
|
|
|
|
|
|
|
# Add custom additional root without trailing slash for compatibility with |
|
|
|
# include_router and possibly app migrations |
|
|
|
# Ref: https://github.com/tiangolo/fastapi/issues/414 |
|
|
|
( |
|
|
|
self._root_path_regex, |
|
|
|
self._root_path_format, |
|
|
|
self._root_param_convertors, |
|
|
|
) = compile_path(self.path) |
|
|
|
( |
|
|
|
self._root_path_regex_trailing, |
|
|
|
self._root_path_format_trailing, |
|
|
|
self._root_param_convertors_trailing, |
|
|
|
) = compile_path(self.path + "/") |
|
|
|
|
|
|
|
def copy(self: APIMountType) -> APIMountType: |
|
|
|
return type(self)( |
|
|
|
router=self.router.copy(), |
|
|
|
name=self.name, |
|
|
|
parent_router=self.parent_router, |
|
|
|
) |
|
|
|
|
|
|
|
def matches(self, scope: Scope) -> Tuple[Match, Scope]: |
|
|
|
if scope["type"] in ("http", "websocket"): |
|
|
|
path = scope["path"] |
|
|
|
if self.app._router_has_empty_route: |
|
|
|
# Custom logic to support paths without trailing slash |
|
|
|
# Ref: https://github.com/tiangolo/fastapi/issues/414 |
|
|
|
# This mixes the code in |
|
|
|
# starlette.routing.Route.matches() and starlette.routing.Mount.matches() |
|
|
|
match = self._root_path_regex.match(path) |
|
|
|
if match: |
|
|
|
matched_params = match.groupdict() |
|
|
|
for key, value in matched_params.items(): |
|
|
|
matched_params[key] = self.param_convertors[key].convert(value) |
|
|
|
path_params = dict(scope.get("path_params", {})) |
|
|
|
path_params.update(matched_params) |
|
|
|
root_path = scope.get("root_path", "") |
|
|
|
child_scope = { |
|
|
|
"path_params": path_params, |
|
|
|
"app_root_path": scope.get("app_root_path", root_path), |
|
|
|
"root_path": root_path, |
|
|
|
"path": "", |
|
|
|
"endpoint": self.app, |
|
|
|
} |
|
|
|
return Match.FULL, child_scope |
|
|
|
if not self.app._router_has_root_route: |
|
|
|
match = self._root_path_regex_trailing.match(path) |
|
|
|
if match: |
|
|
|
return Match.NONE, {} |
|
|
|
# End of custom logic |
|
|
|
# Duplicated code from Starlette |
|
|
|
match = self.path_regex.match(path) |
|
|
|
if match: |
|
|
|
matched_params = match.groupdict() |
|
|
|
for key, value in matched_params.items(): |
|
|
|
matched_params[key] = self.param_convertors[key].convert(value) |
|
|
|
remaining_path = "/" + matched_params.pop("path") |
|
|
|
matched_path = path[: -len(remaining_path)] |
|
|
|
path_params = dict(scope.get("path_params", {})) |
|
|
|
path_params.update(matched_params) |
|
|
|
root_path = scope.get("root_path", "") |
|
|
|
child_scope = { |
|
|
|
"path_params": path_params, |
|
|
|
"app_root_path": scope.get("app_root_path", root_path), |
|
|
|
"root_path": root_path + matched_path, |
|
|
|
"path": remaining_path, |
|
|
|
"endpoint": self.app, |
|
|
|
} |
|
|
|
return Match.FULL, child_scope |
|
|
|
return Match.NONE, {} |
|
|
|
# End of duplicated code from Starlette |
|
|
|