You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

120 lines
3.7 KiB

import inspect
import logging
from collections.abc import Callable
from functools import partial
from typing import Any, Generic, TypeVar
from fastapi import APIRouter
# Type variable for the user class wrapped by the CBV/CBR helpers in this module.
T = TypeVar("T")
class CBV(Generic[T]):
    """Class-based view: expose HTTP-verb-named methods of ``cls`` as routes.

    Calling the wrapper instantiates ``cls``. Each callable attribute of the
    new instance that is named after an HTTP verb (``head``, ``get``,
    ``post``, ...) is added to ``router`` at path ``""``, using a per-verb
    default status code. Attribute lookups and ``dir()`` on the wrapper are
    delegated to ``cls``.

    Args:
        cls: The class to wrap.
        router: Router on which the verb methods are registered.
    """

    def __init__(self, cls: type[T], router: APIRouter):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.router = router
        self.cls = cls

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        """Instantiate ``cls`` and register its HTTP-verb methods as routes.

        Args:
            *args: Positional arguments forwarded to ``cls``.
            **kwargs: Keyword arguments forwarded to ``cls``.

        Returns:
            The new instance, which is also stored as ``self.instance``.

        Note:
            Every call registers the routes again on ``self.router``. Routes
            added by earlier calls are not removed here.
        """
        self.instance = self.cls(*args, **kwargs)
        default_status_codes = {
            "head": 200,
            "get": 200,
            "post": 201,
            "put": 204,
            "delete": 204,
            "patch": 200,
            "options": 200,
            "trace": 200,
            "connect": 200,
        }
        for name, status_code in default_status_codes.items():
            method = getattr(self.instance, name, None)
            # Only callables can serve as endpoints. A plain attribute that
            # happens to share a verb's name (e.g. ``self.options = {...}``)
            # is skipped rather than handed to ``add_api_route``.
            if not callable(method):
                if method is not None:
                    self.logger.debug(
                        "Skipping non-callable attribute %r on %r", name, self.cls
                    )
                continue
            self.router.add_api_route(
                path="",
                endpoint=method,
                status_code=status_code,
                methods=[name.upper()],
                summary=f"{name.upper()} {self.router.prefix}",
            )
        return self.instance

    def __dir__(self) -> list[str]:
        """List the wrapped class's attributes instead of the wrapper's."""
        return dir(self.cls)

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes missing on the wrapper to the wrapped class.

        Raises:
            AttributeError: If ``cls`` is not set yet, or the wrapped class
                has no attribute ``name``.
        """
        # Read ``cls`` straight from the instance dict. If it is not set yet
        # (e.g. the instance was created via ``__new__`` by copy/pickle),
        # ``self.cls`` would re-enter ``__getattr__`` and recurse until
        # RecursionError instead of raising a clean AttributeError.
        try:
            wrapped = self.__dict__["cls"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)
class CBR(Generic[T]):
    """Class-based router: register methods that carry ``cbx_router`` specs.

    Calling the wrapper instantiates ``cls``. For every bound method or
    function of the instance, any route specs stored under the
    ``"cbx_router"`` key of its ``__annotations__`` (as written by
    ``cbr.method``) are registered on ``router``. Attribute lookups and
    ``dir()`` on the wrapper are delegated to ``cls``.

    Args:
        cls: The class to wrap.
        router: Router on which the annotated methods are registered.
    """

    def __init__(self, cls: type[T], router: APIRouter):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.router = router
        self.cls = cls

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        """Instantiate ``cls`` and register its annotated endpoints.

        Args:
            *args: Positional arguments forwarded to ``cls``.
            **kwargs: Keyword arguments forwarded to ``cls``.

        Returns:
            The new instance, which is also stored as ``self.instance``.

        Note:
            ``inspect.getmembers`` returns members sorted by name, so routes
            are added in alphabetical order of the attribute names, not in
            definition order. Every call registers the routes again.
        """
        self.instance = self.cls(*args, **kwargs)
        members = inspect.getmembers(
            self.instance, lambda x: inspect.ismethod(x) or inspect.isfunction(x)
        )
        for _name, endpoint in members:
            # Defensive lookup: a bound method's ``__func__`` need not be a
            # plain function, so ``__annotations__`` may be missing.
            annotations = getattr(endpoint, "__annotations__", None) or {}
            for spec in annotations.get("cbx_router") or ():
                self.router.add_api_route(
                    path=spec["path"],
                    endpoint=endpoint,
                    methods=[spec["method"]],
                    **spec["kwargs"],
                )
        return self.instance

    def __dir__(self) -> list[str]:
        """List the wrapped class's attributes instead of the wrapper's."""
        return dir(self.cls)

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes missing on the wrapper to the wrapped class.

        Raises:
            AttributeError: If ``cls`` is not set yet, or the wrapped class
                has no attribute ``name``.
        """
        # Read ``cls`` straight from the instance dict. If it is not set yet
        # (e.g. the instance was created via ``__new__`` by copy/pickle),
        # ``self.cls`` would re-enter ``__getattr__`` and recurse until
        # RecursionError instead of raising a clean AttributeError.
        try:
            wrapped = self.__dict__["cls"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)
class cbv(Generic[T]):
    """Class decorator factory producing :class:`CBV` wrappers.

    ``cbv(router)`` returns a callable. Applying it to a class yields a
    ``CBV`` that wraps that class and is bound to ``router``.
    """

    def __init__(self, router: APIRouter):
        # Router handed to every CBV this decorator creates.
        self.router = router

    def __call__(self, cls: type[T]) -> CBV[T]:
        """Wrap *cls* in a ``CBV`` bound to this decorator's router."""
        target_router = self.router
        wrapper = CBV(cls, target_router)
        return wrapper
class cbr(Generic[T]):
    """Class decorator factory producing :class:`CBR` wrappers.

    ``cbr(router)`` applied to a class yields a ``CBR`` bound to ``router``.
    The verb helpers (``cbr.get(path, **kwargs)``, ``cbr.post(...)``, ...)
    build ``cbr.method`` decorators that attach route specs to methods.
    ``CBR`` reads those specs back and registers them, forwarding the extra
    ``kwargs`` to ``add_api_route``.
    """

    class method:
        """Decorator that records an HTTP route spec on an endpoint.

        Args:
            method: HTTP method name, e.g. ``"GET"``.
            path: Route path passed to ``add_api_route``.
            **kwargs: Extra keyword arguments for ``add_api_route``.
        """

        def __init__(self, method: str, path: str, **kwargs: Any):
            self.method = method
            self.path = path
            self.kwargs = kwargs

        def __call__(self, endpoint: Callable[..., Any]) -> Callable[..., Any]:
            """Append this route spec to *endpoint* and return it unchanged.

            Specs are stored in a list under ``__annotations__["cbx_router"]``,
            so stacking several verb decorators on one function accumulates
            several routes.
            """
            endpoint.__annotations__.setdefault("cbx_router", []).append(
                {"method": self.method, "path": self.path, "kwargs": self.kwargs}
            )
            return endpoint

    # staticmethod() keeps these plain callables on both class and instance
    # access. Python 3.13 warns when a bare functools.partial class attribute
    # is read through an instance, and in 3.14 partial is a method descriptor:
    # ``cbr(router).get("/x")`` would otherwise bind the cbr instance as the
    # path argument.
    head = staticmethod(partial(method, "HEAD"))
    get = staticmethod(partial(method, "GET"))
    post = staticmethod(partial(method, "POST"))
    put = staticmethod(partial(method, "PUT"))
    delete = staticmethod(partial(method, "DELETE"))
    patch = staticmethod(partial(method, "PATCH"))
    options = staticmethod(partial(method, "OPTIONS"))
    trace = staticmethod(partial(method, "TRACE"))
    connect = staticmethod(partial(method, "CONNECT"))

    def __init__(self, router: APIRouter):
        # Router handed to every CBR this decorator creates.
        self.router = router

    def __call__(self, cls: type[T]) -> CBR[T]:
        """Wrap *cls* in a ``CBR`` bound to this decorator's router."""
        return CBR(cls, self.router)