Browse Source

add param to raise error on Unexpected terminator

pull/31/head
Vert 11 months ago
parent
commit
4d03a9cb50
  1. 2
      rcon/client.py
  2. 10
      rcon/source/async_rcon.py
  3. 20
      rcon/source/client.py
  4. 8
      rcon/source/proto.py

2
rcon/client.py

@ -39,7 +39,7 @@ class BaseClient:
return self._socket.__exit__(typ, value, traceback)
@property
def timeout(self) -> float:
def timeout(self) -> float | None:
"""Return the socket timeout."""
return self._socket.gettimeout()

10
rcon/source/async_rcon.py

@ -23,12 +23,13 @@ async def communicate(
*,
frag_threshold: int = 4096,
frag_detect_cmd: str = "",
raise_unexpected_terminator: bool = False,
) -> Packet:
"""Make an asynchronous request."""
writer.write(bytes(packet))
await writer.drain()
response = await Packet.aread(reader)
response = await Packet.aread(reader, raise_unexpected_terminator)
if len(response.payload) < frag_threshold:
return response
@ -36,7 +37,7 @@ async def communicate(
writer.write(bytes(Packet.make_command(frag_detect_cmd)))
await writer.drain()
while (successor := await Packet.aread(reader)).id == response.id:
while (successor := await Packet.aread(reader, raise_unexpected_terminator)).id == response.id:
response += successor
return response
@ -53,6 +54,7 @@ async def rcon(
frag_detect_cmd: str = "",
timeout: int | None = None,
enforce_id: bool = True,
raise_unexpected_terminator: bool = False,
) -> str:
"""Run a command asynchronously."""
@ -68,14 +70,14 @@ async def rcon(
# Wait for SERVERDATA_AUTH_RESPONSE according to:
# https://developer.valvesoftware.com/wiki/Source_RCON_Protocol
while response.type != Type.SERVERDATA_AUTH_RESPONSE:
response = await Packet.aread(reader)
response = await Packet.aread(reader, raise_unexpected_terminator)
if response.id == -1:
await close(writer)
raise WrongPassword()
request = Packet.make_command(command, *arguments, encoding=encoding)
response = await communicate(reader, writer, request)
response = await communicate(reader, writer, request, raise_unexpected_terminator)
await close(writer)
if enforce_id and response.id != request.id:

20
rcon/source/client.py

@ -6,7 +6,6 @@ from rcon.client import BaseClient
from rcon.exceptions import SessionTimeout, WrongPassword
from rcon.source.proto import Packet, Type
__all__ = ["Client"]
@ -25,20 +24,22 @@ class Client(BaseClient, socket_type=SOCK_STREAM):
self.frag_threshold = frag_threshold
self.frag_detect_cmd = frag_detect_cmd
def communicate(self, packet: Packet) -> Packet:
def communicate(
self, packet: Packet, raise_unexpected_terminator: bool = False
) -> Packet:
"""Send and receive a packet."""
self.send(packet)
return self.read()
return self.read(raise_unexpected_terminator)
def send(self, packet: Packet) -> None:
"""Send a packet to the server."""
with self._socket.makefile("wb") as file:
file.write(bytes(packet))
def read(self) -> Packet:
def read(self, raise_unexpected_terminator: bool = False) -> Packet:
"""Read a packet from the server."""
with self._socket.makefile("rb") as file:
response = Packet.read(file)
response = Packet.read(file, raise_unexpected_terminator)
if len(response.payload) < self.frag_threshold:
return response
@ -65,11 +66,16 @@ class Client(BaseClient, socket_type=SOCK_STREAM):
return True
def run(
self, command: str, *args: str, encoding: str = "utf-8", enforce_id: bool = True
self,
command: str,
*args: str,
encoding: str = "utf-8",
enforce_id: bool = True,
raise_unexpected_terminator: bool = False,
) -> str:
"""Run a command."""
request = Packet.make_command(command, *args, encoding=encoding)
response = self.communicate(request)
response = self.communicate(request, raise_unexpected_terminator)
if enforce_id and response.id != request.id:
raise SessionTimeout("packet ID mismatch")

8
rcon/source/proto.py

@ -118,7 +118,7 @@ class Packet(NamedTuple):
return size + payload
@classmethod
async def aread(cls, reader: StreamReader) -> Packet:
async def aread(cls, reader: StreamReader, raise_unexpected_terminator: bool = False) -> Packet:
"""Read a packet from an asynchronous file-like object."""
LOGGER.debug("Reading packet asynchronously.")
size = await LittleEndianSignedInt32.aread(reader)
@ -137,12 +137,14 @@ class Packet(NamedTuple):
LOGGER.debug(" => terminator: %s", terminator)
if terminator != TERMINATOR:
if raise_unexpected_terminator:
raise ValueError("Unexpected terminator")
LOGGER.warning("Unexpected terminator: %s", terminator)
return cls(id_, type_, payload, terminator)
@classmethod
def read(cls, file: IO) -> Packet:
def read(cls, file: IO, raise_unexpected_terminator: bool = False) -> Packet:
"""Read a packet from a file-like object."""
LOGGER.debug("Reading packet.")
size = LittleEndianSignedInt32.read(file)
@ -161,6 +163,8 @@ class Packet(NamedTuple):
LOGGER.debug(" => terminator: %s", terminator)
if terminator != TERMINATOR:
if raise_unexpected_terminator:
raise ValueError("Unexpected terminator")
LOGGER.warning("Unexpected terminator: %s", terminator)
return cls(id_, type_, payload, terminator)

Loading…
Cancel
Save