diff --git a/docs/src/dependency_testing/tutorial001.py b/docs/src/dependency_testing/tutorial001.py new file mode 100644 index 000000000..2f234d396 --- /dev/null +++ b/docs/src/dependency_testing/tutorial001.py @@ -0,0 +1,55 @@ +from fastapi import Depends, FastAPI +from starlette.testclient import TestClient + +app = FastAPI() + + +async def common_parameters(q: str = None, skip: int = 0, limit: int = 100): + return {"q": q, "skip": skip, "limit": limit} + + +@app.get("/items/") +async def read_items(commons: dict = Depends(common_parameters)): + return {"message": "Hello Items!", "params": commons} + + +@app.get("/users/") +async def read_users(commons: dict = Depends(common_parameters)): + return {"message": "Hello Users!", "params": commons} + + +client = TestClient(app) + + +async def override_dependency(q: str = None): + return {"q": q, "skip": 5, "limit": 10} + + +app.dependency_overrides[common_parameters] = override_dependency + + +def test_override_in_items(): + response = client.get("/items/") + assert response.status_code == 200 + assert response.json() == { + "message": "Hello Items!", + "params": {"q": None, "skip": 5, "limit": 10}, + } + + +def test_override_in_items_with_q(): + response = client.get("/items/?q=foo") + assert response.status_code == 200 + assert response.json() == { + "message": "Hello Items!", + "params": {"q": "foo", "skip": 5, "limit": 10}, + } + + +def test_override_in_items_with_params(): + response = client.get("/items/?q=foo&skip=100&limit=200") + assert response.status_code == 200 + assert response.json() == { + "message": "Hello Items!", + "params": {"q": "foo", "skip": 5, "limit": 10}, + } diff --git a/docs/tutorial/testing-dependencies.md b/docs/tutorial/testing-dependencies.md new file mode 100644 index 000000000..50d50618c --- /dev/null +++ b/docs/tutorial/testing-dependencies.md @@ -0,0 +1,59 @@ +## Overriding dependencies during testing + +There are some scenarios where you might want to override a dependency during testing. + +You don't want the original dependency to run (nor any of the sub-dependencies it might have). + +Instead, you want to provide a different dependency that will be used only during tests (possibly only some specific tests), and will provide a value that can be used where the value of the original dependency was used. + +### Use cases: external service + +An example could be that you have an external authentication provider that you need to call. + +You send it a token and it returns an authenticated user. + +This provider might be charging you per request, and calling it might take some extra time than if you had a fixed mock user for tests. + +You probably want to test the external provider once, but not necessarily call it for every test that runs. + +In this case, you can override the dependency that calls that provider, and use a custom dependency that returns a mock user, only for your tests. + +### Use case: testing database + +Other example could be that you are using a specific database only for testing. + +Your normal dependency would return a database session. + +But then, after each test, you could want to rollback all the operations or remove data. + +Or you could want to alter the data before the tests run, etc. + +In this case, you could use a dependency override to return your *custom* database session instead of the one that would be used normally. + +### Use the `app.dependency_overrides` attribute + +For these cases, your **FastAPI** application has an attribute `app.dependency_overrides`, it is a simple `dict`. + +To override a dependency for testing, you put as a key the original dependency (a function), and as the value, your dependency override (another function). + +And then **FastAPI** will call that override instead of the original dependency. + +```Python hl_lines="24 25 28" +{!./src/dependency_testing/tutorial001.py!} +``` + +!!! tip + You can set a dependency override for a dependency used anywhere in your **FastAPI** application. + + The original dependency could be used in a *path operation function*, a *path operation decorator* (when you don't use the return value), a `.include_router()` call, etc. + + FastAPI will still be able to override it. + +Then you can reset your overrides (remove them) by setting `app.dependency_overrides` to be an empty `dict`: + +```Python +app.dependency_overrides = {} +``` + +!!! tip + If you want to override a dependency only during some tests, you can set the override at the beginning of the test (inside the test function) and reset it at the end (at the end of the test function). diff --git a/fastapi/applications.py b/fastapi/applications.py index 3c38b9d24..a1a2a5605 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -38,7 +38,9 @@ class FastAPI(Starlette): **extra: Dict[str, Any], ) -> None: self._debug = debug - self.router: routing.APIRouter = routing.APIRouter(routes) + self.router: routing.APIRouter = routing.APIRouter( + routes, dependency_overrides_provider=self + ) self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) self.error_middleware = ServerErrorMiddleware( self.exception_middleware, debug=debug @@ -53,6 +55,7 @@ class FastAPI(Starlette): self.redoc_url = redoc_url self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url self.extra = extra + self.dependency_overrides: Dict[Callable, Callable] = {} self.openapi_version = "3.0.2" diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 67eb094e8..33644d764 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -30,6 +30,7 @@ class Dependant: background_tasks_param_name: str = None, security_scopes_param_name: str = None, security_scopes: List[str] = None, + path: str = None, ) -> None: self.path_params = path_params or [] self.query_params = query_params or [] @@ -45,3 +46,5 @@ class Dependant: self.security_scopes_param_name = security_scopes_param_name self.name = name self.call = call + # Store the path to be able to re-generate a dependable from it in overrides + self.path = path diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 74ba61d81..2a64172ef 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -111,6 +111,7 @@ def get_flat_dependant(dependant: Dependant) -> Dependant: cookie_params=dependant.cookie_params.copy(), body_params=dependant.body_params.copy(), security_schemes=dependant.security_requirements.copy(), + path=dependant.path, ) for sub_dependant in dependant.dependencies: flat_sub = get_flat_dependant(sub_dependant) @@ -152,7 +153,7 @@ def get_dependant( path_param_names = get_path_param_names(path) endpoint_signature = inspect.signature(call) signature_params = endpoint_signature.parameters - dependant = Dependant(call=call, name=name) + dependant = Dependant(call=call, name=name, path=path) for param_name, param in signature_params.items(): if isinstance(param.default, params.Depends): sub_dependant = get_param_sub_dependant( @@ -284,26 +285,46 @@ async def solve_dependencies( dependant: Dependant, body: Dict[str, Any] = None, background_tasks: BackgroundTasks = None, + dependency_overrides_provider: Any = None, ) -> Tuple[Dict[str, Any], List[ErrorWrapper], Optional[BackgroundTasks]]: values: Dict[str, Any] = {} errors: List[ErrorWrapper] = [] + sub_dependant: Dependant for sub_dependant in dependant.dependencies: + call: Callable = sub_dependant.call # type: ignore + use_sub_dependant = sub_dependant + if ( + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides + ): + original_call: Callable = sub_dependant.call # type: ignore + call = getattr( + dependency_overrides_provider, "dependency_overrides", {} + ).get(original_call, original_call) + use_path: str = sub_dependant.path # type: ignore + use_sub_dependant = get_dependant( + path=use_path, + call=call, + name=sub_dependant.name, + security_scopes=sub_dependant.security_scopes, + ) + sub_values, sub_errors, background_tasks = await solve_dependencies( request=request, - dependant=sub_dependant, + dependant=use_sub_dependant, body=body, background_tasks=background_tasks, + dependency_overrides_provider=dependency_overrides_provider, ) if sub_errors: errors.extend(sub_errors) continue - assert sub_dependant.call is not None, "sub_dependant.call must be a function" - if is_coroutine_callable(sub_dependant.call): - solved = await sub_dependant.call(**sub_values) + if is_coroutine_callable(call): + solved = await call(**sub_values) else: - solved = await run_in_threadpool(sub_dependant.call, **sub_values) - if sub_dependant.name is not None: - values[sub_dependant.name] = solved + solved = await run_in_threadpool(call, **sub_values) + if use_sub_dependant.name is not None: + values[use_sub_dependant.name] = solved path_values, path_errors = request_params_to_args( dependant.path_params, request.path_params ) diff --git a/fastapi/routing.py b/fastapi/routing.py index 007194652..8526d8c04 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -30,6 +30,7 @@ from starlette.routing import ( websocket_session, ) from starlette.status import WS_1008_POLICY_VIOLATION +from starlette.types import ASGIApp from starlette.websockets import WebSocket @@ -80,6 +81,7 @@ def get_app( response_model_exclude: Set[str] = set(), response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, + dependency_overrides_provider: Any = None, ) -> Callable: assert dependant.call is not None, "dependant.call must be a function" is_coroutine = asyncio.iscoroutinefunction(dependant.call) @@ -101,7 +103,10 @@ def get_app( status_code=400, detail="There was an error parsing the body" ) from e values, errors, background_tasks = await solve_dependencies( - request=request, dependant=dependant, body=body + request=request, + dependant=dependant, + body=body, + dependency_overrides_provider=dependency_overrides_provider, ) if errors: raise RequestValidationError(errors) @@ -132,10 +137,14 @@ def get_app( return app -def get_websocket_app(dependant: Dependant) -> Callable: +def get_websocket_app( + dependant: Dependant, dependency_overrides_provider: Any = None +) -> Callable: async def app(websocket: WebSocket) -> None: values, errors, _ = await solve_dependencies( - request=websocket, dependant=dependant + request=websocket, + dependant=dependant, + dependency_overrides_provider=dependency_overrides_provider, ) if errors: await websocket.close(code=WS_1008_POLICY_VIOLATION) @@ -147,12 +156,24 @@ def get_websocket_app(dependant: Dependant) -> Callable: class APIWebSocketRoute(routing.WebSocketRoute): - def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None: + def __init__( + self, + path: str, + endpoint: Callable, + *, + name: str = None, + dependency_overrides_provider: Any = None, + ) -> None: self.path = path self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name self.dependant = get_dependant(path=path, call=self.endpoint) - self.app = websocket_session(get_websocket_app(dependant=self.dependant)) + self.app = websocket_session( + get_websocket_app( + dependant=self.dependant, + dependency_overrides_provider=dependency_overrides_provider, + ) + ) regex = "^" + path + "$" regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex) self.path_regex, self.path_format, self.param_convertors = compile_path(path) @@ -182,6 +203,7 @@ class APIRoute(routing.Route): response_model_skip_defaults: bool = False, include_in_schema: bool = True, response_class: Type[Response] = JSONResponse, + dependency_overrides_provider: Any = None, ) -> None: assert path.startswith("/"), "Routed paths must always start with '/'" self.path = path @@ -257,6 +279,7 @@ class APIRoute(routing.Route): get_parameterless_sub_dependant(depends=depends, path=self.path_format), ) self.body_field = get_body_field(dependant=self.dependant, name=self.name) + self.dependency_overrides_provider = dependency_overrides_provider self.app = request_response( get_app( dependant=self.dependant, @@ -268,11 +291,24 @@ class APIRoute(routing.Route): response_model_exclude=self.response_model_exclude, response_model_by_alias=self.response_model_by_alias, response_model_skip_defaults=self.response_model_skip_defaults, + dependency_overrides_provider=self.dependency_overrides_provider, ) ) class APIRouter(routing.Router): + def __init__( + self, + routes: List[routing.BaseRoute] = None, + redirect_slashes: bool = True, + default: ASGIApp = None, + dependency_overrides_provider: Any = None, + ) -> None: + super().__init__( + routes=routes, redirect_slashes=redirect_slashes, default=default + ) + self.dependency_overrides_provider = dependency_overrides_provider + def add_api_route( self, path: str, @@ -318,6 +354,7 @@ class APIRouter(routing.Router): include_in_schema=include_in_schema, response_class=response_class, name=name, + dependency_overrides_provider=self.dependency_overrides_provider, ) self.routes.append(route) diff --git a/mkdocs.yml b/mkdocs.yml index 6954eb044..c75581133 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -81,6 +81,7 @@ nav: - WebSockets: 'tutorial/websockets.md' - 'Events: startup - shutdown': 'tutorial/events.md' - Testing: 'tutorial/testing.md' + - Testing Dependencies with Overrides: 'tutorial/testing-dependencies.md' - Debugging: 'tutorial/debugging.md' - Extending OpenAPI: 'tutorial/extending-openapi.md' - Concurrency and async / await: 'async.md' diff --git a/tests/test_dependency_overrides.py b/tests/test_dependency_overrides.py new file mode 100644 index 000000000..7faf12ea5 --- /dev/null +++ b/tests/test_dependency_overrides.py @@ -0,0 +1,313 @@ +import pytest +from fastapi import APIRouter, Depends, FastAPI +from starlette.testclient import TestClient + +app = FastAPI() + +router = APIRouter() + + +async def common_parameters(q: str, skip: int = 0, limit: int = 100): + return {"q": q, "skip": skip, "limit": limit} + + +@app.get("/main-depends/") +async def main_depends(commons: dict = Depends(common_parameters)): + return {"in": "main-depends", "params": commons} + + +@app.get("/decorator-depends/", dependencies=[Depends(common_parameters)]) +async def decorator_depends(): + return {"in": "decorator-depends"} + + +@router.get("/router-depends/") +async def router_depends(commons: dict = Depends(common_parameters)): + return {"in": "router-depends", "params": commons} + + +@router.get("/router-decorator-depends/", dependencies=[Depends(common_parameters)]) +async def router_decorator_depends(): + return {"in": "router-decorator-depends"} + + +app.include_router(router) + +client = TestClient(app) + + +async def overrider_dependency_simple(q: str = None): + return {"q": q, "skip": 5, "limit": 10} + + +async def overrider_sub_dependency(k: str): + return {"k": k} + + +async def overrider_dependency_with_sub(msg: dict = Depends(overrider_sub_dependency)): + return msg + + +@pytest.mark.parametrize( + "url,status_code,expected", + [ + ( + "/main-depends/", + 422, + { + "detail": [ + { + "loc": ["query", "q"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ( + "/main-depends/?q=foo", + 200, + {"in": "main-depends", "params": {"q": "foo", "skip": 0, "limit": 100}}, + ), + ( + "/main-depends/?q=foo&skip=100&limit=200", + 200, + {"in": "main-depends", "params": {"q": "foo", "skip": 100, "limit": 200}}, + ), + ( + "/decorator-depends/", + 422, + { + "detail": [ + { + "loc": ["query", "q"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ("/decorator-depends/?q=foo", 200, {"in": "decorator-depends"}), + ( + "/decorator-depends/?q=foo&skip=100&limit=200", + 200, + {"in": "decorator-depends"}, + ), + ( + "/router-depends/", + 422, + { + "detail": [ + { + "loc": ["query", "q"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ( + "/router-depends/?q=foo", + 200, + {"in": "router-depends", "params": {"q": "foo", "skip": 0, "limit": 100}}, + ), + ( + "/router-depends/?q=foo&skip=100&limit=200", + 200, + {"in": "router-depends", "params": {"q": "foo", "skip": 100, "limit": 200}}, + ), + ( + "/router-decorator-depends/", + 422, + { + "detail": [ + { + "loc": ["query", "q"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ("/router-decorator-depends/?q=foo", 200, {"in": "router-decorator-depends"}), + ( + "/router-decorator-depends/?q=foo&skip=100&limit=200", + 200, + {"in": "router-decorator-depends"}, + ), + ], +) +def test_normal_app(url, status_code, expected): + response = client.get(url) + assert response.status_code == status_code + assert response.json() == expected + + +@pytest.mark.parametrize( + "url,status_code,expected", + [ + ( + "/main-depends/", + 200, + {"in": "main-depends", "params": {"q": None, "skip": 5, "limit": 10}}, + ), + ( + "/main-depends/?q=foo", + 200, + {"in": "main-depends", "params": {"q": "foo", "skip": 5, "limit": 10}}, + ), + ( + "/main-depends/?q=foo&skip=100&limit=200", + 200, + {"in": "main-depends", "params": {"q": "foo", "skip": 5, "limit": 10}}, + ), + ("/decorator-depends/", 200, {"in": "decorator-depends"}), + ( + "/router-depends/", + 200, + {"in": "router-depends", "params": {"q": None, "skip": 5, "limit": 10}}, + ), + ( + "/router-depends/?q=foo", + 200, + {"in": "router-depends", "params": {"q": "foo", "skip": 5, "limit": 10}}, + ), + ( + "/router-depends/?q=foo&skip=100&limit=200", + 200, + {"in": "router-depends", "params": {"q": "foo", "skip": 5, "limit": 10}}, + ), + ("/router-decorator-depends/", 200, {"in": "router-decorator-depends"}), + ], +) +def test_override_simple(url, status_code, expected): + app.dependency_overrides[common_parameters] = overrider_dependency_simple + response = client.get(url) + assert response.status_code == status_code + assert response.json() == expected + app.dependency_overrides = {} + + +@pytest.mark.parametrize( + "url,status_code,expected", + [ + ( + "/main-depends/", + 422, + { + "detail": [ + { + "loc": ["query", "k"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ( + "/main-depends/?q=foo", + 422, + { + "detail": [ + { + "loc": ["query", "k"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ("/main-depends/?k=bar", 200, {"in": "main-depends", "params": {"k": "bar"}}), + ( + "/decorator-depends/", + 422, + { + "detail": [ + { + "loc": ["query", "k"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ( + "/decorator-depends/?q=foo", + 422, + { + "detail": [ + { + "loc": ["query", "k"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ("/decorator-depends/?k=bar", 200, {"in": "decorator-depends"}), + ( + "/router-depends/", + 422, + { + "detail": [ + { + "loc": ["query", "k"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ( + "/router-depends/?q=foo", + 422, + { + "detail": [ + { + "loc": ["query", "k"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ( + "/router-depends/?k=bar", + 200, + {"in": "router-depends", "params": {"k": "bar"}}, + ), + ( + "/router-decorator-depends/", + 422, + { + "detail": [ + { + "loc": ["query", "k"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ( + "/router-decorator-depends/?q=foo", + 422, + { + "detail": [ + { + "loc": ["query", "k"], + "msg": "field required", + "type": "value_error.missing", + } + ] + }, + ), + ("/router-decorator-depends/?k=bar", 200, {"in": "router-decorator-depends"}), + ], +) +def test_override_with_sub(url, status_code, expected): + app.dependency_overrides[common_parameters] = overrider_dependency_with_sub + response = client.get(url) + assert response.status_code == status_code + assert response.json() == expected + app.dependency_overrides = {} diff --git a/tests/test_tutorial/test_testing_dependencies/__init__.py b/tests/test_tutorial/test_testing_dependencies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_tutorial/test_testing_dependencies/test_tutorial001.py b/tests/test_tutorial/test_testing_dependencies/test_tutorial001.py new file mode 100644 index 000000000..093c6499a --- /dev/null +++ b/tests/test_tutorial/test_testing_dependencies/test_tutorial001.py @@ -0,0 +1,56 @@ +from dependency_testing.tutorial001 import ( + app, + client, + test_override_in_items, + test_override_in_items_with_params, + test_override_in_items_with_q, +) + + +def test_override_in_items_run(): + test_override_in_items() + + +def test_override_in_items_with_q_run(): + test_override_in_items_with_q() + + +def test_override_in_items_with_params_run(): + test_override_in_items_with_params() + + +def test_override_in_users(): + response = client.get("/users/") + assert response.status_code == 200 + assert response.json() == { + "message": "Hello Users!", + "params": {"q": None, "skip": 5, "limit": 10}, + } + + +def test_override_in_users_with_q(): + response = client.get("/users/?q=foo") + assert response.status_code == 200 + assert response.json() == { + "message": "Hello Users!", + "params": {"q": "foo", "skip": 5, "limit": 10}, + } + + +def test_override_in_users_with_params(): + response = client.get("/users/?q=foo&skip=100&limit=200") + assert response.status_code == 200 + assert response.json() == { + "message": "Hello Users!", + "params": {"q": "foo", "skip": 5, "limit": 10}, + } + + +def test_normal_app(): + app.dependency_overrides = None + response = client.get("/items/?q=foo&skip=100&limit=200") + assert response.status_code == 200 + assert response.json() == { + "message": "Hello Items!", + "params": {"q": "foo", "skip": 100, "limit": 200}, + }