Browse Source

switch to a base type for the a2s protocols

pull/43/head
Alex Nørgaard 2 years ago
parent
commit
a8278f2b6d
No known key found for this signature in database GPG Key ID: 94D54F6A3604E97
  1. 32
      a2s/a2s_async.py
  2. 22
      a2s/a2s_protocol.py
  3. 17
      a2s/a2s_sync.py
  4. 3
      a2s/info.py
  5. 3
      a2s/players.py
  6. 3
      a2s/rules.py

32
a2s/a2s_async.py

@ -4,45 +4,29 @@ import asyncio
import io import io
import logging import logging
import time import time
from typing import ( from typing import TYPE_CHECKING, Any, List, NoReturn, Optional, Tuple, Type
TYPE_CHECKING,
Any,
List,
NoReturn,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from a2s.a2s_fragment import A2SFragment, decode_fragment from a2s.a2s_fragment import A2SFragment, decode_fragment
from a2s.a2s_protocol import A2SProtocol
from a2s.byteio import ByteReader from a2s.byteio import ByteReader
from a2s.defaults import DEFAULT_RETRIES from a2s.defaults import DEFAULT_RETRIES
from a2s.exceptions import BrokenMessageError from a2s.exceptions import BrokenMessageError
from .info import InfoProtocol
from .players import PlayersProtocol
from .rules import RulesProtocol
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF" HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF"
HEADER_MULTI = b"\xFE\xFF\xFF\xFF" HEADER_MULTI = b"\xFE\xFF\xFF\xFF"
A2S_CHALLENGE_RESPONSE = 0x41 A2S_CHALLENGE_RESPONSE = 0x41
PROTOCOLS = Union[InfoProtocol, PlayersProtocol, RulesProtocol]
logger: logging.Logger = logging.getLogger("a2s") logger: logging.Logger = logging.getLogger("a2s")
ProtocolT = TypeVar("ProtocolT", InfoProtocol, PlayersProtocol, RulesProtocol)
async def request_async( async def request_async(
address: Tuple[str, int], address: Tuple[str, int],
timeout: float, timeout: float,
encoding: str, encoding: str,
a2s_proto: Type[ProtocolT], a2s_proto: Type[A2SProtocol],
) -> Any: ) -> Any:
conn = await A2SStreamAsync.create(address, timeout) conn = await A2SStreamAsync.create(address, timeout)
response = await request_async_impl(conn, encoding, a2s_proto) response = await request_async_impl(conn, encoding, a2s_proto)
@ -53,7 +37,7 @@ async def request_async(
async def request_async_impl( async def request_async_impl(
conn: A2SStreamAsync, conn: A2SStreamAsync,
encoding: str, encoding: str,
a2s_proto: Type[ProtocolT], a2s_proto: Type[A2SProtocol],
challenge: int = 0, challenge: int = 0,
retries: int = 0, retries: int = 0,
ping: Optional[float] = None, ping: Optional[float] = None,
@ -86,7 +70,7 @@ async def request_async_impl(
return a2s_proto.deserialize_response(reader, response_type, ping) return a2s_proto.deserialize_response(reader, response_type, ping)
class A2SProtocol(asyncio.DatagramProtocol): class A2SDatagramProtocol(asyncio.DatagramProtocol):
def __init__(self) -> None: def __init__(self) -> None:
self.recv_queue: asyncio.Queue[bytes] = asyncio.Queue() self.recv_queue: asyncio.Queue[bytes] = asyncio.Queue()
self.error_event: asyncio.Event = asyncio.Event() self.error_event: asyncio.Event = asyncio.Event()
@ -142,11 +126,11 @@ class A2SStreamAsync:
def __init__( def __init__(
self, self,
transport: asyncio.DatagramTransport, transport: asyncio.DatagramTransport,
protocol: A2SProtocol, protocol: A2SDatagramProtocol,
timeout: float, timeout: float,
) -> None: ) -> None:
self.transport: asyncio.DatagramTransport = transport self.transport: asyncio.DatagramTransport = transport
self.protocol: A2SProtocol = protocol self.protocol: A2SDatagramProtocol = protocol
self.timeout: float = timeout self.timeout: float = timeout
def __del__(self) -> None: def __del__(self) -> None:
@ -156,7 +140,7 @@ class A2SStreamAsync:
async def create(cls, address: Tuple[str, int], timeout: float) -> Self: async def create(cls, address: Tuple[str, int], timeout: float) -> Self:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
transport, protocol = await loop.create_datagram_endpoint( transport, protocol = await loop.create_datagram_endpoint(
lambda: A2SProtocol(), remote_addr=address lambda: A2SDatagramProtocol(), remote_addr=address
) )
return cls(transport, protocol, timeout) return cls(transport, protocol, timeout)

22
a2s/a2s_protocol.py

@ -0,0 +1,22 @@
__all__ = ("A2SProtocol",)
from typing import Any, Optional
from .byteio import ByteReader
class A2SProtocol:
@staticmethod
def serialize_request(challenge: int) -> bytes:
raise NotImplemented
@staticmethod
def validate_response_type(response_type: int) -> bool:
raise NotImplemented
@staticmethod
def deserialize_response(
reader: ByteReader, response_type: int, ping: Optional[float]
) -> Any:
raise NotImplemented

17
a2s/a2s_sync.py

@ -4,29 +4,26 @@ import io
import logging import logging
import socket import socket
import time import time
from typing import Any, Optional, Tuple, Type, TypeVar, Union from typing import Any, Optional, Tuple, Type
from a2s.a2s_fragment import decode_fragment from a2s.a2s_fragment import decode_fragment
from a2s.a2s_protocol import A2SProtocol
from a2s.byteio import ByteReader from a2s.byteio import ByteReader
from a2s.defaults import DEFAULT_RETRIES from a2s.defaults import DEFAULT_RETRIES
from a2s.exceptions import BrokenMessageError from a2s.exceptions import BrokenMessageError
from .info import InfoProtocol
from .players import PlayersProtocol
from .rules import RulesProtocol
HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF" HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF"
HEADER_MULTI = b"\xFE\xFF\xFF\xFF" HEADER_MULTI = b"\xFE\xFF\xFF\xFF"
A2S_CHALLENGE_RESPONSE = 0x41 A2S_CHALLENGE_RESPONSE = 0x41
PROTOCOLS = Union[InfoProtocol, RulesProtocol, PlayersProtocol]
logger: logging.Logger = logging.getLogger("a2s") logger: logging.Logger = logging.getLogger("a2s")
T = TypeVar("T", InfoProtocol, RulesProtocol, PlayersProtocol)
def request_sync( def request_sync(
address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[T] address: Tuple[str, int],
timeout: float,
encoding: str,
a2s_proto: Type[A2SProtocol],
) -> Any: ) -> Any:
conn = A2SStream(address, timeout) conn = A2SStream(address, timeout)
response = request_sync_impl(conn, encoding, a2s_proto) response = request_sync_impl(conn, encoding, a2s_proto)
@ -37,7 +34,7 @@ def request_sync(
def request_sync_impl( def request_sync_impl(
conn: A2SStream, conn: A2SStream,
encoding: str, encoding: str,
a2s_proto: Type[T], a2s_proto: Type[A2SProtocol],
challenge: int = 0, challenge: int = 0,
retries: int = 0, retries: int = 0,
ping: Optional[float] = None, ping: Optional[float] = None,

3
a2s/info.py

@ -8,6 +8,7 @@ from a2s.datacls import DataclsMeta
from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT
from a2s.exceptions import BufferExhaustedError from a2s.exceptions import BufferExhaustedError
from .a2s_protocol import A2SProtocol
from .byteio import ByteReader from .byteio import ByteReader
A2S_INFO_RESPONSE = 0x49 A2S_INFO_RESPONSE = 0x49
@ -204,7 +205,7 @@ async def ainfo(
return await request_async(address, timeout, encoding, InfoProtocol) return await request_async(address, timeout, encoding, InfoProtocol)
class InfoProtocol: class InfoProtocol(A2SProtocol):
@staticmethod @staticmethod
def validate_response_type(response_type: int) -> bool: def validate_response_type(response_type: int) -> bool:
return response_type in (A2S_INFO_RESPONSE, A2S_INFO_RESPONSE_LEGACY) return response_type in (A2S_INFO_RESPONSE, A2S_INFO_RESPONSE_LEGACY)

3
a2s/players.py

@ -1,6 +1,7 @@
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from a2s.a2s_async import request_async from a2s.a2s_async import request_async
from a2s.a2s_protocol import A2SProtocol
from a2s.a2s_sync import request_sync from a2s.a2s_sync import request_sync
from a2s.byteio import ByteReader from a2s.byteio import ByteReader
from a2s.datacls import DataclsMeta from a2s.datacls import DataclsMeta
@ -45,7 +46,7 @@ async def aplayers(
return await request_async(address, timeout, encoding, PlayersProtocol) return await request_async(address, timeout, encoding, PlayersProtocol)
class PlayersProtocol: class PlayersProtocol(A2SProtocol):
@staticmethod @staticmethod
def validate_response_type(response_type: int) -> bool: def validate_response_type(response_type: int) -> bool:
return response_type == A2S_PLAYER_RESPONSE return response_type == A2S_PLAYER_RESPONSE

3
a2s/rules.py

@ -1,6 +1,7 @@
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union
from a2s.a2s_async import request_async from a2s.a2s_async import request_async
from a2s.a2s_protocol import A2SProtocol
from a2s.a2s_sync import request_sync from a2s.a2s_sync import request_sync
from a2s.byteio import ByteReader from a2s.byteio import ByteReader
from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT
@ -29,7 +30,7 @@ async def arules(
return await request_async(address, timeout, encoding, RulesProtocol) return await request_async(address, timeout, encoding, RulesProtocol)
class RulesProtocol: class RulesProtocol(A2SProtocol):
@staticmethod @staticmethod
def validate_response_type(response_type: int) -> bool: def validate_response_type(response_type: int) -> bool:
return response_type == A2S_RULES_RESPONSE return response_type == A2S_RULES_RESPONSE

Loading…
Cancel
Save