diff --git a/.gitignore b/.gitignore index e32c3ef..d0b4e09 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ __pycache__ build dist *.egg-info - +venv +.venv diff --git a/a2s/datacls.py b/a2s/datacls.py deleted file mode 100644 index e8152cb..0000000 --- a/a2s/datacls.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Cheap dataclasses module backport - -Check out the official documentation to see what this is trying to -achieve: -https://docs.python.org/3/library/dataclasses.html -""" - -import collections -import copy - -class DataclsBase: - def __init__(self, **kwargs): - for name, value in self._defaults.items(): - if name in kwargs: - value = kwargs[name] - setattr(self, name, copy.copy(value)) - - def __iter__(self): - for name in self.__annotations__: - yield (name, getattr(self, name)) - - def __repr__(self): - return "{}({})".format( - self.__class__.__name__, - ", ".join(name + "=" + repr(value) for name, value in self)) - -class DataclsMeta(type): - def __new__(cls, name, bases, prop): - values = collections.OrderedDict() - for member_name in prop["__annotations__"].keys(): - # Check if member has a default value set as class variable - if member_name in prop: - # Store default value and remove the class variable - values[member_name] = prop[member_name] - del prop[member_name] - else: - # Set None as the default value - values[member_name] = None - - prop["__slots__"] = list(values.keys()) - prop["_defaults"] = values - bases = (DataclsBase, *bases) - return super().__new__(cls, name, bases, prop) - - def __prepare__(self, *args, **kwargs): - return collections.OrderedDict() diff --git a/a2s/info.py b/a2s/info.py index 2fd3d5a..55087b7 100644 --- a/a2s/info.py +++ b/a2s/info.py @@ -1,11 +1,12 @@ import io +from dataclasses import dataclass +from typing import Optional, Generic, TypeVar, overload from a2s.exceptions import BrokenMessageError, BufferExhaustedError from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING from a2s.a2s_sync import request_sync from a2s.a2s_async import request_async from a2s.byteio import ByteReader -from a2s.datacls import DataclsMeta @@ -13,20 +14,23 @@ A2S_INFO_RESPONSE = 0x49 A2S_INFO_RESPONSE_LEGACY = 0x6D -class SourceInfo(metaclass=DataclsMeta): +StrType = TypeVar("StrType", str, bytes) # str (default) or bytes if encoding=None is used + +@dataclass +class SourceInfo(Generic[StrType]): protocol: int """Protocol version used by the server""" - server_name: str + server_name: StrType """Display name of the server""" - map_name: str + map_name: StrType """The currently loaded map""" - folder: str + folder: StrType """Name of the game directory""" - game: str + game: StrType """Name of the game""" app_id: int @@ -41,13 +45,13 @@ class SourceInfo(metaclass=DataclsMeta): bot_count: int """Number of bots on the server""" - server_type: str + server_type: StrType """Type of the server: 'd': Dedicated server 'l': Non-dedicated server 'p': SourceTV relay (proxy)""" - platform: str + platform: StrType """Operating system of the server 'l', 'w', 'm' for Linux, Windows, macOS""" @@ -57,36 +61,34 @@ class SourceInfo(metaclass=DataclsMeta): vac_enabled: bool """Server has VAC enabled""" - version: str + version: StrType """Version of the server software""" - # Optional: - edf: int = 0 - """Extra data field, used to indicate if extra values are - included in the response""" + edf: int + """Extra data field, used to indicate if extra values are included in the response""" - port: int + ping: float + """Round-trip time for the request in seconds, not actually sent by the server""" + + # Optional: + port: Optional[int] = None """Port of the game server.""" - steam_id: int + steam_id: Optional[int] = None """Steam ID of the server""" - stv_port: int + stv_port: Optional[int] = None """Port of the SourceTV server""" - stv_name: str + stv_name: Optional[StrType] = None """Name of the SourceTV server""" - keywords: str + keywords: Optional[StrType] = None """Tags that describe the gamemode being played""" - game_id: int + game_id: Optional[int] = None """Game ID for games that have an app ID too high for 16bit.""" - # Client determined values: - ping: float - """Round-trip delay time for the request in seconds""" - @property def has_port(self): return bool(self.edf & 0x80) @@ -107,20 +109,21 @@ class SourceInfo(metaclass=DataclsMeta): def has_game_id(self): return bool(self.edf & 0x01) -class GoldSrcInfo(metaclass=DataclsMeta): - address: str +@dataclass +class GoldSrcInfo(Generic[StrType]): + address: StrType """IP Address and port of the server""" - server_name: str + server_name: StrType """Display name of the server""" - map_name: str + map_name: StrType """The currently loaded map""" - folder: str + folder: StrType """Name of the game directory""" - game: str + game: StrType """Name of the game""" player_count: int @@ -132,13 +135,13 @@ class GoldSrcInfo(metaclass=DataclsMeta): protocol: int """Protocol version used by the server""" - server_type: str + server_type: StrType """Type of the server: 'd': Dedicated server 'l': Non-dedicated server 'p': SourceTV relay (proxy)""" - platform: str + platform: StrType """Operating system of the server 'l', 'w' for Linux and Windows""" @@ -154,34 +157,62 @@ class GoldSrcInfo(metaclass=DataclsMeta): bot_count: int """Number of bots on the server""" + ping: float + """Round-trip time for the request in seconds, not actually sent by the server""" + # Optional: - mod_website: str + mod_website: Optional[StrType] """URL to the mod website""" - mod_download: str + mod_download: Optional[StrType] """URL to download the mod""" - mod_version: int + mod_version: Optional[int] """Version of the mod installed on the server""" - mod_size: int + mod_size: Optional[int] """Size in bytes of the mod""" - multiplayer_only: bool = False + multiplayer_only: Optional[bool] """Mod supports multiplayer only""" - uses_hl_dll: bool = True + uses_custom_dll: Optional[bool] """Mod uses a custom DLL""" - # Client determined values: - ping: float - """Round-trip delay time for the request in seconds""" + @property + def uses_hl_dll(self) -> Optional[bool]: + """Compatibility alias, because it got renamed""" + return self.uses_custom_dll + +@overload +def info(address: tuple[str, int], timeout: float, encoding: str) -> SourceInfo[str] | GoldSrcInfo[str]: + ... -def info(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): +@overload +def info(address: tuple[str, int], timeout: float, encoding: None) -> SourceInfo[bytes] | GoldSrcInfo[bytes]: + ... + +def info( + address: tuple[str, int], + timeout: float = DEFAULT_TIMEOUT, + encoding: str | None = DEFAULT_ENCODING +) -> SourceInfo[str] | SourceInfo[bytes] | GoldSrcInfo[str] | GoldSrcInfo[bytes]: return request_sync(address, timeout, encoding, InfoProtocol) -async def ainfo(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): +@overload +async def ainfo(address: tuple[str, int], timeout: float, encoding: str) -> SourceInfo[str] | GoldSrcInfo[str]: + ... + +@overload +async def ainfo(address: tuple[str, int], timeout: float, encoding: None) -> SourceInfo[bytes] | GoldSrcInfo[bytes]: + ... + +async def ainfo( + address: tuple[str, int], + timeout: float = DEFAULT_TIMEOUT, + encoding: str | None = DEFAULT_ENCODING +) -> SourceInfo[str] | SourceInfo[bytes] | GoldSrcInfo[str] | GoldSrcInfo[bytes]: return await request_async(address, timeout, encoding, InfoProtocol) @@ -200,39 +231,41 @@ class InfoProtocol: @staticmethod def deserialize_response(reader, response_type, ping): if response_type == A2S_INFO_RESPONSE: - resp = parse_source(reader) + resp = parse_source(reader, ping) elif response_type == A2S_INFO_RESPONSE_LEGACY: - resp = parse_goldsrc(reader) + resp = parse_goldsrc(reader, ping) else: raise Exception(str(response_type)) - resp.ping = ping return resp -def parse_source(reader): - resp = SourceInfo() - resp.protocol = reader.read_uint8() - resp.server_name = reader.read_cstring() - resp.map_name = reader.read_cstring() - resp.folder = reader.read_cstring() - resp.game = reader.read_cstring() - resp.app_id = reader.read_uint16() - resp.player_count = reader.read_uint8() - resp.max_players = reader.read_uint8() - resp.bot_count = reader.read_uint8() - resp.server_type = reader.read_char().lower() - resp.platform = reader.read_char().lower() - if resp.platform == "o": # Deprecated mac value - resp.platform = "m" - resp.password_protected = reader.read_bool() - resp.vac_enabled = reader.read_bool() - resp.version = reader.read_cstring() +def parse_source(reader, ping): + protocol = reader.read_uint8() + server_name = reader.read_cstring() + map_name = reader.read_cstring() + folder = reader.read_cstring() + game = reader.read_cstring() + app_id = reader.read_uint16() + player_count = reader.read_uint8() + max_players = reader.read_uint8() + bot_count = reader.read_uint8() + server_type = reader.read_char().lower() + platform = reader.read_char().lower() + if platform == "o": # Deprecated mac value + platform = "m" + password_protected = reader.read_bool() + vac_enabled = reader.read_bool() + version = reader.read_cstring() try: - resp.edf = reader.read_uint8() + edf = reader.read_uint8() except BufferExhaustedError: - pass + edf = 0 + resp = SourceInfo( + protocol, server_name, map_name, folder, game, app_id, player_count, max_players, + bot_count, server_type, platform, password_protected, vac_enabled, version, edf, ping + ) if resp.has_port: resp.port = reader.read_uint16() if resp.has_steam_id: @@ -247,32 +280,42 @@ def parse_source(reader): return resp -def parse_goldsrc(reader): - resp = GoldSrcInfo() - resp.address = reader.read_cstring() - resp.server_name = reader.read_cstring() - resp.map_name = reader.read_cstring() - resp.folder = reader.read_cstring() - resp.game = reader.read_cstring() - resp.player_count = reader.read_uint8() - resp.max_players = reader.read_uint8() - resp.protocol = reader.read_uint8() - resp.server_type = reader.read_char() - resp.platform = reader.read_char() - resp.password_protected = reader.read_bool() - resp.is_mod = reader.read_bool() +def parse_goldsrc(reader, ping): + address = reader.read_cstring() + server_name = reader.read_cstring() + map_name = reader.read_cstring() + folder = reader.read_cstring() + game = reader.read_cstring() + player_count = reader.read_uint8() + max_players = reader.read_uint8() + protocol = reader.read_uint8() + server_type = reader.read_char() + platform = reader.read_char() + password_protected = reader.read_bool() + is_mod = reader.read_bool() # Some games don't send this section - if resp.is_mod and len(reader.peek()) > 2: - resp.mod_website = reader.read_cstring() - resp.mod_download = reader.read_cstring() + if is_mod and len(reader.peek()) > 2: + mod_website = reader.read_cstring() + mod_download = reader.read_cstring() reader.read(1) # Skip a NULL byte - resp.mod_version = reader.read_uint32() - resp.mod_size = reader.read_uint32() - resp.multiplayer_only = reader.read_bool() - resp.uses_custom_dll = reader.read_bool() - - resp.vac_enabled = reader.read_bool() - resp.bot_count = reader.read_uint8() - - return resp + mod_version = reader.read_uint32() + mod_size = reader.read_uint32() + multiplayer_only = reader.read_bool() + uses_custom_dll = reader.read_bool() + else: + mod_website = None + mod_download = None + mod_version = None + mod_size = None + multiplayer_only = None + uses_custom_dll = None + + vac_enabled = reader.read_bool() + bot_count = reader.read_uint8() + + return GoldSrcInfo( + address, server_name, map_name, folder, game, player_count, max_players, protocol, + server_type, platform, password_protected, is_mod, vac_enabled, bot_count, mod_website, + mod_download, mod_version, mod_size, multiplayer_only, uses_custom_dll, ping + ) diff --git a/a2s/players.py b/a2s/players.py index f95289a..1c302e5 100644 --- a/a2s/players.py +++ b/a2s/players.py @@ -1,21 +1,25 @@ import io +from dataclasses import dataclass +from typing import Generic, TypeVar, overload from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING from a2s.a2s_sync import request_sync from a2s.a2s_async import request_async from a2s.byteio import ByteReader -from a2s.datacls import DataclsMeta A2S_PLAYER_RESPONSE = 0x44 -class Player(metaclass=DataclsMeta): +StrType = TypeVar("StrType", str, bytes) # str (default) or bytes if encoding=None is used + +@dataclass +class Player(Generic[StrType]): index: int """Apparently an entry index, but seems to be always 0""" - name: str + name: StrType """Name of the player""" score: int @@ -25,10 +29,34 @@ class Player(metaclass=DataclsMeta): """Time the player has been connected to the server""" -def players(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): +@overload +def players(address: tuple[str, int], timeout: float, encoding: str) -> list[Player[str]]: + ... + +@overload +def players(address: tuple[str, int], timeout: float, encoding: None) -> list[Player[bytes]]: + ... + +def players( + address: tuple[str, int], + timeout: float = DEFAULT_TIMEOUT, + encoding: str | None = DEFAULT_ENCODING +) -> list[Player[str]] | list[Player[bytes]]: return request_sync(address, timeout, encoding, PlayersProtocol) -async def aplayers(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): +@overload +async def aplayers(address: tuple[str, int], timeout: float, encoding: str) -> list[Player[str]]: + ... + +@overload +async def aplayers(address: tuple[str, int], timeout: float, encoding: None) -> list[Player[bytes]]: + ... + +async def aplayers( + address: tuple[str, int], + timeout: float = DEFAULT_TIMEOUT, + encoding: str | None = DEFAULT_ENCODING +) -> list[Player[str]] | list[Player[bytes]]: return await request_async(address, timeout, encoding, PlayersProtocol) diff --git a/a2s/rules.py b/a2s/rules.py index 1224f72..74a2423 100644 --- a/a2s/rules.py +++ b/a2s/rules.py @@ -1,20 +1,44 @@ import io +from typing import overload from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING from a2s.a2s_sync import request_sync from a2s.a2s_async import request_async from a2s.byteio import ByteReader -from a2s.datacls import DataclsMeta A2S_RULES_RESPONSE = 0x45 -def rules(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): +@overload +def rules(address: tuple[str, int], timeout: float, encoding: str) -> dict[str, str]: + ... + +@overload +def rules(address: tuple[str, int], timeout: float, encoding: None) -> dict[bytes, bytes]: + ... + +def rules( + address: tuple[str, int], + timeout: float = DEFAULT_TIMEOUT, + encoding: str | None = DEFAULT_ENCODING +) -> dict[str, str] | dict[bytes, bytes]: return request_sync(address, timeout, encoding, RulesProtocol) -async def arules(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): +@overload +async def arules(address: tuple[str, int], timeout: float, encoding: str) -> dict[str, str]: + ... + +@overload +async def arules(address: tuple[str, int], timeout: float, encoding: None) -> dict[bytes, bytes]: + ... + +async def arules( + address: tuple[str, int], + timeout: float = DEFAULT_TIMEOUT, + encoding: str | None = DEFAULT_ENCODING +) -> dict[str, str] | dict[bytes, bytes]: return await request_async(address, timeout, encoding, RulesProtocol)