Browse Source

switch to a base type for the a2s protocols

pull/43/head
Alex Nørgaard 2 years ago
parent
commit
a8278f2b6d
No known key found for this signature in database GPG Key ID: 94D54F6A3604E97
  1. 32
      a2s/a2s_async.py
  2. 22
      a2s/a2s_protocol.py
  3. 17
      a2s/a2s_sync.py
  4. 3
      a2s/info.py
  5. 3
      a2s/players.py
  6. 3
      a2s/rules.py

32
a2s/a2s_async.py

@ -4,45 +4,29 @@ import asyncio
import io
import logging
import time
from typing import (
TYPE_CHECKING,
Any,
List,
NoReturn,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from typing import TYPE_CHECKING, Any, List, NoReturn, Optional, Tuple, Type
from a2s.a2s_fragment import A2SFragment, decode_fragment
from a2s.a2s_protocol import A2SProtocol
from a2s.byteio import ByteReader
from a2s.defaults import DEFAULT_RETRIES
from a2s.exceptions import BrokenMessageError
from .info import InfoProtocol
from .players import PlayersProtocol
from .rules import RulesProtocol
if TYPE_CHECKING:
from typing_extensions import Self
HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF"
HEADER_MULTI = b"\xFE\xFF\xFF\xFF"
A2S_CHALLENGE_RESPONSE = 0x41
PROTOCOLS = Union[InfoProtocol, PlayersProtocol, RulesProtocol]
logger: logging.Logger = logging.getLogger("a2s")
ProtocolT = TypeVar("ProtocolT", InfoProtocol, PlayersProtocol, RulesProtocol)
async def request_async(
address: Tuple[str, int],
timeout: float,
encoding: str,
a2s_proto: Type[ProtocolT],
a2s_proto: Type[A2SProtocol],
) -> Any:
conn = await A2SStreamAsync.create(address, timeout)
response = await request_async_impl(conn, encoding, a2s_proto)
@ -53,7 +37,7 @@ async def request_async(
async def request_async_impl(
conn: A2SStreamAsync,
encoding: str,
a2s_proto: Type[ProtocolT],
a2s_proto: Type[A2SProtocol],
challenge: int = 0,
retries: int = 0,
ping: Optional[float] = None,
@ -86,7 +70,7 @@ async def request_async_impl(
return a2s_proto.deserialize_response(reader, response_type, ping)
class A2SProtocol(asyncio.DatagramProtocol):
class A2SDatagramProtocol(asyncio.DatagramProtocol):
def __init__(self) -> None:
self.recv_queue: asyncio.Queue[bytes] = asyncio.Queue()
self.error_event: asyncio.Event = asyncio.Event()
@ -142,11 +126,11 @@ class A2SStreamAsync:
def __init__(
self,
transport: asyncio.DatagramTransport,
protocol: A2SProtocol,
protocol: A2SDatagramProtocol,
timeout: float,
) -> None:
self.transport: asyncio.DatagramTransport = transport
self.protocol: A2SProtocol = protocol
self.protocol: A2SDatagramProtocol = protocol
self.timeout: float = timeout
def __del__(self) -> None:
@ -156,7 +140,7 @@ class A2SStreamAsync:
async def create(cls, address: Tuple[str, int], timeout: float) -> Self:
loop = asyncio.get_running_loop()
transport, protocol = await loop.create_datagram_endpoint(
lambda: A2SProtocol(), remote_addr=address
lambda: A2SDatagramProtocol(), remote_addr=address
)
return cls(transport, protocol, timeout)

22
a2s/a2s_protocol.py

@ -0,0 +1,22 @@
__all__ = ("A2SProtocol",)
from typing import Any, Optional
from .byteio import ByteReader
class A2SProtocol:
@staticmethod
def serialize_request(challenge: int) -> bytes:
raise NotImplemented
@staticmethod
def validate_response_type(response_type: int) -> bool:
raise NotImplemented
@staticmethod
def deserialize_response(
reader: ByteReader, response_type: int, ping: Optional[float]
) -> Any:
raise NotImplemented

17
a2s/a2s_sync.py

@ -4,29 +4,26 @@ import io
import logging
import socket
import time
from typing import Any, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Optional, Tuple, Type
from a2s.a2s_fragment import decode_fragment
from a2s.a2s_protocol import A2SProtocol
from a2s.byteio import ByteReader
from a2s.defaults import DEFAULT_RETRIES
from a2s.exceptions import BrokenMessageError
from .info import InfoProtocol
from .players import PlayersProtocol
from .rules import RulesProtocol
HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF"
HEADER_MULTI = b"\xFE\xFF\xFF\xFF"
A2S_CHALLENGE_RESPONSE = 0x41
PROTOCOLS = Union[InfoProtocol, RulesProtocol, PlayersProtocol]
logger: logging.Logger = logging.getLogger("a2s")
T = TypeVar("T", InfoProtocol, RulesProtocol, PlayersProtocol)
def request_sync(
address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[T]
address: Tuple[str, int],
timeout: float,
encoding: str,
a2s_proto: Type[A2SProtocol],
) -> Any:
conn = A2SStream(address, timeout)
response = request_sync_impl(conn, encoding, a2s_proto)
@ -37,7 +34,7 @@ def request_sync(
def request_sync_impl(
conn: A2SStream,
encoding: str,
a2s_proto: Type[T],
a2s_proto: Type[A2SProtocol],
challenge: int = 0,
retries: int = 0,
ping: Optional[float] = None,

3
a2s/info.py

@ -8,6 +8,7 @@ from a2s.datacls import DataclsMeta
from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT
from a2s.exceptions import BufferExhaustedError
from .a2s_protocol import A2SProtocol
from .byteio import ByteReader
A2S_INFO_RESPONSE = 0x49
@ -204,7 +205,7 @@ async def ainfo(
return await request_async(address, timeout, encoding, InfoProtocol)
class InfoProtocol:
class InfoProtocol(A2SProtocol):
@staticmethod
def validate_response_type(response_type: int) -> bool:
return response_type in (A2S_INFO_RESPONSE, A2S_INFO_RESPONSE_LEGACY)

3
a2s/players.py

@ -1,6 +1,7 @@
from typing import List, Optional, Tuple
from a2s.a2s_async import request_async
from a2s.a2s_protocol import A2SProtocol
from a2s.a2s_sync import request_sync
from a2s.byteio import ByteReader
from a2s.datacls import DataclsMeta
@ -45,7 +46,7 @@ async def aplayers(
return await request_async(address, timeout, encoding, PlayersProtocol)
class PlayersProtocol:
class PlayersProtocol(A2SProtocol):
@staticmethod
def validate_response_type(response_type: int) -> bool:
return response_type == A2S_PLAYER_RESPONSE

3
a2s/rules.py

@ -1,6 +1,7 @@
from typing import Dict, Optional, Tuple, Union
from a2s.a2s_async import request_async
from a2s.a2s_protocol import A2SProtocol
from a2s.a2s_sync import request_sync
from a2s.byteio import ByteReader
from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT
@ -29,7 +30,7 @@ async def arules(
return await request_async(address, timeout, encoding, RulesProtocol)
class RulesProtocol:
class RulesProtocol(A2SProtocol):
@staticmethod
def validate_response_type(response_type: int) -> bool:
return response_type == A2S_RULES_RESPONSE

Loading…
Cancel
Save