|
|
@ -4,7 +4,18 @@ import asyncio |
|
|
|
import io |
|
|
|
import logging |
|
|
|
import time |
|
|
|
from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Tuple, Type, TypeVar, Union, overload |
|
|
|
from typing import ( |
|
|
|
TYPE_CHECKING, |
|
|
|
Dict, |
|
|
|
List, |
|
|
|
NoReturn, |
|
|
|
Optional, |
|
|
|
Tuple, |
|
|
|
Type, |
|
|
|
TypeVar, |
|
|
|
Union, |
|
|
|
overload, |
|
|
|
) |
|
|
|
|
|
|
|
from a2s.a2s_fragment import A2SFragment, decode_fragment |
|
|
|
from a2s.byteio import ByteReader |
|
|
@ -30,28 +41,42 @@ T = TypeVar("T", InfoProtocol, PlayersProtocol, RulesProtocol) |
|
|
|
|
|
|
|
@overload |
|
|
|
async def request_async( |
|
|
|
address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[InfoProtocol] |
|
|
|
address: Tuple[str, int], |
|
|
|
timeout: float, |
|
|
|
encoding: str, |
|
|
|
a2s_proto: Type[InfoProtocol], |
|
|
|
) -> Union[SourceInfo, GoldSrcInfo]: |
|
|
|
... |
|
|
|
|
|
|
|
|
|
|
|
@overload |
|
|
|
async def request_async( |
|
|
|
address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[PlayersProtocol] |
|
|
|
address: Tuple[str, int], |
|
|
|
timeout: float, |
|
|
|
encoding: str, |
|
|
|
a2s_proto: Type[PlayersProtocol], |
|
|
|
) -> List[Player]: |
|
|
|
... |
|
|
|
|
|
|
|
|
|
|
|
@overload |
|
|
|
async def request_async( |
|
|
|
address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[RulesProtocol] |
|
|
|
address: Tuple[str, int], |
|
|
|
timeout: float, |
|
|
|
encoding: str, |
|
|
|
a2s_proto: Type[RulesProtocol], |
|
|
|
) -> Dict[Union[str, bytes], Union[str, bytes]]: |
|
|
|
... |
|
|
|
|
|
|
|
|
|
|
|
async def request_async( |
|
|
|
address: Tuple[str, int], timeout: float, encoding: str, a2s_proto: Type[T] |
|
|
|
) -> Union[SourceInfo, GoldSrcInfo, List[Player], Dict[Union[str, bytes], Union[str, bytes]]]: |
|
|
|
) -> Union[ |
|
|
|
SourceInfo, |
|
|
|
GoldSrcInfo, |
|
|
|
List[Player], |
|
|
|
Dict[Union[str, bytes], Union[str, bytes]], |
|
|
|
]: |
|
|
|
conn = await A2SStreamAsync.create(address, timeout) |
|
|
|
response = await request_async_impl(conn, encoding, a2s_proto) |
|
|
|
conn.close() |
|
|
@ -101,7 +126,12 @@ async def request_async_impl( |
|
|
|
challenge: int = 0, |
|
|
|
retries: int = 0, |
|
|
|
ping: Optional[float] = None, |
|
|
|
) -> Union[SourceInfo, GoldSrcInfo, Dict[Union[str, bytes], Union[str, bytes]], List[Player]]: |
|
|
|
) -> Union[ |
|
|
|
SourceInfo, |
|
|
|
GoldSrcInfo, |
|
|
|
Dict[Union[str, bytes], Union[str, bytes]], |
|
|
|
List[Player], |
|
|
|
]: |
|
|
|
send_time = time.monotonic() |
|
|
|
resp_data = await conn.request(a2s_proto.serialize_request(challenge)) |
|
|
|
recv_time = time.monotonic() |
|
|
@ -114,12 +144,18 @@ async def request_async_impl( |
|
|
|
response_type = reader.read_uint8() |
|
|
|
if response_type == A2S_CHALLENGE_RESPONSE: |
|
|
|
if retries >= DEFAULT_RETRIES: |
|
|
|
raise BrokenMessageError("Server keeps sending challenge responses") |
|
|
|
raise BrokenMessageError( |
|
|
|
"Server keeps sending challenge responses" |
|
|
|
) |
|
|
|
challenge = reader.read_uint32() |
|
|
|
return await request_async_impl(conn, encoding, a2s_proto, challenge, retries + 1, ping) |
|
|
|
return await request_async_impl( |
|
|
|
conn, encoding, a2s_proto, challenge, retries + 1, ping |
|
|
|
) |
|
|
|
|
|
|
|
if not a2s_proto.validate_response_type(response_type): |
|
|
|
raise BrokenMessageError("Invalid response type: " + hex(response_type)) |
|
|
|
raise BrokenMessageError( |
|
|
|
"Invalid response type: " + hex(response_type) |
|
|
|
) |
|
|
|
|
|
|
|
return a2s_proto.deserialize_response(reader, response_type, ping) |
|
|
|
|
|
|
@ -153,15 +189,23 @@ class A2SProtocol(asyncio.DatagramProtocol): |
|
|
|
if len(self.fragment_buf) < self.fragment_buf[0].fragment_count: |
|
|
|
return # Wait for more packets to arrive |
|
|
|
self.fragment_buf.sort(key=lambda f: f.fragment_id) |
|
|
|
reassembled = b"".join(fragment.payload for fragment in self.fragment_buf) |
|
|
|
reassembled = b"".join( |
|
|
|
fragment.payload for fragment in self.fragment_buf |
|
|
|
) |
|
|
|
# Sometimes there's an additional header present |
|
|
|
if reassembled.startswith(b"\xFF\xFF\xFF\xFF"): |
|
|
|
reassembled = reassembled[4:] |
|
|
|
logger.debug("Received %s part packet with content: %r", len(self.fragment_buf), reassembled) |
|
|
|
logger.debug( |
|
|
|
"Received %s part packet with content: %r", |
|
|
|
len(self.fragment_buf), |
|
|
|
reassembled, |
|
|
|
) |
|
|
|
self.recv_queue.put_nowait(reassembled) |
|
|
|
self.fragment_buf = [] |
|
|
|
else: |
|
|
|
self.error = BrokenMessageError("Invalid packet header: " + repr(header)) |
|
|
|
self.error = BrokenMessageError( |
|
|
|
"Invalid packet header: " + repr(header) |
|
|
|
) |
|
|
|
self.error_event.set() |
|
|
|
|
|
|
|
def error_received(self, exc: Exception) -> None: |
|
|
@ -183,7 +227,12 @@ class A2SStreamAsync: |
|
|
|
"timeout", |
|
|
|
) |
|
|
|
|
|
|
|
def __init__(self, transport: asyncio.DatagramTransport, protocol: A2SProtocol, timeout: float) -> None: |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
transport: asyncio.DatagramTransport, |
|
|
|
protocol: A2SProtocol, |
|
|
|
timeout: float, |
|
|
|
) -> None: |
|
|
|
self.transport: asyncio.DatagramTransport = transport |
|
|
|
self.protocol: A2SProtocol = protocol |
|
|
|
self.timeout: float = timeout |
|
|
@ -194,7 +243,9 @@ class A2SStreamAsync: |
|
|
|
@classmethod |
|
|
|
async def create(cls, address: Tuple[str, int], timeout: float) -> Self: |
|
|
|
loop = asyncio.get_running_loop() |
|
|
|
transport, protocol = await loop.create_datagram_endpoint(lambda: A2SProtocol(), remote_addr=address) |
|
|
|
transport, protocol = await loop.create_datagram_endpoint( |
|
|
|
lambda: A2SProtocol(), remote_addr=address |
|
|
|
) |
|
|
|
return cls(transport, protocol, timeout) |
|
|
|
|
|
|
|
def send(self, payload: bytes) -> None: |
|
|
@ -206,7 +257,9 @@ class A2SStreamAsync: |
|
|
|
queue_task = asyncio.create_task(self.protocol.recv_queue.get()) |
|
|
|
error_task = asyncio.create_task(self.protocol.error_event.wait()) |
|
|
|
done, pending = await asyncio.wait( |
|
|
|
{queue_task, error_task}, timeout=self.timeout, return_when=asyncio.FIRST_COMPLETED |
|
|
|
{queue_task, error_task}, |
|
|
|
timeout=self.timeout, |
|
|
|
return_when=asyncio.FIRST_COMPLETED, |
|
|
|
) |
|
|
|
|
|
|
|
for task in pending: |
|
|
|