Browse Source

Typehint opus.py

pull/7448/head
apple502j 4 years ago
committed by GitHub
parent
commit
9db8698748
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 97
      discord/opus.py

97
discord/opus.py

@ -22,7 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from typing import List, Tuple, TypedDict, Any from __future__ import annotations
from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload
import array import array
import ctypes import ctypes
@ -33,7 +35,12 @@ import os.path
import struct import struct
import sys import sys
from .errors import DiscordException from .errors import DiscordException, InvalidArgument
if TYPE_CHECKING:
T = TypeVar('T')
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full']
SIGNAL_CTL = Literal['auto', 'voice', 'music']
class BandCtl(TypedDict): class BandCtl(TypedDict):
narrow: int narrow: int
@ -104,13 +111,13 @@ signal_ctl: SignalCtl = {
'music': 3002, 'music': 3002,
} }
def _err_lt(result, func, args): def _err_lt(result: int, func: Callable, args: List) -> int:
if result < OK: if result < OK:
_log.info('error has happened in %s', func.__name__) _log.info('error has happened in %s', func.__name__)
raise OpusError(result) raise OpusError(result)
return result return result
def _err_ne(result, func, args): def _err_ne(result: T, func: Callable, args: List) -> T:
ret = args[-1]._obj ret = args[-1]._obj
if ret.value != OK: if ret.value != OK:
_log.info('error has happened in %s', func.__name__) _log.info('error has happened in %s', func.__name__)
@ -172,7 +179,7 @@ exported_functions: List[Tuple[Any, ...]] = [
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
] ]
def libopus_loader(name): def libopus_loader(name: str) -> Any:
# create the library... # create the library...
lib = ctypes.cdll.LoadLibrary(name) lib = ctypes.cdll.LoadLibrary(name)
@ -196,7 +203,7 @@ def libopus_loader(name):
return lib return lib
def _load_default(): def _load_default() -> bool:
global _lib global _lib
try: try:
if sys.platform == 'win32': if sys.platform == 'win32':
@ -212,7 +219,7 @@ def _load_default():
return _lib is not None return _lib is not None
def load_opus(name): def load_opus(name: str) -> None:
"""Loads the libopus shared library for use with voice. """Loads the libopus shared library for use with voice.
If this function is not called then the library uses the function If this function is not called then the library uses the function
@ -250,7 +257,7 @@ def load_opus(name):
global _lib global _lib
_lib = libopus_loader(name) _lib = libopus_loader(name)
def is_loaded(): def is_loaded() -> bool:
"""Function to check if opus lib is successfully loaded either """Function to check if opus lib is successfully loaded either
via the :func:`ctypes.util.find_library` call of :func:`load_opus`. via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
@ -273,8 +280,8 @@ class OpusError(DiscordException):
The error code returned. The error code returned.
""" """
def __init__(self, code): def __init__(self, code: int):
self.code = code self.code: int = code
msg = _lib.opus_strerror(self.code).decode('utf-8') msg = _lib.opus_strerror(self.code).decode('utf-8')
_log.info('"%s" has happened', msg) _log.info('"%s" has happened', msg)
super().__init__(msg) super().__init__(msg)
@ -300,92 +307,96 @@ class _OpusStruct:
return _lib.opus_get_version_string().decode('utf-8') return _lib.opus_get_version_string().decode('utf-8')
class Encoder(_OpusStruct): class Encoder(_OpusStruct):
def __init__(self, application=APPLICATION_AUDIO): def __init__(self, application: int = APPLICATION_AUDIO):
_OpusStruct.get_opus_version() _OpusStruct.get_opus_version()
self.application = application self.application: int = application
self._state = self._create_state() self._state: EncoderStruct = self._create_state()
self.set_bitrate(128) self.set_bitrate(128)
self.set_fec(True) self.set_fec(True)
self.set_expected_packet_loss_percent(0.15) self.set_expected_packet_loss_percent(0.15)
self.set_bandwidth('full') self.set_bandwidth('full')
self.set_signal_type('auto') self.set_signal_type('auto')
def __del__(self): def __del__(self) -> None:
if hasattr(self, '_state'): if hasattr(self, '_state'):
_lib.opus_encoder_destroy(self._state) _lib.opus_encoder_destroy(self._state)
self._state = None # This is a destructor, so it's okay to assign None
self._state = None # type: ignore
def _create_state(self): def _create_state(self) -> EncoderStruct:
ret = ctypes.c_int() ret = ctypes.c_int()
return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret)) return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret))
def set_bitrate(self, kbps): def set_bitrate(self, kbps: int) -> int:
kbps = min(512, max(16, int(kbps))) kbps = min(512, max(16, int(kbps)))
_lib.opus_encoder_ctl(self._state, CTL_SET_BITRATE, kbps * 1024) _lib.opus_encoder_ctl(self._state, CTL_SET_BITRATE, kbps * 1024)
return kbps return kbps
def set_bandwidth(self, req): def set_bandwidth(self, req: BAND_CTL) -> None:
if req not in band_ctl: if req not in band_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}') raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}')
k = band_ctl[req] k = band_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k) _lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k)
def set_signal_type(self, req): def set_signal_type(self, req: SIGNAL_CTL) -> None:
if req not in signal_ctl: if req not in signal_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}') raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}')
k = signal_ctl[req] k = signal_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k) _lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k)
def set_fec(self, enabled=True): def set_fec(self, enabled: bool = True) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage): def set_expected_packet_loss_percent(self, percentage: float) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
def encode(self, pcm, frame_size): def encode(self, pcm: bytes, frame_size: int) -> bytes:
max_data_bytes = len(pcm) max_data_bytes = len(pcm)
pcm = ctypes.cast(pcm, c_int16_ptr) # bytes can be used to reference pointer
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
data = (ctypes.c_char * max_data_bytes)() data = (ctypes.c_char * max_data_bytes)()
ret = _lib.opus_encode(self._state, pcm, frame_size, data, max_data_bytes) ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
return array.array('b', data[:ret]).tobytes() # array can be initialized with bytes but mypy doesn't know
return array.array('b', data[:ret]).tobytes() # type: ignore
class Decoder(_OpusStruct): class Decoder(_OpusStruct):
def __init__(self): def __init__(self):
_OpusStruct.get_opus_version() _OpusStruct.get_opus_version()
self._state = self._create_state() self._state: DecoderStruct = self._create_state()
def __del__(self): def __del__(self) -> None:
if hasattr(self, '_state'): if hasattr(self, '_state'):
_lib.opus_decoder_destroy(self._state) _lib.opus_decoder_destroy(self._state)
self._state = None # This is a destructor, so it's okay to assign None
self._state = None # type: ignore
def _create_state(self): def _create_state(self) -> DecoderStruct:
ret = ctypes.c_int() ret = ctypes.c_int()
return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret)) return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret))
@staticmethod @staticmethod
def packet_get_nb_frames(data): def packet_get_nb_frames(data: bytes) -> int:
"""Gets the number of frames in an Opus packet""" """Gets the number of frames in an Opus packet"""
return _lib.opus_packet_get_nb_frames(data, len(data)) return _lib.opus_packet_get_nb_frames(data, len(data))
@staticmethod @staticmethod
def packet_get_nb_channels(data): def packet_get_nb_channels(data: bytes) -> int:
"""Gets the number of channels in an Opus packet""" """Gets the number of channels in an Opus packet"""
return _lib.opus_packet_get_nb_channels(data) return _lib.opus_packet_get_nb_channels(data)
@classmethod @classmethod
def packet_get_samples_per_frame(cls, data): def packet_get_samples_per_frame(cls, data: bytes) -> int:
"""Gets the number of samples per frame from an Opus packet""" """Gets the number of samples per frame from an Opus packet"""
return _lib.opus_packet_get_samples_per_frame(data, cls.SAMPLING_RATE) return _lib.opus_packet_get_samples_per_frame(data, cls.SAMPLING_RATE)
def _set_gain(self, adjustment): def _set_gain(self, adjustment: int) -> int:
"""Configures decoder gain adjustment. """Configures decoder gain adjustment.
Scales the decoded output by a factor specified in Q8 dB units. Scales the decoded output by a factor specified in Q8 dB units.
@ -397,26 +408,34 @@ class Decoder(_OpusStruct):
""" """
return _lib.opus_decoder_ctl(self._state, CTL_SET_GAIN, adjustment) return _lib.opus_decoder_ctl(self._state, CTL_SET_GAIN, adjustment)
def set_gain(self, dB): def set_gain(self, dB: float) -> int:
"""Sets the decoder gain in dB, from -128 to 128.""" """Sets the decoder gain in dB, from -128 to 128."""
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
return self._set_gain(dB_Q8) return self._set_gain(dB_Q8)
def set_volume(self, mult): def set_volume(self, mult: float) -> int:
"""Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc.""" """Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc."""
return self.set_gain(20 * math.log10(mult)) # amplitude ratio return self.set_gain(20 * math.log10(mult)) # amplitude ratio
def _get_last_packet_duration(self): def _get_last_packet_duration(self) -> int:
"""Gets the duration (in samples) of the last packet successfully decoded or concealed.""" """Gets the duration (in samples) of the last packet successfully decoded or concealed."""
ret = ctypes.c_int32() ret = ctypes.c_int32()
_lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret)) _lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret))
return ret.value return ret.value
def decode(self, data, *, fec=False): @overload
def decode(self, data: bytes, *, fec: bool) -> bytes:
...
@overload
def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes:
...
def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes:
if data is None and fec: if data is None and fec:
raise OpusError("Invalid arguments: FEC cannot be used with null data") raise InvalidArgument("Invalid arguments: FEC cannot be used with null data")
if data is None: if data is None:
frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME

Loading…
Cancel
Save