diff --git a/a2s/a2s_async.py b/a2s/a2s_async.py index 04399d0..473fb48 100644 --- a/a2s/a2s_async.py +++ b/a2s/a2s_async.py @@ -4,45 +4,29 @@ import asyncio import io import logging import time -from typing import ( - TYPE_CHECKING, - Any, - List, - NoReturn, - Optional, - Tuple, - Type, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, List, NoReturn, Optional, Tuple, Type from a2s.a2s_fragment import A2SFragment, decode_fragment +from a2s.a2s_protocol import A2SProtocol from a2s.byteio import ByteReader from a2s.defaults import DEFAULT_RETRIES from a2s.exceptions import BrokenMessageError -from .info import InfoProtocol -from .players import PlayersProtocol -from .rules import RulesProtocol - if TYPE_CHECKING: from typing_extensions import Self HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF" HEADER_MULTI = b"\xFE\xFF\xFF\xFF" A2S_CHALLENGE_RESPONSE = 0x41 -PROTOCOLS = Union[InfoProtocol, PlayersProtocol, RulesProtocol] logger: logging.Logger = logging.getLogger("a2s") -ProtocolT = TypeVar("ProtocolT", InfoProtocol, PlayersProtocol, RulesProtocol) - async def request_async( address: Tuple[str, int], timeout: float, encoding: str, - a2s_proto: Type[ProtocolT], + a2s_proto: Type[A2SProtocol], ) -> Any: conn = await A2SStreamAsync.create(address, timeout) response = await request_async_impl(conn, encoding, a2s_proto) @@ -53,7 +37,7 @@ async def request_async( async def request_async_impl( conn: A2SStreamAsync, encoding: str, - a2s_proto: Type[ProtocolT], + a2s_proto: Type[A2SProtocol], challenge: int = 0, retries: int = 0, ping: Optional[float] = None, @@ -86,7 +70,7 @@ async def request_async_impl( return a2s_proto.deserialize_response(reader, response_type, ping) -class A2SProtocol(asyncio.DatagramProtocol): +class A2SDatagramProtocol(asyncio.DatagramProtocol): def __init__(self) -> None: self.recv_queue: asyncio.Queue[bytes] = asyncio.Queue() self.error_event: asyncio.Event = asyncio.Event() @@ -142,11 +126,11 @@ class A2SStreamAsync: def __init__( self, transport: asyncio.DatagramTransport, - protocol: A2SProtocol, + protocol: A2SDatagramProtocol, timeout: float, ) -> None: self.transport: asyncio.DatagramTransport = transport - self.protocol: A2SProtocol = protocol + self.protocol: A2SDatagramProtocol = protocol self.timeout: float = timeout def __del__(self) -> None: @@ -156,7 +140,7 @@ class A2SStreamAsync: async def create(cls, address: Tuple[str, int], timeout: float) -> Self: loop = asyncio.get_running_loop() transport, protocol = await loop.create_datagram_endpoint( - lambda: A2SProtocol(), remote_addr=address + lambda: A2SDatagramProtocol(), remote_addr=address ) return cls(transport, protocol, timeout) diff --git a/a2s/a2s_protocol.py b/a2s/a2s_protocol.py new file mode 100644 index 0000000..bf5b5b9 --- /dev/null +++ b/a2s/a2s_protocol.py @@ -0,0 +1,22 @@ +__all__ = ("A2SProtocol",) + + +from typing import Any, Optional + +from .byteio import ByteReader + + +class A2SProtocol: + @staticmethod + def serialize_request(challenge: int) -> bytes: + raise NotImplemented + + @staticmethod + def validate_response_type(response_type: int) -> bool: + raise NotImplemented + + @staticmethod + def deserialize_response( + reader: ByteReader, response_type: int, ping: Optional[float] + ) -> Any: + raise NotImplemented diff --git a/a2s/a2s_sync.py b/a2s/a2s_sync.py index 9266885..dc2c2cc 100644 --- a/a2s/a2s_sync.py +++ b/a2s/a2s_sync.py @@ -4,29 +4,26 @@ import io import logging import socket import time -from typing import Any, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Optional, Tuple, Type from a2s.a2s_fragment import decode_fragment +from a2s.a2s_protocol import A2SProtocol from a2s.byteio import ByteReader from a2s.defaults import DEFAULT_RETRIES from a2s.exceptions import BrokenMessageError -from .info import InfoProtocol -from .players import PlayersProtocol -from .rules import RulesProtocol - HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF" HEADER_MULTI = b"\xFE\xFF\xFF\xFF" A2S_CHALLENGE_RESPONSE = 0x41 -PROTOCOLS = Union[InfoProtocol, RulesProtocol, PlayersProtocol] logger: logging.Logger = logging.getLogger("a2s") -T = TypeVar("T", InfoProtocol, RulesProtocol, PlayersProtocol) - def request_sync( - address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[T] + address: Tuple[str, int], + timeout: float, + encoding: str, + a2s_proto: Type[A2SProtocol], ) -> Any: conn = A2SStream(address, timeout) response = request_sync_impl(conn, encoding, a2s_proto) @@ -37,7 +34,7 @@ def request_sync( def request_sync_impl( conn: A2SStream, encoding: str, - a2s_proto: Type[T], + a2s_proto: Type[A2SProtocol], challenge: int = 0, retries: int = 0, ping: Optional[float] = None, diff --git a/a2s/info.py b/a2s/info.py index e697583..c563e80 100644 --- a/a2s/info.py +++ b/a2s/info.py @@ -8,6 +8,7 @@ from a2s.datacls import DataclsMeta from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT from a2s.exceptions import BufferExhaustedError +from .a2s_protocol import A2SProtocol from .byteio import ByteReader A2S_INFO_RESPONSE = 0x49 @@ -204,7 +205,7 @@ async def ainfo( return await request_async(address, timeout, encoding, InfoProtocol) -class InfoProtocol: +class InfoProtocol(A2SProtocol): @staticmethod def validate_response_type(response_type: int) -> bool: return response_type in (A2S_INFO_RESPONSE, A2S_INFO_RESPONSE_LEGACY) diff --git a/a2s/players.py b/a2s/players.py index 2ba3959..0b9e034 100644 --- a/a2s/players.py +++ b/a2s/players.py @@ -1,6 +1,7 @@ from typing import List, Optional, Tuple from a2s.a2s_async import request_async +from a2s.a2s_protocol import A2SProtocol from a2s.a2s_sync import request_sync from a2s.byteio import ByteReader from a2s.datacls import DataclsMeta @@ -45,7 +46,7 @@ async def aplayers( return await request_async(address, timeout, encoding, PlayersProtocol) -class PlayersProtocol: +class PlayersProtocol(A2SProtocol): @staticmethod def validate_response_type(response_type: int) -> bool: return response_type == A2S_PLAYER_RESPONSE diff --git a/a2s/rules.py b/a2s/rules.py index 5575c65..06e4c7b 100644 --- a/a2s/rules.py +++ b/a2s/rules.py @@ -1,6 +1,7 @@ from typing import Dict, Optional, Tuple, Union from a2s.a2s_async import request_async +from a2s.a2s_protocol import A2SProtocol from a2s.a2s_sync import request_sync from a2s.byteio import ByteReader from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT @@ -29,7 +30,7 @@ async def arules( return await request_async(address, timeout, encoding, RulesProtocol) -class RulesProtocol: +class RulesProtocol(A2SProtocol): @staticmethod def validate_response_type(response_type: int) -> bool: return response_type == A2S_RULES_RESPONSE