You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

206 lines
5.7 KiB

from enum import Enum
from typing import Any, AsyncGenerator, Generator, TypeVar, Union
from typing_extensions import assert_never
from fastapi import APIRouter, FastAPI, WebSocket
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
# Generic type variable available for annotations in this module.
T = TypeVar("T")
class DependencyStyle(str, Enum):
    """The callable shapes a dependency can take.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    # Plain callables.
    SYNC_FUNCTION = "sync_function"
    ASYNC_FUNCTION = "async_function"

    # Generator callables.
    SYNC_GENERATOR = "sync_generator"
    ASYNC_GENERATOR = "async_generator"
class IntentionallyBadDependency(Exception):
    """Error raised on purpose by a dependency that is configured to fail."""
class DependencyFactory:
    """Produce a call-counting dependency callable in a chosen style.

    The callable returned by :meth:`get_dependency` is a bound method of this
    instance, so every invocation updates the same counters.

    Args:
        dependency_style: Which kind of callable ``get_dependency`` returns.
        should_error: When true, the dependency raises
            ``IntentionallyBadDependency`` instead of producing a value.
        value_offset: Added to the activation count to form the produced value.
    """

    def __init__(
        self,
        dependency_style: DependencyStyle,
        *,
        should_error: bool = False,
        value_offset: int = 0,
    ) -> None:
        # Incremented each time a dependency body starts running.
        self.activation_times = 0
        # Incremented each time a generator dependency resumes past its yield.
        self.deactivation_times = 0
        self.dependency_style = dependency_style
        self._should_error = should_error
        self._value_offset = value_offset

    def get_dependency(self) -> Any:
        """Return the bound dependency callable matching ``dependency_style``."""
        style = self.dependency_style
        if style == DependencyStyle.SYNC_FUNCTION:
            return self._synchronous_function_dependency
        if style == DependencyStyle.SYNC_GENERATOR:
            return self._synchronous_generator_dependency
        if style == DependencyStyle.ASYNC_FUNCTION:
            return self._asynchronous_function_dependency
        if style == DependencyStyle.ASYNC_GENERATOR:
            return self._asynchronous_generator_dependency
        # Every member is handled above; this guards against new members.
        assert_never(style)

    def _activate(self) -> int:
        """Count one activation and return the value to hand out.

        Returns:
            The updated activation count plus the configured value offset.

        Raises:
            IntentionallyBadDependency: If the factory was created with
                ``should_error=True``. The activation is counted first, and
                the exception carries the updated count.
        """
        self.activation_times += 1
        if self._should_error:
            raise IntentionallyBadDependency(self.activation_times)
        return self.activation_times + self._value_offset

    async def _asynchronous_generator_dependency(self) -> AsyncGenerator[int, None]:
        """Async generator dependency: yield one value, then count deactivation."""
        yield self._activate()
        # Not in a ``finally``: only reached when the generator is resumed
        # normally after the yield.
        self.deactivation_times += 1

    def _synchronous_generator_dependency(self) -> Generator[int, None, None]:
        """Sync generator dependency: yield one value, then count deactivation."""
        yield self._activate()
        # Not in a ``finally``: only reached when the generator is resumed
        # normally after the yield.
        self.deactivation_times += 1

    async def _asynchronous_function_dependency(self) -> int:
        """Async function dependency: return one value (no deactivation step)."""
        return self._activate()

    def _synchronous_function_dependency(self) -> int:
        """Sync function dependency: return one value (no deactivation step)."""
        return self._activate()
def use_endpoint(client: TestClient, url: str) -> Any:
    """POST to *url* with *client* and return the decoded JSON body.

    ``raise_for_status()`` is called on the response before its body is read.
    """
    reply = client.post(url)
    reply.raise_for_status()
    payload = reply.json()
    return payload
def use_websocket(client: TestClient, url: str) -> Any:
    """Connect to the websocket at *url* and return one received JSON message."""
    with client.websocket_connect(url) as session:
        first_message = session.receive_json()
        # Return from inside the ``with`` so the session's exit handling is
        # applied to the return exactly as before.
        return first_message
def create_endpoint_0_annotations(
    *,
    router: Union[APIRouter, FastAPI],
    path: str,
    is_websocket: bool,
) -> None:
    """Register on *router* an endpoint at *path* that takes no extra parameters.

    When ``is_websocket`` is true, a websocket route is added that accepts the
    connection and sends JSON ``null``. Otherwise a POST route that returns
    ``None`` is added.
    """
    if not is_websocket:

        @router.post(path)
        async def endpoint() -> None:
            return

        return

    @router.websocket(path)
    async def endpoint(websocket: WebSocket) -> None:
        await websocket.accept()
        try:
            await websocket.send_json(None)
        except WebSocketDisconnect:
            # A disconnect while sending is deliberately ignored.
            return
def create_endpoint_1_annotation(
    *,
    router: Union[APIRouter, FastAPI],
    path: str,
    is_websocket: bool,
    annotation: Any,
    expected_value: Any = None,
) -> None:
    """Register on *router* an endpoint at *path* with one annotated parameter.

    The endpoint's ``value`` parameter is annotated with *annotation*. If
    *expected_value* is not ``None``, the endpoint asserts that ``value``
    equals it. The websocket variant sends ``value`` as JSON; the POST variant
    returns it.
    """
    if not is_websocket:

        # NOTE(review): annotated ``-> None`` although ``value`` is returned;
        # kept unchanged because FastAPI reads the return annotation when
        # building the route -- confirm this is intentional.
        @router.post(path)
        async def endpoint(value: annotation) -> None:
            assert expected_value is None or value == expected_value
            return value

        return

    @router.websocket(path)
    async def endpoint(websocket: WebSocket, value: annotation) -> None:
        # The check runs before the connection is accepted.
        assert expected_value is None or value == expected_value
        await websocket.accept()
        try:
            await websocket.send_json(value)
        except WebSocketDisconnect:
            # A disconnect while sending is deliberately ignored.
            return
def create_endpoint_2_annotations(
    *,
    router: Union[APIRouter, FastAPI],
    path: str,
    is_websocket: bool,
    annotation1: Any,
    annotation2: Any,
) -> None:
    """Register on *router* an endpoint at *path* with two annotated parameters.

    The endpoint's ``value1`` and ``value2`` parameters are annotated with
    *annotation1* and *annotation2*. The websocket variant accepts the
    connection and sends ``[value1, value2]`` as JSON; the POST variant
    returns that list.
    """
    if is_websocket:

        @router.websocket(path)
        async def endpoint(
            websocket: WebSocket,
            value1: annotation1,
            value2: annotation2,
        ) -> None:
            await websocket.accept()
            try:
                await websocket.send_json([value1, value2])
            except WebSocketDisconnect:
                # The peer is already gone, so there is no open connection
                # left to close; attempting to close it here would try to
                # send on a disconnected socket. Ignore the disconnect, as
                # the 0- and 1-annotation endpoint factories do.
                pass

    else:

        @router.post(path)
        async def endpoint(
            value1: annotation1,
            value2: annotation2,
        ) -> list[Any]:
            return [value1, value2]
def create_endpoint_3_annotations(
    *,
    router: Union[APIRouter, FastAPI],
    path: str,
    is_websocket: bool,
    annotation1: Any,
    annotation2: Any,
    annotation3: Any,
) -> None:
    """Register on *router* an endpoint at *path* with three annotated parameters.

    The endpoint's ``value1``, ``value2`` and ``value3`` parameters are
    annotated with *annotation1*, *annotation2* and *annotation3*. The
    websocket variant accepts the connection and sends
    ``[value1, value2, value3]`` as JSON; the POST variant returns that list.
    """
    if is_websocket:

        @router.websocket(path)
        async def endpoint(
            websocket: WebSocket,
            value1: annotation1,
            value2: annotation2,
            value3: annotation3,
        ) -> None:
            await websocket.accept()
            try:
                await websocket.send_json([value1, value2, value3])
            except WebSocketDisconnect:
                # The peer is already gone, so there is no open connection
                # left to close; attempting to close it here would try to
                # send on a disconnected socket. Ignore the disconnect, as
                # the 0- and 1-annotation endpoint factories do.
                pass

    else:

        @router.post(path)
        async def endpoint(
            value1: annotation1, value2: annotation2, value3: annotation3
        ) -> list[Any]:
            return [value1, value2, value3]