From 9293795e99afbda07d2744f1eb7d23d2f0ea0154 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 8 Feb 2023 11:23:07 +0100 Subject: [PATCH] =?UTF-8?q?=E2=AC=86=EF=B8=8F=20Bump=20Starlette=20from=20?= =?UTF-8?q?0.22.0=20to=200.23.0=20(#5739)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sebastián Ramírez --- docs_src/app_testing/tutorial002.py | 2 +- fastapi/applications.py | 33 +++++++++++++++++++++++++ fastapi/routing.py | 37 +++++++++++++++++++++++++++++ pyproject.toml | 2 +- tests/test_route_scope.py | 4 ++-- 5 files changed, 74 insertions(+), 4 deletions(-) diff --git a/docs_src/app_testing/tutorial002.py b/docs_src/app_testing/tutorial002.py index b4a9c0586..71c898b3c 100644 --- a/docs_src/app_testing/tutorial002.py +++ b/docs_src/app_testing/tutorial002.py @@ -10,7 +10,7 @@ async def read_main(): return {"msg": "Hello World"} -@app.websocket_route("/ws") +@app.websocket("/ws") async def websocket(websocket: WebSocket): await websocket.accept() await websocket.send_json({"msg": "Hello WebSocket"}) diff --git a/fastapi/applications.py b/fastapi/applications.py index 36dc2605d..160d66301 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -35,6 +35,7 @@ from starlette.applications import Starlette from starlette.datastructures import State from starlette.exceptions import HTTPException from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware from starlette.requests import Request @@ -870,3 +871,35 @@ class FastAPI(Starlette): openapi_extra=openapi_extra, generate_unique_id_function=generate_unique_id_function, ) + + def websocket_route( + self, path: str, name: Union[str, None] = None + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: + self.router.add_websocket_route(path, func, name=name) + return func + + return decorator + + def on_event( + self, event_type: str + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + return self.router.on_event(event_type) + + def middleware( + self, middleware_type: str + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: + self.add_middleware(BaseHTTPMiddleware, dispatch=func) + return func + + return decorator + + def exception_handler( + self, exc_class_or_status_code: Union[int, Type[Exception]] + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: + self.add_exception_handler(exc_class_or_status_code, func) + return func + + return decorator diff --git a/fastapi/routing.py b/fastapi/routing.py index f131fa903..7ab6275b6 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -522,6 +522,25 @@ class APIRouter(routing.Router): self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function + def route( + self, + path: str, + methods: Optional[List[str]] = None, + name: Optional[str] = None, + include_in_schema: bool = True, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: + self.add_route( + path, + func, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + return func + + return decorator + def add_api_route( self, path: str, @@ -686,6 +705,15 @@ class APIRouter(routing.Router): return decorator + def websocket_route( + self, path: str, name: Union[str, None] = None + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: + self.add_websocket_route(path, func, name=name) + return func + + return decorator + def include_router( self, router: "APIRouter", @@ -1247,3 +1275,12 @@ class APIRouter(routing.Router): openapi_extra=openapi_extra, generate_unique_id_function=generate_unique_id_function, ) + + def on_event( + self, event_type: str + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: + self.add_event_handler(event_type, func) + return func + + return decorator diff --git a/pyproject.toml b/pyproject.toml index 7fb8078f9..4498f9432 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ classifiers = [ "Topic :: Internet :: WWW/HTTP", ] dependencies = [ - "starlette==0.22.0", + "starlette>=0.22.0,<=0.23.0", "pydantic >=1.6.2,!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<2.0.0", ] dynamic = ["version"] diff --git a/tests/test_route_scope.py b/tests/test_route_scope.py index a188e9a5f..2021c828f 100644 --- a/tests/test_route_scope.py +++ b/tests/test_route_scope.py @@ -46,5 +46,5 @@ def test_websocket(): def test_websocket_invalid_path_doesnt_match(): with pytest.raises(WebSocketDisconnect): - with client.websocket_connect("/itemsx/portal-gun") as websocket: - websocket.receive_json() + with client.websocket_connect("/itemsx/portal-gun"): + pass