Browse Source

Merge remote-tracking branch 'github/master'

pull/7/head
Richard Neumann 3 years ago
parent
commit
43f474b977
  1. 8
      rcon/async_rcon.py
  2. 11
      rcon/client.py
  3. 28
      rcon/proto.py
  4. 2
      tests/test_proto.py

8
rcon/async_rcon.py

@ -19,11 +19,11 @@ async def communicate(reader: IO, writer: IO, packet: Packet) -> Packet:
async def rcon(command: str, *arguments: str, host: str, port: int, async def rcon(command: str, *arguments: str, host: str, port: int,
passwd: str) -> str: passwd: str, encoding: str = 'utf-8') -> str:
"""Runs a command asynchronously.""" """Runs a command asynchronously."""
reader, writer = await open_connection(host, port) reader, writer = await open_connection(host, port)
login = Packet.make_login(passwd) login = Packet.make_login(passwd, encoding=encoding)
response = await communicate(reader, writer, login) response = await communicate(reader, writer, login)
# Wait for SERVERDATA_AUTH_RESPONSE according to: # Wait for SERVERDATA_AUTH_RESPONSE according to:
@ -34,10 +34,10 @@ async def rcon(command: str, *arguments: str, host: str, port: int,
if response.id == -1: if response.id == -1:
raise WrongPassword() raise WrongPassword()
request = Packet.make_command(command, *arguments) request = Packet.make_command(command, *arguments, encoding=encoding)
response = await communicate(reader, writer, request) response = await communicate(reader, writer, request)
if response.id != request.id: if response.id != request.id:
raise RequestIdMismatch(request.id, response.id) raise RequestIdMismatch(request.id, response.id)
return response.payload return response.payload.decode(encoding)

11
rcon/client.py

@ -70,9 +70,10 @@ class Client:
with self._socket.makefile('rb') as file: with self._socket.makefile('rb') as file:
return Packet.read(file) return Packet.read(file)
def login(self, passwd: str) -> bool: def login(self, passwd: str, *, encoding: str = 'utf-8') -> bool:
"""Performs a login.""" """Performs a login."""
response = self.communicate(Packet.make_login(passwd)) request = Packet.make_login(passwd, encoding=encoding)
response = self.communicate(request)
# Wait for SERVERDATA_AUTH_RESPONSE according to: # Wait for SERVERDATA_AUTH_RESPONSE according to:
# https://developer.valvesoftware.com/wiki/Source_RCON_Protocol # https://developer.valvesoftware.com/wiki/Source_RCON_Protocol
@ -84,12 +85,12 @@ class Client:
return True return True
def run(self, command: str, *arguments: str, raw: bool = False) -> str: def run(self, command: str, *args: str, encoding: str = 'utf-8') -> str:
"""Runs a command.""" """Runs a command."""
request = Packet.make_command(command, *arguments) request = Packet.make_command(command, *args, encoding=encoding)
response = self.communicate(request) response = self.communicate(request)
if response.id != request.id: if response.id != request.id:
raise RequestIdMismatch(request.id, response.id) raise RequestIdMismatch(request.id, response.id)
return response if raw else response.payload return response.payload.decode(encoding)

28
rcon/proto.py

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
from functools import partial
from logging import getLogger from logging import getLogger
from random import randint from random import randint
from typing import IO, NamedTuple from typing import IO, NamedTuple
@ -11,7 +12,7 @@ __all__ = ['LittleEndianSignedInt32', 'Type', 'Packet', 'random_request_id']
LOGGER = getLogger(__file__) LOGGER = getLogger(__file__)
TERMINATOR = '\x00\x00' TERMINATOR = b'\x00\x00'
def random_request_id() -> LittleEndianSignedInt32: def random_request_id() -> LittleEndianSignedInt32:
@ -80,15 +81,15 @@ class Packet(NamedTuple):
id: LittleEndianSignedInt32 id: LittleEndianSignedInt32
type: Type type: Type
payload: str payload: bytes
terminator: str = TERMINATOR terminator: bytes = TERMINATOR
def __bytes__(self): def __bytes__(self):
"""Returns the packet as bytes with prepended length.""" """Returns the packet as bytes with prepended length."""
payload = bytes(self.id) payload = bytes(self.id)
payload += bytes(self.type) payload += bytes(self.type)
payload += self.payload.encode() payload += self.payload
payload += self.terminator.encode() payload += self.terminator
size = bytes(LittleEndianSignedInt32(len(payload))) size = bytes(LittleEndianSignedInt32(len(payload)))
return size + payload return size + payload
@ -101,10 +102,10 @@ class Packet(NamedTuple):
payload = await file.read(size - 10) payload = await file.read(size - 10)
terminator = await file.read(2) terminator = await file.read(2)
if (terminator := terminator.decode()) != TERMINATOR: if terminator != TERMINATOR:
LOGGER.warning('Unexpected terminator: %s', terminator) LOGGER.warning('Unexpected terminator: %s', terminator)
return cls(id_, type_, payload.decode(), terminator) return cls(id_, type_, payload, terminator)
@classmethod @classmethod
def read(cls, file: IO) -> Packet: def read(cls, file: IO) -> Packet:
@ -112,8 +113,8 @@ class Packet(NamedTuple):
size = LittleEndianSignedInt32.read(file) size = LittleEndianSignedInt32.read(file)
id_ = LittleEndianSignedInt32.read(file) id_ = LittleEndianSignedInt32.read(file)
type_ = Type.read(file) type_ = Type.read(file)
payload = file.read(size - 10).decode() payload = file.read(size - 10)
terminator = file.read(2).decode() terminator = file.read(2)
if terminator != TERMINATOR: if terminator != TERMINATOR:
LOGGER.warning('Unexpected terminator: %s', terminator) LOGGER.warning('Unexpected terminator: %s', terminator)
@ -121,12 +122,13 @@ class Packet(NamedTuple):
return cls(id_, type_, payload, terminator) return cls(id_, type_, payload, terminator)
@classmethod @classmethod
def make_command(cls, *args: str) -> Packet: def make_command(cls, *args: str, encoding: str = 'utf-8') -> Packet:
"""Creates a command packet.""" """Creates a command packet."""
return cls(random_request_id(), Type.SERVERDATA_EXECCOMMAND, return cls(random_request_id(), Type.SERVERDATA_EXECCOMMAND,
' '.join(args)) b' '.join(map(partial(str.encode, encoding=encoding), args)))
@classmethod @classmethod
def make_login(cls, passwd: str) -> Packet: def make_login(cls, passwd: bytes, *, encoding: str = 'utf-8') -> Packet:
"""Creates a login packet.""" """Creates a login packet."""
return cls(random_request_id(), Type.SERVERDATA_AUTH, passwd) return cls(random_request_id(), Type.SERVERDATA_AUTH,
passwd.encode(encoding))

2
tests/test_proto.py

@ -137,7 +137,7 @@ class TestPacket(TestCase):
self.packet = Packet( self.packet = Packet(
random_request_id(), random_request_id(),
Type.SERVERDATA_EXECCOMMAND, Type.SERVERDATA_EXECCOMMAND,
'Lorem ipsum sit amet...' 'Lorem ipsum sit amet...'.encode()
) )
def test_bytes_rw(self): def test_bytes_rw(self):

Loading…
Cancel
Save