diff --git a/fastapi/applications.py b/fastapi/applications.py index 132a94c9a..1f60d0285 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -137,6 +137,10 @@ class FastAPI(Starlette): self.middleware_stack: ASGIApp = self.build_middleware_stack() self.setup() + @property + def routes(self) -> List[BaseRoute]: + return list(self.router.iter_all_routes()) + def build_middleware_stack(self) -> ASGIApp: # Duplicate/override from Starlette to add AsyncExitStackMiddleware # inside of ExceptionMiddleware, inside of custom user middlewares diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 58a748d04..bba0c4bcb 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -152,7 +152,7 @@ def generate_operation_id( ) if route.operation_id: return route.operation_id - path: str = route.path_format + path: str = route._route_full_path_format return generate_operation_id_for_path(name=route.name, path=path, method=method) @@ -243,7 +243,7 @@ def get_openapi_path( model_name_map=model_name_map, operation_ids=operation_ids, ) - callbacks[callback.name] = {callback.path: cb_path} + callbacks[callback.name] = {callback._route_full_path: cb_path} operation["callbacks"] = callbacks if route.status_code is not None: status_code = str(route.status_code) @@ -422,7 +422,7 @@ def get_openapi( if result: path, security_schemes, path_definitions = result if path: - paths.setdefault(route.path_format, {}).update(path) + paths.setdefault(route._route_full_path_format, {}).update(path) if security_schemes: components.setdefault("securitySchemes", {}).update( security_schemes diff --git a/fastapi/routing.py b/fastapi/routing.py index 0f416ac42..8f27656d8 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -9,12 +9,14 @@ from typing import ( Callable, Coroutine, Dict, + Iterator, List, Optional, Sequence, Set, Tuple, Type, + TypeVar, Union, ) @@ -57,6 +59,10 @@ from starlette.status import WS_1008_POLICY_VIOLATION from starlette.types import ASGIApp, Scope from starlette.websockets import WebSocket +APIRouteType = TypeVar("APIRouteType", bound="APIRoute") +APIRouterType = TypeVar("APIRouterType", bound="APIRouter") +APIMountType = TypeVar("APIMountType", bound="APIMount") + def _prepare_response_content( res: Any, @@ -338,13 +344,13 @@ class APIRoute(routing.Route): generate_unique_id_function: Union[ Callable[["APIRoute"], str], DefaultPlaceholder ] = Default(generate_unique_id), + router: Optional["APIRouter"] = None, ) -> None: self.path = path self.endpoint = endpoint self.response_model = response_model self.summary = summary self.response_description = response_description - self.deprecated = deprecated self.operation_id = operation_id self.response_model_include = response_model_include self.response_model_exclude = response_model_exclude @@ -352,34 +358,128 @@ class APIRoute(routing.Route): self.response_model_exclude_unset = response_model_exclude_unset self.response_model_exclude_defaults = response_model_exclude_defaults self.response_model_exclude_none = response_model_exclude_none - self.include_in_schema = include_in_schema - self.response_class = response_class self.dependency_overrides_provider = dependency_overrides_provider - self.callbacks = callbacks self.openapi_extra = openapi_extra - self.generate_unique_id_function = generate_unique_id_function - self.tags = tags or [] - self.responses = responses or {} + self.router = router + self.name = get_name(endpoint) if name is None else name - self.path_regex, self.path_format, self.param_convertors = compile_path(path) + # normalize enums e.g. http.HTTPStatus + if isinstance(status_code, IntEnum): + status_code = int(status_code) + self.status_code = status_code if methods is None: methods = ["GET"] self.methods: Set[str] = set([method.upper() for method in methods]) - if isinstance(generate_unique_id_function, DefaultPlaceholder): - current_generate_unique_id: Callable[ + + self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") + # if a "form feed" character (page break) is found in the description text, + # truncate description text to the content preceding the first "form feed" + self.description = self.description.split("\f")[0] + + assert callable(endpoint), "An endpoint must be a callable" + + self.path_regex, self.path_format, self.param_convertors = compile_path( + self.path + ) + + # Attributes set in route used to compute resolved attributes + self._route_deprecated = deprecated + self._route_include_in_schema = include_in_schema + self._route_response_class = response_class + self._route_callbacks = callbacks + self._route_generate_unique_id_function = generate_unique_id_function + self._route_tags = tags or [] + self._route_responses = responses or {} + if dependencies: + self._route_dependencies = dependencies + else: + self._route_dependencies = [] + + self.setup() + + def setup(self) -> None: + # setup full path + self._route_full_path = self.path + if self.router: + self._route_full_path = self.router._router_full_path + self.path + + # setup dependencies + self.dependencies: List[params.Depends] = [] + if self.router: + self.dependencies.extend(self.router.dependencies) + self.dependencies.extend(self._route_dependencies) + + # setup generate_unique_id + generate_unique_id_functions: List[ + Union[Callable[[APIRoute], str], DefaultPlaceholder] + ] = [self._route_generate_unique_id_function] + if self.router: + generate_unique_id_functions.append(self.router.generate_unique_id_function) + current_generate_unique_id_function = get_value_or_default( + *generate_unique_id_functions + ) + self.generate_unique_id_function: Union[ + Callable[[APIRoute], str], DefaultPlaceholder + ] = current_generate_unique_id_function + + # setup responses + responses: Dict[Union[int, str], Dict[str, Any]] = {} + if self.router: + responses.update(self.router.responses) + responses.update(self._route_responses) + self.responses: Dict[Union[int, str], Dict[str, Any]] = responses + + # setup default_response_class + default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [ + self._route_response_class + ] + if self.router: + default_response_classes.append(self.router.default_response_class) + current_default_response_class = get_value_or_default(*default_response_classes) + self.response_class: Union[ + Type[Response], DefaultPlaceholder + ] = current_default_response_class + + # setup tags + self.tags: List[Union[str, Enum]] = [] + if self.router: + self.tags.extend(self.router.tags) + self.tags.extend(self._route_tags) + + # setup callbacks + callbacks: List[BaseRoute] = [] + if self.router: + callbacks.extend(self.router.callbacks) + if self._route_callbacks: + callbacks.extend(self._route_callbacks) + self.callbacks = callbacks + + # setup deprecated + self.deprecated = self._route_deprecated + if self.router: + self.deprecated = self._route_deprecated or self.router.deprecated + + # setup include_in_schema + self.include_in_schema = self._route_include_in_schema + if self.router: + self.include_in_schema = ( + self._route_include_in_schema and self.router.include_in_schema + ) + + _, self._route_full_path_format, _ = compile_path(self._route_full_path) + + if isinstance(self.generate_unique_id_function, DefaultPlaceholder): + resolved_generate_unique_id: Callable[ ["APIRoute"], str - ] = generate_unique_id_function.value + ] = self.generate_unique_id_function.value else: - current_generate_unique_id = generate_unique_id_function - self.unique_id = self.operation_id or current_generate_unique_id(self) - # normalize enums e.g. http.HTTPStatus - if isinstance(status_code, IntEnum): - status_code = int(status_code) - self.status_code = status_code + resolved_generate_unique_id = self.generate_unique_id_function + self.unique_id = self.operation_id or resolved_generate_unique_id(self) + if self.response_model: assert ( - status_code not in STATUS_CODES_WITH_NO_BODY - ), f"Status code {status_code} must not have a response body" + self.status_code not in STATUS_CODES_WITH_NO_BODY + ), f"Status code {self.status_code} must not have a response body" response_name = "Response_" + self.unique_id self.response_field = create_response_field( name=response_name, type_=self.response_model @@ -397,14 +497,7 @@ class APIRoute(routing.Route): else: self.response_field = None # type: ignore self.secure_cloned_response_field = None - if dependencies: - self.dependencies = list(dependencies) - else: - self.dependencies = [] - self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") - # if a "form feed" character (page break) is found in the description text, - # truncate description text to the content preceding the first "form feed" - self.description = self.description.split("\f")[0] + response_fields = {} for additional_status_code, response in self.responses.items(): assert isinstance(response, dict), "An additional response must be a dict" @@ -421,16 +514,50 @@ class APIRoute(routing.Route): else: self.response_fields = {} - assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_dependant( + path=self._route_full_path_format, call=self.endpoint + ) for depends in self.dependencies[::-1]: self.dependant.dependencies.insert( 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), + get_parameterless_sub_dependant( + depends=depends, path=self._route_full_path_format + ), ) self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id) self.app = request_response(self.get_route_handler()) + def copy(self: APIRouteType) -> APIRouteType: + return type(self)( + path=self.path, + endpoint=self.endpoint, + response_model=self.response_model, + status_code=self.status_code, + tags=self._route_tags, + dependencies=self._route_dependencies, + summary=self.summary, + description=self.description, + response_description=self.response_description, + responses=self._route_responses, + deprecated=self._route_deprecated, + name=self.name, + methods=self.methods, + operation_id=self.operation_id, + response_model_include=self.response_model_include, + response_model_exclude=self.response_model_exclude, + response_model_by_alias=self.response_model_by_alias, + response_model_exclude_unset=self.response_model_exclude_unset, + response_model_exclude_defaults=self.response_model_exclude_defaults, + response_model_exclude_none=self.response_model_exclude_none, + include_in_schema=self._route_include_in_schema, + response_class=self._route_response_class, + dependency_overrides_provider=self.dependency_overrides_provider, + callbacks=self._route_callbacks, + openapi_extra=self.openapi_extra, + generate_unique_id_function=self._route_generate_unique_id_function, + router=self.router, + ) + def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: return get_request_handler( dependant=self.dependant, @@ -476,6 +603,7 @@ class APIRouter(routing.Router): generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id ), + parent_router: Optional["APIRouter"] = None, ) -> None: super().__init__( routes=routes, # type: ignore # in Starlette @@ -490,16 +618,151 @@ class APIRouter(routing.Router): "/" ), "A path prefix must not end with '/', as the routes will start with '/'" self.prefix = prefix - self.tags: List[Union[str, Enum]] = tags or [] - self.dependencies = list(dependencies or []) or [] - self.deprecated = deprecated - self.include_in_schema = include_in_schema - self.responses = responses or {} - self.callbacks = callbacks or [] self.dependency_overrides_provider = dependency_overrides_provider self.route_class = route_class - self.default_response_class = default_response_class - self.generate_unique_id_function = generate_unique_id_function + + self.parent_router = parent_router + + # Attributes set in router used to compute resolved attributes + self._router_dependencies = list(dependencies or []) or [] + self._router_generate_unique_id_function = generate_unique_id_function + self._router_responses = responses or {} + self._router_default_response_class = default_response_class + self._router_tags: List[Union[str, Enum]] = tags or [] + self._router_callbacks = callbacks or [] + self._router_deprecated = deprecated + self._router_include_in_schema = include_in_schema + self._router_has_empty_route = False + self._router_has_root_route = False + self.setup() + + def setup(self) -> None: + # setup full path + self._router_full_path = self.prefix + if self.parent_router: + self._router_full_path = self.parent_router._router_full_path + self.prefix + # setup dependencies + self.dependencies: List[params.Depends] = [] + if self.parent_router: + self.dependencies.extend(self.parent_router.dependencies) + self.dependencies.extend(self._router_dependencies) + + # setup generate_unique_id + generate_unique_id_functions: List[ + Union[Callable[[APIRoute], str], DefaultPlaceholder] + ] = [self._router_generate_unique_id_function] + if self.parent_router: + generate_unique_id_functions.append( + self.parent_router.generate_unique_id_function + ) + current_generate_unique_id_function = get_value_or_default( + *generate_unique_id_functions + ) + self.generate_unique_id_function: Union[ + Callable[[APIRoute], str], DefaultPlaceholder + ] = current_generate_unique_id_function + + # setup responses + responses: Dict[Union[int, str], Dict[str, Any]] = {} + if self.parent_router: + responses.update(self.parent_router.responses) + responses.update(self._router_responses) + self.responses: Dict[Union[int, str], Dict[str, Any]] = responses + + # setup default_response_class + default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [ + self._router_default_response_class + ] + if self.parent_router: + default_response_classes.append(self.parent_router.default_response_class) + current_default_response_class = get_value_or_default(*default_response_classes) + self.default_response_class: Union[ + Type[Response], DefaultPlaceholder + ] = current_default_response_class + + # setup tags + self.tags: List[Union[str, Enum]] = [] + if self.parent_router: + self.tags.extend(self.parent_router.tags) + self.tags.extend(self._router_tags) + + # setup callbacks + self.callbacks: List[BaseRoute] = [] + if self.parent_router: + self.callbacks.extend(self.parent_router.callbacks) + self.callbacks.extend(self._router_callbacks) + + # setup deprecated + self.deprecated = self._router_deprecated + if self.parent_router: + self.deprecated = self._router_deprecated or self.parent_router.deprecated + + # setup include_in_schema + self.include_in_schema = self._router_include_in_schema + if self.parent_router: + self.include_in_schema = ( + self._router_include_in_schema and self.parent_router.include_in_schema + ) + + # setup routes + for route in self.routes: + if isinstance(route, APIRoute): + route.router = self + route.setup() + elif isinstance(route, APIMount): + route.parent_router = self + route.setup() + + def copy(self: APIRouterType) -> APIRouterType: + routes: List[routing.BaseRoute] = [] + for route in self.routes: + if isinstance(route, APIRoute): + routes.append(route.copy()) + elif isinstance(route, APIMount): + routes.append(route.copy()) + else: + routes.append(route) + copied_router = type(self)( + prefix=self.prefix, + tags=self._router_tags, + dependencies=self._router_dependencies, + default_response_class=self._router_default_response_class, + responses=self._router_responses, + callbacks=self._router_callbacks, + routes=routes, + redirect_slashes=self.redirect_slashes, + default=self.default, + dependency_overrides_provider=self.dependency_overrides_provider, + route_class=self.route_class, + on_startup=self.on_startup, + on_shutdown=self.on_shutdown, + deprecated=self._router_deprecated, + include_in_schema=self._router_include_in_schema, + generate_unique_id_function=self._router_generate_unique_id_function, + parent_router=self.parent_router, + ) + copied_router._router_has_empty_route = self._router_has_empty_route + copied_router._router_has_root_route = self._router_has_root_route + for route in copied_router.routes: + if isinstance(route, APIRoute): + route.router = copied_router + route.setup() + elif isinstance(route, Mount): + if isinstance(route.app, APIRouter): + route.app.setup() + return copied_router + + def iter_all_routes(self) -> Iterator[routing.BaseRoute]: + for route in self.routes: + if isinstance(route, Mount): + if isinstance(route.app, APIRouter): + yield from route.app.iter_all_routes() + else: + yield route + + def api_mount(self, router: "APIRouter", name: Optional[str] = None) -> None: + route = APIMount(router=router, name=name, parent_router=self) + self.routes.append(route) def add_api_route( self, @@ -537,34 +800,18 @@ class APIRouter(routing.Router): ) -> None: route_class = route_class_override or self.route_class responses = responses or {} - combined_responses = {**self.responses, **responses} - current_response_class = get_value_or_default( - response_class, self.default_response_class - ) - current_tags = self.tags.copy() - if tags: - current_tags.extend(tags) - current_dependencies = self.dependencies.copy() - if dependencies: - current_dependencies.extend(dependencies) - current_callbacks = self.callbacks.copy() - if callbacks: - current_callbacks.extend(callbacks) - current_generate_unique_id = get_value_or_default( - generate_unique_id_function, self.generate_unique_id_function - ) route = route_class( - self.prefix + path, + path, endpoint=endpoint, response_model=response_model, status_code=status_code, - tags=current_tags, - dependencies=current_dependencies, + tags=tags, + dependencies=dependencies, summary=summary, description=description, response_description=response_description, - responses=combined_responses, - deprecated=deprecated or self.deprecated, + responses=responses, + deprecated=deprecated, methods=methods, operation_id=operation_id, response_model_include=response_model_include, @@ -573,15 +820,20 @@ class APIRouter(routing.Router): response_model_exclude_unset=response_model_exclude_unset, response_model_exclude_defaults=response_model_exclude_defaults, response_model_exclude_none=response_model_exclude_none, - include_in_schema=include_in_schema and self.include_in_schema, - response_class=current_response_class, + include_in_schema=include_in_schema, + response_class=response_class, name=name, dependency_overrides_provider=self.dependency_overrides_provider, - callbacks=current_callbacks, + callbacks=callbacks, openapi_extra=openapi_extra, - generate_unique_id_function=current_generate_unique_id, + generate_unique_id_function=generate_unique_id_function, + router=self, ) self.routes.append(route) + if not path: + self._router_has_empty_route = True + if path == "/": + self._router_has_root_route = True def api_route( self, @@ -680,103 +932,197 @@ class APIRouter(routing.Router): generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id ), + copy_flat_routes: Optional[bool] = None, ) -> None: if prefix: assert prefix.startswith("/"), "A path prefix must start with '/'" assert not prefix.endswith( "/" ), "A path prefix must not end with '/', as the routes will start with '/'" + resolved_copy_flat_routes = copy_flat_routes + if resolved_copy_flat_routes is None: + resolved_copy_flat_routes = not (prefix or router.prefix) + if not resolved_copy_flat_routes: + included_router = router.copy() + if ( + prefix + or tags + or dependencies + or not isinstance(default_response_class, DefaultPlaceholder) + or responses + or callbacks + or deprecated is not None + or include_in_schema is not True + or not isinstance(generate_unique_id_function, DefaultPlaceholder) + ): + current_router = type(self)( + prefix=prefix, + tags=tags, + dependencies=dependencies, + default_response_class=default_response_class, + responses=responses, + callbacks=callbacks, + deprecated=deprecated, + include_in_schema=include_in_schema, + generate_unique_id_function=generate_unique_id_function, + parent_router=self, + ) + # current_router.api_mount(included_router) + current_router.include_router(included_router) + if included_router._router_has_empty_route and not self.prefix: + current_router._router_has_empty_route = True + current_router._router_has_root_route = ( + included_router._router_has_root_route + ) + self.api_mount(current_router) + included_router.parent_router = current_router + else: + self.api_mount(included_router) + included_router.parent_router = self + + included_router.setup() else: - for r in router.routes: - path = getattr(r, "path") - name = getattr(r, "name", "unknown") - if path is not None and not path: - raise Exception( - f"Prefix and path cannot be both empty (path operation: {name})" + # TODO: remove this and its test, as a subrouter can mount another + # subrouter (done automatically of other things are overwritten) and both + # can omit a prefix, this would error out + # for r in router.routes: + # path = getattr(r, "path") + # name = getattr(r, "name", "unknown") + # if path is not None and not path: + # raise Exception( + # f"Prefix and path cannot be both empty (path operation: {name})" + # ) + if responses is None: + responses = {} + for route in router.routes: + if isinstance(route, APIRoute): + combined_responses = {} + if route.router: + combined_responses.update(route.router.responses) + combined_responses.update(responses) + combined_responses.update(route.responses) + + response_classes: List[ + Union[Type[Response], DefaultPlaceholder] + ] = [] + if route.router: + response_classes.append(route.router.default_response_class) + response_classes.extend( + [ + route.response_class, + router.default_response_class, + default_response_class, + self.default_response_class, + ] ) - if responses is None: - responses = {} - for route in router.routes: - if isinstance(route, APIRoute): - combined_responses = {**responses, **route.responses} - use_response_class = get_value_or_default( - route.response_class, - router.default_response_class, - default_response_class, - self.default_response_class, - ) - current_tags = [] - if tags: - current_tags.extend(tags) - if route.tags: - current_tags.extend(route.tags) - current_dependencies: List[params.Depends] = [] - if dependencies: - current_dependencies.extend(dependencies) - if route.dependencies: - current_dependencies.extend(route.dependencies) - current_callbacks = [] - if callbacks: - current_callbacks.extend(callbacks) - if route.callbacks: - current_callbacks.extend(route.callbacks) - current_generate_unique_id = get_value_or_default( - route.generate_unique_id_function, - router.generate_unique_id_function, - generate_unique_id_function, - self.generate_unique_id_function, - ) - self.add_api_route( - prefix + route.path, - route.endpoint, - response_model=route.response_model, - status_code=route.status_code, - tags=current_tags, - dependencies=current_dependencies, - summary=route.summary, - description=route.description, - response_description=route.response_description, - responses=combined_responses, - deprecated=route.deprecated or deprecated or self.deprecated, - methods=route.methods, - operation_id=route.operation_id, - response_model_include=route.response_model_include, - response_model_exclude=route.response_model_exclude, - response_model_by_alias=route.response_model_by_alias, - response_model_exclude_unset=route.response_model_exclude_unset, - response_model_exclude_defaults=route.response_model_exclude_defaults, - response_model_exclude_none=route.response_model_exclude_none, - include_in_schema=route.include_in_schema - and self.include_in_schema - and include_in_schema, - response_class=use_response_class, - name=route.name, - route_class_override=type(route), - callbacks=current_callbacks, - openapi_extra=route.openapi_extra, - generate_unique_id_function=current_generate_unique_id, - ) - elif isinstance(route, routing.Route): - methods = list(route.methods or []) # type: ignore # in Starlette - self.add_route( - prefix + route.path, - route.endpoint, - methods=methods, - include_in_schema=route.include_in_schema, - name=route.name, - ) - elif isinstance(route, APIWebSocketRoute): - self.add_api_websocket_route( - prefix + route.path, route.endpoint, name=route.name - ) - elif isinstance(route, routing.WebSocketRoute): - self.add_websocket_route( - prefix + route.path, route.endpoint, name=route.name - ) - for handler in router.on_startup: - self.add_event_handler("startup", handler) - for handler in router.on_shutdown: - self.add_event_handler("shutdown", handler) + use_response_class = get_value_or_default(*response_classes) + current_tags = [] + if route.router: + current_tags.extend(route.router.tags) + if tags: + current_tags.extend(tags) + if route.tags: + current_tags.extend(route.tags) + current_dependencies: List[params.Depends] = [] + if route.router: + current_dependencies.extend(route.router.dependencies) + if dependencies: + current_dependencies.extend(dependencies) + if route.dependencies: + current_dependencies.extend(route.dependencies) + current_callbacks = [] + if route.router: + current_callbacks.extend(route.router.callbacks) + if callbacks: + current_callbacks.extend(callbacks) + if route.callbacks: + current_callbacks.extend(route.callbacks) + + generate_unique_id_functions: List[ + Union[Callable[[APIRoute], str], DefaultPlaceholder] + ] = [] + if route.router: + generate_unique_id_functions.append( + route.router.generate_unique_id_function + ) + generate_unique_id_functions.extend( + [ + route.generate_unique_id_function, + router.generate_unique_id_function, + generate_unique_id_function, + self.generate_unique_id_function, + ] + ) + current_generate_unique_id_function = get_value_or_default( + *generate_unique_id_functions + ) + path = prefix + route.path + if route.router: + path = prefix + route.router.prefix + path + self.add_api_route( + path, + route.endpoint, + response_model=route.response_model, + status_code=route.status_code, + tags=current_tags, + dependencies=current_dependencies, + summary=route.summary, + description=route.description, + response_description=route.response_description, + responses=combined_responses, + deprecated=route.deprecated or deprecated or self.deprecated, + methods=route.methods, + operation_id=route.operation_id, + response_model_include=route.response_model_include, + response_model_exclude=route.response_model_exclude, + response_model_by_alias=route.response_model_by_alias, + response_model_exclude_unset=route.response_model_exclude_unset, + response_model_exclude_defaults=route.response_model_exclude_defaults, + response_model_exclude_none=route.response_model_exclude_none, + include_in_schema=route.include_in_schema + and self.include_in_schema + and include_in_schema, + response_class=use_response_class, + name=route.name, + route_class_override=type(route), + callbacks=current_callbacks, + openapi_extra=route.openapi_extra, + generate_unique_id_function=current_generate_unique_id_function, + ) + elif isinstance(route, APIMount): + self.include_router( + route.app, + prefix=prefix, + tags=tags, + dependencies=dependencies, + default_response_class=default_response_class, + responses=responses, + callbacks=callbacks, + deprecated=deprecated, + include_in_schema=include_in_schema, + generate_unique_id_function=generate_unique_id_function, + ) + elif isinstance(route, routing.Route): + methods = list(route.methods or []) # type: ignore # in Starlette + self.add_route( + prefix + route.path, + route.endpoint, + methods=methods, + include_in_schema=route.include_in_schema, + name=route.name, + ) + elif isinstance(route, APIWebSocketRoute): + self.add_api_websocket_route( + prefix + route.path, route.endpoint, name=route.name + ) + elif isinstance(route, routing.WebSocketRoute): + self.add_websocket_route( + prefix + route.path, route.endpoint, name=route.name + ) + for handler in router.on_startup: + self.add_event_handler("startup", handler) + for handler in router.on_shutdown: + self.add_event_handler("shutdown", handler) def get( self, @@ -1226,3 +1572,100 @@ class APIRouter(routing.Router): openapi_extra=openapi_extra, generate_unique_id_function=generate_unique_id_function, ) + + +class APIMount(routing.Mount): + def __init__( + self, + router: APIRouter, + *, + name: Optional[str] = None, + parent_router: Optional[APIRouter] = None, + ) -> None: + self.name = name # type: ignore # in Starlette + self.parent_router = parent_router + self.router = router + + self.setup() + + def setup(self) -> None: + self.app: APIRouter = self.router.copy() + if self.parent_router: + self.app.parent_router = self.parent_router + self.app.setup() + self.path = self.app.prefix + self.path_regex, self.path_format, self.param_convertors = compile_path( + self.path + "/{path:path}" + ) + + # Add custom additional root without trailing slash for compatibility with + # include_router and possibly app migrations + # Ref: https://github.com/tiangolo/fastapi/issues/414 + ( + self._root_path_regex, + self._root_path_format, + self._root_param_convertors, + ) = compile_path(self.path) + ( + self._root_path_regex_trailing, + self._root_path_format_trailing, + self._root_param_convertors_trailing, + ) = compile_path(self.path + "/") + + def copy(self: APIMountType) -> APIMountType: + return type(self)( + router=self.router.copy(), + name=self.name, + parent_router=self.parent_router, + ) + + def matches(self, scope: Scope) -> Tuple[Match, Scope]: + if scope["type"] in ("http", "websocket"): + path = scope["path"] + if self.app._router_has_empty_route: + # Custom logic to support paths without trailing slash + # Ref: https://github.com/tiangolo/fastapi/issues/414 + # This mixes the code in + # starlette.routing.Route.matches() and starlette.routing.Mount.matches() + match = self._root_path_regex.match(path) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + root_path = scope.get("root_path", "") + child_scope = { + "path_params": path_params, + "app_root_path": scope.get("app_root_path", root_path), + "root_path": root_path, + "path": "", + "endpoint": self.app, + } + return Match.FULL, child_scope + if not self.app._router_has_root_route: + match = self._root_path_regex_trailing.match(path) + if match: + return Match.NONE, {} + # End of custom logic + # Duplicated code from Starlette + match = self.path_regex.match(path) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + remaining_path = "/" + matched_params.pop("path") + matched_path = path[: -len(remaining_path)] + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + root_path = scope.get("root_path", "") + child_scope = { + "path_params": path_params, + "app_root_path": scope.get("app_root_path", root_path), + "root_path": root_path + matched_path, + "path": remaining_path, + "endpoint": self.app, + } + return Match.FULL, child_scope + return Match.NONE, {} + # End of duplicated code from Starlette diff --git a/fastapi/utils.py b/fastapi/utils.py index b9301499a..9f832a6d4 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -139,7 +139,7 @@ def generate_operation_id_for_path( def generate_unique_id(route: "APIRoute") -> str: - operation_id = route.name + route.path_format + operation_id = route.name + route._route_full_path_format operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id) assert route.methods operation_id = operation_id + "_" + list(route.methods)[0].lower() diff --git a/tests/test_custom_route_class.py b/tests/test_custom_route_class.py index 1a9ea7199..7d8b5f141 100644 --- a/tests/test_custom_route_class.py +++ b/tests/test_custom_route_class.py @@ -107,9 +107,9 @@ def test_get_path(path, expected_status, expected_response): def test_route_classes(): routes = {} - for r in app.router.routes: - assert isinstance(r, Route) - routes[r.path] = r + for r in app.router.iter_all_routes(): + if isinstance(r, APIRoute): + routes[r._route_full_path_format] = r assert getattr(routes["/a/"], "x_type") == "A" assert getattr(routes["/a/b/"], "x_type") == "B" assert getattr(routes["/a/b/c/"], "x_type") == "C"