diff --git a/a2s/a2sasync.py b/a2s/a2sasync.py new file mode 100644 index 0000000..9fb58bd --- /dev/null +++ b/a2s/a2sasync.py @@ -0,0 +1,69 @@ +import asyncio +import logging + +from a2s.exceptions import BrokenMessageError +from a2s.a2sfragment import decode_fragment + + + +HEADER_SIMPLE = b"\xFF\xFF\xFF\xFF" +HEADER_MULTI = b"\xFE\xFF\xFF\xFF" + +logger = logging.getLogger("a2s") + +class A2SProtocol: + def __init__(self): + self.recv_queue = asyncio.Queue() + self.error_event = asyncio.Event() + self.error = None + self.fragment_buf = [] + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, packet, addr): + header = packet[:4] + payload = packet[4:] + if header == HEADER_SIMPLE: + logger.debug("Received single packet: %r", payload) + self.recv_queue.put_nowait(payload) + elif header == HEADER_MULTI: + self.fragment_buf.append(decode_fragment(payload)) + if len(self.fragment_buf) < self.fragment_buf[0].fragment_count: + return # Wait for more packets to arrive + self.fragment_buf.sort(key=lambda f: f.fragment_id) + reassembled = b"".join( + fragment.payload for fragment in self.fragment_buf) + logger.debug("Received %s part packet with content: %r", + len(fragments), reassembled) + self.recv_queue.put_nowait(reassembled) + self.fragment_buf = [] + else: + self.error = BrokenMessageError( + "Invalid packet header: " + repr(header)) + self.error_event.set() + + def error_received(self, exc): + self.error = exc + self.error_event.set() + + def send(self, payload): + packet = HEADER_SIMPLE + payload + self.transport.sendto(packet) + + async def recv(self, timeout): + queue_task = asyncio.create_task(self.recv_queue.get()) + error_task = asyncio.create_task(self.error_event.wait()) + done, pending = await asyncio.wait({queue_task, error_task}, + timeout=timeout, return_when=FIRST_COMPLETED) + + for task in pending: task.cancel() + if error_task in done: + error = self.error + self.error = None + self.error_event.clear() + raise error + if not done: + raise asyncio.TimeoutError() + + return queue_task.result() diff --git a/a2s/a2sfragment.py b/a2s/a2sfragment.py new file mode 100644 index 0000000..38a40ba --- /dev/null +++ b/a2s/a2sfragment.py @@ -0,0 +1,39 @@ +import bz2 +import io + +from a2s.byteio import ByteReader + + + +class A2SFragment: + def __init__(self, message_id, fragment_count, fragment_id, mtu, + decompressed_size=0, crc=0, payload=b""): + self.message_id = message_id + self.fragment_count = fragment_count + self.fragment_id = fragment_id + self.mtu = mtu + self.decompressed_size = decompressed_size + self.crc = crc + self.payload = payload + + @property + def is_compressed(self): + return bool(self.message_id & (1 << 15)) + +def decode_fragment(data): + reader = ByteReader( + io.BytesIO(data), endian="<", encoding="utf-8") + frag = A2SFragment( + message_id=reader.read_uint32(), + fragment_count=reader.read_uint8(), + fragment_id=reader.read_uint8(), + mtu=reader.read_uint16() + ) + if frag.is_compressed: + frag.decompressed_size = reader.read_uint32() + frag.crc = reader.read_uint32() + frag.payload = bz2.decompress(reader.read()) + else: + frag.payload = reader.read() + + return frag diff --git a/a2s/a2sstream.py b/a2s/a2sstream.py index 83ef472..abaacd5 100644 --- a/a2s/a2sstream.py +++ b/a2s/a2sstream.py @@ -1,10 +1,8 @@ import socket -import bz2 -import io import logging from a2s.exceptions import BrokenMessageError -from a2s.byteio import ByteReader +from a2s.a2sfragment import decode_fragment @@ -13,39 +11,6 @@ HEADER_MULTI = b"\xFE\xFF\xFF\xFF" logger = logging.getLogger("a2s") -class A2SFragment: - def __init__(self, message_id, fragment_count, fragment_id, mtu, - decompressed_size=0, crc=0, payload=b""): - self.message_id = message_id - self.fragment_count = fragment_count - self.fragment_id = fragment_id - self.mtu = mtu - self.decompressed_size = decompressed_size - self.crc = crc - self.payload = payload - - @property - def is_compressed(self): - return bool(self.message_id & (1 << 15)) - -def decode_fragment(data): - reader = ByteReader( - io.BytesIO(data), endian="<", encoding="utf-8") - frag = A2SFragment( - message_id=reader.read_uint32(), - fragment_count=reader.read_uint8(), - fragment_id=reader.read_uint8(), - mtu=reader.read_uint16() - ) - if frag.is_compressed: - frag.decompressed_size = reader.read_uint32() - frag.crc = reader.read_uint32() - frag.payload = bz2.decompress(reader.read()) - else: - frag.payload = reader.read() - - return frag - class A2SStream: def __init__(self, address, timeout): self.address = address