From 9db8698748480825b7676cd6845c8248283b6cc2 Mon Sep 17 00:00:00 2001 From: apple502j <33279053+apple502j@users.noreply.github.com> Date: Sun, 22 Aug 2021 19:08:30 +0900 Subject: [PATCH] Typehint opus.py --- discord/opus.py | 97 +++++++++++++++++++++++++++++-------------------- 1 file changed, 58 insertions(+), 39 deletions(-) diff --git a/discord/opus.py b/discord/opus.py index ef95e8d59..97d437a36 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -22,7 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import List, Tuple, TypedDict, Any +from __future__ import annotations + +from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload import array import ctypes @@ -33,7 +35,12 @@ import os.path import struct import sys -from .errors import DiscordException +from .errors import DiscordException, InvalidArgument + +if TYPE_CHECKING: + T = TypeVar('T') + BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full'] + SIGNAL_CTL = Literal['auto', 'voice', 'music'] class BandCtl(TypedDict): narrow: int @@ -104,13 +111,13 @@ signal_ctl: SignalCtl = { 'music': 3002, } -def _err_lt(result, func, args): +def _err_lt(result: int, func: Callable, args: List) -> int: if result < OK: _log.info('error has happened in %s', func.__name__) raise OpusError(result) return result -def _err_ne(result, func, args): +def _err_ne(result: T, func: Callable, args: List) -> T: ret = args[-1]._obj if ret.value != OK: _log.info('error has happened in %s', func.__name__) @@ -172,7 +179,7 @@ exported_functions: List[Tuple[Any, ...]] = [ [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), ] -def libopus_loader(name): +def libopus_loader(name: str) -> Any: # create the library... lib = ctypes.cdll.LoadLibrary(name) @@ -196,7 +203,7 @@ def libopus_loader(name): return lib -def _load_default(): +def _load_default() -> bool: global _lib try: if sys.platform == 'win32': @@ -212,7 +219,7 @@ def _load_default(): return _lib is not None -def load_opus(name): +def load_opus(name: str) -> None: """Loads the libopus shared library for use with voice. If this function is not called then the library uses the function @@ -250,7 +257,7 @@ def load_opus(name): global _lib _lib = libopus_loader(name) -def is_loaded(): +def is_loaded() -> bool: """Function to check if opus lib is successfully loaded either via the :func:`ctypes.util.find_library` call of :func:`load_opus`. @@ -273,8 +280,8 @@ class OpusError(DiscordException): The error code returned. """ - def __init__(self, code): - self.code = code + def __init__(self, code: int): + self.code: int = code msg = _lib.opus_strerror(self.code).decode('utf-8') _log.info('"%s" has happened', msg) super().__init__(msg) @@ -300,92 +307,96 @@ class _OpusStruct: return _lib.opus_get_version_string().decode('utf-8') class Encoder(_OpusStruct): - def __init__(self, application=APPLICATION_AUDIO): + def __init__(self, application: int = APPLICATION_AUDIO): _OpusStruct.get_opus_version() - self.application = application - self._state = self._create_state() + self.application: int = application + self._state: EncoderStruct = self._create_state() self.set_bitrate(128) self.set_fec(True) self.set_expected_packet_loss_percent(0.15) self.set_bandwidth('full') self.set_signal_type('auto') - def __del__(self): + def __del__(self) -> None: if hasattr(self, '_state'): _lib.opus_encoder_destroy(self._state) - self._state = None + # This is a destructor, so it's okay to assign None + self._state = None # type: ignore - def _create_state(self): + def _create_state(self) -> EncoderStruct: ret = ctypes.c_int() return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret)) - def set_bitrate(self, kbps): + def set_bitrate(self, kbps: int) -> int: kbps = min(512, max(16, int(kbps))) _lib.opus_encoder_ctl(self._state, CTL_SET_BITRATE, kbps * 1024) return kbps - def set_bandwidth(self, req): + def set_bandwidth(self, req: BAND_CTL) -> None: if req not in band_ctl: raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}') k = band_ctl[req] _lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k) - def set_signal_type(self, req): + def set_signal_type(self, req: SIGNAL_CTL) -> None: if req not in signal_ctl: raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}') k = signal_ctl[req] _lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k) - def set_fec(self, enabled=True): + def set_fec(self, enabled: bool = True) -> None: _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) - def set_expected_packet_loss_percent(self, percentage): - _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) + def set_expected_packet_loss_percent(self, percentage: float) -> None: + _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore - def encode(self, pcm, frame_size): + def encode(self, pcm: bytes, frame_size: int) -> bytes: max_data_bytes = len(pcm) - pcm = ctypes.cast(pcm, c_int16_ptr) + # bytes can be used to reference pointer + pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore data = (ctypes.c_char * max_data_bytes)() - ret = _lib.opus_encode(self._state, pcm, frame_size, data, max_data_bytes) + ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) - return array.array('b', data[:ret]).tobytes() + # array can be initialized with bytes but mypy doesn't know + return array.array('b', data[:ret]).tobytes() # type: ignore class Decoder(_OpusStruct): def __init__(self): _OpusStruct.get_opus_version() - self._state = self._create_state() + self._state: DecoderStruct = self._create_state() - def __del__(self): + def __del__(self) -> None: if hasattr(self, '_state'): _lib.opus_decoder_destroy(self._state) - self._state = None + # This is a destructor, so it's okay to assign None + self._state = None # type: ignore - def _create_state(self): + def _create_state(self) -> DecoderStruct: ret = ctypes.c_int() return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret)) @staticmethod - def packet_get_nb_frames(data): + def packet_get_nb_frames(data: bytes) -> int: """Gets the number of frames in an Opus packet""" return _lib.opus_packet_get_nb_frames(data, len(data)) @staticmethod - def packet_get_nb_channels(data): + def packet_get_nb_channels(data: bytes) -> int: """Gets the number of channels in an Opus packet""" return _lib.opus_packet_get_nb_channels(data) @classmethod - def packet_get_samples_per_frame(cls, data): + def packet_get_samples_per_frame(cls, data: bytes) -> int: """Gets the number of samples per frame from an Opus packet""" return _lib.opus_packet_get_samples_per_frame(data, cls.SAMPLING_RATE) - def _set_gain(self, adjustment): + def _set_gain(self, adjustment: int) -> int: """Configures decoder gain adjustment. Scales the decoded output by a factor specified in Q8 dB units. @@ -397,26 +408,34 @@ class Decoder(_OpusStruct): """ return _lib.opus_decoder_ctl(self._state, CTL_SET_GAIN, adjustment) - def set_gain(self, dB): + def set_gain(self, dB: float) -> int: """Sets the decoder gain in dB, from -128 to 128.""" dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) return self._set_gain(dB_Q8) - def set_volume(self, mult): + def set_volume(self, mult: float) -> int: """Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc.""" return self.set_gain(20 * math.log10(mult)) # amplitude ratio - def _get_last_packet_duration(self): + def _get_last_packet_duration(self) -> int: """Gets the duration (in samples) of the last packet successfully decoded or concealed.""" ret = ctypes.c_int32() _lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret)) return ret.value - def decode(self, data, *, fec=False): + @overload + def decode(self, data: bytes, *, fec: bool) -> bytes: + ... + + @overload + def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes: + ... + + def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes: if data is None and fec: - raise OpusError("Invalid arguments: FEC cannot be used with null data") + raise InvalidArgument("Invalid arguments: FEC cannot be used with null data") if data is None: frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME