Browse Source

Add param to raise error on Unexpected terminator (#31)

* add param to raise error on Unexpected terminator
* add new exceptions
pull/33/head
Alex 11 months ago
committed by GitHub
parent
commit
b5ca7c2186
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 4
      rcon/exceptions.py
  2. 10
      rcon/source/async_rcon.py
  3. 20
      rcon/source/client.py
  4. 12
      rcon/source/proto.py

4
rcon/exceptions.py

@ -27,3 +27,7 @@ class UserAbort(Exception):
class WrongPassword(Exception): class WrongPassword(Exception):
"""Indicates a wrong password.""" """Indicates a wrong password."""
class UnexpectedTerminator(Exception):
"""Indicates an unexpected terminator in the response."""

10
rcon/source/async_rcon.py

@ -23,12 +23,13 @@ async def communicate(
*, *,
frag_threshold: int = 4096, frag_threshold: int = 4096,
frag_detect_cmd: str = "", frag_detect_cmd: str = "",
raise_unexpected_terminator: bool = False,
) -> Packet: ) -> Packet:
"""Make an asynchronous request.""" """Make an asynchronous request."""
writer.write(bytes(packet)) writer.write(bytes(packet))
await writer.drain() await writer.drain()
response = await Packet.aread(reader) response = await Packet.aread(reader, raise_unexpected_terminator)
if len(response.payload) < frag_threshold: if len(response.payload) < frag_threshold:
return response return response
@ -36,7 +37,7 @@ async def communicate(
writer.write(bytes(Packet.make_command(frag_detect_cmd))) writer.write(bytes(Packet.make_command(frag_detect_cmd)))
await writer.drain() await writer.drain()
while (successor := await Packet.aread(reader)).id == response.id: while (successor := await Packet.aread(reader, raise_unexpected_terminator)).id == response.id:
response += successor response += successor
return response return response
@ -53,6 +54,7 @@ async def rcon(
frag_detect_cmd: str = "", frag_detect_cmd: str = "",
timeout: int | None = None, timeout: int | None = None,
enforce_id: bool = True, enforce_id: bool = True,
raise_unexpected_terminator: bool = False,
) -> str: ) -> str:
"""Run a command asynchronously.""" """Run a command asynchronously."""
@ -68,14 +70,14 @@ async def rcon(
# Wait for SERVERDATA_AUTH_RESPONSE according to: # Wait for SERVERDATA_AUTH_RESPONSE according to:
# https://developer.valvesoftware.com/wiki/Source_RCON_Protocol # https://developer.valvesoftware.com/wiki/Source_RCON_Protocol
while response.type != Type.SERVERDATA_AUTH_RESPONSE: while response.type != Type.SERVERDATA_AUTH_RESPONSE:
response = await Packet.aread(reader) response = await Packet.aread(reader, raise_unexpected_terminator)
if response.id == -1: if response.id == -1:
await close(writer) await close(writer)
raise WrongPassword() raise WrongPassword()
request = Packet.make_command(command, *arguments, encoding=encoding) request = Packet.make_command(command, *arguments, encoding=encoding)
response = await communicate(reader, writer, request) response = await communicate(reader, writer, request, raise_unexpected_terminator)
await close(writer) await close(writer)
if enforce_id and response.id != request.id: if enforce_id and response.id != request.id:

20
rcon/source/client.py

@ -6,7 +6,6 @@ from rcon.client import BaseClient
from rcon.exceptions import SessionTimeout, WrongPassword from rcon.exceptions import SessionTimeout, WrongPassword
from rcon.source.proto import Packet, Type from rcon.source.proto import Packet, Type
__all__ = ["Client"] __all__ = ["Client"]
@ -25,20 +24,22 @@ class Client(BaseClient, socket_type=SOCK_STREAM):
self.frag_threshold = frag_threshold self.frag_threshold = frag_threshold
self.frag_detect_cmd = frag_detect_cmd self.frag_detect_cmd = frag_detect_cmd
def communicate(self, packet: Packet) -> Packet: def communicate(
self, packet: Packet, raise_unexpected_terminator: bool = False
) -> Packet:
"""Send and receive a packet.""" """Send and receive a packet."""
self.send(packet) self.send(packet)
return self.read() return self.read(raise_unexpected_terminator)
def send(self, packet: Packet) -> None: def send(self, packet: Packet) -> None:
"""Send a packet to the server.""" """Send a packet to the server."""
with self._socket.makefile("wb") as file: with self._socket.makefile("wb") as file:
file.write(bytes(packet)) file.write(bytes(packet))
def read(self) -> Packet: def read(self, raise_unexpected_terminator: bool = False) -> Packet:
"""Read a packet from the server.""" """Read a packet from the server."""
with self._socket.makefile("rb") as file: with self._socket.makefile("rb") as file:
response = Packet.read(file) response = Packet.read(file, raise_unexpected_terminator)
if len(response.payload) < self.frag_threshold: if len(response.payload) < self.frag_threshold:
return response return response
@ -65,11 +66,16 @@ class Client(BaseClient, socket_type=SOCK_STREAM):
return True return True
def run( def run(
self, command: str, *args: str, encoding: str = "utf-8", enforce_id: bool = True self,
command: str,
*args: str,
encoding: str = "utf-8",
enforce_id: bool = True,
raise_unexpected_terminator: bool = False,
) -> str: ) -> str:
"""Run a command.""" """Run a command."""
request = Packet.make_command(command, *args, encoding=encoding) request = Packet.make_command(command, *args, encoding=encoding)
response = self.communicate(request) response = self.communicate(request, raise_unexpected_terminator)
if enforce_id and response.id != request.id: if enforce_id and response.id != request.id:
raise SessionTimeout("packet ID mismatch") raise SessionTimeout("packet ID mismatch")

12
rcon/source/proto.py

@ -8,7 +8,7 @@ from logging import getLogger
from random import randint from random import randint
from typing import IO, NamedTuple from typing import IO, NamedTuple
from rcon.exceptions import EmptyResponse from rcon.exceptions import EmptyResponse, UnexpectedTerminator
__all__ = ["LittleEndianSignedInt32", "Type", "Packet", "random_request_id"] __all__ = ["LittleEndianSignedInt32", "Type", "Packet", "random_request_id"]
@ -118,7 +118,9 @@ class Packet(NamedTuple):
return size + payload return size + payload
@classmethod @classmethod
async def aread(cls, reader: StreamReader) -> Packet: async def aread(
cls, reader: StreamReader, raise_unexpected_terminator: bool = False
) -> Packet:
"""Read a packet from an asynchronous file-like object.""" """Read a packet from an asynchronous file-like object."""
LOGGER.debug("Reading packet asynchronously.") LOGGER.debug("Reading packet asynchronously.")
size = await LittleEndianSignedInt32.aread(reader) size = await LittleEndianSignedInt32.aread(reader)
@ -137,12 +139,14 @@ class Packet(NamedTuple):
LOGGER.debug(" => terminator: %s", terminator) LOGGER.debug(" => terminator: %s", terminator)
if terminator != TERMINATOR: if terminator != TERMINATOR:
if raise_unexpected_terminator:
raise UnexpectedTerminator(terminator)
LOGGER.warning("Unexpected terminator: %s", terminator) LOGGER.warning("Unexpected terminator: %s", terminator)
return cls(id_, type_, payload, terminator) return cls(id_, type_, payload, terminator)
@classmethod @classmethod
def read(cls, file: IO) -> Packet: def read(cls, file: IO, raise_unexpected_terminator: bool = False) -> Packet:
"""Read a packet from a file-like object.""" """Read a packet from a file-like object."""
LOGGER.debug("Reading packet.") LOGGER.debug("Reading packet.")
size = LittleEndianSignedInt32.read(file) size = LittleEndianSignedInt32.read(file)
@ -161,6 +165,8 @@ class Packet(NamedTuple):
LOGGER.debug(" => terminator: %s", terminator) LOGGER.debug(" => terminator: %s", terminator)
if terminator != TERMINATOR: if terminator != TERMINATOR:
if raise_unexpected_terminator:
raise UnexpectedTerminator(terminator)
LOGGER.warning("Unexpected terminator: %s", terminator) LOGGER.warning("Unexpected terminator: %s", terminator)
return cls(id_, type_, payload, terminator) return cls(id_, type_, payload, terminator)

Loading…
Cancel
Save