Browse Source

Implement async requests

async
Gabriel Huber 5 years ago
parent
commit
f6120952a3
  1. 43
      a2s/a2sasync.py
  2. 11
      a2s/a2sstream.py
  3. 26
      a2s/info.py
  4. 69
      a2s/players.py
  5. 63
      a2s/rules.py

43
a2s/a2sasync.py

@ -11,7 +11,7 @@ HEADER_MULTI = b"\xFE\xFF\xFF\xFF"
logger = logging.getLogger("a2s") logger = logging.getLogger("a2s")
class A2SProtocol: class A2SProtocol(asyncio.DatagramProtocol):
def __init__(self): def __init__(self):
self.recv_queue = asyncio.Queue() self.recv_queue = asyncio.Queue()
self.error_event = asyncio.Event() self.error_event = asyncio.Event()
@ -47,23 +47,48 @@ class A2SProtocol:
self.error = exc self.error = exc
self.error_event.set() self.error_event.set()
def raise_on_error():
error = self.error
self.error = None
self.error_event.clear()
raise error
class A2SStreamAsync:
def __init__(self, transport, protocol, timeout):
self.transport = transport
self.protocol = protocol
self.timeout = timeout
def __del__(self):
self.close()
@classmethod
async def create(cls, address, timeout):
transport, protocol = await asyncio.create_datagram_endpoint(
lambda: A2SProtocol(), remote_addr=address)
return cls(transport, protocol, timeout)
def send(self, payload): def send(self, payload):
packet = HEADER_SIMPLE + payload packet = HEADER_SIMPLE + payload
self.transport.sendto(packet) self.transport.sendto(packet)
async def recv(self, timeout): async def recv(self):
queue_task = asyncio.create_task(self.recv_queue.get()) queue_task = asyncio.create_task(self.protocol.recv_queue.get())
error_task = asyncio.create_task(self.error_event.wait()) error_task = asyncio.create_task(self.protocol.error_event.wait())
done, pending = await asyncio.wait({queue_task, error_task}, done, pending = await asyncio.wait({queue_task, error_task},
timeout=timeout, return_when=FIRST_COMPLETED) timeout=self.timeout, return_when=FIRST_COMPLETED)
for task in pending: task.cancel() for task in pending: task.cancel()
if error_task in done: if error_task in done:
error = self.error self.protocol.raise_on_error()
self.error = None
self.error_event.clear()
raise error
if not done: if not done:
raise asyncio.TimeoutError() raise asyncio.TimeoutError()
return queue_task.result() return queue_task.result()
async def request(payload):
self.send(payload)
return await self.recv()
def close(self):
self.transport.close()

11
a2s/a2sstream.py

@ -45,12 +45,9 @@ class A2SStream:
raise BrokenMessageError( raise BrokenMessageError(
"Invalid packet header: " + repr(header)) "Invalid packet header: " + repr(header))
def request(payload):
self.send(payload)
return self.recv()
def close(self): def close(self):
self._socket.close() self._socket.close()
def request(address, data, timeout):
stream = A2SStream(address, timeout)
stream.send(data)
resp = stream.recv()
stream.close()
return resp

26
a2s/info.py

@ -3,7 +3,8 @@ import io
from a2s.exceptions import BrokenMessageError, BufferExhaustedError from a2s.exceptions import BrokenMessageError, BufferExhaustedError
from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING
from a2s.a2sstream import request from a2s.a2sstream import A2SStream
from a2s.a2sasync import A2SStreamAsync
from a2s.byteio import ByteReader from a2s.byteio import ByteReader
from a2s.datacls import DataclsMeta from a2s.datacls import DataclsMeta
@ -244,10 +245,7 @@ def parse_goldsrc(reader):
return resp return resp
def info(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): def info_response(resp_data):
send_time = time.monotonic()
resp_data = request(address, b"\x54Source Engine Query\0", timeout)
recv_time = time.monotonic()
reader = ByteReader( reader = ByteReader(
io.BytesIO(resp_data), endian="<", encoding=encoding) io.BytesIO(resp_data), endian="<", encoding=encoding)
@ -262,3 +260,21 @@ def info(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING):
resp.ping = recv_time - send_time resp.ping = recv_time - send_time
return resp return resp
def info(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING):
conn = A2SStream(address, timeout)
send_time = time.monotonic()
resp_data = conn.request(b"\x54Source Engine Query\0")
recv_time = time.monotonic()
conn.close()
return info_response(resp_data)
async def info_async(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING):
conn = await A2SStreamAsync.create(address, timeout)
send_time = time.monotonic()
resp_data = await conn.request(b"\x54Source Engine Query\0")
recv_time = time.monotonic()
conn.close()
return info_response(resp_data)

69
a2s/players.py

@ -4,7 +4,8 @@ from typing import List
from a2s.exceptions import BrokenMessageError from a2s.exceptions import BrokenMessageError
from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING, \ from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING, \
DEFAULT_RETRIES DEFAULT_RETRIES
from a2s.a2sstream import request from a2s.a2sstream import A2SStream
from a2s.a2sasync import A2SStreamAsync
from a2s.byteio import ByteReader from a2s.byteio import ByteReader
from a2s.datacls import DataclsMeta from a2s.datacls import DataclsMeta
@ -26,13 +27,28 @@ class Player(metaclass=DataclsMeta):
"""Time the player has been connected to the server""" """Time the player has been connected to the server"""
duration: float duration: float
def players(address, timeout=DEFAULT_TIMEOUT, def players_response(reader):
encoding=DEFAULT_ENCODING): player_count = reader.read_uint8()
return players_impl(address, timeout, encoding) resp = [
Player(
index=reader.read_uint8(),
name=reader.read_cstring(),
score=reader.read_int32(),
duration=reader.read_float()
)
for player_num in range(player_count)
]
def players_impl(address, timeout, encoding, challenge=0, retries=0): return resp
resp_data = request(
address, b"\x55" + challenge.to_bytes(4, "little"), timeout) def players(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING):
conn = A2SStream(address, timeout)
reader = players_request(conn, encoding)
conn.close()
return players_response(reader)
def players_request(conn, encoding, challenge=0, retries=0):
resp_data = conn.request(b"\x55" + challenge.to_bytes(4, "little"))
reader = ByteReader( reader = ByteReader(
io.BytesIO(resp_data), endian="<", encoding=encoding) io.BytesIO(resp_data), endian="<", encoding=encoding)
@ -43,21 +59,36 @@ def players_impl(address, timeout, encoding, challenge=0, retries=0):
"Server keeps sending challenge responses") "Server keeps sending challenge responses")
challenge = reader.read_uint32() challenge = reader.read_uint32()
return players_impl( return players_impl(
address, timeout, encoding, challenge, retries + 1) conn, encoding, challenge, retries + 1)
if response_type != A2S_PLAYER_RESPONSE: if response_type != A2S_PLAYER_RESPONSE:
raise BrokenMessageError( raise BrokenMessageError(
"Invalid response type: " + str(response_type)) "Invalid response type: " + str(response_type))
player_count = reader.read_uint8() return reader
resp = [
Player(
index=reader.read_uint8(),
name=reader.read_cstring(),
score=reader.read_int32(),
duration=reader.read_float()
)
for player_num in range(player_count)
]
return resp async def players_async(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING):
conn = await A2SStream.create(address, timeout)
reader = await players_request_async(conn, encoding)
conn.close()
return players_response(reader)
async def players_request_async(conn, encoding, challenge=0, retries=0):
resp_data = await conn.request(b"\x55" + challenge.to_bytes(4, "little"))
reader = ByteReader(
io.BytesIO(resp_data), endian="<", encoding=encoding)
response_type = reader.read_uint8()
if response_type == A2S_CHALLENGE_RESPONSE:
if retries >= DEFAULT_RETRIES:
raise BrokenMessageError(
"Server keeps sending challenge responses")
challenge = reader.read_uint32()
return await players_impl(
conn, encoding, challenge, retries + 1)
if response_type != A2S_PLAYER_RESPONSE:
raise BrokenMessageError(
"Invalid response type: " + str(response_type))
return reader

63
a2s/rules.py

@ -3,7 +3,8 @@ import io
from a2s.exceptions import BrokenMessageError from a2s.exceptions import BrokenMessageError
from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING, \ from a2s.defaults import DEFAULT_TIMEOUT, DEFAULT_ENCODING, \
DEFAULT_RETRIES DEFAULT_RETRIES
from a2s.a2sstream import request from a2s.a2sstream import A2SStream
from a2s.a2sasync import A2SStreamAsync
from a2s.byteio import ByteReader from a2s.byteio import ByteReader
@ -11,12 +12,24 @@ from a2s.byteio import ByteReader
A2S_RULES_RESPONSE = 0x45 A2S_RULES_RESPONSE = 0x45
A2S_CHALLENGE_RESPONSE = 0x41 A2S_CHALLENGE_RESPONSE = 0x41
def rules_response(reader):
rule_count = reader.read_int16()
# Have to use tuples to preserve evaluation order
resp = dict(
(reader.read_cstring(), reader.read_cstring())
for rule_num in range(rule_count)
)
return resp
def rules(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING): def rules(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING):
return rules_impl(address, timeout, encoding) conn = A2SStream(address, timeout)
reader = rules_request(conn, encoding)
conn.close()
return rules_response(reader)
def rules_impl(address, timeout, encoding, challenge=0, retries=0): def rules_request(conn, encoding, challenge=0, retries=0):
resp_data = request( resp_data = conn.request(b"\x56" + challenge.to_bytes(4, "little"))
address, b"\x56" + challenge.to_bytes(4, "little"), timeout)
reader = ByteReader( reader = ByteReader(
io.BytesIO(resp_data), endian="<", encoding=encoding) io.BytesIO(resp_data), endian="<", encoding=encoding)
@ -36,18 +49,40 @@ def rules_impl(address, timeout, encoding, challenge=0, retries=0):
raise BrokenMessageError( raise BrokenMessageError(
"Server keeps sending challenge responses") "Server keeps sending challenge responses")
challenge = reader.read_uint32() challenge = reader.read_uint32()
return rules_impl( return rules_request(
address, timeout, encoding, challenge, retries + 1) conn, encoding, challenge, retries + 1)
if response_type != A2S_RULES_RESPONSE: if response_type != A2S_RULES_RESPONSE:
raise BrokenMessageError( raise BrokenMessageError(
"Invalid response type: " + str(response_type)) "Invalid response type: " + str(response_type))
rule_count = reader.read_int16() return reader
# Have to use tuples to preserve evaluation order
resp = dict(
(reader.read_cstring(), reader.read_cstring())
for rule_num in range(rule_count)
)
return resp async def rules(address, timeout=DEFAULT_TIMEOUT, encoding=DEFAULT_ENCODING):
conn = await A2SStreamAsync.create(address, timeout)
reader = await rules_request_async(conn, encoding)
conn.close()
return rules_response(reader)
async def rules_request_async(conn, encoding, challenge=0, retries=0):
resp_data = conn.request(b"\x56" + challenge.to_bytes(4, "little"))
reader = ByteReader(
io.BytesIO(resp_data), endian="<", encoding=encoding)
if reader.peek(4) == b"\xFF\xFF\xFF\xFF":
reader.read(4)
response_type = reader.read_uint8()
if response_type == A2S_CHALLENGE_RESPONSE:
if retries >= DEFAULT_RETRIES:
raise BrokenMessageError(
"Server keeps sending challenge responses")
challenge = reader.read_uint32()
return await rules_request(
conn, encoding, challenge, retries + 1)
if response_type != A2S_RULES_RESPONSE:
raise BrokenMessageError(
"Invalid response type: " + str(response_type))
return reader

Loading…
Cancel
Save