From 04c8c2f6e74e2d198e1c4faae9430c2c6c52b14f Mon Sep 17 00:00:00 2001 From: Steve Myers <35839355+smyers119@users.noreply.github.com> Date: Sun, 3 Sep 2023 22:51:54 -0400 Subject: [PATCH] Change async_rcon to class By changing this to a class we clean up code duplication. Now we just declare the host, port, and password then call the method with just a command and arguments. tested to confirm working --- README.md | 10 ++- rcon/__init__.py | 2 +- rcon/source/__init__.py | 4 +- rcon/source/async_rcon.py | 156 +++++++++++++++++++++----------------- 4 files changed, 99 insertions(+), 73 deletions(-) diff --git a/README.md b/README.md index abd4cd6..cf84a00 100644 --- a/README.md +++ b/README.md @@ -35,13 +35,19 @@ If you prefer to use Source RCON in an asynchronous environment, you can use `rcon()`. ```python -from rcon.source import rcon +from rcon.source import Rcon -response = await rcon( +response = await Rcon.rcon( 'some_command', 'with', 'some', 'arguments', host='127.0.0.1', port=5000, passwd='mysecretpassword' ) print(response) + +or + +connect = Rcon('127.0.0.1', 12715, 'mysecretpassword') +response = await connect.rcon('some_command','with', 'some', 'arguments') +print(response) ``` ### BattlEye RCon diff --git a/rcon/__init__.py b/rcon/__init__.py index c2bd8b9..5943aa7 100644 --- a/rcon/__init__.py +++ b/rcon/__init__.py @@ -4,7 +4,7 @@ from typing import Any, Coroutine from warnings import warn from rcon.exceptions import EmptyResponse, SessionTimeout, WrongPassword -from rcon.source import rcon as _rcon +from rcon.source import Rcon as _rcon from rcon.source import Client as _Client diff --git a/rcon/source/__init__.py b/rcon/source/__init__.py index ed9a7f8..63c2303 100644 --- a/rcon/source/__init__.py +++ b/rcon/source/__init__.py @@ -1,7 +1,7 @@ """Source RCON implementation.""" -from rcon.source.async_rcon import rcon +from rcon.source.async_rcon import Rcon from rcon.source.client import Client -__all__ = ["Client", "rcon"] +__all__ = ["Client", "Rcon"] diff --git a/rcon/source/async_rcon.py b/rcon/source/async_rcon.py index a012f0f..60aee86 100644 --- a/rcon/source/async_rcon.py +++ b/rcon/source/async_rcon.py @@ -6,77 +6,97 @@ from rcon.exceptions import SessionTimeout, WrongPassword from rcon.source.proto import Packet, Type -__all__ = ["rcon"] - - -async def close(writer: StreamWriter) -> None: - """Close socket asynchronously.""" - - writer.close() - await writer.wait_closed() +__all__ = ["Rcon"] + + +class Rcon: + """ + Set's the variables needed for RCON connection + """ + def __init__(self, + host: str, + port: int = 27015, + password: str = '' + ): + + self.host = host + self.port = port + self.password = password + + async def close(self, writer: StreamWriter) -> None: + """Close socket asynchronously.""" + + writer.close() + await writer.wait_closed() + + async def communicate( + self, + reader: StreamReader, + writer: StreamWriter, + packet: Packet, + *, + frag_threshold: int = 4096, + frag_detect_cmd: str = "", + ) -> Packet: + """Make an asynchronous request.""" + + writer.write(bytes(packet)) + await writer.drain() + response = await Packet.aread(reader) + if len(response.payload) < frag_threshold: + return response -async def communicate( - reader: StreamReader, - writer: StreamWriter, - packet: Packet, - *, - frag_threshold: int = 4096, - frag_detect_cmd: str = "", -) -> Packet: - """Make an asynchronous request.""" + writer.write(bytes(Packet.make_command(frag_detect_cmd))) + await writer.drain() - writer.write(bytes(packet)) - await writer.drain() - response = await Packet.aread(reader) + while (successor := await Packet.aread(reader)).id == response.id: + response += successor - if len(response.payload) < frag_threshold: return response - writer.write(bytes(Packet.make_command(frag_detect_cmd))) - await writer.drain() - - while (successor := await Packet.aread(reader)).id == response.id: - response += successor - - return response - - -async def rcon( - command: str, - *arguments: str, - host: str, - port: int, - passwd: str, - encoding: str = "utf-8", - frag_threshold: int = 4096, - frag_detect_cmd: str = "", -) -> str: - """Run a command asynchronously.""" - - reader, writer = await open_connection(host, port) - response = await communicate( - reader, - writer, - Packet.make_login(passwd, encoding=encoding), - frag_threshold=frag_threshold, - frag_detect_cmd=frag_detect_cmd, - ) - - # Wait for SERVERDATA_AUTH_RESPONSE according to: - # https://developer.valvesoftware.com/wiki/Source_RCON_Protocol - while response.type != Type.SERVERDATA_AUTH_RESPONSE: - response = await Packet.aread(reader) - - if response.id == -1: - await close(writer) - raise WrongPassword() - - request = Packet.make_command(command, *arguments, encoding=encoding) - response = await communicate(reader, writer, request) - await close(writer) - - if response.id != request.id: - raise SessionTimeout() - - return response.payload.decode(encoding) + async def rcon( + self, + command: str, + *arguments: str, + host: str = None, + port: int = None, + passwd: str = None, + encoding: str = "utf-8", + frag_threshold: int = 4096, + frag_detect_cmd: str = "", + ) -> str: + """Run a command asynchronously.""" + try: + host = host or self.host + port = port or self.port + passwd = passwd or self.password + except AttributeError: + return print("Make sure you declare the host, port, or password.") + + reader, writer = await open_connection(host, port) + response = await self.communicate( + reader, + writer, + Packet.make_login(passwd, encoding=encoding), + frag_threshold=frag_threshold, + frag_detect_cmd=frag_detect_cmd, + ) + + # Wait for SERVERDATA_AUTH_RESPONSE according to: + # https://developer.valvesoftware.com/wiki/Source_RCON_Protocol + while response.type != Type.SERVERDATA_AUTH_RESPONSE: + response = await Packet.aread(reader) + + if response.id == -1: + await self.close(writer) + raise WrongPassword() + + request = Packet.make_command(command, *arguments, encoding=encoding) + response = await self.communicate(reader, writer, request) + await self.close(writer) + + if response.id != request.id: + raise SessionTimeout() + + return response.payload.decode(encoding)