Browse Source

Use helper method to avoid code duplication

pull/12145/head
Synrom 7 months ago
parent
commit
031c69fa9e
  1. 38
      fastapi/routing.py

38
fastapi/routing.py

@ -879,12 +879,9 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
include_in_schema: bool = True,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_slash:
path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_route(
path,
self._normalize_path(path),
func,
methods=methods,
name=name,
@ -928,8 +925,7 @@ class APIRouter(routing.Router):
Callable[[APIRoute], str], DefaultPlaceholder
] = Default(generate_unique_id),
) -> None:
if self.ignore_trailing_slash:
path = path.rstrip("/")
path = self._normalize_path(path)
route_class = route_class_override or self.route_class
responses = responses or {}
combined_responses = {**self.responses, **responses}
@ -1046,8 +1042,7 @@ class APIRouter(routing.Router):
endpoint: Callable[..., Any],
name: Optional[str] = None,
) -> None:
if self.ignore_trailing_slash:
path = path.rstrip("/")
path = self._normalize_path(path)
super().add_websocket_route(path, endpoint, name)
def add_api_websocket_route(
@ -1058,8 +1053,7 @@ class APIRouter(routing.Router):
*,
dependencies: Optional[Sequence[params.Depends]] = None,
) -> None:
if self.ignore_trailing_slash:
path = path.rstrip("/")
path = self._normalize_path(path)
current_dependencies = self.dependencies.copy()
if dependencies:
current_dependencies.extend(dependencies)
@ -1143,8 +1137,7 @@ class APIRouter(routing.Router):
def websocket_route(
self, path: str, name: Union[str, None] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_slash:
path = path.rstrip("/")
path = self._normalize_path(path)
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_websocket_route(path, func, name=name)
@ -1304,9 +1297,7 @@ class APIRouter(routing.Router):
responses = {}
for route in router.routes:
if isinstance(route, APIRoute):
path = route.path
if self.ignore_trailing_slash:
path = path.rstrip("/")
path = self._normalize_path(route.path)
combined_responses = {**responses, **route.responses}
use_response_class = get_value_or_default(
route.response_class,
@ -1366,9 +1357,7 @@ class APIRouter(routing.Router):
generate_unique_id_function=current_generate_unique_id,
)
elif isinstance(route, routing.Route):
path = route.path
if self.ignore_trailing_slash:
path = path.rstrip("/")
path = self._normalize_path(route.path)
methods = list(route.methods or [])
self.add_route(
prefix + path,
@ -1378,9 +1367,7 @@ class APIRouter(routing.Router):
name=route.name,
)
elif isinstance(route, APIWebSocketRoute):
path = route.path
if self.ignore_trailing_slash:
path = path.rstrip("/")
path = self._normalize_path(route.path)
current_dependencies = []
if dependencies:
current_dependencies.extend(dependencies)
@ -1393,9 +1380,7 @@ class APIRouter(routing.Router):
name=route.name,
)
elif isinstance(route, routing.WebSocketRoute):
path = route.path
if self.ignore_trailing_slash:
path = path.rstrip("/")
path = self._normalize_path(route.path)
self.add_websocket_route(prefix + path, route.endpoint, name=route.name)
for handler in router.on_startup:
self.add_event_handler("startup", handler)
@ -1406,6 +1391,11 @@ class APIRouter(routing.Router):
router.lifespan_context,
)
def _normalize_path(self, path: str) -> str:
if self.ignore_trailing_slash:
return path.rstrip("/")
return path
def get(
self,
path: Annotated[

Loading…
Cancel
Save