diff --git a/rcon/source/client.py b/rcon/source/client.py index 4596fc8..dccc817 100644 --- a/rcon/source/client.py +++ b/rcon/source/client.py @@ -1,7 +1,6 @@ """Synchronous client.""" from socket import SOCK_STREAM -from typing import IO from rcon.client import BaseClient from rcon.exceptions import SessionTimeout, WrongPassword @@ -14,33 +13,45 @@ __all__ = ['Client'] class Client(BaseClient, socket_type=SOCK_STREAM): """An RCON client.""" + _frag_detect: str | None = None + + def __init_subclass__( + cls, + *args, + frag_detect: str | None = None, + **kwargs + ): + """Set an optional fragmentation command + in order to detect fragmented packets. + + See: https://wiki.vg/RCON#Fragmentation + """ + super().__init_subclass__(*args, **kwargs) + + if frag_detect is not None: + cls._frag_detect = frag_detect + def communicate(self, packet: Packet) -> Packet: """Send and receive a packet.""" with self._socket.makefile('wb') as file: file.write(bytes(packet)) + if self._frag_detect is not None: + with self._socket.makefile('wb') as file: + file.write(bytes(Packet.make_command(self._frag_detect))) + return self.read() def read(self) -> Packet: """Read a packet.""" with self._socket.makefile('rb') as file: - return self._read_from_file(file) - - def _read_from_file(self, file: IO) -> Packet: - """Read a packet from the given file.""" - packet = Packet.read(file) + packet = Packet.read(file) - if self._max_packet_size is None: - return packet + if self._frag_detect is not None: + while (successor := Packet.read(file)).id == packet.id: + packet += successor - if len(packet.payload) < self._max_packet_size: - return packet - - with ChangedTimeout(self, 1): - try: - return packet + Packet.read(file) - except TimeoutError: - return packet + return packet def login(self, passwd: str, *, encoding: str = 'utf-8') -> bool: """Perform a login.""" @@ -66,19 +77,3 @@ class Client(BaseClient, socket_type=SOCK_STREAM): raise SessionTimeout() return response.payload.decode(encoding) - - -class ChangedTimeout: - """Context manager to temporarily change a client's timeout.""" - - def __init__(self, client: Client, timeout: int | None): - self.client = client - self.timeout = timeout - self.original_timeout = None - - def __enter__(self) -> None: - self.original_timeout = self.client.timeout - self.client.timeout = self.timeout - - def __exit__(self, exc_type, exc_val, exc_tb): - self.client.timeout = self.original_timeout