From 4d03a9cb50e0072a0270408bb2cc254722437ca5 Mon Sep 17 00:00:00 2001 From: Vert Date: Sun, 12 May 2024 11:38:37 -0400 Subject: [PATCH] add param to raise error on Unexpected terminator --- rcon/client.py | 2 +- rcon/source/async_rcon.py | 10 ++++++---- rcon/source/client.py | 20 +++++++++++++------- rcon/source/proto.py | 8 ++++++-- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/rcon/client.py b/rcon/client.py index 6296f5a..73846df 100644 --- a/rcon/client.py +++ b/rcon/client.py @@ -39,7 +39,7 @@ class BaseClient: return self._socket.__exit__(typ, value, traceback) @property - def timeout(self) -> float: + def timeout(self) -> float | None: """Return the socket timeout.""" return self._socket.gettimeout() diff --git a/rcon/source/async_rcon.py b/rcon/source/async_rcon.py index cfa9428..095873a 100644 --- a/rcon/source/async_rcon.py +++ b/rcon/source/async_rcon.py @@ -23,12 +23,13 @@ async def communicate( *, frag_threshold: int = 4096, frag_detect_cmd: str = "", + raise_unexpected_terminator: bool = False, ) -> Packet: """Make an asynchronous request.""" writer.write(bytes(packet)) await writer.drain() - response = await Packet.aread(reader) + response = await Packet.aread(reader, raise_unexpected_terminator) if len(response.payload) < frag_threshold: return response @@ -36,7 +37,7 @@ async def communicate( writer.write(bytes(Packet.make_command(frag_detect_cmd))) await writer.drain() - while (successor := await Packet.aread(reader)).id == response.id: + while (successor := await Packet.aread(reader, raise_unexpected_terminator)).id == response.id: response += successor return response @@ -53,6 +54,7 @@ async def rcon( frag_detect_cmd: str = "", timeout: int | None = None, enforce_id: bool = True, + raise_unexpected_terminator: bool = False, ) -> str: """Run a command asynchronously.""" @@ -68,14 +70,14 @@ async def rcon( # Wait for SERVERDATA_AUTH_RESPONSE according to: # https://developer.valvesoftware.com/wiki/Source_RCON_Protocol while response.type != Type.SERVERDATA_AUTH_RESPONSE: - response = await Packet.aread(reader) + response = await Packet.aread(reader, raise_unexpected_terminator) if response.id == -1: await close(writer) raise WrongPassword() request = Packet.make_command(command, *arguments, encoding=encoding) - response = await communicate(reader, writer, request) + response = await communicate(reader, writer, request, raise_unexpected_terminator) await close(writer) if enforce_id and response.id != request.id: diff --git a/rcon/source/client.py b/rcon/source/client.py index 44a2129..3ec7d14 100644 --- a/rcon/source/client.py +++ b/rcon/source/client.py @@ -6,7 +6,6 @@ from rcon.client import BaseClient from rcon.exceptions import SessionTimeout, WrongPassword from rcon.source.proto import Packet, Type - __all__ = ["Client"] @@ -25,20 +24,22 @@ class Client(BaseClient, socket_type=SOCK_STREAM): self.frag_threshold = frag_threshold self.frag_detect_cmd = frag_detect_cmd - def communicate(self, packet: Packet) -> Packet: + def communicate( + self, packet: Packet, raise_unexpected_terminator: bool = False + ) -> Packet: """Send and receive a packet.""" self.send(packet) - return self.read() + return self.read(raise_unexpected_terminator) def send(self, packet: Packet) -> None: """Send a packet to the server.""" with self._socket.makefile("wb") as file: file.write(bytes(packet)) - def read(self) -> Packet: + def read(self, raise_unexpected_terminator: bool = False) -> Packet: """Read a packet from the server.""" with self._socket.makefile("rb") as file: - response = Packet.read(file) + response = Packet.read(file, raise_unexpected_terminator) if len(response.payload) < self.frag_threshold: return response @@ -65,11 +66,16 @@ class Client(BaseClient, socket_type=SOCK_STREAM): return True def run( - self, command: str, *args: str, encoding: str = "utf-8", enforce_id: bool = True + self, + command: str, + *args: str, + encoding: str = "utf-8", + enforce_id: bool = True, + raise_unexpected_terminator: bool = False, ) -> str: """Run a command.""" request = Packet.make_command(command, *args, encoding=encoding) - response = self.communicate(request) + response = self.communicate(request, raise_unexpected_terminator) if enforce_id and response.id != request.id: raise SessionTimeout("packet ID mismatch") diff --git a/rcon/source/proto.py b/rcon/source/proto.py index fa2312f..dc0c1d2 100644 --- a/rcon/source/proto.py +++ b/rcon/source/proto.py @@ -118,7 +118,7 @@ class Packet(NamedTuple): return size + payload @classmethod - async def aread(cls, reader: StreamReader) -> Packet: + async def aread(cls, reader: StreamReader, raise_unexpected_terminator: bool = False) -> Packet: """Read a packet from an asynchronous file-like object.""" LOGGER.debug("Reading packet asynchronously.") size = await LittleEndianSignedInt32.aread(reader) @@ -137,12 +137,14 @@ class Packet(NamedTuple): LOGGER.debug(" => terminator: %s", terminator) if terminator != TERMINATOR: + if raise_unexpected_terminator: + raise ValueError("Unexpected terminator") LOGGER.warning("Unexpected terminator: %s", terminator) return cls(id_, type_, payload, terminator) @classmethod - def read(cls, file: IO) -> Packet: + def read(cls, file: IO, raise_unexpected_terminator: bool = False) -> Packet: """Read a packet from a file-like object.""" LOGGER.debug("Reading packet.") size = LittleEndianSignedInt32.read(file) @@ -161,6 +163,8 @@ class Packet(NamedTuple): LOGGER.debug(" => terminator: %s", terminator) if terminator != TERMINATOR: + if raise_unexpected_terminator: + raise ValueError("Unexpected terminator") LOGGER.warning("Unexpected terminator: %s", terminator) return cls(id_, type_, payload, terminator)