diff --git a/rcon/async_rcon.py b/rcon/async_rcon.py index af3de82..ba5a02d 100644 --- a/rcon/async_rcon.py +++ b/rcon/async_rcon.py @@ -19,11 +19,11 @@ async def communicate(reader: IO, writer: IO, packet: Packet) -> Packet: async def rcon(command: str, *arguments: str, host: str, port: int, - passwd: str) -> str: + passwd: str, encoding: str = 'utf-8') -> str: """Runs a command asynchronously.""" reader, writer = await open_connection(host, port) - login = Packet.make_login(passwd) + login = Packet.make_login(passwd, encoding=encoding) response = await communicate(reader, writer, login) # Wait for SERVERDATA_AUTH_RESPONSE according to: @@ -34,10 +34,10 @@ async def rcon(command: str, *arguments: str, host: str, port: int, if response.id == -1: raise WrongPassword() - request = Packet.make_command(command, *arguments) + request = Packet.make_command(command, *arguments, encoding=encoding) response = await communicate(reader, writer, request) if response.id != request.id: raise RequestIdMismatch(request.id, response.id) - return response.payload + return response.payload.decode(encoding) diff --git a/rcon/client.py b/rcon/client.py index 5433932..4510cd8 100644 --- a/rcon/client.py +++ b/rcon/client.py @@ -70,9 +70,10 @@ class Client: with self._socket.makefile('rb') as file: return Packet.read(file) - def login(self, passwd: str) -> bool: + def login(self, passwd: str, *, encoding: str = 'utf-8') -> bool: """Performs a login.""" - response = self.communicate(Packet.make_login(passwd)) + request = Packet.make_login(passwd, encoding=encoding) + response = self.communicate(request) # Wait for SERVERDATA_AUTH_RESPONSE according to: # https://developer.valvesoftware.com/wiki/Source_RCON_Protocol @@ -84,12 +85,12 @@ class Client: return True - def run(self, command: str, *arguments: str, raw: bool = False) -> str: + def run(self, command: str, *args: str, encoding: str = 'utf-8') -> str: """Runs a command.""" - request = Packet.make_command(command, *arguments) + request = Packet.make_command(command, *args, encoding=encoding) response = self.communicate(request) if response.id != request.id: raise RequestIdMismatch(request.id, response.id) - return response if raw else response.payload + return response.payload.decode(encoding) diff --git a/rcon/proto.py b/rcon/proto.py index cb653d1..44913f9 100644 --- a/rcon/proto.py +++ b/rcon/proto.py @@ -2,6 +2,7 @@ from __future__ import annotations from enum import Enum +from functools import partial from logging import getLogger from random import randint from typing import IO, NamedTuple @@ -11,7 +12,7 @@ __all__ = ['LittleEndianSignedInt32', 'Type', 'Packet', 'random_request_id'] LOGGER = getLogger(__file__) -TERMINATOR = '\x00\x00' +TERMINATOR = b'\x00\x00' def random_request_id() -> LittleEndianSignedInt32: @@ -80,15 +81,15 @@ class Packet(NamedTuple): id: LittleEndianSignedInt32 type: Type - payload: str - terminator: str = TERMINATOR + payload: bytes + terminator: bytes = TERMINATOR def __bytes__(self): """Returns the packet as bytes with prepended length.""" payload = bytes(self.id) payload += bytes(self.type) - payload += self.payload.encode() - payload += self.terminator.encode() + payload += self.payload + payload += self.terminator size = bytes(LittleEndianSignedInt32(len(payload))) return size + payload @@ -101,10 +102,10 @@ class Packet(NamedTuple): payload = await file.read(size - 10) terminator = await file.read(2) - if (terminator := terminator.decode()) != TERMINATOR: + if terminator != TERMINATOR: LOGGER.warning('Unexpected terminator: %s', terminator) - return cls(id_, type_, payload.decode(), terminator) + return cls(id_, type_, payload, terminator) @classmethod def read(cls, file: IO) -> Packet: @@ -112,8 +113,8 @@ class Packet(NamedTuple): size = LittleEndianSignedInt32.read(file) id_ = LittleEndianSignedInt32.read(file) type_ = Type.read(file) - payload = file.read(size - 10).decode() - terminator = file.read(2).decode() + payload = file.read(size - 10) + terminator = file.read(2) if terminator != TERMINATOR: LOGGER.warning('Unexpected terminator: %s', terminator) @@ -121,12 +122,13 @@ class Packet(NamedTuple): return cls(id_, type_, payload, terminator) @classmethod - def make_command(cls, *args: str) -> Packet: + def make_command(cls, *args: str, encoding: str = 'utf-8') -> Packet: """Creates a command packet.""" return cls(random_request_id(), Type.SERVERDATA_EXECCOMMAND, - ' '.join(args)) + b' '.join(map(partial(str.encode, encoding=encoding), args))) @classmethod - def make_login(cls, passwd: str) -> Packet: + def make_login(cls, passwd: bytes, *, encoding: str = 'utf-8') -> Packet: """Creates a login packet.""" - return cls(random_request_id(), Type.SERVERDATA_AUTH, passwd) + return cls(random_request_id(), Type.SERVERDATA_AUTH, + passwd.encode(encoding)) diff --git a/tests/test_proto.py b/tests/test_proto.py index 8031222..1e4df3b 100644 --- a/tests/test_proto.py +++ b/tests/test_proto.py @@ -137,7 +137,7 @@ class TestPacket(TestCase): self.packet = Packet( random_request_id(), Type.SERVERDATA_EXECCOMMAND, - 'Lorem ipsum sit amet...' + 'Lorem ipsum sit amet...'.encode() ) def test_bytes_rw(self):