Browse Source

Merge remote-tracking branch 'github/master'

pull/7/head
Richard Neumann 3 years ago
parent
commit
43f474b977
  1. 8
      rcon/async_rcon.py
  2. 11
      rcon/client.py
  3. 28
      rcon/proto.py
  4. 2
      tests/test_proto.py

8
rcon/async_rcon.py

@ -19,11 +19,11 @@ async def communicate(reader: IO, writer: IO, packet: Packet) -> Packet:
async def rcon(command: str, *arguments: str, host: str, port: int,
passwd: str) -> str:
passwd: str, encoding: str = 'utf-8') -> str:
"""Runs a command asynchronously."""
reader, writer = await open_connection(host, port)
login = Packet.make_login(passwd)
login = Packet.make_login(passwd, encoding=encoding)
response = await communicate(reader, writer, login)
# Wait for SERVERDATA_AUTH_RESPONSE according to:
@ -34,10 +34,10 @@ async def rcon(command: str, *arguments: str, host: str, port: int,
if response.id == -1:
raise WrongPassword()
request = Packet.make_command(command, *arguments)
request = Packet.make_command(command, *arguments, encoding=encoding)
response = await communicate(reader, writer, request)
if response.id != request.id:
raise RequestIdMismatch(request.id, response.id)
return response.payload
return response.payload.decode(encoding)

11
rcon/client.py

@ -70,9 +70,10 @@ class Client:
with self._socket.makefile('rb') as file:
return Packet.read(file)
def login(self, passwd: str) -> bool:
def login(self, passwd: str, *, encoding: str = 'utf-8') -> bool:
"""Performs a login."""
response = self.communicate(Packet.make_login(passwd))
request = Packet.make_login(passwd, encoding=encoding)
response = self.communicate(request)
# Wait for SERVERDATA_AUTH_RESPONSE according to:
# https://developer.valvesoftware.com/wiki/Source_RCON_Protocol
@ -84,12 +85,12 @@ class Client:
return True
def run(self, command: str, *arguments: str, raw: bool = False) -> str:
def run(self, command: str, *args: str, encoding: str = 'utf-8') -> str:
"""Runs a command."""
request = Packet.make_command(command, *arguments)
request = Packet.make_command(command, *args, encoding=encoding)
response = self.communicate(request)
if response.id != request.id:
raise RequestIdMismatch(request.id, response.id)
return response if raw else response.payload
return response.payload.decode(encoding)

28
rcon/proto.py

@ -2,6 +2,7 @@
from __future__ import annotations
from enum import Enum
from functools import partial
from logging import getLogger
from random import randint
from typing import IO, NamedTuple
@ -11,7 +12,7 @@ __all__ = ['LittleEndianSignedInt32', 'Type', 'Packet', 'random_request_id']
LOGGER = getLogger(__file__)
TERMINATOR = '\x00\x00'
TERMINATOR = b'\x00\x00'
def random_request_id() -> LittleEndianSignedInt32:
@ -80,15 +81,15 @@ class Packet(NamedTuple):
id: LittleEndianSignedInt32
type: Type
payload: str
terminator: str = TERMINATOR
payload: bytes
terminator: bytes = TERMINATOR
def __bytes__(self):
"""Returns the packet as bytes with prepended length."""
payload = bytes(self.id)
payload += bytes(self.type)
payload += self.payload.encode()
payload += self.terminator.encode()
payload += self.payload
payload += self.terminator
size = bytes(LittleEndianSignedInt32(len(payload)))
return size + payload
@ -101,10 +102,10 @@ class Packet(NamedTuple):
payload = await file.read(size - 10)
terminator = await file.read(2)
if (terminator := terminator.decode()) != TERMINATOR:
if terminator != TERMINATOR:
LOGGER.warning('Unexpected terminator: %s', terminator)
return cls(id_, type_, payload.decode(), terminator)
return cls(id_, type_, payload, terminator)
@classmethod
def read(cls, file: IO) -> Packet:
@ -112,8 +113,8 @@ class Packet(NamedTuple):
size = LittleEndianSignedInt32.read(file)
id_ = LittleEndianSignedInt32.read(file)
type_ = Type.read(file)
payload = file.read(size - 10).decode()
terminator = file.read(2).decode()
payload = file.read(size - 10)
terminator = file.read(2)
if terminator != TERMINATOR:
LOGGER.warning('Unexpected terminator: %s', terminator)
@ -121,12 +122,13 @@ class Packet(NamedTuple):
return cls(id_, type_, payload, terminator)
@classmethod
def make_command(cls, *args: str) -> Packet:
def make_command(cls, *args: str, encoding: str = 'utf-8') -> Packet:
"""Creates a command packet."""
return cls(random_request_id(), Type.SERVERDATA_EXECCOMMAND,
' '.join(args))
b' '.join(map(partial(str.encode, encoding=encoding), args)))
@classmethod
def make_login(cls, passwd: str) -> Packet:
def make_login(cls, passwd: bytes, *, encoding: str = 'utf-8') -> Packet:
"""Creates a login packet."""
return cls(random_request_id(), Type.SERVERDATA_AUTH, passwd)
return cls(random_request_id(), Type.SERVERDATA_AUTH,
passwd.encode(encoding))

2
tests/test_proto.py

@ -137,7 +137,7 @@ class TestPacket(TestCase):
self.packet = Packet(
random_request_id(),
Type.SERVERDATA_EXECCOMMAND,
'Lorem ipsum sit amet...'
'Lorem ipsum sit amet...'.encode()
)
def test_bytes_rw(self):

Loading…
Cancel
Save