You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

125 lines
3.7 KiB

from contextlib import asynccontextmanager
import threading
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from typing_extensions import Annotated
import anyio
import sniffio
def get_borrowed_tokens():
    """Return ``{limiter_name: borrowed_tokens}`` for every entry in ``limiters``.

    Keys are emitted in sorted name order. ``limited_dep`` calls this through
    ``anyio.from_thread.run_sync`` so that it executes on the event loop thread.
    Every value in ``limiters`` must already be a limiter object; the
    ``'default'`` placeholder is ``None`` until ``lifespan`` fills it in, and
    calling this earlier would fail with ``AttributeError``.
    """
    snapshot = {}
    for name in sorted(limiters):
        snapshot[name] = limiters[name].borrowed_tokens
    return snapshot
def limited_dep():
    """Yield a snapshot of every limiter's borrowed-token count.

    This is a sync generator dependency. ``anyio.from_thread.run_sync`` only
    works from an AnyIO worker thread, so the dependency is expected to be
    running in one. ``test_depends_limiter`` relies on that thread holding a
    token from the limiter selected for the route.
    ``get_borrowed_tokens`` is handed back to the event loop thread rather
    than reading limiter state from here.
    """
    yield anyio.from_thread.run_sync(get_borrowed_tokens)
def init_limiters():
    """Build the table of named limiters used by the routes in this module.

    ``'a'`` gets a 5-token ``anyio.CapacityLimiter`` and ``'b'`` a 3-token one.
    ``'default'`` is only a ``None`` placeholder; the lifespan handler replaces
    it with AnyIO's default thread limiter once the event loop is running.
    """
    capacities = (('a', 5), ('b', 3))
    table = {'default': None}
    for name, total_tokens in capacities:
        table[name] = anyio.CapacityLimiter(total_tokens)
    return table
# Temporary workaround for anyio < 4.2.0.
# Creating CapacityLimiter objects at import time, before any event loop is
# running, is only supported from anyio 4.2.0 onwards
# (see https://github.com/agronholm/anyio/pull/651), while starlette still
# accepts anyio >= 3.4.0. On the older versions the construction raises
# sniffio.AsyncLibraryNotFoundError, so we retry it inside a short-lived
# event loop started with anyio.run().
def _build_limiters():
    try:
        return init_limiters()
    except sniffio.AsyncLibraryNotFoundError:
        # Not inside an async context yet: construct the limiters from a
        # coroutine so an async library is detectable.
        async def _init_in_event_loop():
            return init_limiters()

        return anyio.run(_init_in_event_loop)


limiters = _build_limiters()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Fill in the ``'default'`` limiter and expose the table as lifespan state.

    ``anyio.to_thread.current_default_thread_limiter()`` is looked up here, with
    the event loop running, instead of at import time. The resulting
    ``limiters`` dict is yielded as the application state under
    ``'limiters'``. Nothing is undone on shutdown.
    """
    limiters.update(default=anyio.to_thread.current_default_thread_limiter())
    app_state = {'limiters': limiters}
    yield app_state
# Register `lifespan` so that limiters['default'] is replaced by AnyIO's default
# thread limiter on startup, before any request reaches `limited_dep`.
app = FastAPI(lifespan=lifespan)
# Route without an explicit limiter: test_depends_limiter expects the dependency
# to borrow its token from the 'default' limiter here.
# NOTE: documented with comments, not a docstring, on purpose. FastAPI copies an
# endpoint docstring into the OpenAPI "description", and test_openapi_schema
# asserts the exact schema (summary/operationId are derived from the name).
@app.get("/")
async def root(borrowed_tokens: Annotated[dict, Depends(limited_dep)]):
    # Echo the snapshot taken inside the dependency.
    return borrowed_tokens
# Route whose dependency is given limiters['a'] via the `limiter=` option of
# Depends, the feature this file exercises. test_depends_limiter expects the
# token to come from 'a' and the 'default' limiter to stay untouched.
# NOTE: no docstring -- it would add an OpenAPI "description" and break
# test_openapi_schema.
@app.get("/a")
async def a(borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters['a'])]):
    # Echo the snapshot taken inside the dependency.
    return borrowed_tokens
# Same as route "/a" but bound to limiters['b']; test_depends_limiter expects
# only the 'b' count to be non-zero.
# NOTE: no docstring -- it would add an OpenAPI "description" and break
# test_openapi_schema.
@app.get("/b")
async def b(borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters['b'])]):
    # Echo the snapshot taken inside the dependency.
    return borrowed_tokens
def test_depends_limiter():
    """Each route's dependency should hold exactly one token from its own limiter.

    "/" uses the default thread limiter; "/a" and "/b" use the limiters passed
    through ``Depends(..., limiter=...)``.
    """
    cases = [
        ("/", {"a": 0, "b": 0, "default": 1}),
        ("/a", {"a": 1, "b": 0, "default": 0}),
        ("/b", {"a": 0, "b": 1, "default": 0}),
    ]
    with TestClient(app) as client:
        for path, expected_tokens in cases:
            resp = client.get(path)
            assert resp.status_code == 200, resp.text
            assert resp.json() == expected_tokens
def test_openapi_schema():
    """The three routes must appear in OpenAPI as plain, parameterless GETs.

    The expected entries carry no ``parameters``: the dependency contributes
    nothing to the schema, with or without a ``limiter``.
    """

    def _get_operation(summary, operation_id):
        # Expected entry for a GET with no parameters returning untyped JSON.
        return {
            'get': {
                'summary': summary,
                'operationId': operation_id,
                'responses': {
                    '200': {
                        'description': 'Successful Response',
                        'content': {'application/json': {'schema': {}}},
                    },
                },
            },
        }

    expected_paths = {
        '/': _get_operation('Root', 'root__get'),
        '/a': _get_operation('A', 'a_a_get'),
        '/b': _get_operation('B', 'b_b_get'),
    }
    with TestClient(app) as client:
        resp = client.get("/openapi.json")
        assert resp.status_code == 200, resp.text
        assert resp.json()['paths'] == expected_paths