From 031c69fa9e56f794aed5da9c952fa7bc669ca59d Mon Sep 17 00:00:00 2001 From: Synrom <max.leiwig@gmail.com> Date: Sun, 8 Sep 2024 13:36:26 +0200 Subject: [PATCH] Use helper method to avoid code duplication --- fastapi/routing.py | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 33e3b7434..140e6e75a 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -879,12 +879,9 @@ class APIRouter(routing.Router): name: Optional[str] = None, include_in_schema: bool = True, ) -> Callable[[DecoratedCallable], DecoratedCallable]: - if self.ignore_trailing_slash: - path = path.rstrip("/") - def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_route( - path, + self._normalize_path(path), func, methods=methods, name=name, @@ -928,8 +925,7 @@ class APIRouter(routing.Router): Callable[[APIRoute], str], DefaultPlaceholder ] = Default(generate_unique_id), ) -> None: - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(path) route_class = route_class_override or self.route_class responses = responses or {} combined_responses = {**self.responses, **responses} @@ -1046,8 +1042,7 @@ class APIRouter(routing.Router): endpoint: Callable[..., Any], name: Optional[str] = None, ) -> None: - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(path) super().add_websocket_route(path, endpoint, name) def add_api_websocket_route( @@ -1058,8 +1053,7 @@ class APIRouter(routing.Router): *, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(path) current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) @@ -1143,8 +1137,7 @@ class APIRouter(routing.Router): def websocket_route( self, path: str, name: Union[str, None] = None ) -> Callable[[DecoratedCallable], DecoratedCallable]: - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(path) def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_websocket_route(path, func, name=name) @@ -1304,9 +1297,7 @@ class APIRouter(routing.Router): responses = {} for route in router.routes: if isinstance(route, APIRoute): - path = route.path - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(route.path) combined_responses = {**responses, **route.responses} use_response_class = get_value_or_default( route.response_class, @@ -1366,9 +1357,7 @@ class APIRouter(routing.Router): generate_unique_id_function=current_generate_unique_id, ) elif isinstance(route, routing.Route): - path = route.path - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(route.path) methods = list(route.methods or []) self.add_route( prefix + path, @@ -1378,9 +1367,7 @@ class APIRouter(routing.Router): name=route.name, ) elif isinstance(route, APIWebSocketRoute): - path = route.path - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(route.path) current_dependencies = [] if dependencies: current_dependencies.extend(dependencies) @@ -1393,9 +1380,7 @@ class APIRouter(routing.Router): name=route.name, ) elif isinstance(route, routing.WebSocketRoute): - path = route.path - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(route.path) self.add_websocket_route(prefix + path, route.endpoint, name=route.name) for handler in router.on_startup: self.add_event_handler("startup", handler) @@ -1406,6 +1391,11 @@ class APIRouter(routing.Router): router.lifespan_context, ) + def _normalize_path(self, path: str) -> str: + if self.ignore_trailing_slash: + return path.rstrip("/") + return path + def get( self, path: Annotated[