You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

130 lines
3.7 KiB

from contextlib import asynccontextmanager
import anyio
import sniffio
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from typing_extensions import Annotated
def get_borrowed_tokens():
return {lname: l.borrowed_tokens for lname, l in sorted(limiters.items())}
def limited_dep():
# run this in the event loop thread:
tokens = anyio.from_thread.run_sync(get_borrowed_tokens)
yield tokens
def init_limiters():
return {
"default": None, # should be set in Lifespan handler
"a": anyio.CapacityLimiter(5),
"b": anyio.CapacityLimiter(3),
}
# Note:
# initializing CapacityLimiters at module level before the event loop has started
# needs anyio >= 4.2.0; starlette currently requires anyio >= 3.4.0
# see https://github.com/agronholm/anyio/pull/651
# The following is a temporary workaround for anyio < 4.2.0:
try:
limiters = init_limiters()
except sniffio.AsyncLibraryNotFoundError:
# not in an async context yet
async def _init_limiters():
return init_limiters()
limiters = anyio.run(_init_limiters)
@asynccontextmanager
async def lifespan(app: FastAPI):
limiters["default"] = anyio.to_thread.current_default_thread_limiter()
yield {
"limiters": limiters,
}
app = FastAPI(lifespan=lifespan)
@app.get("/")
async def root(borrowed_tokens: Annotated[dict, Depends(limited_dep)]):
return borrowed_tokens
@app.get("/a")
async def a(
borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters["a"])],
):
return borrowed_tokens
@app.get("/b")
async def b(
borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters["b"])],
):
return borrowed_tokens
def test_depends_limiter():
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"a": 0, "b": 0, "default": 1}
response = client.get("/a")
assert response.status_code == 200, response.text
assert response.json() == {"a": 1, "b": 0, "default": 0}
response = client.get("/b")
assert response.status_code == 200, response.text
assert response.json() == {"a": 0, "b": 1, "default": 0}
def test_openapi_schema():
with TestClient(app) as client:
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json()["paths"] == {
"/": {
"get": {
"summary": "Root",
"operationId": "root__get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
}
},
"/a": {
"get": {
"summary": "A",
"operationId": "a_a_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
}
},
"/b": {
"get": {
"summary": "B",
"operationId": "b_b_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
}
},
}