From 06a35dd0132b5c60cdd9335c08afe168af3fe1bd Mon Sep 17 00:00:00 2001 From: Richard Neumann Date: Sun, 14 Aug 2022 23:43:21 +0200 Subject: [PATCH] Add hack to read follow-up packets --- rcon/source/client.py | 32 +++++++++++++++++++++++++++++++- rcon/source/proto.py | 19 ++++++++++++------- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/rcon/source/client.py b/rcon/source/client.py index 7e992ea..5d58b24 100644 --- a/rcon/source/client.py +++ b/rcon/source/client.py @@ -23,7 +23,20 @@ class Client(BaseClient, socket_type=SOCK_STREAM): def read(self) -> Packet: """Read a packet.""" with self._socket.makefile('rb') as file: - return Packet.read(file, max_pkg_size=self.max_pkg_size) + packet = Packet.read(file) + + if self.max_pkg_size and len(packet.payload) >= self.max_pkg_size: + return packet + self.read_followup_packet() + + return packet + + def read_followup_packet(self) -> Packet | None: + """Reads a potential followup packet.""" + with ChangedTimeout(self, 1) as client: + try: + return client.read() + except TimeoutError: + return None def login(self, passwd: str, *, encoding: str = 'utf-8') -> bool: """Perform a login.""" @@ -49,3 +62,20 @@ class Client(BaseClient, socket_type=SOCK_STREAM): raise SessionTimeout() return response.payload.decode(encoding) + + +class ChangedTimeout: + """Context manager to temporarily change a client's timeout.""" + + def __init__(self, client: Client, timeout: int | None): + self.client = client + self.timeout = timeout + self.original_timeout = None + + def __enter__(self) -> Client: + self.original_timeout = self.client.timeout + self.client.timeout = self.timeout + return self.client + + def __exit__(self, exc_type, exc_val, exc_tb): + self.client.timeout = self.original_timeout diff --git a/rcon/source/proto.py b/rcon/source/proto.py index f14567e..537fd66 100644 --- a/rcon/source/proto.py +++ b/rcon/source/proto.py @@ -2,7 +2,6 @@ from __future__ import annotations from asyncio import StreamReader -from contextlib import suppress from enum import Enum from functools import partial from logging import getLogger @@ -80,6 +79,17 @@ class Packet(NamedTuple): payload: bytes terminator: bytes = TERMINATOR + def __add__(self, other: Packet): + return Packet( + self.id, + self.type, + self.payload + other.payload, + self.terminator + ) + + def __radd__(self, other: Packet): + return other.__add__(self) + def __bytes__(self): """Return the packet as bytes with prepended length.""" payload = bytes(self.id) @@ -104,7 +114,7 @@ class Packet(NamedTuple): return cls(id_, type_, payload, terminator) @classmethod - def read(cls, file: IO, *, max_pkg_size: int | None = None) -> Packet: + def read(cls, file: IO) -> Packet: """Read a packet from a file-like object.""" size = LittleEndianSignedInt32.read(file) id_ = LittleEndianSignedInt32.read(file) @@ -115,11 +125,6 @@ class Packet(NamedTuple): if terminator != TERMINATOR: LOGGER.warning('Unexpected terminator: %s', terminator) - # Attempt to read following packets on large responses. - if size >= max_pkg_size: - with suppress(TimeoutError): - payload += cls.read(file, max_pkg_size=max_pkg_size).payload - return cls(id_, type_, payload, terminator) @classmethod