From 4e8080f29013dcb6850518bb9011bf82c9081828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 29 Feb 2020 21:28:23 +0100 Subject: [PATCH] :pushpin: Upgrade Starlette version (#1057) --- fastapi/applications.py | 23 +++++++--- fastapi/routing.py | 12 ++++- pyproject.toml | 2 +- tests/test_empty_router.py | 4 +- tests/test_router_events.py | 87 +++++++++++++++++++++++++++++++++++++ 5 files changed, 118 insertions(+), 10 deletions(-) create mode 100644 tests/test_router_events.py diff --git a/fastapi/applications.py b/fastapi/applications.py index ff2eb52c8..8270e54fd 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -18,8 +18,8 @@ from fastapi.params import Depends from fastapi.utils import warning_response_model_skip_defaults_deprecated from starlette.applications import Starlette from starlette.datastructures import State -from starlette.exceptions import ExceptionMiddleware, HTTPException -from starlette.middleware.errors import ServerErrorMiddleware +from starlette.exceptions import HTTPException +from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse, Response from starlette.routing import BaseRoute @@ -29,9 +29,9 @@ from starlette.types import Receive, Scope, Send class FastAPI(Starlette): def __init__( self, + *, debug: bool = False, routes: List[BaseRoute] = None, - template_directory: str = None, title: str = "FastAPI", description: str = "", version: str = "0.1.0", @@ -42,19 +42,28 @@ class FastAPI(Starlette): redoc_url: Optional[str] = "/redoc", swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect", swagger_ui_init_oauth: Optional[dict] = None, + middleware: Sequence[Middleware] = None, + exception_handlers: Dict[Union[int, Type[Exception]], Callable] = None, + on_startup: Sequence[Callable] = None, + on_shutdown: Sequence[Callable] = None, **extra: Dict[str, Any], ) -> None: self.default_response_class = default_response_class self._debug = debug self.state = State() self.router: routing.APIRouter = routing.APIRouter( - routes, dependency_overrides_provider=self + routes, + dependency_overrides_provider=self, + on_startup=on_startup, + on_shutdown=on_shutdown, ) - self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) - self.error_middleware = ServerErrorMiddleware( - self.exception_middleware, debug=debug + self.exception_handlers = ( + {} if exception_handlers is None else dict(exception_handlers) ) + self.user_middleware = [] if middleware is None else list(middleware) + self.middleware_stack = self.build_middleware_stack() + self.title = title self.description = description self.version = version diff --git a/fastapi/routing.py b/fastapi/routing.py index d5211c489..7f3dc4977 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -346,9 +346,15 @@ class APIRouter(routing.Router): dependency_overrides_provider: Any = None, route_class: Type[APIRoute] = APIRoute, default_response_class: Type[Response] = None, + on_startup: Sequence[Callable] = None, + on_shutdown: Sequence[Callable] = None, ) -> None: super().__init__( - routes=routes, redirect_slashes=redirect_slashes, default=default + routes=routes, + redirect_slashes=redirect_slashes, + default=default, + on_startup=on_startup, + on_shutdown=on_shutdown, ) self.dependency_overrides_provider = dependency_overrides_provider self.route_class = route_class @@ -552,6 +558,10 @@ class APIRouter(routing.Router): self.add_websocket_route( prefix + route.path, route.endpoint, name=route.name ) + for handler in router.on_startup: + self.add_event_handler("startup", handler) + for handler in router.on_shutdown: + self.add_event_handler("shutdown", handler) def get( self, diff --git a/pyproject.toml b/pyproject.toml index 7761a54c9..3805104ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ "Topic :: Internet :: WWW/HTTP", ] requires = [ - "starlette >=0.12.9,<=0.12.9", + "starlette ==0.13.2", "pydantic >=0.32.2,<2.0.0" ] description-file = "README.md" diff --git a/tests/test_empty_router.py b/tests/test_empty_router.py index 57dd006fa..c38fae855 100644 --- a/tests/test_empty_router.py +++ b/tests/test_empty_router.py @@ -21,10 +21,12 @@ client = TestClient(app) def test_use_empty(): with client: response = client.get("/prefix") + assert response.status_code == 200 assert response.json() == ["OK"] response = client.get("/prefix/") - assert response.status_code == 404 + assert response.status_code == 200 + assert response.json() == ["OK"] def test_include_empty(): diff --git a/tests/test_router_events.py b/tests/test_router_events.py new file mode 100644 index 000000000..3a499b149 --- /dev/null +++ b/tests/test_router_events.py @@ -0,0 +1,87 @@ +from fastapi import APIRouter, FastAPI +from pydantic import BaseModel +from starlette.testclient import TestClient + + +class State(BaseModel): + app_startup: bool = False + app_shutdown: bool = False + router_startup: bool = False + router_shutdown: bool = False + sub_router_startup: bool = False + sub_router_shutdown: bool = False + + +state = State() + +app = FastAPI() + + +@app.on_event("startup") +def app_startup(): + state.app_startup = True + + +@app.on_event("shutdown") +def app_shutdown(): + state.app_shutdown = True + + +router = APIRouter() + + +@router.on_event("startup") +def router_startup(): + state.router_startup = True + + +@router.on_event("shutdown") +def router_shutdown(): + state.router_shutdown = True + + +sub_router = APIRouter() + + +@sub_router.on_event("startup") +def sub_router_startup(): + state.sub_router_startup = True + + +@sub_router.on_event("shutdown") +def sub_router_shutdown(): + state.sub_router_shutdown = True + + +@sub_router.get("/") +def main(): + return {"message": "Hello World"} + + +router.include_router(sub_router) +app.include_router(router) + + +def test_router_events(): + assert state.app_startup is False + assert state.router_startup is False + assert state.sub_router_startup is False + assert state.app_shutdown is False + assert state.router_shutdown is False + assert state.sub_router_shutdown is False + with TestClient(app) as client: + assert state.app_startup is True + assert state.router_startup is True + assert state.sub_router_startup is True + assert state.app_shutdown is False + assert state.router_shutdown is False + assert state.sub_router_shutdown is False + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"message": "Hello World"} + assert state.app_startup is True + assert state.router_startup is True + assert state.sub_router_startup is True + assert state.app_shutdown is True + assert state.router_shutdown is True + assert state.sub_router_shutdown is True