Browse Source

add param to raise error on Unexpected terminator

pull/31/head
Vert 11 months ago
parent
commit
4d03a9cb50
  1. 2
      rcon/client.py
  2. 10
      rcon/source/async_rcon.py
  3. 20
      rcon/source/client.py
  4. 8
      rcon/source/proto.py

2
rcon/client.py

@ -39,7 +39,7 @@ class BaseClient:
return self._socket.__exit__(typ, value, traceback) return self._socket.__exit__(typ, value, traceback)
@property @property
def timeout(self) -> float: def timeout(self) -> float | None:
"""Return the socket timeout.""" """Return the socket timeout."""
return self._socket.gettimeout() return self._socket.gettimeout()

10
rcon/source/async_rcon.py

@ -23,12 +23,13 @@ async def communicate(
*, *,
frag_threshold: int = 4096, frag_threshold: int = 4096,
frag_detect_cmd: str = "", frag_detect_cmd: str = "",
raise_unexpected_terminator: bool = False,
) -> Packet: ) -> Packet:
"""Make an asynchronous request.""" """Make an asynchronous request."""
writer.write(bytes(packet)) writer.write(bytes(packet))
await writer.drain() await writer.drain()
response = await Packet.aread(reader) response = await Packet.aread(reader, raise_unexpected_terminator)
if len(response.payload) < frag_threshold: if len(response.payload) < frag_threshold:
return response return response
@ -36,7 +37,7 @@ async def communicate(
writer.write(bytes(Packet.make_command(frag_detect_cmd))) writer.write(bytes(Packet.make_command(frag_detect_cmd)))
await writer.drain() await writer.drain()
while (successor := await Packet.aread(reader)).id == response.id: while (successor := await Packet.aread(reader, raise_unexpected_terminator)).id == response.id:
response += successor response += successor
return response return response
@ -53,6 +54,7 @@ async def rcon(
frag_detect_cmd: str = "", frag_detect_cmd: str = "",
timeout: int | None = None, timeout: int | None = None,
enforce_id: bool = True, enforce_id: bool = True,
raise_unexpected_terminator: bool = False,
) -> str: ) -> str:
"""Run a command asynchronously.""" """Run a command asynchronously."""
@ -68,14 +70,14 @@ async def rcon(
# Wait for SERVERDATA_AUTH_RESPONSE according to: # Wait for SERVERDATA_AUTH_RESPONSE according to:
# https://developer.valvesoftware.com/wiki/Source_RCON_Protocol # https://developer.valvesoftware.com/wiki/Source_RCON_Protocol
while response.type != Type.SERVERDATA_AUTH_RESPONSE: while response.type != Type.SERVERDATA_AUTH_RESPONSE:
response = await Packet.aread(reader) response = await Packet.aread(reader, raise_unexpected_terminator)
if response.id == -1: if response.id == -1:
await close(writer) await close(writer)
raise WrongPassword() raise WrongPassword()
request = Packet.make_command(command, *arguments, encoding=encoding) request = Packet.make_command(command, *arguments, encoding=encoding)
response = await communicate(reader, writer, request) response = await communicate(reader, writer, request, raise_unexpected_terminator)
await close(writer) await close(writer)
if enforce_id and response.id != request.id: if enforce_id and response.id != request.id:

20
rcon/source/client.py

@ -6,7 +6,6 @@ from rcon.client import BaseClient
from rcon.exceptions import SessionTimeout, WrongPassword from rcon.exceptions import SessionTimeout, WrongPassword
from rcon.source.proto import Packet, Type from rcon.source.proto import Packet, Type
__all__ = ["Client"] __all__ = ["Client"]
@ -25,20 +24,22 @@ class Client(BaseClient, socket_type=SOCK_STREAM):
self.frag_threshold = frag_threshold self.frag_threshold = frag_threshold
self.frag_detect_cmd = frag_detect_cmd self.frag_detect_cmd = frag_detect_cmd
def communicate(self, packet: Packet) -> Packet: def communicate(
self, packet: Packet, raise_unexpected_terminator: bool = False
) -> Packet:
"""Send and receive a packet.""" """Send and receive a packet."""
self.send(packet) self.send(packet)
return self.read() return self.read(raise_unexpected_terminator)
def send(self, packet: Packet) -> None: def send(self, packet: Packet) -> None:
"""Send a packet to the server.""" """Send a packet to the server."""
with self._socket.makefile("wb") as file: with self._socket.makefile("wb") as file:
file.write(bytes(packet)) file.write(bytes(packet))
def read(self) -> Packet: def read(self, raise_unexpected_terminator: bool = False) -> Packet:
"""Read a packet from the server.""" """Read a packet from the server."""
with self._socket.makefile("rb") as file: with self._socket.makefile("rb") as file:
response = Packet.read(file) response = Packet.read(file, raise_unexpected_terminator)
if len(response.payload) < self.frag_threshold: if len(response.payload) < self.frag_threshold:
return response return response
@ -65,11 +66,16 @@ class Client(BaseClient, socket_type=SOCK_STREAM):
return True return True
def run( def run(
self, command: str, *args: str, encoding: str = "utf-8", enforce_id: bool = True self,
command: str,
*args: str,
encoding: str = "utf-8",
enforce_id: bool = True,
raise_unexpected_terminator: bool = False,
) -> str: ) -> str:
"""Run a command.""" """Run a command."""
request = Packet.make_command(command, *args, encoding=encoding) request = Packet.make_command(command, *args, encoding=encoding)
response = self.communicate(request) response = self.communicate(request, raise_unexpected_terminator)
if enforce_id and response.id != request.id: if enforce_id and response.id != request.id:
raise SessionTimeout("packet ID mismatch") raise SessionTimeout("packet ID mismatch")

8
rcon/source/proto.py

@ -118,7 +118,7 @@ class Packet(NamedTuple):
return size + payload return size + payload
@classmethod @classmethod
async def aread(cls, reader: StreamReader) -> Packet: async def aread(cls, reader: StreamReader, raise_unexpected_terminator: bool = False) -> Packet:
"""Read a packet from an asynchronous file-like object.""" """Read a packet from an asynchronous file-like object."""
LOGGER.debug("Reading packet asynchronously.") LOGGER.debug("Reading packet asynchronously.")
size = await LittleEndianSignedInt32.aread(reader) size = await LittleEndianSignedInt32.aread(reader)
@ -137,12 +137,14 @@ class Packet(NamedTuple):
LOGGER.debug(" => terminator: %s", terminator) LOGGER.debug(" => terminator: %s", terminator)
if terminator != TERMINATOR: if terminator != TERMINATOR:
if raise_unexpected_terminator:
raise ValueError("Unexpected terminator")
LOGGER.warning("Unexpected terminator: %s", terminator) LOGGER.warning("Unexpected terminator: %s", terminator)
return cls(id_, type_, payload, terminator) return cls(id_, type_, payload, terminator)
@classmethod @classmethod
def read(cls, file: IO) -> Packet: def read(cls, file: IO, raise_unexpected_terminator: bool = False) -> Packet:
"""Read a packet from a file-like object.""" """Read a packet from a file-like object."""
LOGGER.debug("Reading packet.") LOGGER.debug("Reading packet.")
size = LittleEndianSignedInt32.read(file) size = LittleEndianSignedInt32.read(file)
@ -161,6 +163,8 @@ class Packet(NamedTuple):
LOGGER.debug(" => terminator: %s", terminator) LOGGER.debug(" => terminator: %s", terminator)
if terminator != TERMINATOR: if terminator != TERMINATOR:
if raise_unexpected_terminator:
raise ValueError("Unexpected terminator")
LOGGER.warning("Unexpected terminator: %s", terminator) LOGGER.warning("Unexpected terminator: %s", terminator)
return cls(id_, type_, payload, terminator) return cls(id_, type_, payload, terminator)

Loading…
Cancel
Save