From dd963511d699b463b416408a0ad705b3dda0d067 Mon Sep 17 00:00:00 2001 From: dmontagu <35119617+dmontagu@users.noreply.github.com> Date: Fri, 4 Oct 2019 14:35:20 -0700 Subject: [PATCH] :bug: Fix preserving route_class when calling include_router (#538) --- fastapi/routing.py | 5 +- tests/test_custom_route_class.py | 114 +++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 tests/test_custom_route_class.py diff --git a/fastapi/routing.py b/fastapi/routing.py index 8f61ea50c..b0902310c 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -348,8 +348,10 @@ class APIRouter(routing.Router): include_in_schema: bool = True, response_class: Type[Response] = None, name: str = None, + route_class_override: Optional[Type[APIRoute]] = None, ) -> None: - route = self.route_class( + route_class = route_class_override or self.route_class + route = route_class( path, endpoint=endpoint, response_model=response_model, @@ -487,6 +489,7 @@ class APIRouter(routing.Router): include_in_schema=route.include_in_schema, response_class=route.response_class or default_response_class, name=route.name, + route_class_override=type(route), ) elif isinstance(route, routing.Route): self.add_route( diff --git a/tests/test_custom_route_class.py b/tests/test_custom_route_class.py new file mode 100644 index 000000000..8bbf88ad3 --- /dev/null +++ b/tests/test_custom_route_class.py @@ -0,0 +1,114 @@ +import pytest +from fastapi import APIRouter, FastAPI +from fastapi.routing import APIRoute +from starlette.testclient import TestClient + +app = FastAPI() + + +class APIRouteA(APIRoute): + x_type = "A" + + +class APIRouteB(APIRoute): + x_type = "B" + + +class APIRouteC(APIRoute): + x_type = "C" + + +router_a = APIRouter(route_class=APIRouteA) +router_b = APIRouter(route_class=APIRouteB) +router_c = APIRouter(route_class=APIRouteC) + + +@router_a.get("/") +def get_a(): + return {"msg": "A"} + + +@router_b.get("/") +def get_b(): + return {"msg": "B"} + + +@router_c.get("/") +def get_c(): + return {"msg": "C"} + + +router_b.include_router(router=router_c, prefix="/c") +router_a.include_router(router=router_b, prefix="/b") +app.include_router(router=router_a, prefix="/a") + + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/a/": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Get A", + "operationId": "get_a_a__get", + } + }, + "/a/b/": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Get B", + "operationId": "get_b_a_b__get", + } + }, + "/a/b/c/": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Get C", + "operationId": "get_c_a_b_c__get", + } + }, + }, +} + + +@pytest.mark.parametrize( + "path,expected_status,expected_response", + [ + ("/a", 200, {"msg": "A"}), + ("/a/b", 200, {"msg": "B"}), + ("/a/b/c", 200, {"msg": "C"}), + ("/openapi.json", 200, openapi_schema), + ], +) +def test_get_path(path, expected_status, expected_response): + response = client.get(path) + assert response.status_code == expected_status + assert response.json() == expected_response + + +def test_route_classes(): + routes = {} + r: APIRoute + for r in app.router.routes: + routes[r.path] = r + assert routes["/a/"].x_type == "A" + assert routes["/a/b/"].x_type == "B" + assert routes["/a/b/c/"].x_type == "C"