From 43150c13363cecd32a63d9b5c863a980741ba877 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20N=C3=B8rgaard?= Date: Sat, 31 Dec 2022 18:30:30 +0000 Subject: [PATCH] complete typing --- .github/workflows/coverage_and_lint.yml | 61 +++++++++++ .gitignore | 2 +- MANIFEST.in | 3 + a2s/__init__.py | 32 +++++- a2s/a2s_async.py | 132 +++++++++++++++-------- a2s/a2s_async.pyi | 48 +++++++++ a2s/a2s_fragment.py | 37 ++++--- a2s/a2s_sync.py | 65 ++++++----- a2s/a2s_sync.pyi | 48 +++++++++ a2s/byteio.py | 109 ++++++++++--------- a2s/datacls.py | 24 +++-- a2s/exceptions.py | 1 + a2s/info.py | 137 +++++++++++++----------- a2s/players.py | 26 ++--- a2s/py.typed | 0 a2s/rules.py | 27 +++-- pyproject.toml | 19 ++++ setup.py | 4 +- 18 files changed, 533 insertions(+), 242 deletions(-) create mode 100644 .github/workflows/coverage_and_lint.yml create mode 100644 MANIFEST.in create mode 100644 a2s/a2s_async.pyi create mode 100644 a2s/a2s_sync.pyi create mode 100644 a2s/py.typed create mode 100644 pyproject.toml diff --git a/.github/workflows/coverage_and_lint.yml b/.github/workflows/coverage_and_lint.yml new file mode 100644 index 0000000..ade407b --- /dev/null +++ b/.github/workflows/coverage_and_lint.yml @@ -0,0 +1,61 @@ +name: Type Coverage and Linting + +on: + push: + branches: + - master + pull_request: + branches: + - master + types: + - opened + - synchronize + +jobs: + job: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11' ] + + name: "Type Coverage and Linting @ ${{ matrix.python-version }}" + steps: + - name: "Checkout Repository" + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: "Setup Python @ ${{ matrix.python-version }}" + uses: actions/setup-python@v3 + with: + python-version: "${{ matrix.python-version }}" + + - name: "Install Python deps @ ${{ matrix.python-version }}" + env: + PY_VER: "${{ matrix.python-version }}" + run: | + pip install -U . + + - uses: actions/setup-node@v3 + with: + node-version: "17" + - run: npm install --location=global pyright@latest + + - name: "Type Coverage @ ${{ matrix.python-version }}" + run: | + pyright + pyright --ignoreexternal --lib --verifytypes a2s + + - name: Lint + if: ${{ github.event_name != 'pull_request' }} + uses: github/super-linter/slim@v4 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + DEFAULT_BRANCH: master + VALIDATE_ALL_CODEBASE: false + VALIDATE_PYTHON_BLACK: true + VALIDATE_PYTHON_ISORT: true + LINTER_RULES_PATH: / + PYTHON_ISORT_CONFIG_FILE: pyproject.toml + PYTHON_BLACK_CONFIG_FILE: pyproject.toml diff --git a/.gitignore b/.gitignore index e32c3ef..d1a8619 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ __pycache__ build dist *.egg-info - +.venv/ diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..c9e6c81 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include README.md +include LICENSE +include a2s/py.typed diff --git a/a2s/__init__.py b/a2s/__init__.py index 0251452..c68a690 100644 --- a/a2s/__init__.py +++ b/a2s/__init__.py @@ -1,5 +1,29 @@ -from a2s.exceptions import BrokenMessageError, BufferExhaustedError +""" +MIT License -from a2s.info import info, ainfo, SourceInfo, GoldSrcInfo -from a2s.players import players, aplayers, Player -from a2s.rules import rules, arules +Copyright (c) 2020 Gabriel Huber + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from a2s.exceptions import BrokenMessageError as BrokenMessageError, BufferExhaustedError as BufferExhaustedError + +from a2s.info import info as info, ainfo as ainfo, SourceInfo as SourceInfo, GoldSrcInfo as GoldSrcInfo +from a2s.players import players as players, aplayers as aplayers, Player as Player +from a2s.rules import rules as rules, arules as arules diff --git a/a2s/a2s_async.py b/a2s/a2s_async.py index 5131c35..73ac321 100644 --- a/a2s/a2s_async.py +++ b/a2s/a2s_async.py @@ -1,29 +1,73 @@ +""" +MIT License + +Copyright (c) 2020 Gabriel Huber + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +from __future__ import annotations + import asyncio +import io import logging import time -import io +from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Tuple, Type, TypeVar, Union -from a2s.exceptions import BrokenMessageError -from a2s.a2s_fragment import decode_fragment -from a2s.defaults import DEFAULT_RETRIES +from a2s.a2s_fragment import A2SFragment, decode_fragment from a2s.byteio import ByteReader +from a2s.defaults import DEFAULT_RETRIES +from a2s.exceptions import BrokenMessageError +from .info import GoldSrcInfo, InfoProtocol, SourceInfo +from .players import Player, PlayersProtocol +from .rules import RulesProtocol +if TYPE_CHECKING: + from typing_extensions import Self HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF" HEADER_MULTI = b"\xFE\xFF\xFF\xFF" A2S_CHALLENGE_RESPONSE = 0x41 +PROTOCOLS = Union[InfoProtocol, PlayersProtocol, RulesProtocol] + +logger: logging.Logger = logging.getLogger("a2s") -logger = logging.getLogger("a2s") +T = TypeVar("T", bound=PROTOCOLS) -async def request_async(address, timeout, encoding, a2s_proto): +async def request_async( + address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[T] +) -> Union[SourceInfo, GoldSrcInfo, List[Player], Dict[str, str]]: conn = await A2SStreamAsync.create(address, timeout) response = await request_async_impl(conn, encoding, a2s_proto) conn.close() return response -async def request_async_impl(conn, encoding, a2s_proto, challenge=0, retries=0, ping=None): + +async def request_async_impl( + conn: A2SStreamAsync, + encoding: str, + a2s_proto: Type[T], + challenge: int = 0, + retries: int = 0, + ping: Optional[float] = None, +) -> Union[SourceInfo, GoldSrcInfo, Dict[str, str], List[Player]]: send_time = time.monotonic() resp_data = await conn.request(a2s_proto.serialize_request(challenge)) recv_time = time.monotonic() @@ -31,108 +75,104 @@ async def request_async_impl(conn, encoding, a2s_proto, challenge=0, retries=0, if retries == 0: ping = recv_time - send_time - reader = ByteReader( - io.BytesIO(resp_data), endian="<", encoding=encoding) + reader = ByteReader(io.BytesIO(resp_data), endian="<", encoding=encoding) response_type = reader.read_uint8() if response_type == A2S_CHALLENGE_RESPONSE: if retries >= DEFAULT_RETRIES: - raise BrokenMessageError( - "Server keeps sending challenge responses") + raise BrokenMessageError("Server keeps sending challenge responses") challenge = reader.read_uint32() - return await request_async_impl( - conn, encoding, a2s_proto, challenge, retries + 1, ping) + return await request_async_impl(conn, encoding, a2s_proto, challenge, retries + 1, ping) if not a2s_proto.validate_response_type(response_type): - raise BrokenMessageError( - "Invalid response type: " + hex(response_type)) + raise BrokenMessageError("Invalid response type: " + hex(response_type)) return a2s_proto.deserialize_response(reader, response_type, ping) class A2SProtocol(asyncio.DatagramProtocol): def __init__(self): - self.recv_queue = asyncio.Queue() - self.error_event = asyncio.Event() - self.error = None - self.fragment_buf = [] + self.recv_queue: asyncio.Queue[bytes] = asyncio.Queue() + self.error_event: asyncio.Event = asyncio.Event() + self.error: Optional[Exception] = None + self.fragment_buf: List[A2SFragment] = [] - def connection_made(self, transport): + def connection_made(self, transport: asyncio.DatagramTransport) -> None: self.transport = transport - def datagram_received(self, packet, addr): - header = packet[:4] - payload = packet[4:] + def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: + header = data[:4] + payload = data[4:] if header == HEADER_SIMPLE: logger.debug("Received single packet: %r", payload) self.recv_queue.put_nowait(payload) elif header == HEADER_MULTI: self.fragment_buf.append(decode_fragment(payload)) if len(self.fragment_buf) < self.fragment_buf[0].fragment_count: - return # Wait for more packets to arrive + return # Wait for more packets to arrive self.fragment_buf.sort(key=lambda f: f.fragment_id) - reassembled = b"".join( - fragment.payload for fragment in self.fragment_buf) + reassembled = b"".join(fragment.payload for fragment in self.fragment_buf) # Sometimes there's an additional header present if reassembled.startswith(b"\xFF\xFF\xFF\xFF"): reassembled = reassembled[4:] - logger.debug("Received %s part packet with content: %r", - len(self.fragment_buf), reassembled) + logger.debug("Received %s part packet with content: %r", len(self.fragment_buf), reassembled) self.recv_queue.put_nowait(reassembled) self.fragment_buf = [] else: - self.error = BrokenMessageError( - "Invalid packet header: " + repr(header)) + self.error = BrokenMessageError("Invalid packet header: " + repr(header)) self.error_event.set() - def error_received(self, exc): + def error_received(self, exc: Exception) -> None: self.error = exc self.error_event.set() - def raise_on_error(self): - error = self.error + def raise_on_error(self) -> NoReturn: + assert self.error + error: Exception = self.error self.error = None self.error_event.clear() raise error + class A2SStreamAsync: - def __init__(self, transport, protocol, timeout): + def __init__(self, transport: asyncio.DatagramTransport, protocol: A2SProtocol, timeout: float) -> None: self.transport = transport self.protocol = protocol self.timeout = timeout - def __del__(self): + def __del__(self) -> None: self.close() @classmethod - async def create(cls, address, timeout): + async def create(cls, address: Tuple[str, int], timeout: float) -> Self: loop = asyncio.get_running_loop() - transport, protocol = await loop.create_datagram_endpoint( - lambda: A2SProtocol(), remote_addr=address) + transport, protocol = await loop.create_datagram_endpoint(lambda: A2SProtocol(), remote_addr=address) return cls(transport, protocol, timeout) - def send(self, payload): + def send(self, payload: bytes) -> None: logger.debug("Sending packet: %r", payload) packet = HEADER_SIMPLE + payload self.transport.sendto(packet) - async def recv(self): + async def recv(self) -> bytes: queue_task = asyncio.create_task(self.protocol.recv_queue.get()) error_task = asyncio.create_task(self.protocol.error_event.wait()) - done, pending = await asyncio.wait({queue_task, error_task}, - timeout=self.timeout, return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait( + {queue_task, error_task}, timeout=self.timeout, return_when=asyncio.FIRST_COMPLETED + ) - for task in pending: task.cancel() + for task in pending: + task.cancel() if error_task in done: - self.protocol.raise_on_error() + self.protocol.raise_on_error() if not done: raise asyncio.TimeoutError() return queue_task.result() - async def request(self, payload): + async def request(self, payload: bytes) -> bytes: self.send(payload) return await self.recv() - def close(self): + def close(self) -> None: self.transport.close() diff --git a/a2s/a2s_async.pyi b/a2s/a2s_async.pyi new file mode 100644 index 0000000..8091341 --- /dev/null +++ b/a2s/a2s_async.pyi @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union, overload + +from .a2s_async import A2SStreamAsync + +if TYPE_CHECKING: + from .info import GoldSrcInfo, InfoProtocol, SourceInfo + from .players import Player, PlayersProtocol + from .rules import RulesProtocol + +@overload +async def request_async( + address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[InfoProtocol] +) -> Union[SourceInfo, GoldSrcInfo]: ... +@overload +async def request_async( + address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[PlayersProtocol] +) -> List[Player]: ... +@overload +async def request_async( + address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[RulesProtocol] +) -> Dict[str, str]: ... +@overload +async def request_async_impl( + conn: A2SStreamAsync, + encoding: str, + a2s_proto: Type[InfoProtocol], + challenge: int = ..., + retries: int = ..., + ping: Optional[float] = ..., +) -> Union[SourceInfo, GoldSrcInfo]: ... +@overload +async def request_async_impl( + conn: A2SStreamAsync, + encoding: str, + a2s_proto: Type[PlayersProtocol], + challenge: int = ..., + retries: int = ..., + ping: Optional[float] = ..., +) -> List[Player]: ... +@overload +async def request_async_impl( + conn: A2SStreamAsync, + encoding: str, + a2s_proto: Type[RulesProtocol], + challenge: int = ..., + retries: int = ..., + ping: Optional[float] = ..., +) -> Dict[str, str]: ... diff --git a/a2s/a2s_fragment.py b/a2s/a2s_fragment.py index 38a40ba..9758333 100644 --- a/a2s/a2s_fragment.py +++ b/a2s/a2s_fragment.py @@ -4,30 +4,37 @@ import io from a2s.byteio import ByteReader - class A2SFragment: - def __init__(self, message_id, fragment_count, fragment_id, mtu, - decompressed_size=0, crc=0, payload=b""): - self.message_id = message_id - self.fragment_count = fragment_count - self.fragment_id = fragment_id - self.mtu = mtu - self.decompressed_size = decompressed_size - self.crc = crc - self.payload = payload + def __init__( + self, + message_id: int, + fragment_count: int, + fragment_id: int, + mtu: int, + decompressed_size: int = 0, + crc: int = 0, + payload: bytes = b"", + ) -> None: + self.message_id: int = message_id + self.fragment_count: int = fragment_count + self.fragment_id: int = fragment_id + self.mtu: int = mtu + self.decompressed_size: int = decompressed_size + self.crc: int = crc + self.payload: bytes = payload @property - def is_compressed(self): + def is_compressed(self) -> bool: return bool(self.message_id & (1 << 15)) -def decode_fragment(data): - reader = ByteReader( - io.BytesIO(data), endian="<", encoding="utf-8") + +def decode_fragment(data: bytes) -> A2SFragment: + reader = ByteReader(io.BytesIO(data), endian="<", encoding="utf-8") frag = A2SFragment( message_id=reader.read_uint32(), fragment_count=reader.read_uint8(), fragment_id=reader.read_uint8(), - mtu=reader.read_uint16() + mtu=reader.read_uint16(), ) if frag.is_compressed: frag.decompressed_size = reader.read_uint32() diff --git a/a2s/a2s_sync.py b/a2s/a2s_sync.py index 6e97e33..0950660 100644 --- a/a2s/a2s_sync.py +++ b/a2s/a2s_sync.py @@ -1,29 +1,42 @@ -import socket +from __future__ import annotations + +import io import logging +import socket import time -import io +from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union -from a2s.exceptions import BrokenMessageError from a2s.a2s_fragment import decode_fragment -from a2s.defaults import DEFAULT_RETRIES from a2s.byteio import ByteReader +from a2s.defaults import DEFAULT_RETRIES +from a2s.exceptions import BrokenMessageError - +from .info import GoldSrcInfo, InfoProtocol, SourceInfo +from .players import Player, PlayersProtocol +from .rules import RulesProtocol HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF" HEADER_MULTI = b"\xFE\xFF\xFF\xFF" A2S_CHALLENGE_RESPONSE = 0x41 +PROTOCOLS = Union[InfoProtocol, RulesProtocol, PlayersProtocol] + +logger: logging.Logger = logging.getLogger("a2s") -logger = logging.getLogger("a2s") +T = TypeVar("T", InfoProtocol, RulesProtocol, PlayersProtocol) -def request_sync(address, timeout, encoding, a2s_proto): +def request_sync( + address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[T] +) -> Union[List[Player], GoldSrcInfo, SourceInfo, Dict[str, str]]: conn = A2SStream(address, timeout) - response = request_sync_impl(conn, encoding, a2s_proto) + response = request_sync_impl(conn, encoding, a2s_proto) # type: ignore conn.close() return response -def request_sync_impl(conn, encoding, a2s_proto, challenge=0, retries=0, ping=None): + +def request_sync_impl( + conn: A2SStream, encoding: str, a2s_proto: Type[T], challenge: int = 0, retries: int = 0, ping: Optional[float] = None +) -> Union[SourceInfo, GoldSrcInfo, Dict[str, str], List[Player]]: send_time = time.monotonic() resp_data = conn.request(a2s_proto.serialize_request(challenge)) recv_time = time.monotonic() @@ -31,40 +44,36 @@ def request_sync_impl(conn, encoding, a2s_proto, challenge=0, retries=0, ping=No if retries == 0: ping = recv_time - send_time - reader = ByteReader( - io.BytesIO(resp_data), endian="<", encoding=encoding) + reader = ByteReader(io.BytesIO(resp_data), endian="<", encoding=encoding) response_type = reader.read_uint8() if response_type == A2S_CHALLENGE_RESPONSE: if retries >= DEFAULT_RETRIES: - raise BrokenMessageError( - "Server keeps sending challenge responses") + raise BrokenMessageError("Server keeps sending challenge responses") challenge = reader.read_uint32() - return request_sync_impl( - conn, encoding, a2s_proto, challenge, retries + 1, ping) + return request_sync_impl(conn, encoding, a2s_proto, challenge, retries + 1, ping) if not a2s_proto.validate_response_type(response_type): - raise BrokenMessageError( - "Invalid response type: " + hex(response_type)) + raise BrokenMessageError("Invalid response type: " + hex(response_type)) return a2s_proto.deserialize_response(reader, response_type, ping) class A2SStream: - def __init__(self, address, timeout): - self.address = address - self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + def __init__(self, address: Tuple[str, int], timeout: float) -> None: + self.address: Tuple[str, int] = address + self._socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._socket.settimeout(timeout) - def __del__(self): + def __del__(self) -> None: self.close() - def send(self, data): + def send(self, data: bytes) -> None: logger.debug("Sending packet: %r", data) packet = HEADER_SIMPLE + data self._socket.sendto(packet, self.address) - def recv(self): + def recv(self) -> bytes: packet = self._socket.recv(65535) header = packet[:4] data = packet[4:] @@ -81,16 +90,14 @@ class A2SStream: # Sometimes there's an additional header present if reassembled.startswith(b"\xFF\xFF\xFF\xFF"): reassembled = reassembled[4:] - logger.debug("Received %s part packet with content: %r", - len(fragments), reassembled) + logger.debug("Received %s part packet with content: %r", len(fragments), reassembled) return reassembled else: - raise BrokenMessageError( - "Invalid packet header: " + repr(header)) + raise BrokenMessageError("Invalid packet header: " + repr(header)) - def request(self, payload): + def request(self, payload: bytes) -> bytes: self.send(payload) return self.recv() - def close(self): + def close(self) -> None: self._socket.close() diff --git a/a2s/a2s_sync.pyi b/a2s/a2s_sync.pyi new file mode 100644 index 0000000..a2f48ea --- /dev/null +++ b/a2s/a2s_sync.pyi @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union, overload + +from .a2s_sync import A2SStream + +if TYPE_CHECKING: + from .info import GoldSrcInfo, InfoProtocol, SourceInfo + from .players import Player, PlayersProtocol + from .rules import RulesProtocol + +@overload +def request_sync( + address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[InfoProtocol] +) -> Union[SourceInfo, GoldSrcInfo]: ... +@overload +def request_sync( + address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[PlayersProtocol] +) -> List[Player]: ... +@overload +def request_sync( + address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[RulesProtocol] +) -> Dict[str, str]: ... +@overload +def request_sync_impl( + conn: A2SStream, + encoding: str, + a2s_proto: Type[InfoProtocol], + challenge: int = ..., + retries: int = ..., + ping: Optional[float] = ..., +) -> Union[SourceInfo, GoldSrcInfo]: ... +@overload +def request_sync_impl( + conn: A2SStream, + encoding: str, + a2s_proto: Type[PlayersProtocol], + challenge: int = ..., + retries: int = ..., + ping: Optional[float] = ..., +) -> List[Player]: ... +@overload +def request_sync_impl( + conn: A2SStream, + encoding: str, + a2s_proto: Type[RulesProtocol], + challenge: int = ..., + retries: int = ..., + ping: Optional[float] = ..., +) -> Dict[str, str]: ... diff --git a/a2s/byteio.py b/a2s/byteio.py index b40aaba..25e3a95 100644 --- a/a2s/byteio.py +++ b/a2s/byteio.py @@ -1,80 +1,92 @@ -import struct +from __future__ import annotations + import io +import struct +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union from a2s.exceptions import BufferExhaustedError +from .defaults import DEFAULT_ENCODING + +if TYPE_CHECKING: + from typing_extensions import Literal -class ByteReader(): - def __init__(self, stream, endian="=", encoding=None): - self.stream = stream - self.endian = endian - self.encoding = encoding +STRUCT_OPTIONS = Literal[ + "x", "c", "b", "B", "?", "h", "H", "i", "I", "l", "L", "q", "Q", "n", "N", "e", "f", "d", "s", "p", "P" +] - def read(self, size=-1): + +class ByteReader: + def __init__(self, stream: io.BytesIO, endian: str = "=", encoding: Optional[str] = None) -> None: + self.stream: io.BytesIO = stream + self.endian: str = endian + self.encoding: Optional[str] = encoding + + def read(self, size: int = -1) -> bytes: data = self.stream.read(size) if size > -1 and len(data) != size: raise BufferExhaustedError() return data - def peek(self, size=-1): + def peek(self, size: int = -1) -> bytes: cur_pos = self.stream.tell() data = self.stream.read(size) self.stream.seek(cur_pos, io.SEEK_SET) return data - def unpack(self, fmt): - fmt = self.endian + fmt + def unpack(self, fmt: STRUCT_OPTIONS) -> Tuple[Any, ...]: + new_fmt = self.endian + fmt fmt_size = struct.calcsize(fmt) - return struct.unpack(fmt, self.read(fmt_size)) + return struct.unpack(new_fmt, self.read(fmt_size)) - def unpack_one(self, fmt): + def unpack_one(self, fmt: STRUCT_OPTIONS) -> Any: values = self.unpack(fmt) assert len(values) == 1 return values[0] - def read_int8(self): + def read_int8(self) -> int: return self.unpack_one("b") - def read_uint8(self): + def read_uint8(self) -> int: return self.unpack_one("B") - def read_int16(self): + def read_int16(self) -> int: return self.unpack_one("h") - def read_uint16(self): + def read_uint16(self) -> int: return self.unpack_one("H") - def read_int32(self): + def read_int32(self) -> int: return self.unpack_one("l") - def read_uint32(self): + def read_uint32(self) -> int: return self.unpack_one("L") - def read_int64(self): + def read_int64(self) -> int: return self.unpack_one("q") - def read_uint64(self): + def read_uint64(self) -> int: return self.unpack_one("Q") - def read_float(self): + def read_float(self) -> float: return self.unpack_one("f") - def read_double(self): + def read_double(self) -> float: return self.unpack_one("d") - def read_bool(self): + def read_bool(self) -> bool: return bool(self.unpack_one("b")) - def read_char(self): + def read_char(self) -> str: char = self.unpack_one("c") if self.encoding is not None: return char.decode(self.encoding, errors="replace") else: - return char + return char.decode(DEFAULT_ENCODING, errors="replace") - def read_cstring(self, charsize=1): + def read_cstring(self, charsize: int = 1) -> str: string = b"" while True: c = self.read(charsize) @@ -86,64 +98,65 @@ class ByteReader(): if self.encoding is not None: return string.decode(self.encoding, errors="replace") else: - return string + return string.decode(DEFAULT_ENCODING, errors="replace") -class ByteWriter(): - def __init__(self, stream, endian="=", encoding=None): - self.stream = stream - self.endian = endian - self.encoding = encoding +class ByteWriter: + def __init__(self, stream: io.BytesIO, endian: str = "=", encoding: Optional[str] = None) -> None: + self.stream: io.BytesIO = stream + self.endian: str = endian + self.encoding: Optional[str] = encoding - def write(self, *args): + def write(self, *args: bytes) -> int: return self.stream.write(*args) - def pack(self, fmt, *values): + def pack(self, fmt: str, *values: Any) -> int: fmt = self.endian + fmt - fmt_size = struct.calcsize(fmt) return self.stream.write(struct.pack(fmt, *values)) - def write_int8(self, val): + def write_int8(self, val: int) -> None: self.pack("b", val) - def write_uint8(self, val): + def write_uint8(self, val: int) -> None: self.pack("B", val) - def write_int16(self, val): + def write_int16(self, val: int) -> None: self.pack("h", val) - def write_uint16(self, val): + def write_uint16(self, val: int) -> None: self.pack("H", val) - def write_int32(self, val): + def write_int32(self, val: int) -> None: self.pack("l", val) - def write_uint32(self, val): + def write_uint32(self, val: int) -> None: self.pack("L", val) - def write_int64(self, val): + def write_int64(self, val: int) -> None: self.pack("q", val) - def write_uint64(self, val): + def write_uint64(self, val: int) -> None: self.pack("Q", val) - def write_float(self, val): + def write_float(self, val: float) -> None: self.pack("f", val) - def write_double(self, val): + def write_double(self, val: float) -> None: self.pack("d", val) - def write_bool(self, val): + def write_bool(self, val: bool) -> None: self.pack("b", val) - def write_char(self, val): + def write_char(self, val: str) -> None: if self.encoding is not None: self.pack("c", val.encode(self.encoding)) else: self.pack("c", val) - def write_cstring(self, val): + def write_cstring(self, val: Union[str, bytes]) -> None: if self.encoding is not None: + assert isinstance(val, str) self.write(val.encode(self.encoding) + b"\x00") else: + assert isinstance(val, bytes) self.write(val + b"\x00") diff --git a/a2s/datacls.py b/a2s/datacls.py index e8152cb..bc62b73 100644 --- a/a2s/datacls.py +++ b/a2s/datacls.py @@ -5,29 +5,37 @@ Check out the official documentation to see what this is trying to achieve: https://docs.python.org/3/library/dataclasses.html """ +from __future__ import annotations -import collections +from collections import OrderedDict import copy +from typing import Any, Generator, Tuple, TYPE_CHECKING, Dict + +if TYPE_CHECKING: + from typing_extensions import Self + class DataclsBase: - def __init__(self, **kwargs): + _defaults: "OrderedDict[str, Any]" + + def __init__(self, **kwargs: Any) -> None: for name, value in self._defaults.items(): if name in kwargs: value = kwargs[name] setattr(self, name, copy.copy(value)) - def __iter__(self): + def __iter__(self) -> Generator[Tuple[str, Any], None, None]: for name in self.__annotations__: yield (name, getattr(self, name)) - def __repr__(self): + def __repr__(self) -> str: return "{}({})".format( self.__class__.__name__, ", ".join(name + "=" + repr(value) for name, value in self)) class DataclsMeta(type): - def __new__(cls, name, bases, prop): - values = collections.OrderedDict() + def __new__(cls, name: str, bases: Tuple[type, ...], prop: Dict[str, Any]) -> Self: + values: OrderedDict[str, Any] = OrderedDict() for member_name in prop["__annotations__"].keys(): # Check if member has a default value set as class variable if member_name in prop: @@ -43,5 +51,5 @@ class DataclsMeta(type): bases = (DataclsBase, *bases) return super().__new__(cls, name, bases, prop) - def __prepare__(self, *args, **kwargs): - return collections.OrderedDict() + def __prepare__(self, *args: Any, **kwargs: Any) -> OrderedDict[str, Any]: # type: ignore # this is custom overriden + return OrderedDict() diff --git a/a2s/exceptions.py b/a2s/exceptions.py index e12d0c3..fad37c5 100644 --- a/a2s/exceptions.py +++ b/a2s/exceptions.py @@ -1,5 +1,6 @@ class BrokenMessageError(Exception): pass + class BufferExhaustedError(BrokenMessageError): pass diff --git a/a2s/info.py b/a2s/info.py index dd1dd02..c0487a3 100644 --- a/a2s/info.py +++ b/a2s/info.py @@ -1,204 +1,214 @@ -import io +from __future__ import annotations + +from typing import Optional, Tuple, Union -from a2s.exceptions import BrokenMessageError, BufferExhaustedError -from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING -from a2s.a2s_sync import request_sync from a2s.a2s_async import request_async -from a2s.byteio import ByteReader +from a2s.a2s_sync import request_sync from a2s.datacls import DataclsMeta +from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT +from a2s.exceptions import BufferExhaustedError - +from .byteio import ByteReader A2S_INFO_RESPONSE = 0x49 A2S_INFO_RESPONSE_LEGACY = 0x6D class SourceInfo(metaclass=DataclsMeta): - """Protocol version used by the server""" + protocol: int + """Protocol version used by the server""" - """Display name of the server""" server_name: str + """Display name of the server""" - """The currently loaded map""" map_name: str + """The currently loaded map""" - """Name of the game directory""" folder: str + """Name of the game directory""" - """Name of the game""" game: str + """Name of the game""" - """App ID of the game required to connect""" app_id: int + """App ID of the game required to connect""" - """Number of players currently connected""" player_count: int + """Number of players currently connected""" - """Number of player slots available""" max_players: int + """Number of player slots available""" - """Number of bots on the server""" bot_count: int + """Number of bots on the server""" + server_type: str """Type of the server: 'd': Dedicated server 'l': Non-dedicated server 'p': SourceTV relay (proxy)""" - server_type: str + platform: str """Operating system of the server 'l', 'w', 'm' for Linux, Windows, macOS""" - platform: str - """Server requires a password to connect""" password_protected: bool + """Server requires a password to connect""" - """Server has VAC enabled""" vac_enabled: bool + """Server has VAC enabled""" - """Version of the server software""" version: str + """Version of the server software""" # Optional: + edf: int = 0 """Extra data field, used to indicate if extra values are included in the response""" - edf: int = 0 - """Port of the game server.""" port: int + """Port of the game server.""" - """Steam ID of the server""" steam_id: int + """Steam ID of the server""" - """Port of the SourceTV server""" stv_port: int + """Port of the SourceTV server""" - """Name of the SourceTV server""" stv_name: str + """Name of the SourceTV server""" - """Tags that describe the gamemode being played""" keywords: str + """Tags that describe the gamemode being played""" - """Game ID for games that have an app ID too high for 16bit.""" game_id: int + """Game ID for games that have an app ID too high for 16bit.""" # Client determined values: - """Round-trip delay time for the request in seconds""" ping: float + """Round-trip delay time for the request in seconds""" @property - def has_port(self): + def has_port(self) -> bool: return bool(self.edf & 0x80) @property - def has_steam_id(self): + def has_steam_id(self) -> bool: return bool(self.edf & 0x10) @property - def has_stv(self): + def has_stv(self) -> bool: return bool(self.edf & 0x40) @property - def has_keywords(self): + def has_keywords(self) -> bool: return bool(self.edf & 0x20) @property - def has_game_id(self): + def has_game_id(self) -> bool: return bool(self.edf & 0x01) + class GoldSrcInfo(metaclass=DataclsMeta): - """IP Address and port of the server""" address: str + """IP Address and port of the server""" - """Display name of the server""" server_name: str + """Display name of the server""" - """The currently loaded map""" map_name: str + """The currently loaded map""" - """Name of the game directory""" folder: str + """Name of the game directory""" - """Name of the game""" game: str + """Name of the game""" - """Number of players currently connected""" player_count: int + """Number of players currently connected""" - """Number of player slots available""" max_players: int + """Number of player slots available""" - """Protocol version used by the server""" protocol: int + """Protocol version used by the server""" + server_type: str """Type of the server: 'd': Dedicated server 'l': Non-dedicated server 'p': SourceTV relay (proxy)""" - server_type: str + platform: str """Operating system of the server 'l', 'w' for Linux and Windows""" - platform: str - """Server requires a password to connect""" password_protected: bool + """Server requires a password to connect""" """Server is running a Half-Life mod instead of the base game""" is_mod: bool - """Server has VAC enabled""" vac_enabled: bool + """Server has VAC enabled""" - """Number of bots on the server""" bot_count: int + """Number of bots on the server""" # Optional: - """URL to the mod website""" mod_website: str + """URL to the mod website""" - """URL to download the mod""" mod_download: str + """URL to download the mod""" - """Version of the mod installed on the server""" mod_version: int + """Version of the mod installed on the server""" - """Size in bytes of the mod""" mod_size: int + """Size in bytes of the mod""" - """Mod supports multiplayer only""" multiplayer_only: bool = False + """Mod supports multiplayer only""" + uses_custom_dll: bool = True """Mod uses a custom DLL""" - uses_hl_dll: bool = True # Client determined values: - """Round-trip delay time for the request in seconds""" ping: float + """Round-trip delay time for the request in seconds""" -def info(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): +def info( + address: Tuple[str, int], timeout: float = DEFAULT_TIMEOUT, encoding: str = DEFAULT_ENCODING +) -> Union[SourceInfo, GoldSrcInfo]: return request_sync(address, timeout, encoding, InfoProtocol) -async def ainfo(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): + +async def ainfo( + address: Tuple[str, int], timeout: float = DEFAULT_TIMEOUT, encoding: str = DEFAULT_ENCODING +) -> Union[SourceInfo, GoldSrcInfo]: return await request_async(address, timeout, encoding, InfoProtocol) class InfoProtocol: @staticmethod - def validate_response_type(response_type): + def validate_response_type(response_type: int) -> bool: return response_type in (A2S_INFO_RESPONSE, A2S_INFO_RESPONSE_LEGACY) @staticmethod - def serialize_request(challenge): + def serialize_request(challenge: int) -> bytes: if challenge: return b"\x54Source Engine Query\0" + challenge.to_bytes(4, "little") else: return b"\x54Source Engine Query\0" @staticmethod - def deserialize_response(reader, response_type, ping): + def deserialize_response( + reader: ByteReader, response_type: int, ping: Optional[float] + ) -> Union[SourceInfo, GoldSrcInfo]: if response_type == A2S_INFO_RESPONSE: resp = parse_source(reader) elif response_type == A2S_INFO_RESPONSE_LEGACY: @@ -206,10 +216,12 @@ class InfoProtocol: else: raise Exception(str(response_type)) + assert ping resp.ping = ping return resp -def parse_source(reader): + +def parse_source(reader: ByteReader) -> SourceInfo: resp = SourceInfo() resp.protocol = reader.read_uint8() resp.server_name = reader.read_cstring() @@ -222,7 +234,7 @@ def parse_source(reader): resp.bot_count = reader.read_uint8() resp.server_type = reader.read_char().lower() resp.platform = reader.read_char().lower() - if resp.platform == "o": # Deprecated mac value + if resp.platform == "o": # Deprecated mac value resp.platform = "m" resp.password_protected = reader.read_bool() resp.vac_enabled = reader.read_bool() @@ -247,7 +259,8 @@ def parse_source(reader): return resp -def parse_goldsrc(reader): + +def parse_goldsrc(reader: ByteReader) -> GoldSrcInfo: resp = GoldSrcInfo() resp.address = reader.read_cstring() resp.server_name = reader.read_cstring() @@ -266,7 +279,7 @@ def parse_goldsrc(reader): if resp.is_mod and len(reader.peek()) > 2: resp.mod_website = reader.read_cstring() resp.mod_download = reader.read_cstring() - reader.read(1) # Skip a NULL byte + reader.read(1) # Skip a NULL byte resp.mod_version = reader.read_uint32() resp.mod_size = reader.read_uint32() resp.multiplayer_only = reader.read_bool() diff --git a/a2s/players.py b/a2s/players.py index 1100cf1..8bb23fd 100644 --- a/a2s/players.py +++ b/a2s/players.py @@ -1,18 +1,17 @@ -import io +from typing import List, Optional, Tuple -from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING -from a2s.a2s_sync import request_sync from a2s.a2s_async import request_async +from a2s.a2s_sync import request_sync from a2s.byteio import ByteReader from a2s.datacls import DataclsMeta - - +from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT A2S_PLAYER_RESPONSE = 0x44 class Player(metaclass=DataclsMeta): """Apparently an entry index, but seems to be always 0""" + index: int """Name of the player""" @@ -25,32 +24,35 @@ class Player(metaclass=DataclsMeta): duration: float -def players(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): +def players(address: Tuple[str, int], timeout: float = DEFAULT_TIMEOUT, encoding: str = DEFAULT_ENCODING) -> List[Player]: return request_sync(address, timeout, encoding, PlayersProtocol) -async def aplayers(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): + +async def aplayers( + address: Tuple[str, int], timeout: float = DEFAULT_TIMEOUT, encoding: str = DEFAULT_ENCODING +) -> List[Player]: return await request_async(address, timeout, encoding, PlayersProtocol) class PlayersProtocol: @staticmethod - def validate_response_type(response_type): + def validate_response_type(response_type: int) -> bool: return response_type == A2S_PLAYER_RESPONSE @staticmethod - def serialize_request(challenge): + def serialize_request(challenge: int) -> bytes: return b"\x55" + challenge.to_bytes(4, "little") @staticmethod - def deserialize_response(reader, response_type, ping): + def deserialize_response(reader: ByteReader, response_type: int, ping: Optional[float]) -> List[Player]: player_count = reader.read_uint8() resp = [ Player( index=reader.read_uint8(), name=reader.read_cstring(), score=reader.read_int32(), - duration=reader.read_float() + duration=reader.read_float(), ) - for player_num in range(player_count) + for _ in range(player_count) ] return resp diff --git a/a2s/py.typed b/a2s/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/a2s/rules.py b/a2s/rules.py index 1224f72..e2a7e09 100644 --- a/a2s/rules.py +++ b/a2s/rules.py @@ -1,38 +1,35 @@ -import io +from typing import Dict, Optional, Tuple -from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING -from a2s.a2s_sync import request_sync from a2s.a2s_async import request_async +from a2s.a2s_sync import request_sync from a2s.byteio import ByteReader -from a2s.datacls import DataclsMeta - - +from a2s.defaults import DEFAULT_ENCODING, DEFAULT_TIMEOUT A2S_RULES_RESPONSE = 0x45 -def rules(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): +def rules(address: Tuple[str, int], timeout: float = DEFAULT_TIMEOUT, encoding: str = DEFAULT_ENCODING) -> Dict[str, str]: return request_sync(address, timeout, encoding, RulesProtocol) -async def arules(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): + +async def arules( + address: Tuple[str, int], timeout: float = DEFAULT_TIMEOUT, encoding: str = DEFAULT_ENCODING +) -> Dict[str, str]: return await request_async(address, timeout, encoding, RulesProtocol) class RulesProtocol: @staticmethod - def validate_response_type(response_type): + def validate_response_type(response_type: int) -> bool: return response_type == A2S_RULES_RESPONSE @staticmethod - def serialize_request(challenge): + def serialize_request(challenge: int) -> bytes: return b"\x56" + challenge.to_bytes(4, "little") @staticmethod - def deserialize_response(reader, response_type, ping): + def deserialize_response(reader: ByteReader, response_type: int, ping: Optional[float]) -> Dict[str, str]: rule_count = reader.read_int16() # Have to use tuples to preserve evaluation order - resp = dict( - (reader.read_cstring(), reader.read_cstring()) - for rule_num in range(rule_count) - ) + resp = dict((reader.read_cstring(), reader.read_cstring()) for _ in range(rule_count)) return resp diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ef0b3c5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,19 @@ +[tool.black] +line-length = 125 +target-version = ["py37"] + +[tool.isort] +profile = "black" +combine_as_imports = true +combine_star = true +line_length = 125 + +[tool.pyright] +include = ["a2s/**/*.py"] +useLibraryCodeForTypes = true +typeCheckingMode = "strict" +pythonVersion = "3.7" + +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 7fb2fb0..e2b6aab 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ setuptools.setup( "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Operating System :: OS Independent", - "Topic :: Games/Entertainment" + "Topic :: Games/Entertainment", ], - python_requires=">=3.7" + python_requires=">=3.7", )