Browse Source

Add hack to read follow-up packets

pull/14/head
Richard Neumann 3 years ago
parent
commit
06a35dd013
  1. 32
      rcon/source/client.py
  2. 19
      rcon/source/proto.py

32
rcon/source/client.py

@ -23,7 +23,20 @@ class Client(BaseClient, socket_type=SOCK_STREAM):
def read(self) -> Packet: def read(self) -> Packet:
"""Read a packet.""" """Read a packet."""
with self._socket.makefile('rb') as file: with self._socket.makefile('rb') as file:
return Packet.read(file, max_pkg_size=self.max_pkg_size) packet = Packet.read(file)
if self.max_pkg_size and len(packet.payload) >= self.max_pkg_size:
return packet + self.read_followup_packet()
return packet
def read_followup_packet(self) -> Packet | None:
"""Reads a potential followup packet."""
with ChangedTimeout(self, 1) as client:
try:
return client.read()
except TimeoutError:
return None
def login(self, passwd: str, *, encoding: str = 'utf-8') -> bool: def login(self, passwd: str, *, encoding: str = 'utf-8') -> bool:
"""Perform a login.""" """Perform a login."""
@ -49,3 +62,20 @@ class Client(BaseClient, socket_type=SOCK_STREAM):
raise SessionTimeout() raise SessionTimeout()
return response.payload.decode(encoding) return response.payload.decode(encoding)
class ChangedTimeout:
"""Context manager to temporarily change a client's timeout."""
def __init__(self, client: Client, timeout: int | None):
self.client = client
self.timeout = timeout
self.original_timeout = None
def __enter__(self) -> Client:
self.original_timeout = self.client.timeout
self.client.timeout = self.timeout
return self.client
def __exit__(self, exc_type, exc_val, exc_tb):
self.client.timeout = self.original_timeout

19
rcon/source/proto.py

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from asyncio import StreamReader from asyncio import StreamReader
from contextlib import suppress
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from logging import getLogger from logging import getLogger
@ -80,6 +79,17 @@ class Packet(NamedTuple):
payload: bytes payload: bytes
terminator: bytes = TERMINATOR terminator: bytes = TERMINATOR
def __add__(self, other: Packet):
return Packet(
self.id,
self.type,
self.payload + other.payload,
self.terminator
)
def __radd__(self, other: Packet):
return other.__add__(self)
def __bytes__(self): def __bytes__(self):
"""Return the packet as bytes with prepended length.""" """Return the packet as bytes with prepended length."""
payload = bytes(self.id) payload = bytes(self.id)
@ -104,7 +114,7 @@ class Packet(NamedTuple):
return cls(id_, type_, payload, terminator) return cls(id_, type_, payload, terminator)
@classmethod @classmethod
def read(cls, file: IO, *, max_pkg_size: int | None = None) -> Packet: def read(cls, file: IO) -> Packet:
"""Read a packet from a file-like object.""" """Read a packet from a file-like object."""
size = LittleEndianSignedInt32.read(file) size = LittleEndianSignedInt32.read(file)
id_ = LittleEndianSignedInt32.read(file) id_ = LittleEndianSignedInt32.read(file)
@ -115,11 +125,6 @@ class Packet(NamedTuple):
if terminator != TERMINATOR: if terminator != TERMINATOR:
LOGGER.warning('Unexpected terminator: %s', terminator) LOGGER.warning('Unexpected terminator: %s', terminator)
# Attempt to read following packets on large responses.
if size >= max_pkg_size:
with suppress(TimeoutError):
payload += cls.read(file, max_pkg_size=max_pkg_size).payload
return cls(id_, type_, payload, terminator) return cls(id_, type_, payload, terminator)
@classmethod @classmethod

Loading…
Cancel
Save