diff --git a/rcon/proto.py b/rcon/proto.py index 7056bf7..273dde6 100644 --- a/rcon/proto.py +++ b/rcon/proto.py @@ -5,7 +5,7 @@ from enum import Enum from logging import getLogger from random import randint from socket import SOCK_STREAM, socket -from typing import NamedTuple, Optional +from typing import IO, NamedTuple, Optional from rcon.exceptions import RequestIdMismatch from rcon.exceptions import WrongPassword @@ -48,18 +48,18 @@ class LittleEndianSignedInt32(int): return self.to_bytes(4, 'little', signed=True) @classmethod - def from_bytes(cls, bytes_: bytes) -> LittleEndianSignedInt32: + def read(cls, file: IO) -> LittleEndianSignedInt32: """Creates the integer from the given bytes.""" - return super().from_bytes(bytes_, 'little', signed=True) + return super().from_bytes(file.read(4), 'little', signed=True) class Type(Enum): """RCON packet types.""" - SERVERDATA_AUTH = 3 - SERVERDATA_AUTH_RESPONSE = 2 - SERVERDATA_EXECCOMMAND = 2 - SERVERDATA_RESPONSE_VALUE = 0 + SERVERDATA_AUTH = LittleEndianSignedInt32(3) + SERVERDATA_AUTH_RESPONSE = LittleEndianSignedInt32(2) + SERVERDATA_EXECCOMMAND = LittleEndianSignedInt32(2) + SERVERDATA_RESPONSE_VALUE = LittleEndianSignedInt32(0) def __int__(self): """Returns the actual integer value.""" @@ -70,9 +70,9 @@ class Type(Enum): return int(self).to_bytes(4, 'little', signed=True) @classmethod - def from_bytes(cls, bytes_: bytes) -> Type: + def read(cls, file: IO) -> Type: """Creates a type from the given bytes.""" - return cls(int.from_bytes(bytes_, 'little', signed=True)) + return cls(LittleEndianSignedInt32.read(file)) class Packet(NamedTuple): @@ -89,29 +89,30 @@ class Packet(NamedTuple): payload += bytes(self.type) payload += self.payload.encode() payload += self.terminator.encode() - size = len(payload).to_bytes(4, 'little', signed=True) + size = bytes(LittleEndianSignedInt32(len(payload))) return size + payload @classmethod - def from_bytes(cls, bytes_: bytes) -> Packet: - """Creates a packet from the respective bytes.""" - id_ = LittleEndianSignedInt32.from_bytes(bytes_[:4]) - type_ = Type.from_bytes(bytes_[4:8]) - payload = bytes_[8:-2].decode() - - if (terminator := bytes_[-2:].decode()) != TERMINATOR: + def read(cls, file: IO) -> Packet: + """Reads a packet from a file-like object.""" + size = LittleEndianSignedInt32.read(file) + id_ = LittleEndianSignedInt32.read(file) + type_ = Type.read(file) + payload = file.read(size - 10).decode() + + if (terminator := file.read(2).decode()) != TERMINATOR: LOGGER.warning('Unexpected terminator: %s', terminator) return cls(id_, type_, payload, terminator) @classmethod - def from_args(cls, *args: str) -> Packet: + def make_command(cls, *args: str) -> Packet: """Creates a command packet.""" return cls(random_request_id(), Type.SERVERDATA_EXECCOMMAND, ' '.join(args)) @classmethod - def from_login(cls, passwd: str) -> Packet: + def make_login(cls, passwd: str) -> Packet: """Creates a login packet.""" return cls(random_request_id(), Type.SERVERDATA_AUTH, passwd) @@ -119,17 +120,17 @@ class Packet(NamedTuple): class Client: """An RCON client.""" - __slots__ = ('_socket', 'host', 'port', 'timeout', 'passwd') + __slots__ = ('host', 'port', 'timeout', 'passwd', '_socket') def __init__(self, host: str, port: int, *, timeout: Optional[float] = None, passwd: Optional[str] = None): """Initializes the base client with the SOCK_STREAM socket type.""" - self._socket = socket(type=SOCK_STREAM) self.host = host self.port = port self.timeout = timeout self.passwd = passwd + self._socket = socket(type=SOCK_STREAM) def __enter__(self): """Attempts an auto-login if a password is set.""" @@ -152,34 +153,28 @@ class Client: file.write(bytes(packet)) with self._socket.makefile('rb') as file: - header = file.read(4) - length = int.from_bytes(header, 'little') - payload = file.read(length) - - response = Packet.from_bytes(payload) + response = Packet.read(file) if response.id == packet.id: return response raise RequestIdMismatch(packet.id, response.id) - def login(self, passwd: str) -> bool: + def login(self, passwd: str) -> Packet: """Performs a login.""" - packet = Packet.from_login(passwd) + packet = Packet.make_login(passwd) try: - self.communicate(packet) + return self.communicate(packet) except RequestIdMismatch as mismatch: if mismatch.received == -1: raise WrongPassword() from None raise - return True - def run(self, command: str, *arguments: str, raw: bool = False) -> str: """Runs a command.""" - packet = Packet.from_args(command, *arguments) + packet = Packet.make_command(command, *arguments) try: response = self.communicate(packet)