diff --git a/a2s/a2sasync.py b/a2s/a2sasync.py index 9fb58bd..2356a7e 100644 --- a/a2s/a2sasync.py +++ b/a2s/a2sasync.py @@ -11,7 +11,7 @@ HEADER_MULTI = b"\xFE\xFF\xFF\xFF" logger = logging.getLogger("a2s") -class A2SProtocol: +class A2SProtocol(asyncio.DatagramProtocol): def __init__(self): self.recv_queue = asyncio.Queue() self.error_event = asyncio.Event() @@ -47,23 +47,48 @@ class A2SProtocol: self.error = exc self.error_event.set() + def raise_on_error(): + error = self.error + self.error = None + self.error_event.clear() + raise error + +class A2SStreamAsync: + def __init__(self, transport, protocol, timeout): + self.transport = transport + self.protocol = protocol + self.timeout = timeout + + def __del__(self): + self.close() + + @classmethod + async def create(cls, address, timeout): + transport, protocol = await asyncio.create_datagram_endpoint( + lambda: A2SProtocol(), remote_addr=address) + return cls(transport, protocol, timeout) + def send(self, payload): packet = HEADER_SIMPLE + payload self.transport.sendto(packet) - async def recv(self, timeout): - queue_task = asyncio.create_task(self.recv_queue.get()) - error_task = asyncio.create_task(self.error_event.wait()) + async def recv(self): + queue_task = asyncio.create_task(self.protocol.recv_queue.get()) + error_task = asyncio.create_task(self.protocol.error_event.wait()) done, pending = await asyncio.wait({queue_task, error_task}, - timeout=timeout, return_when=FIRST_COMPLETED) + timeout=self.timeout, return_when=FIRST_COMPLETED) for task in pending: task.cancel() if error_task in done: - error = self.error - self.error = None - self.error_event.clear() - raise error + self.protocol.raise_on_error() if not done: raise asyncio.TimeoutError() return queue_task.result() + + async def request(payload): + self.send(payload) + return await self.recv() + + def close(self): + self.transport.close() diff --git a/a2s/a2sstream.py b/a2s/a2sstream.py index abaacd5..313d9a0 100644 --- a/a2s/a2sstream.py +++ b/a2s/a2sstream.py @@ -45,12 +45,9 @@ class A2SStream: raise BrokenMessageError( "Invalid packet header: " + repr(header)) + def request(payload): + self.send(payload) + return self.recv() + def close(self): self._socket.close() - -def request(address, data, timeout): - stream = A2SStream(address, timeout) - stream.send(data) - resp = stream.recv() - stream.close() - return resp diff --git a/a2s/info.py b/a2s/info.py index 50c2e88..a810068 100644 --- a/a2s/info.py +++ b/a2s/info.py @@ -3,7 +3,8 @@ import io from a2s.exceptions import BrokenMessageError, BufferExhaustedError from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING -from a2s.a2sstream import request +from a2s.a2sstream import A2SStream +from a2s.a2sasync import A2SStreamAsync from a2s.byteio import ByteReader from a2s.datacls import DataclsMeta @@ -244,10 +245,7 @@ def parse_goldsrc(reader): return resp -def info(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): - send_time = time.monotonic() - resp_data = request(address, b"\x54Source Engine Query\0", timeout) - recv_time = time.monotonic() +def info_response(resp_data): reader = ByteReader( io.BytesIO(resp_data), endian="<", encoding=encoding) @@ -262,3 +260,21 @@ def info(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): resp.ping = recv_time - send_time return resp + +def info(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): + conn = A2SStream(address, timeout) + send_time = time.monotonic() + resp_data = conn.request(b"\x54Source Engine Query\0") + recv_time = time.monotonic() + conn.close() + + return info_response(resp_data) + +async def info_async(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): + conn = await A2SStreamAsync.create(address, timeout) + send_time = time.monotonic() + resp_data = await conn.request(b"\x54Source Engine Query\0") + recv_time = time.monotonic() + conn.close() + + return info_response(resp_data) diff --git a/a2s/players.py b/a2s/players.py index 71e3460..49d0634 100644 --- a/a2s/players.py +++ b/a2s/players.py @@ -4,7 +4,8 @@ from typing import List from a2s.exceptions import BrokenMessageError from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING, \ DEFAULT_RETRIES -from a2s.a2sstream import request +from a2s.a2sstream import A2SStream +from a2s.a2sasync import A2SStreamAsync from a2s.byteio import ByteReader from a2s.datacls import DataclsMeta @@ -26,13 +27,28 @@ class Player(metaclass=DataclsMeta): """Time the player has been connected to the server""" duration: float -def players(address, timeout=DEFAULT_TIMEOUT, - encoding=DEFAULT_ENCODING): - return players_impl(address, timeout, encoding) +def players_response(reader): + player_count = reader.read_uint8() + resp = [ + Player( + index=reader.read_uint8(), + name=reader.read_cstring(), + score=reader.read_int32(), + duration=reader.read_float() + ) + for player_num in range(player_count) + ] -def players_impl(address, timeout, encoding, challenge=0, retries=0): - resp_data = request( - address, b"\x55" + challenge.to_bytes(4, "little"), timeout) + return resp + +def players(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): + conn = A2SStream(address, timeout) + reader = players_request(conn, encoding) + conn.close() + return players_response(reader) + +def players_request(conn, encoding, challenge=0, retries=0): + resp_data = conn.request(b"\x55" + challenge.to_bytes(4, "little")) reader = ByteReader( io.BytesIO(resp_data), endian="<", encoding=encoding) @@ -43,21 +59,36 @@ def players_impl(address, timeout, encoding, challenge=0, retries=0): "Server keeps sending challenge responses") challenge = reader.read_uint32() return players_impl( - address, timeout, encoding, challenge, retries + 1) + conn, encoding, challenge, retries + 1) if response_type != A2S_PLAYER_RESPONSE: raise BrokenMessageError( "Invalid response type: " + str(response_type)) - player_count = reader.read_uint8() - resp = [ - Player( - index=reader.read_uint8(), - name=reader.read_cstring(), - score=reader.read_int32(), - duration=reader.read_float() - ) - for player_num in range(player_count) - ] + return reader - return resp +async def players_async(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): + conn = await A2SStream.create(address, timeout) + reader = await players_request_async(conn, encoding) + conn.close() + return players_response(reader) + +async def players_request_async(conn, encoding, challenge=0, retries=0): + resp_data = await conn.request(b"\x55" + challenge.to_bytes(4, "little")) + reader = ByteReader( + io.BytesIO(resp_data), endian="<", encoding=encoding) + + response_type = reader.read_uint8() + if response_type == A2S_CHALLENGE_RESPONSE: + if retries >= DEFAULT_RETRIES: + raise BrokenMessageError( + "Server keeps sending challenge responses") + challenge = reader.read_uint32() + return await players_impl( + conn, encoding, challenge, retries + 1) + + if response_type != A2S_PLAYER_RESPONSE: + raise BrokenMessageError( + "Invalid response type: " + str(response_type)) + + return reader diff --git a/a2s/rules.py b/a2s/rules.py index 8812f1f..fad10d5 100644 --- a/a2s/rules.py +++ b/a2s/rules.py @@ -3,7 +3,8 @@ import io from a2s.exceptions import BrokenMessageError from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING, \ DEFAULT_RETRIES -from a2s.a2sstream import request +from a2s.a2sstream import A2SStream +from a2s.a2sasync import A2SStreamAsync from a2s.byteio import ByteReader @@ -11,12 +12,24 @@ from a2s.byteio import ByteReader A2S_RULES_RESPONSE = 0x45 A2S_CHALLENGE_RESPONSE = 0x41 +def rules_response(reader): + rule_count = reader.read_int16() + # Have to use tuples to preserve evaluation order + resp = dict( + (reader.read_cstring(), reader.read_cstring()) + for rule_num in range(rule_count) + ) + + return resp + def rules(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): - return rules_impl(address, timeout, encoding) + conn = A2SStream(address, timeout) + reader = rules_request(conn, encoding) + conn.close() + return rules_response(reader) -def rules_impl(address, timeout, encoding, challenge=0, retries=0): - resp_data = request( - address, b"\x56" + challenge.to_bytes(4, "little"), timeout) +def rules_request(conn, encoding, challenge=0, retries=0): + resp_data = conn.request(b"\x56" + challenge.to_bytes(4, "little")) reader = ByteReader( io.BytesIO(resp_data), endian="<", encoding=encoding) @@ -36,18 +49,40 @@ def rules_impl(address, timeout, encoding, challenge=0, retries=0): raise BrokenMessageError( "Server keeps sending challenge responses") challenge = reader.read_uint32() - return rules_impl( - address, timeout, encoding, challenge, retries + 1) + return rules_request( + conn, encoding, challenge, retries + 1) if response_type != A2S_RULES_RESPONSE: raise BrokenMessageError( "Invalid response type: " + str(response_type)) - rule_count = reader.read_int16() - # Have to use tuples to preserve evaluation order - resp = dict( - (reader.read_cstring(), reader.read_cstring()) - for rule_num in range(rule_count) - ) + return reader - return resp +async def rules(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): + conn = await A2SStreamAsync.create(address, timeout) + reader = await rules_request_async(conn, encoding) + conn.close() + return rules_response(reader) + +async def rules_request_async(conn, encoding, challenge=0, retries=0): + resp_data = conn.request(b"\x56" + challenge.to_bytes(4, "little")) + reader = ByteReader( + io.BytesIO(resp_data), endian="<", encoding=encoding) + + if reader.peek(4) == b"\xFF\xFF\xFF\xFF": + reader.read(4) + + response_type = reader.read_uint8() + if response_type == A2S_CHALLENGE_RESPONSE: + if retries >= DEFAULT_RETRIES: + raise BrokenMessageError( + "Server keeps sending challenge responses") + challenge = reader.read_uint32() + return await rules_request( + conn, encoding, challenge, retries + 1) + + if response_type != A2S_RULES_RESPONSE: + raise BrokenMessageError( + "Invalid response type: " + str(response_type)) + + return reader