python asyncio api async fastapi framework json json-schema openapi openapi3 pydantic python-types python3 redoc rest starlette swagger swagger-ui uvicorn web
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
3.7 KiB
120 lines
3.7 KiB
import inspect
|
|
import logging
|
|
from collections.abc import Callable
|
|
from functools import partial
|
|
from typing import Any, Generic, TypeVar
|
|
|
|
from fastapi import APIRouter
|
|
|
|
# Type variable standing for the user-defined class that the wrappers in this
# module are parameterised over (``type[T]`` in, ``T`` instance out).
T = TypeVar("T")
|
|
|
|
|
|
class CBV(Generic[T]):
    """Wrap a view class so that instantiating it registers its verb handlers.

    Calling the wrapper builds an instance of the wrapped class and, for every
    attribute of that instance named after an HTTP verb (``get``, ``post``,
    ...), adds a route with an empty path on ``router``. Unknown attributes
    and ``dir()`` are delegated to the wrapped class.
    """

    def __init__(self, cls: type[T], router: APIRouter):
        """Remember the wrapped class and the router that receives its routes.

        Args:
            cls: Class to instantiate whenever the wrapper is called.
            router: Object whose ``add_api_route`` is called for each handler
                and whose ``prefix`` is used in the route summary.
        """
        # Not used inside this class; kept as a public attribute.
        self.logger = logging.getLogger(self.__class__.__name__)
        self.router = router
        self.cls = cls
        # Expose the wrapped class's docstring instead of the wrapper's own.
        self.__doc__ = cls.__doc__

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        """Instantiate the wrapped class and register its verb-named handlers.

        Args:
            *args: Positional arguments forwarded to the wrapped class.
            **kwargs: Keyword arguments forwarded to the wrapped class.

        Returns:
            The new instance, which is also stored as ``self.instance``.
        """
        # NOTE(review): nothing here deduplicates, so each call adds the same
        # routes to the router again - confirm callers instantiate only once.
        self.instance = self.cls(*args, **kwargs)

        # Attribute name looked up on the instance -> status code for its route.
        default_status_codes = {
            "head": 200,
            "get": 200,
            "post": 201,
            "put": 204,
            "delete": 204,
            "patch": 200,
            "options": 200,
            "trace": 200,
            "connect": 200,
        }
        for name, status_code in default_status_codes.items():
            if not hasattr(self.instance, name):
                continue
            verb = name.upper()
            self.router.add_api_route(
                path="",
                endpoint=getattr(self.instance, name),
                status_code=status_code,
                methods=[verb],
                summary=f"{verb} {self.router.prefix}",
            )

        return self.instance

    def __dir__(self) -> list[str]:
        """Return the attribute names of the wrapped class."""
        return dir(self.cls)

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on the wrapper to the wrapped class.

        Raises:
            AttributeError: If the wrapped class lacks ``name``, or if the
                wrapper has no ``cls`` yet.
        """
        # __getattr__ only runs after normal lookup failed. If ``cls`` itself
        # is missing (instance created without __init__, as copy/pickle do),
        # reading ``self.cls`` would re-enter this method forever.
        try:
            wrapped = self.__dict__["cls"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(wrapped, name)
|
|
|
|
|
|
class CBR(Generic[T]):
    """Wrap a class so that instantiating it registers its annotated routes.

    Calling the wrapper builds an instance of the wrapped class, then scans
    the instance's methods and functions for a ``"cbx_router"`` entry in
    ``__annotations__`` and adds one route on ``router`` per item in that
    entry. Unknown attributes and ``dir()`` are delegated to the wrapped
    class.
    """

    def __init__(self, cls: type[T], router: APIRouter):
        """Remember the wrapped class and the router that receives its routes.

        Args:
            cls: Class to instantiate whenever the wrapper is called.
            router: Object whose ``add_api_route`` is called for each route.
        """
        # Not used inside this class; kept as a public attribute.
        self.logger = logging.getLogger(self.__class__.__name__)
        self.router = router
        self.cls = cls
        # Expose the wrapped class's docstring instead of the wrapper's own.
        self.__doc__ = cls.__doc__

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        """Instantiate the wrapped class and register its annotated endpoints.

        Args:
            *args: Positional arguments forwarded to the wrapped class.
            **kwargs: Keyword arguments forwarded to the wrapped class.

        Returns:
            The new instance, which is also stored as ``self.instance``.
        """
        # NOTE(review): nothing here deduplicates, so each call adds the same
        # routes to the router again - confirm callers instantiate only once.
        self.instance = self.cls(*args, **kwargs)

        def is_routine(member: Any) -> bool:
            return inspect.ismethod(member) or inspect.isfunction(member)

        for _name, endpoint in inspect.getmembers(self.instance, is_routine):
            # Each item is a dict with "method", "path" and "kwargs" keys; a
            # missing or empty entry means the member is not an endpoint.
            routes = endpoint.__annotations__.get("cbx_router")
            if not routes:
                continue
            for route in routes:
                self.router.add_api_route(
                    path=route["path"],
                    endpoint=endpoint,
                    methods=[route["method"]],
                    **route["kwargs"],
                )
        return self.instance

    def __dir__(self) -> list[str]:
        """Return the attribute names of the wrapped class."""
        return dir(self.cls)

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on the wrapper to the wrapped class.

        Raises:
            AttributeError: If the wrapped class lacks ``name``, or if the
                wrapper has no ``cls`` yet.
        """
        # __getattr__ only runs after normal lookup failed. If ``cls`` itself
        # is missing (instance created without __init__, as copy/pickle do),
        # reading ``self.cls`` would re-enter this method forever.
        try:
            wrapped = self.__dict__["cls"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(wrapped, name)
|
|
|
|
|
|
class cbv(Generic[T]):
    """Class decorator that wraps the decorated class in a ``CBV``.

    Usage: ``@cbv(router)`` above a class definition; the name is then bound
    to a ``CBV`` wrapper tied to ``router``.
    """

    def __init__(self, router: APIRouter):
        """Store the router handed to every wrapper this decorator creates."""
        self.router = router

    def __call__(self, cls: type[T]) -> CBV[T]:
        """Return a ``CBV`` wrapper around ``cls`` bound to the stored router."""
        wrapper: CBV[T] = CBV(cls, self.router)
        return wrapper
|
|
|
|
|
|
class cbr(Generic[T]):
    """Class decorator that wraps the decorated class in a ``CBR``.

    The nested ``method`` decorator and its per-verb shortcuts
    (``cbr.get``, ``cbr.post``, ...) tag individual callables with route
    metadata under the ``"cbx_router"`` key of their ``__annotations__``.
    """

    class method:
        """Decorator that records one route description on an endpoint."""

        def __init__(self, method: str, path: str, **kwargs: Any):
            """Capture the route description.

            Args:
                method: HTTP method name stored with the route.
                path: Route path stored with the route.
                **kwargs: Extra options stored under the ``"kwargs"`` key.
            """
            self.method = method
            self.path = path
            self.kwargs = kwargs

        def __call__(self, endpoint: Callable[..., Any]) -> Callable[..., Any]:
            """Append this route to ``endpoint``'s metadata and return it.

            The metadata is a list kept under
            ``endpoint.__annotations__["cbx_router"]``; it is created on first
            use, so stacking several decorators accumulates several routes.
            """
            route = {"method": self.method, "path": self.path, "kwargs": self.kwargs}
            endpoint.__annotations__.setdefault("cbx_router", []).append(route)
            return endpoint

    # Per-verb shortcuts. Wrapped in staticmethod so that access through an
    # instance keeps returning the plain partial: newer Python versions turn
    # functools.partial into a method descriptor that would bind the instance
    # as an extra positional argument.
    head = staticmethod(partial(method, "HEAD"))
    get = staticmethod(partial(method, "GET"))
    post = staticmethod(partial(method, "POST"))
    put = staticmethod(partial(method, "PUT"))
    delete = staticmethod(partial(method, "DELETE"))
    patch = staticmethod(partial(method, "PATCH"))
    options = staticmethod(partial(method, "OPTIONS"))
    trace = staticmethod(partial(method, "TRACE"))
    connect = staticmethod(partial(method, "CONNECT"))

    def __init__(self, router: APIRouter):
        """Store the router handed to every wrapper this decorator creates."""
        self.router = router

    def __call__(self, cls: type[T]) -> CBR[T]:
        """Return a ``CBR`` wrapper around ``cls`` bound to the stored router."""
        return CBR(cls, self.router)
|
|
|