diff --git a/tests/test_depends_limiter.py b/tests/test_depends_limiter.py new file mode 100644 index 000000000..f7e015b7e --- /dev/null +++ b/tests/test_depends_limiter.py @@ -0,0 +1,125 @@ +from contextlib import asynccontextmanager +import threading + +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient +from typing_extensions import Annotated +import anyio +import sniffio + + +def get_borrowed_tokens(): + return { + lname: l.borrowed_tokens + for lname, l in sorted(limiters.items()) + } + +def limited_dep(): + # run this in the event loop thread: + tokens = anyio.from_thread.run_sync(get_borrowed_tokens) + yield tokens + + +def init_limiters(): + return { + 'default': None, # should be set in Lifespan handler + 'a': anyio.CapacityLimiter(5), + 'b': anyio.CapacityLimiter(3), + } + +# Note: +# initializing CapacityLimiters at module level before the event loop has started +# needs anyio >= 4.2.0; starlette currently requires anyio >= 3.4.0 +# see https://github.com/agronholm/anyio/pull/651 + +# The following is a temporary workaround for anyio < 4.2.0: +try: + limiters = init_limiters() +except sniffio.AsyncLibraryNotFoundError: + # not in an async context yet + async def _init_limiters(): + return init_limiters() + limiters = anyio.run(_init_limiters) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + limiters['default'] = anyio.to_thread.current_default_thread_limiter() + yield { + 'limiters': limiters, + } + +app = FastAPI(lifespan=lifespan) + + +@app.get("/") +async def root(borrowed_tokens: Annotated[dict, Depends(limited_dep)]): + return borrowed_tokens + +@app.get("/a") +async def a(borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters['a'])]): + return borrowed_tokens + +@app.get("/b") +async def b(borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters['b'])]): + return borrowed_tokens + + + +def test_depends_limiter(): + with TestClient(app) as client: + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"a":0,"b":0,"default":1} + + response = client.get("/a") + assert response.status_code == 200, response.text + assert response.json() == {"a":1,"b":0,"default":0} + + response = client.get("/b") + assert response.status_code == 200, response.text + assert response.json() == {"a":0,"b":1,"default":0} + + +def test_openapi_schema(): + with TestClient(app) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json()['paths'] == { + '/': { + 'get': { + 'summary': 'Root', + 'operationId': 'root__get', + 'responses': { + '200': { + 'description': 'Successful Response', + 'content': {'application/json': {'schema': {}}} + } + } + } + }, + '/a': { + 'get': { + 'summary': 'A', + 'operationId': 'a_a_get', + 'responses': { + '200': { + 'description': 'Successful Response', + 'content': {'application/json': {'schema': {}}} + } + } + } + }, + '/b': { + 'get': { + 'summary': 'B', + 'operationId': 'b_b_get', + 'responses': { + '200': { + 'description': 'Successful Response', + 'content': {'application/json': {'schema': {}}} + } + } + } + } + }