From ab5906e63bde7e042edbb3811a38ee6ae56646f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jeremy=20Lain=C3=A9?= Date: Sat, 2 Nov 2019 22:33:21 +0100 Subject: [PATCH] Add more type hints --- .gitignore | 4 +- aiortc/codecs/__init__.py | 56 ++++++------ aiortc/codecs/_opus.pyi | 4 + aiortc/codecs/_vpx.pyi | 4 + aiortc/rtcdatachannel.py | 109 +++++++++++------------ aiortc/rtcdtlstransport.py | 24 +++--- aiortc/rtcicetransport.py | 59 ++++++------- aiortc/rtcpeerconnection.py | 67 +++++++++----- aiortc/rtcrtpreceiver.py | 37 ++++---- aiortc/rtcrtpsender.py | 24 +++--- aiortc/rtcrtptransceiver.py | 56 ++++++++---- aiortc/rtcsctptransport.py | 116 +++++++++++++------------ aiortc/rtp.py | 163 ++++++++++++++++++----------------- aiortc/sdp.py | 38 ++++---- stubs/audioop.pyi | 8 +- stubs/pyee.pyi | 6 +- tests/test_rtcrtpreceiver.py | 4 +- tests/test_rtcrtpsender.py | 4 +- tests/utils.py | 2 +- 19 files changed, 435 insertions(+), 350 deletions(-) create mode 100644 aiortc/codecs/_opus.pyi create mode 100644 aiortc/codecs/_vpx.pyi diff --git a/.gitignore b/.gitignore index da58c8625..785971d98 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,11 @@ *.egg-info *.pyc +*.so .coverage .eggs .idea +.mypy_cache .vscode -/aiortc/codecs/_opus.* -/aiortc/codecs/_vpx.* /build /dist /docs/_build diff --git a/aiortc/codecs/__init__.py b/aiortc/codecs/__init__.py index d7eda538c..ee730920f 100644 --- a/aiortc/codecs/__init__.py +++ b/aiortc/codecs/__init__.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from ..rtcrtpparameters import ( RTCRtcpFeedback, @@ -108,34 +108,36 @@ def depayload(codec: RTCRtpCodecParameters, payload: bytes) -> bytes: return payload -def get_capabilities(kind): - if kind in CODECS: - codecs = [] - rtx_added = False - for params in CODECS[kind]: - if not is_rtx(params): - codecs.append( - RTCRtpCodecCapability( - mimeType=params.mimeType, - clockRate=params.clockRate, - channels=params.channels, - parameters=params.parameters, - ) +def get_capabilities(kind: str) -> RTCRtpCapabilities: + if kind not in CODECS: + raise ValueError("cannot get capabilities for unknown media %s" % kind) + + codecs = [] + rtx_added = False + for params in CODECS[kind]: + if not is_rtx(params): + codecs.append( + RTCRtpCodecCapability( + mimeType=params.mimeType, + clockRate=params.clockRate, + channels=params.channels, + parameters=params.parameters, ) - elif not rtx_added: - # There will only be a single entry in codecs[] for retransmission - # via RTX, with sdpFmtpLine not present. - codecs.append( - RTCRtpCodecCapability( - mimeType=params.mimeType, clockRate=params.clockRate - ) + ) + elif not rtx_added: + # There will only be a single entry in codecs[] for retransmission + # via RTX, with sdpFmtpLine not present. + codecs.append( + RTCRtpCodecCapability( + mimeType=params.mimeType, clockRate=params.clockRate ) - rtx_added = True + ) + rtx_added = True - headerExtensions = [] - for params in HEADER_EXTENSIONS[kind]: - headerExtensions.append(RTCRtpHeaderExtensionCapability(uri=params.uri)) - return RTCRtpCapabilities(codecs=codecs, headerExtensions=headerExtensions) + headerExtensions = [] + for extension in HEADER_EXTENSIONS[kind]: + headerExtensions.append(RTCRtpHeaderExtensionCapability(uri=extension.uri)) + return RTCRtpCapabilities(codecs=codecs, headerExtensions=headerExtensions) def get_decoder(codec: RTCRtpCodecParameters): @@ -168,7 +170,7 @@ def get_encoder(codec: RTCRtpCodecParameters): return Vp8Encoder() -def is_rtx(codec: RTCRtpCodecParameters) -> bool: +def is_rtx(codec: Union[RTCRtpCodecCapability, RTCRtpCodecParameters]) -> bool: return codec.name.lower() == "rtx" diff --git a/aiortc/codecs/_opus.pyi b/aiortc/codecs/_opus.pyi new file mode 100644 index 000000000..9991e617b --- /dev/null +++ b/aiortc/codecs/_opus.pyi @@ -0,0 +1,4 @@ +from typing import Any + +ffi: Any +lib: Any diff --git a/aiortc/codecs/_vpx.pyi b/aiortc/codecs/_vpx.pyi new file mode 100644 index 000000000..9991e617b --- /dev/null +++ b/aiortc/codecs/_vpx.pyi @@ -0,0 +1,4 @@ +from typing import Any + +ffi: Any +lib: Any diff --git a/aiortc/rtcdatachannel.py b/aiortc/rtcdatachannel.py index 18366a11b..43578cc76 100644 --- a/aiortc/rtcdatachannel.py +++ b/aiortc/rtcdatachannel.py @@ -1,4 +1,5 @@ import logging +from typing import Optional, Union import attr from pyee import AsyncIOEventEmitter @@ -8,6 +9,41 @@ logger = logging.getLogger("datachannel") +@attr.s +class RTCDataChannelParameters: + """ + The :class:`RTCDataChannelParameters` dictionary describes the + configuration of an :class:`RTCDataChannel`. + """ + + label = attr.ib(default="") # type: str + "A name describing the data channel." + + maxPacketLifeTime = attr.ib(default=None) # type: Optional[int] + "The maximum time in milliseconds during which transmissions are attempted." + + maxRetransmits = attr.ib(default=None) # type: Optional[int] + "The maximum number of retransmissions that are attempted." + + ordered = attr.ib(default=True) # type: bool + "Whether the data channel guarantees in-order delivery of messages." + + protocol = attr.ib(default="") # type: str + "The name of the subprotocol in use." + + negotiated = attr.ib(default=False) # type: bool + """ + Whether data channel will be negotiated out of-band, where both sides + create data channel with an agreed-upon ID.""" + + id = attr.ib(default=None) # type: Optional[int] + """ + An numeric ID for the channel; permitted values are 0-65534. + If you don't include this option, the user agent will select an ID for you. + Must be set when negotiating out-of-band. + """ + + class RTCDataChannel(AsyncIOEventEmitter): """ The :class:`RTCDataChannel` interface represents a network channel which @@ -17,7 +53,9 @@ class RTCDataChannel(AsyncIOEventEmitter): :param: parameters: An :class:`RTCDataChannelParameters`. """ - def __init__(self, transport, parameters, send_open=True): + def __init__( + self, transport, parameters: RTCDataChannelParameters, send_open: bool = True + ) -> None: super().__init__() self.__bufferedAmount = 0 self.__bufferedAmountLowThreshold = 0 @@ -43,21 +81,21 @@ def __init__(self, transport, parameters, send_open=True): self.__transport._data_channel_add_negotiated(self) @property - def bufferedAmount(self): + def bufferedAmount(self) -> int: """ The number of bytes of data currently queued to be sent over the data channel. """ return self.__bufferedAmount @property - def bufferedAmountLowThreshold(self): + def bufferedAmountLowThreshold(self) -> int: """ The number of bytes of buffered outgoing data that is considered "low". """ return self.__bufferedAmountLowThreshold @bufferedAmountLowThreshold.setter - def bufferedAmountLowThreshold(self, value): + def bufferedAmountLowThreshold(self, value: int) -> None: if value < 0 or value > 4294967295: raise ValueError( "bufferedAmountLowThreshold must be in range 0 - 4294967295" @@ -65,21 +103,21 @@ def bufferedAmountLowThreshold(self, value): self.__bufferedAmountLowThreshold = value @property - def negotiated(self): + def negotiated(self) -> bool: """ Whether data channel was negotiated out-of-band. """ return self.__parameters.negotiated @property - def id(self): + def id(self) -> Optional[int]: """ An ID number which uniquely identifies the data channel. """ return self.__id @property - def label(self): + def label(self) -> str: """ A name describing the data channel. @@ -88,35 +126,35 @@ def label(self): return self.__parameters.label @property - def ordered(self): + def ordered(self) -> bool: """ Indicates whether or not the data channel guarantees in-order delivery of messages. """ return self.__parameters.ordered @property - def maxPacketLifeTime(self): + def maxPacketLifeTime(self) -> Optional[int]: """ The maximum time in milliseconds during which transmissions are attempted. """ return self.__parameters.maxPacketLifeTime @property - def maxRetransmits(self): + def maxRetransmits(self) -> Optional[int]: """ "The maximum number of retransmissions that are attempted. """ return self.__parameters.maxRetransmits @property - def protocol(self): + def protocol(self) -> str: """ The name of the subprotocol in use. """ return self.__parameters.protocol @property - def readyState(self): + def readyState(self) -> str: """ A string indicating the current state of the underlying data transport. """ @@ -129,13 +167,13 @@ def transport(self): """ return self.__transport - def close(self): + def close(self) -> None: """ Close the data channel. """ self.transport._data_channel_close(self) - def send(self, data): + def send(self, data: Union[bytes, str]) -> None: """ Send `data` across the data channel to the remote peer. """ @@ -147,7 +185,7 @@ def send(self, data): self.transport._data_channel_send(self, data) - def _addBufferedAmount(self, amount): + def _addBufferedAmount(self, amount: int) -> None: crosses_threshold = ( self.__bufferedAmount > self.bufferedAmountLowThreshold and self.__bufferedAmount + amount <= self.bufferedAmountLowThreshold @@ -156,10 +194,10 @@ def _addBufferedAmount(self, amount): if crosses_threshold: self.emit("bufferedamountlow") - def _setId(self, id): + def _setId(self, id: int) -> None: self.__id = id - def _setReadyState(self, state): + def _setReadyState(self, state: str) -> None: if state != self.__readyState: self.__log_debug("- %s -> %s", self.__readyState, state) self.__readyState = state @@ -173,40 +211,5 @@ def _setReadyState(self, state): # to facilitate garbage collection. self.remove_all_listeners() - def __log_debug(self, msg, *args): + def __log_debug(self, msg: str, *args) -> None: logger.debug(str(self.id) + " " + msg, *args) - - -@attr.s -class RTCDataChannelParameters: - """ - The :class:`RTCDataChannelParameters` dictionary describes the - configuration of an :class:`RTCDataChannel`. - """ - - label = attr.ib(default="") - "A name describing the data channel." - - maxPacketLifeTime = attr.ib(default=None) - "The maximum time in milliseconds during which transmissions are attempted." - - maxRetransmits = attr.ib(default=None) - "The maximum number of retransmissions that are attempted." - - ordered = attr.ib(default=True) - "Whether the data channel guarantees in-order delivery of messages." - - protocol = attr.ib(default="") - "The name of the subprotocol in use." - - negotiated = attr.ib(default=False) - """ - Whether data channel will be negotiated out of-band, where both sides - create data channel with an agreed-upon ID.""" - - id = attr.ib(default=None) - """ - An numeric ID for the channel; permitted values are 0-65534. - If you don't include this option, the user agent will select an ID for you. - Must be set when negotiating out-of-band. - """ diff --git a/aiortc/rtcdtlstransport.py b/aiortc/rtcdtlstransport.py index 840829dd6..05182ab09 100644 --- a/aiortc/rtcdtlstransport.py +++ b/aiortc/rtcdtlstransport.py @@ -117,7 +117,7 @@ def text(charp) -> str: return errors -def get_srtp_key_salt(src, idx): +def get_srtp_key_salt(src, idx: int) -> bytes: key_start = idx * SRTP_KEY_LEN salt_start = 2 * SRTP_KEY_LEN + idx * SRTP_SALT_LEN return ( @@ -397,7 +397,7 @@ def transport(self): """ return self._transport - def getLocalParameters(self): + def getLocalParameters(self) -> RTCDtlsParameters: """ Get the local parameters of the DTLS transport. @@ -407,7 +407,7 @@ def getLocalParameters(self): fingerprints=self.__local_certificate.getFingerprints() ) - async def start(self, remoteParameters): + async def start(self, remoteParameters: RTCDtlsParameters) -> None: """ Start DTLS transport negotiation with the parameters of the remote DTLS transport. @@ -497,7 +497,7 @@ async def start(self, remoteParameters): self._set_state(State.CONNECTED) self._task = asyncio.ensure_future(self.__run()) - async def stop(self): + async def stop(self) -> None: """ Stop and close the DTLS transport. """ @@ -513,7 +513,7 @@ async def stop(self): pass self.__log_debug("- DTLS shutdown complete") - async def __run(self): + async def __run(self) -> None: try: while True: await self._recv_next() @@ -546,7 +546,7 @@ def _get_stats(self): ) return report - async def _handle_rtcp_data(self, data): + async def _handle_rtcp_data(self, data: bytes) -> None: try: packets = RtcpPacket.parse(data) except ValueError as exc: @@ -558,7 +558,7 @@ async def _handle_rtcp_data(self, data): for recipient in self._rtp_router.route_rtcp(packet): await recipient._handle_rtcp_packet(packet) - async def _handle_rtp_data(self, data, arrival_time_ms): + async def _handle_rtp_data(self, data: bytes, arrival_time_ms: int) -> None: try: packet = RtpPacket.parse(data, self._rtp_header_extensions_map) except ValueError as exc: @@ -570,7 +570,7 @@ async def _handle_rtp_data(self, data, arrival_time_ms): if receiver is not None: await receiver._handle_rtp_packet(packet, arrival_time_ms=arrival_time_ms) - async def _recv_next(self): + async def _recv_next(self) -> None: # get timeout timeout = None if not self.encrypted: @@ -659,7 +659,7 @@ async def _send_rtp(self, data): self.__tx_bytes += len(data) self.__tx_packets += 1 - def _set_state(self, state): + def _set_state(self, state: State) -> None: if state != self._state: self.__log_debug("- %s -> %s", self._state, state) self._state = state @@ -675,7 +675,7 @@ def _unregister_rtp_receiver(self, receiver): def _unregister_rtp_sender(self, sender): self._rtp_router.unregister_sender(sender) - async def _write_ssl(self): + async def _write_ssl(self) -> None: """ Flush outgoing data which OpenSSL put in our BIO to the transport. """ @@ -688,8 +688,8 @@ async def _write_ssl(self): self.__tx_bytes += result self.__tx_packets += 1 - def __log_debug(self, msg, *args): + def __log_debug(self, msg: str, *args) -> None: logger.debug(self._role + " " + msg, *args) - def __log_warning(self, msg, *args): + def __log_warning(self, msg: str, *args) -> None: logger.warning(self._role + " " + msg, *args) diff --git a/aiortc/rtcicetransport.py b/aiortc/rtcicetransport.py index baff19b67..59dd32f6e 100644 --- a/aiortc/rtcicetransport.py +++ b/aiortc/rtcicetransport.py @@ -1,7 +1,7 @@ import asyncio import logging import re -from typing import List, Optional +from typing import Any, Dict, List, Optional import attr from aioice import Candidate, Connection @@ -83,8 +83,8 @@ def candidate_to_aioice(x: RTCIceCandidate) -> Candidate: ) -def connection_kwargs(servers): - kwargs = {} +def connection_kwargs(servers: List[RTCIceServer]) -> Dict[str, Any]: + kwargs = {} # type: Dict[str, Any] for server in servers: if isinstance(server.urls, list): @@ -128,7 +128,7 @@ def connection_kwargs(servers): return kwargs -def parse_stun_turn_uri(uri): +def parse_stun_turn_uri(uri: str) -> Dict[str, Any]: if uri.startswith("stun"): match = STUN_REGEX.fullmatch(uri) elif uri.startswith("turn"): @@ -140,21 +140,21 @@ def parse_stun_turn_uri(uri): raise ValueError("malformed uri") # set port - match = match.groupdict() - if match["port"]: - match["port"] = int(match["port"]) - elif match["scheme"] in ["stuns", "turns"]: - match["port"] = 5349 + parsed = match.groupdict() + if parsed["port"]: + parsed["port"] = int(parsed["port"]) + elif parsed["scheme"] in ["stuns", "turns"]: + parsed["port"] = 5349 else: - match["port"] = 3478 + parsed["port"] = 3478 # set transport - if match["scheme"] == "turn" and not match["transport"]: - match["transport"] = "udp" - elif match["scheme"] == "turns" and not match["transport"]: - match["transport"] = "tcp" + if parsed["scheme"] == "turn" and not parsed["transport"]: + parsed["transport"] = "udp" + elif parsed["scheme"] == "turns" and not parsed["transport"]: + parsed["transport"] = "tcp" - return match + return parsed class RTCIceGatherer(AsyncIOEventEmitter): @@ -165,7 +165,7 @@ class RTCIceGatherer(AsyncIOEventEmitter): exchanged in signaling. """ - def __init__(self, iceServers=None): + def __init__(self, iceServers: Optional[List[RTCIceServer]] = None) -> None: super().__init__() if iceServers is None: @@ -230,22 +230,22 @@ class RTCIceTransport(AsyncIOEventEmitter): :param: gatherer: An :class:`RTCIceGatherer`. """ - def __init__(self, gatherer): + def __init__(self, gatherer: RTCIceGatherer) -> None: super().__init__() - self.__start = None + self.__start = None # type: Optional[asyncio.Event] self.__iceGatherer = gatherer self.__state = "new" self._connection = gatherer._connection @property - def iceGatherer(self): + def iceGatherer(self) -> RTCIceGatherer: """ The ICE gatherer passed in the constructor. """ return self.__iceGatherer @property - def role(self): + def role(self) -> str: """ The current role of the ICE transport. @@ -257,13 +257,13 @@ def role(self): return "controlled" @property - def state(self): + def state(self) -> str: """ The current state of the ICE transport. """ return self.__state - def addRemoteCandidate(self, candidate): + def addRemoteCandidate(self, candidate: Optional[RTCIceCandidate]) -> None: """ Add a remote candidate. """ @@ -281,7 +281,7 @@ def getRemoteCandidates(self): """ return [candidate_from_aioice(x) for x in self._connection.remote_candidates] - async def start(self, remoteParameters): + async def start(self, remoteParameters: RTCIceParameters) -> None: """ Initiate connectivity checks. @@ -293,7 +293,8 @@ async def start(self, remoteParameters): # handle the case where start is already in progress if self.__start is not None: - return await self.__start.wait() + await self.__start.wait() + return self.__start = asyncio.Event() self.__setState("checking") @@ -307,7 +308,7 @@ async def start(self, remoteParameters): self.__setState("completed") self.__start.set() - async def stop(self): + async def stop(self) -> None: """ Irreversibly stop the :class:`RTCIceTransport`. """ @@ -315,7 +316,7 @@ async def stop(self): self.__setState("closed") await self._connection.close() - async def _recv(self): + async def _recv(self) -> bytes: try: return await self._connection.recv() except ConnectionError: @@ -323,7 +324,7 @@ async def _recv(self): self.__setState("failed") raise - async def _send(self, data): + async def _send(self, data: bytes) -> None: try: await self._connection.send(data) except ConnectionError: @@ -331,10 +332,10 @@ async def _send(self, data): self.__setState("failed") raise - def __log_debug(self, msg, *args): + def __log_debug(self, msg: str, *args) -> None: logger.debug(self.role + " " + msg, *args) - def __setState(self, state): + def __setState(self, state: str) -> None: if state != self.__state: self.__log_debug("- %s -> %s", self.__state, state) self.__state = state diff --git a/aiortc/rtcpeerconnection.py b/aiortc/rtcpeerconnection.py index cd89f7197..b4e4097b2 100644 --- a/aiortc/rtcpeerconnection.py +++ b/aiortc/rtcpeerconnection.py @@ -2,7 +2,7 @@ import copy import uuid from collections import OrderedDict -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Union from pyee import AsyncIOEventEmitter @@ -12,14 +12,17 @@ from .exceptions import InternalError, InvalidAccessError, InvalidStateError from .rtcconfiguration import RTCConfiguration from .rtcdatachannel import RTCDataChannel, RTCDataChannelParameters -from .rtcdtlstransport import RTCCertificate, RTCDtlsTransport -from .rtcicetransport import RTCIceGatherer, RTCIceTransport +from .rtcdtlstransport import RTCCertificate, RTCDtlsParameters, RTCDtlsTransport +from .rtcicetransport import RTCIceGatherer, RTCIceParameters, RTCIceTransport from .rtcrtpparameters import ( + RTCRtpCodecCapability, RTCRtpCodecParameters, RTCRtpDecodingParameters, + RTCRtpHeaderExtensionParameters, RTCRtpParameters, RTCRtpReceiveParameters, RTCRtpRtxParameters, + RTCRtpSendParameters, ) from .rtcrtpreceiver import RemoteStreamTrack, RTCRtpReceiver from .rtcrtpsender import RTCRtpSender @@ -34,7 +37,7 @@ def filter_preferred_codecs( - codecs: List[RTCRtpCodecParameters], preferred: List[RTCRtpCodecParameters] + codecs: List[RTCRtpCodecParameters], preferred: List[RTCRtpCodecCapability] ) -> List[RTCRtpCodecParameters]: if not preferred: return codecs @@ -63,9 +66,12 @@ def filter_preferred_codecs( return filtered -def find_common_codecs(local_codecs, remote_codecs): +def find_common_codecs( + local_codecs: List[RTCRtpCodecParameters], + remote_codecs: List[RTCRtpCodecParameters], +) -> List[RTCRtpCodecParameters]: common = [] - common_base = {} + common_base = {} # type: Dict[int, RTCRtpCodecParameters] for c in remote_codecs: # for RTX, check we accepted the base codec if is_rtx(c): @@ -102,7 +108,10 @@ def find_common_codecs(local_codecs, remote_codecs): return common -def find_common_header_extensions(local_extensions, remote_extensions): +def find_common_header_extensions( + local_extensions: List[RTCRtpHeaderExtensionParameters], + remote_extensions: List[RTCRtpHeaderExtensionParameters], +) -> List[RTCRtpHeaderExtensionParameters]: common = [] for rx in remote_extensions: for lx in local_extensions: @@ -135,7 +144,9 @@ def add_transport_description( media.dtls.role = "client" -def add_remote_candidates(iceTransport, media): +def add_remote_candidates( + iceTransport: RTCIceTransport, media: sdp.MediaDescription +) -> None: for candidate in media.ice_candidates: iceTransport.addRemoteCandidate(candidate) if media.ice_candidates_complete: @@ -261,13 +272,17 @@ def __init__(self, configuration: Optional[RTCConfiguration] = None) -> None: self.__configuration = configuration or RTCConfiguration() self.__iceTransports = set() # type: Set[RTCIceTransport] self.__initialOfferer = None - self.__remoteDtls = {} # type: Dict[RTCRtpTransceiver, RTCDtlsTransport] - self.__remoteIce = {} # type: Dict[RTCRtpTransceiver, RTCIceTransport] + self.__remoteDtls = ( + {} + ) # type: Dict[Union[RTCRtpTransceiver, RTCSctpTransport], RTCDtlsParameters] + self.__remoteIce = ( + {} + ) # type: Dict[Union[RTCRtpTransceiver, RTCSctpTransport], RTCIceParameters] self.__seenMids = set() # type: Set[str] - self.__sctp = None + self.__sctp = None # type: Optional[RTCSctpTransport] self.__sctp_mline_index = None # type: Optional[int] self._sctpLegacySdp = True - self.__sctpRemotePort = None + self.__sctpRemotePort = None # type: Optional[int] self.__sctpRemoteCaps = None self.__stream_id = str(uuid.uuid4()) self.__transceivers = [] # type: List[RTCRtpTransceiver] @@ -698,7 +713,9 @@ async def setLocalDescription(self, sessionDescription): else: self.__pendingLocalDescription = description - async def setRemoteDescription(self, sessionDescription): + async def setRemoteDescription( + self, sessionDescription: RTCSessionDescription + ) -> None: """ Changes the remote description associated with the connection. @@ -844,7 +861,7 @@ async def setRemoteDescription(self, sessionDescription): else: self.__pendingRemoteDescription = description - async def __connect(self): + async def __connect(self) -> None: for transceiver in self.__transceivers: dtlsTransport = transceiver._transport iceTransport = dtlsTransport.transport @@ -890,7 +907,7 @@ def __assertTrackHasNoSender(self, track): if sender.track == track: raise InvalidAccessError("Track already has a sender") - def __createDtlsTransport(self): + def __createDtlsTransport(self) -> RTCDtlsTransport: # create ICE transport iceGatherer = RTCIceGatherer(iceServers=self.__configuration.iceServers) iceGatherer.on("statechange", self.__updateIceGatheringState) @@ -904,7 +921,7 @@ def __createDtlsTransport(self): return RTCDtlsTransport(iceTransport, self.__certificates) - def __createSctpTransport(self): + def __createSctpTransport(self) -> None: self.__sctp = RTCSctpTransport(self.__createDtlsTransport()) self.__sctp._bundled = False self.__sctp.mid = None @@ -913,7 +930,9 @@ def __createSctpTransport(self): def on_datachannel(channel): self.emit("datachannel", channel) - def __createTransceiver(self, direction, kind, sender_track=None): + def __createTransceiver( + self, direction: str, kind: str, sender_track=None + ) -> RTCRtpTransceiver: dtlsTransport = self.__createDtlsTransport() transceiver = RTCRtpTransceiver( direction=direction, @@ -939,8 +958,8 @@ def __getTransceiverByMLineIndex(self, index: int) -> Optional[RTCRtpTransceiver def __localDescription(self) -> Optional[sdp.SessionDescription]: return self.__pendingLocalDescription or self.__currentLocalDescription - def __localRtp(self, transceiver: RTCRtpTransceiver) -> RTCRtpParameters: - rtp = RTCRtpParameters( + def __localRtp(self, transceiver: RTCRtpTransceiver) -> RTCRtpSendParameters: + rtp = RTCRtpSendParameters( codecs=transceiver._codecs, headerExtensions=transceiver._headerExtensions, muxId=transceiver.mid, @@ -980,11 +999,11 @@ def __remoteRtp(self, transceiver: RTCRtpTransceiver) -> RTCRtpReceiveParameters receiveParameters.encodings = list(encodings.values()) return receiveParameters - def __setSignalingState(self, state): + def __setSignalingState(self, state: str) -> None: self.__signalingState = state self.emit("signalingstatechange") - def __updateIceConnectionState(self): + def __updateIceConnectionState(self) -> None: # compute new state states = set(map(lambda x: x.state, self.__iceTransports)) if self.__isClosed: @@ -1003,7 +1022,7 @@ def __updateIceConnectionState(self): self.__iceConnectionState = state self.emit("iceconnectionstatechange") - def __updateIceGatheringState(self): + def __updateIceGatheringState(self) -> None: # compute new state states = set(map(lambda x: x.iceGatherer.state, self.__iceTransports)) if states == set(["completed"]): @@ -1018,7 +1037,9 @@ def __updateIceGatheringState(self): self.__iceGatheringState = state self.emit("icegatheringstatechange") - def __validate_description(self, description, is_local): + def __validate_description( + self, description: sdp.SessionDescription, is_local: bool + ) -> None: # check description is compatible with signaling state if is_local: if description.type == "offer": diff --git a/aiortc/rtcrtpreceiver.py b/aiortc/rtcrtpreceiver.py index c02ecf4af..9616317a5 100644 --- a/aiortc/rtcrtpreceiver.py +++ b/aiortc/rtcrtpreceiver.py @@ -5,7 +5,7 @@ import random import threading import time -from typing import Optional, Set +from typing import Dict, Optional, Set import attr @@ -15,7 +15,12 @@ from .jitterbuffer import JitterBuffer from .mediastreams import MediaStreamError, MediaStreamTrack from .rate import RemoteBitrateEstimator -from .rtcrtpparameters import RTCRtpReceiveParameters +from .rtcdtlstransport import RTCDtlsTransport +from .rtcrtpparameters import ( + RTCRtpCapabilities, + RTCRtpCodecParameters, + RTCRtpReceiveParameters, +) from .rtp import ( RTCP_PSFB_APP, RTCP_PSFB_PLI, @@ -231,14 +236,14 @@ class RTCRtpReceiver: :param: transport: An :class:`RTCDtlsTransport`. """ - def __init__(self, kind, transport): + def __init__(self, kind: str, transport) -> None: if transport.state == "closed": raise InvalidStateError - self.__active_ssrc = {} - self.__codecs = {} + self.__active_ssrc = {} # type: Dict[int, datetime.datetime] + self.__codecs = {} # type: Dict[int, RTCRtpCodecParameters] self.__decoder_queue = queue.Queue() - self.__decoder_thread = None + self.__decoder_thread = None # type: Optional[threading.Thread] self.__kind = kind if kind == "audio": self.__jitter_buffer = JitterBuffer(capacity=16, prefetch=4) @@ -250,18 +255,18 @@ def __init__(self, kind, transport): self.__remote_bitrate_estimator = RemoteBitrateEstimator() self._track = None self.__rtcp_exited = asyncio.Event() - self.__rtcp_task = None - self.__rtx_ssrc = {} + self.__rtcp_task = None # type: Optional[asyncio.Future[None]] + self.__rtx_ssrc = {} # type: Dict[int, int] self.__started = False self.__stats = RTCStatsReport() self.__timestamp_mapper = TimestampMapper() self.__transport = transport # RTCP - self.__lsr = {} - self.__lsr_time = {} - self.__remote_streams = {} - self.__rtcp_ssrc = None + self.__lsr = {} # type: Dict[int, int] + self.__lsr_time = {} # type: Dict[int, float] + self.__remote_streams = {} # type: Dict[int, StreamStatistics] + self.__rtcp_ssrc = None # type: Optional[int] @property def transport(self): @@ -272,7 +277,7 @@ def transport(self): return self.__transport @classmethod - def getCapabilities(self, kind): + def getCapabilities(self, kind) -> Optional[RTCRtpCapabilities]: """ Returns the most optimistic view of the system's capabilities for receiving media of the given `kind`. @@ -323,7 +328,7 @@ def getSynchronizationSources(self): ) return sources - async def receive(self, parameters: RTCRtpReceiveParameters): + async def receive(self, parameters: RTCRtpReceiveParameters) -> None: """ Attempt to set the parameters controlling the receiving of media. @@ -352,7 +357,7 @@ async def receive(self, parameters: RTCRtpReceiveParameters): self.__rtcp_task = asyncio.ensure_future(self._run_rtcp()) self.__started = True - def setTransport(self, transport): + def setTransport(self, transport: RTCDtlsTransport) -> None: self.__transport = transport async def stop(self): @@ -550,7 +555,7 @@ async def _send_rtcp_pli(self, media_ssrc): ) await self._send_rtcp(packet) - def _set_rtcp_ssrc(self, ssrc): + def _set_rtcp_ssrc(self, ssrc: int) -> None: self.__rtcp_ssrc = ssrc def __stop_decoder(self): diff --git a/aiortc/rtcrtpsender.py b/aiortc/rtcrtpsender.py index b9a597cc9..e5f02bc5d 100644 --- a/aiortc/rtcrtpsender.py +++ b/aiortc/rtcrtpsender.py @@ -3,7 +3,7 @@ import random import time import uuid -from typing import List +from typing import Dict, List, Optional from . import clock, rtp from .codecs import get_capabilities, get_encoder, is_rtx @@ -51,7 +51,7 @@ class RTCRtpSender: :param: transport: An :class:`RTCDtlsTransport`. """ - def __init__(self, trackOrKind, transport): + def __init__(self, trackOrKind, transport) -> None: if transport.state == "closed": raise InvalidStateError @@ -61,7 +61,7 @@ def __init__(self, trackOrKind, transport): else: self.__kind = trackOrKind self.replaceTrack(None) - self.__cname = None + self.__cname = None # type: Optional[str] self._ssrc = random32() self._rtx_ssrc = random32() # FIXME: how should this be initialised? @@ -69,22 +69,22 @@ def __init__(self, trackOrKind, transport): self.__encoder = None self.__force_keyframe = False self.__loop = asyncio.get_event_loop() - self.__mid = None + self.__mid = None # type: Optional[str] self.__rtp_exited = asyncio.Event() self.__rtp_header_extensions_map = rtp.HeaderExtensionsMap() - self.__rtp_task = None - self.__rtp_history = {} + self.__rtp_task = None # type: Optional[asyncio.Future[None]] + self.__rtp_history = {} # type: Dict[int, RtpPacket] self.__rtcp_exited = asyncio.Event() - self.__rtcp_task = None - self.__rtx_payload_type = None + self.__rtcp_task = None # type: Optional[asyncio.Future[None]] + self.__rtx_payload_type = None # type: Optional[int] self.__rtx_sequence_number = random16() self.__started = False self.__stats = RTCStatsReport() self.__transport = transport # stats - self.__lsr = None - self.__lsr_time = None + self.__lsr = None # type: Optional[int] + self.__lsr_time = None # type: Optional[float] self.__ntp_timestamp = 0 self.__rtp_timestamp = 0 self.__octet_count = 0 @@ -147,14 +147,14 @@ async def getStats(self): return self.__stats - def replaceTrack(self, track): + def replaceTrack(self, track) -> None: self.__track = track if track is not None: self._track_id = track.id else: self._track_id = str(uuid.uuid4()) - def setTransport(self, transport): + def setTransport(self, transport) -> None: self.__transport = transport async def send(self, parameters: RTCRtpSendParameters) -> None: diff --git a/aiortc/rtcrtptransceiver.py b/aiortc/rtcrtptransceiver.py index 523e0c832..9a809a5be 100644 --- a/aiortc/rtcrtptransceiver.py +++ b/aiortc/rtcrtptransceiver.py @@ -1,8 +1,16 @@ import logging -from typing import Optional - -from aiortc.codecs import get_capabilities -from aiortc.sdp import DIRECTIONS +from typing import List, Optional + +from .codecs import get_capabilities +from .rtcdtlstransport import RTCDtlsTransport +from .rtcrtpparameters import ( + RTCRtpCodecCapability, + RTCRtpCodecParameters, + RTCRtpHeaderExtensionParameters, +) +from .rtcrtpreceiver import RTCRtpReceiver +from .rtcrtpsender import RTCRtpSender +from .sdp import DIRECTIONS logger = logging.getLogger("rtp") @@ -14,7 +22,13 @@ class RTCRtpTransceiver: shared state. """ - def __init__(self, kind, receiver, sender, direction="sendrecv"): + def __init__( + self, + kind: str, + receiver: RTCRtpReceiver, + sender: RTCRtpSender, + direction: str = "sendrecv", + ): self.__direction = direction self.__kind = kind self.__mid = None # type: Optional[str] @@ -23,12 +37,18 @@ def __init__(self, kind, receiver, sender, direction="sendrecv"): self.__sender = sender self.__stopped = False - self._currentDirection = None - self._offerDirection = None - self._preferred_codecs = [] + self._currentDirection = None # type: Optional[str] + self._offerDirection = None # type: Optional[str] + self._preferred_codecs = [] # type: List[RTCRtpCodecCapability] + self._transport = None # type: RTCDtlsTransport + + # FIXME: this is only used by RTCPeerConnection + self._bundled = False + self._codecs = [] # type: List[RTCRtpCodecParameters] + self._headerExtensions = [] # type: List[RTCRtpHeaderExtensionParameters] @property - def currentDirection(self): + def currentDirection(self) -> Optional[str]: """ The currently negotiated direction of the transceiver. @@ -37,7 +57,7 @@ def currentDirection(self): return self._currentDirection @property - def direction(self): + def direction(self) -> str: """ The preferred direction of the transceiver, which will be used in :meth:`RTCPeerConnection.createOffer` and :meth:`RTCPeerConnection.createAnswer`. @@ -47,20 +67,20 @@ def direction(self): return self.__direction @direction.setter - def direction(self, direction): + def direction(self, direction: str) -> None: assert direction in DIRECTIONS self.__direction = direction @property - def kind(self): + def kind(self) -> str: return self.__kind @property - def mid(self): + def mid(self) -> Optional[str]: return self.__mid @property - def receiver(self): + def receiver(self) -> RTCRtpReceiver: """ The :class:`RTCRtpReceiver` that handles receiving and decoding incoming media. @@ -68,7 +88,7 @@ def receiver(self): return self.__receiver @property - def sender(self): + def sender(self) -> RTCRtpSender: """ The :class:`RTCRtpSender` responsible for encoding and sending data to the remote peer. @@ -76,10 +96,10 @@ def sender(self): return self.__sender @property - def stopped(self): + def stopped(self) -> bool: return self.__stopped - def setCodecPreferences(self, codecs): + def setCodecPreferences(self, codecs: List[RTCRtpCodecCapability]) -> None: """ Override the default codec preferences. @@ -93,7 +113,7 @@ def setCodecPreferences(self, codecs): self._preferred_codecs = [] capabilities = get_capabilities(self.kind).codecs - unique = [] + unique = [] # type: List[RTCRtpCodecCapability] for codec in reversed(codecs): if codec not in capabilities: raise ValueError("Codec is not in capabilities") diff --git a/aiortc/rtcsctptransport.py b/aiortc/rtcsctptransport.py index c34a5a75c..b8ec48889 100644 --- a/aiortc/rtcsctptransport.py +++ b/aiortc/rtcsctptransport.py @@ -77,7 +77,7 @@ WEBRTC_BINARY_EMPTY = 57 -def chunk_type(chunk): +def chunk_type(chunk) -> str: return chunk.__class__.__name__ @@ -91,7 +91,7 @@ def decode_params(body: bytes) -> List[Tuple[int, bytes]]: return params -def encode_params(params): +def encode_params(params: List[Tuple[int, bytes]]) -> bytes: body = b"" padding = b"" for param_type, param_value in params: @@ -107,26 +107,26 @@ def padl(l: int) -> int: return 4 - m if m else 0 -def tsn_minus_one(a): +def tsn_minus_one(a: int) -> int: return (a - 1) % SCTP_TSN_MODULO -def tsn_plus_one(a): +def tsn_plus_one(a: int) -> int: return (a + 1) % SCTP_TSN_MODULO class Chunk: - def __init__(self, flags=0, body=b""): + def __init__(self, flags: int = 0, body: bytes = b"") -> None: self.flags = flags self.body = body - def __bytes__(self): + def __bytes__(self) -> bytes: body = self.body data = pack("!BBH", self.type, self.flags, len(body) + 4) + body data += b"\x00" * padl(len(body)) return data - def __repr__(self): + def __repr__(self) -> str: return "%s(flags=%d)" % (chunk_type(self), self.flags) @@ -139,7 +139,7 @@ def __init__(self, flags: int = 0, body: Optional[bytes] = None) -> None: self.params = [] @property - def body(self): + def body(self) -> bytes: return encode_params(self.params) @@ -158,7 +158,7 @@ class CookieEchoChunk(Chunk): class DataChunk: type = 0 - def __init__(self, flags=0, body=None): + def __init__(self, flags: int = 0, body: Optional[bytes] = None) -> None: self.flags = flags if body: (self.tsn, self.stream_id, self.stream_seq, self.protocol) = unpack_from( @@ -207,9 +207,9 @@ class ErrorChunk(BaseParamsChunk): class ForwardTsnChunk(Chunk): type = 192 - def __init__(self, flags=0, body=None): + def __init__(self, flags: int = 0, body: Optional[bytes] = None) -> None: self.flags = flags - self.streams = [] + self.streams = [] # type: List[Tuple[int, int]] if body: self.cumulative_tsn = unpack_from("!L", body, 0)[0] pos = 4 @@ -220,13 +220,13 @@ def __init__(self, flags=0, body=None): self.cumulative_tsn = 0 @property - def body(self): + def body(self) -> bytes: body = pack("!L", self.cumulative_tsn) for stream_id, stream_seq in self.streams: body += pack("!HH", stream_id, stream_seq) return body - def __repr__(self): + def __repr__(self) -> str: return "ForwardTsnChunk(cumulative_tsn=%d, streams=%s)" % ( self.cumulative_tsn, self.streams, @@ -242,7 +242,7 @@ class HeartbeatAckChunk(BaseParamsChunk): class BaseInitChunk(Chunk): - def __init__(self, flags=0, body=None): + def __init__(self, flags: int = 0, body: Optional[bytes] = None) -> None: self.flags = flags if body: ( @@ -262,7 +262,7 @@ def __init__(self, flags=0, body=None): self.params = [] @property - def body(self): + def body(self) -> bytes: body = pack( "!LLHHL", self.initiate_tag, @@ -312,7 +312,7 @@ def __init__(self, flags=0, body=None): self.cumulative_tsn = 0 self.advertised_rwnd = 0 - def __bytes__(self): + def __bytes__(self) -> bytes: length = 16 + 4 * (len(self.gaps) + len(self.duplicates)) data = pack( "!BBHLLHH", @@ -330,7 +330,7 @@ def __bytes__(self): data += pack("!L", tsn) return data - def __repr__(self): + def __repr__(self) -> str: return "SackChunk(flags=%d, advertised_rwnd=%d, cumulative_tsn=%d, gaps=%s)" % ( self.flags, self.advertised_rwnd, @@ -350,7 +350,7 @@ def __init__(self, flags=0, body=None): self.cumulative_tsn = 0 @property - def body(self): + def body(self) -> bytes: return pack("!L", self.cumulative_tsn) def __repr__(self): @@ -412,7 +412,9 @@ def parse_packet(data: bytes) -> Tuple[int, int, int, List[Any]]: return source_port, destination_port, verification_tag, chunks -def serialize_packet(source_port, destination_port, verification_tag, chunk): +def serialize_packet( + source_port: int, destination_port: int, verification_tag: int, chunk: Chunk +) -> bytes: header = pack("!HHL", source_port, destination_port, verification_tag) data = bytes(chunk) checksum = crc32(header + b"\x00\x00\x00\x00" + data) @@ -646,7 +648,11 @@ def __init__(self, transport, port=5000): # timers self._rto = SCTP_RTO_INITIAL + self._t1_chunk = None # type: Optional[Chunk] + self._t1_failures = 0 self._t1_handle = None + self._t2_chunk = None # type: Optional[Chunk] + self._t2_failures = 0 self._t2_handle = None self._t3_handle = None @@ -655,6 +661,10 @@ def __init__(self, transport, port=5000): self._data_channel_queue = deque() self._data_channels = {} + # FIXME: this is only used by RTCPeerConnection + self._bundled = False + self.mid = None # type: Optional[str] + @property def is_server(self) -> bool: return self.transport.transport.role != "controlling" @@ -689,10 +699,10 @@ def getCapabilities(cls) -> RTCSctpCapabilities: """ return RTCSctpCapabilities(maxMessageSize=65536) - def setTransport(self, transport): + def setTransport(self, transport) -> None: self.__transport = transport - async def start(self, remoteCaps, remotePort): + async def start(self, remoteCaps: RTCSctpCapabilities, remotePort: int) -> None: """ Start the transport. """ @@ -717,7 +727,7 @@ async def start(self, remoteCaps, remotePort): if not self.is_server: await self._init() - async def stop(self): + async def stop(self) -> None: """ Stop the transport. """ @@ -726,7 +736,7 @@ async def stop(self): self.__transport._unregister_data_receiver(self) self._set_state(self.State.CLOSED) - async def _abort(self): + async def _abort(self) -> None: """ Abort the association. """ @@ -736,7 +746,7 @@ async def _abort(self): except ConnectionError: pass - async def _init(self): + async def _init(self) -> None: """ Initialize the association. """ @@ -753,10 +763,10 @@ async def _init(self): self._t1_start(chunk) self._set_state(self.State.COOKIE_WAIT) - def _flight_size_decrease(self, chunk): + def _flight_size_decrease(self, chunk: Chunk) -> None: self._flight_size = max(0, self._flight_size - chunk._book_size) - def _flight_size_increase(self, chunk): + def _flight_size_increase(self, chunk: Chunk) -> None: self._flight_size += chunk._book_size def _get_extensions(self, params): @@ -769,7 +779,7 @@ def _get_extensions(self, params): elif k == SCTP_SUPPORTED_CHUNK_EXT: self._remote_extensions = list(v) - def _set_extensions(self, params): + def _set_extensions(self, params: List[Tuple[int, bytes]]) -> None: """ Sets what extensions are supported by the local party. """ @@ -824,7 +834,7 @@ async def _handle_data(self, data): if self._sack_needed: await self._send_sack() - def _maybe_abandon(self, chunk): + def _maybe_abandon(self, chunk: Chunk) -> bool: """ Determine if a chunk needs to be marked as abandoned. @@ -1243,13 +1253,13 @@ async def _receive_reconfig_param(self, param): async def _send( self, - stream_id, - pp_id, - user_data, - expiry=None, - max_retransmits=None, - ordered=True, - ): + stream_id: int, + pp_id: int, + user_data: bytes, + expiry: Optional[float] = None, + max_retransmits: Optional[int] = None, + ordered: bool = True, + ) -> None: """ Send data ULP -> stream. """ @@ -1297,7 +1307,7 @@ async def _send( if not self._t3_handle: await self._transmit() - async def _send_chunk(self, chunk): + async def _send_chunk(self, chunk: Chunk) -> None: """ Transmit a chunk (no bundling for now). """ @@ -1347,7 +1357,7 @@ async def _send_sack(self): self._sack_duplicates.clear() self._sack_needed = False - def _set_state(self, state): + def _set_state(self, state) -> None: """ Transition the SCTP association to a new state. """ @@ -1377,14 +1387,14 @@ def _set_state(self, state): # timers - def _t1_cancel(self): + def _t1_cancel(self) -> None: if self._t1_handle is not None: self.__log_debug("- T1(%s) cancel", chunk_type(self._t1_chunk)) self._t1_handle.cancel() self._t1_handle = None self._t1_chunk = None - def _t1_expired(self): + def _t1_expired(self) -> None: self._t1_failures += 1 self._t1_handle = None self.__log_debug( @@ -1396,21 +1406,21 @@ def _t1_expired(self): asyncio.ensure_future(self._send_chunk(self._t1_chunk)) self._t1_handle = self._loop.call_later(self._rto, self._t1_expired) - def _t1_start(self, chunk): + def _t1_start(self, chunk: Chunk) -> None: assert self._t1_handle is None self._t1_chunk = chunk self._t1_failures = 0 self.__log_debug("- T1(%s) start", chunk_type(self._t1_chunk)) self._t1_handle = self._loop.call_later(self._rto, self._t1_expired) - def _t2_cancel(self): + def _t2_cancel(self) -> None: if self._t2_handle is not None: self.__log_debug("- T2(%s) cancel", chunk_type(self._t2_chunk)) self._t2_handle.cancel() self._t2_handle = None self._t2_chunk = None - def _t2_expired(self): + def _t2_expired(self) -> None: self._t2_failures += 1 self._t2_handle = None self.__log_debug( @@ -1422,14 +1432,14 @@ def _t2_expired(self): asyncio.ensure_future(self._send_chunk(self._t2_chunk)) self._t2_handle = self._loop.call_later(self._rto, self._t2_expired) - def _t2_start(self, chunk): + def _t2_start(self, chunk) -> None: assert self._t2_handle is None self._t2_chunk = chunk self._t2_failures = 0 self.__log_debug("- T2(%s) start", chunk_type(self._t2_chunk)) self._t2_handle = self._loop.call_later(self._rto, self._t2_expired) - def _t3_expired(self): + def _t3_expired(self) -> None: self._t3_handle = None self.__log_debug("x T3 expired") @@ -1449,25 +1459,25 @@ def _t3_expired(self): asyncio.ensure_future(self._transmit()) - def _t3_restart(self): + def _t3_restart(self) -> None: self.__log_debug("- T3 restart") if self._t3_handle is not None: self._t3_handle.cancel() self._t3_handle = None self._t3_handle = self._loop.call_later(self._rto, self._t3_expired) - def _t3_start(self): + def _t3_start(self) -> None: assert self._t3_handle is None self.__log_debug("- T3 start") self._t3_handle = self._loop.call_later(self._rto, self._t3_expired) - def _t3_cancel(self): + def _t3_cancel(self) -> None: if self._t3_handle is not None: self.__log_debug("- T3 cancel") self._t3_handle.cancel() self._t3_handle = None - async def _transmit(self): + async def _transmit(self) -> None: """ Transmit outbound data. """ @@ -1535,7 +1545,7 @@ async def _transmit_reconfig(self): await self._send_reconfig_param(param) - def _update_advanced_peer_ack_point(self): + def _update_advanced_peer_ack_point(self) -> None: """ Try to advance "Advanced.Peer.Ack.Point" according to RFC 3758. """ @@ -1581,11 +1591,11 @@ def _data_channel_close(self, channel, transmit=True): if len(self._reconfig_queue) == 1: asyncio.ensure_future(self._transmit_reconfig()) - def _data_channel_closed(self, stream_id): + def _data_channel_closed(self, stream_id: int) -> None: channel = self._data_channels.pop(stream_id) channel._setReadyState("closed") - async def _data_channel_flush(self): + async def _data_channel_flush(self) -> None: """ Try to flush buffered data to the SCTP layer. @@ -1627,7 +1637,7 @@ async def _data_channel_flush(self): ) channel._addBufferedAmount(-len(user_data)) - def _data_channel_add_negotiated(self, channel): + def _data_channel_add_negotiated(self, channel: RTCDataChannel) -> None: if channel.id in self._data_channels: raise ValueError("Data channel with ID %d already registered" % channel.id) @@ -1636,7 +1646,7 @@ def _data_channel_add_negotiated(self, channel): if self._association_state == self.State.ESTABLISHED: channel._setReadyState("open") - def _data_channel_open(self, channel): + def _data_channel_open(self, channel: RTCDataChannel) -> None: if channel.id is not None: if channel.id in self._data_channels: raise ValueError( diff --git a/aiortc/rtp.py b/aiortc/rtp.py index a4a3a9de7..f57c75423 100644 --- a/aiortc/rtp.py +++ b/aiortc/rtp.py @@ -206,7 +206,7 @@ def unpack_remb_fci(data): return (bitrate, ssrcs) -def is_rtcp(msg): +def is_rtcp(msg: bytes) -> bool: return len(msg) >= 2 and msg[1] >= 192 and msg[1] <= 208 @@ -314,14 +314,14 @@ class RtcpReceiverInfo: lsr = attr.ib() dlsr = attr.ib() - def __bytes__(self): + def __bytes__(self) -> bytes: data = pack("!LB", self.ssrc, self.fraction_lost) data += pack_packets_lost(self.packets_lost) data += pack("!LLLL", self.highest_sequence, self.jitter, self.lsr, self.dlsr) return data @classmethod - def parse(cls, data): + def parse(cls, data: bytes): ssrc, fraction_lost = unpack("!LB", data[0:5]) packets_lost = unpack_packets_lost(data[5:8]) highest_sequence, jitter, lsr, dlsr = unpack("!LLLL", data[8:]) @@ -338,12 +338,12 @@ def parse(cls, data): @attr.s class RtcpSenderInfo: - ntp_timestamp = attr.ib() - rtp_timestamp = attr.ib() - packet_count = attr.ib() - octet_count = attr.ib() + ntp_timestamp = attr.ib() # type: int + rtp_timestamp = attr.ib() # type: int + packet_count = attr.ib() # type: int + octet_count = attr.ib() # type: int - def __bytes__(self): + def __bytes__(self) -> bytes: return pack( "!QLLL", self.ntp_timestamp, @@ -353,7 +353,7 @@ def __bytes__(self): ) @classmethod - def parse(cls, data): + def parse(cls, data: bytes): ntp_timestamp, rtp_timestamp, packet_count, octet_count = unpack("!QLLL", data) return cls( ntp_timestamp=ntp_timestamp, @@ -369,63 +369,16 @@ class RtcpSourceInfo: items = attr.ib() -class RtcpPacket: - @classmethod - def parse(cls, data): - pos = 0 - packets = [] - - while pos < len(data): - if len(data) < pos + RTCP_HEADER_LENGTH: - raise ValueError( - "RTCP packet length is less than %d bytes" % RTCP_HEADER_LENGTH - ) - - v_p_count, packet_type, length = unpack("!BBH", data[pos : pos + 4]) - version = v_p_count >> 6 - padding = (v_p_count >> 5) & 1 - count = v_p_count & 0x1F - if version != 2: - raise ValueError("RTCP packet has invalid version") - pos += 4 - - end = pos + length * 4 - if len(data) < end: - raise ValueError("RTCP packet is truncated") - payload = data[pos:end] - pos = end - - if padding: - if not payload or not payload[-1] or payload[-1] > len(payload): - raise ValueError("RTCP packet padding length is invalid") - payload = payload[0 : -payload[-1]] - - if packet_type == RTCP_BYE: - packets.append(RtcpByePacket.parse(payload, count)) - elif packet_type == RTCP_SDES: - packets.append(RtcpSdesPacket.parse(payload, count)) - elif packet_type == RTCP_SR: - packets.append(RtcpSrPacket.parse(payload, count)) - elif packet_type == RTCP_RR: - packets.append(RtcpRrPacket.parse(payload, count)) - elif packet_type == RTCP_RTPFB: - packets.append(RtcpRtpfbPacket.parse(payload, count)) - elif packet_type == RTCP_PSFB: - packets.append(RtcpPsfbPacket.parse(payload, count)) - - return packets - - @attr.s class RtcpByePacket: sources = attr.ib() - def __bytes__(self): + def __bytes__(self) -> bytes: payload = b"".join([pack("!L", ssrc) for ssrc in self.sources]) return pack_rtcp_packet(RTCP_BYE, len(self.sources), payload) @classmethod - def parse(cls, data, count): + def parse(cls, data: bytes, count: int): if len(data) < 4 * count: raise ValueError("RTCP bye length is invalid") if count > 0: @@ -446,12 +399,12 @@ class RtcpPsfbPacket: media_ssrc = attr.ib() fci = attr.ib(default=b"") - def __bytes__(self): + def __bytes__(self) -> bytes: payload = pack("!LL", self.ssrc, self.media_ssrc) + self.fci return pack_rtcp_packet(RTCP_PSFB, self.fmt, payload) @classmethod - def parse(cls, data, fmt): + def parse(cls, data: bytes, fmt: int): if len(data) < 8: raise ValueError("RTCP payload-specific feedback length is invalid") @@ -462,17 +415,17 @@ def parse(cls, data, fmt): @attr.s class RtcpRrPacket: - ssrc = attr.ib() - reports = attr.ib(default=attr.Factory(list)) + ssrc = attr.ib() # type: int + reports = attr.ib(default=attr.Factory(list)) # type: List[RtcpReceiverInfo] - def __bytes__(self): + def __bytes__(self) -> bytes: payload = pack("!L", self.ssrc) for report in self.reports: payload += bytes(report) return pack_rtcp_packet(RTCP_RR, len(self.reports), payload) @classmethod - def parse(cls, data, count): + def parse(cls, data: bytes, count: int): if len(data) != 4 + 24 * count: raise ValueError("RTCP receiver report length is invalid") @@ -496,9 +449,9 @@ class RtcpRtpfbPacket: media_ssrc = attr.ib() # generick NACK - lost = attr.ib(default=attr.Factory(list)) + lost = attr.ib(default=attr.Factory(list)) # type: List[int] - def __bytes__(self): + def __bytes__(self) -> bytes: payload = pack("!LL", self.ssrc, self.media_ssrc) if self.lost: pid = self.lost[0] @@ -515,7 +468,7 @@ def __bytes__(self): return pack_rtcp_packet(RTCP_RTPFB, self.fmt, payload) @classmethod - def parse(cls, data, fmt): + def parse(cls, data: bytes, fmt: int): if len(data) < 8 or len(data) % 4: raise ValueError("RTCP RTP feedback length is invalid") @@ -532,9 +485,9 @@ def parse(cls, data, fmt): @attr.s class RtcpSdesPacket: - chunks = attr.ib(default=attr.Factory(list)) + chunks = attr.ib(default=attr.Factory(list)) # type: List[RtcpSourceInfo] - def __bytes__(self): + def __bytes__(self) -> bytes: payload = b"" for chunk in self.chunks: payload += pack("!L", chunk.ssrc) @@ -546,7 +499,7 @@ def __bytes__(self): return pack_rtcp_packet(RTCP_SDES, len(self.chunks), payload) @classmethod - def parse(cls, data, count): + def parse(cls, data: bytes, count: int): pos = 0 chunks = [] for r in range(count): @@ -574,11 +527,11 @@ def parse(cls, data, count): @attr.s class RtcpSrPacket: - ssrc = attr.ib() - sender_info = attr.ib() - reports = attr.ib(default=attr.Factory(list)) + ssrc = attr.ib() # type: int + sender_info = attr.ib() # type: RtcpSenderInfo + reports = attr.ib(default=attr.Factory(list)) # type: List[RtcpReceiverInfo] - def __bytes__(self): + def __bytes__(self) -> bytes: payload = pack("!L", self.ssrc) payload += bytes(self.sender_info) for report in self.reports: @@ -586,7 +539,7 @@ def __bytes__(self): return pack_rtcp_packet(RTCP_SR, len(self.reports), payload) @classmethod - def parse(cls, data, count): + def parse(cls, data: bytes, count: int): if len(data) != 24 + 24 * count: raise ValueError("RTCP sender report length is invalid") @@ -600,6 +553,63 @@ def parse(cls, data, count): return RtcpSrPacket(ssrc=ssrc, sender_info=sender_info, reports=reports) +AnyRtcpPacket = Union[ + RtcpByePacket, + RtcpPsfbPacket, + RtcpRrPacket, + RtcpRtpfbPacket, + RtcpSdesPacket, + RtcpSrPacket, +] + + +class RtcpPacket: + @classmethod + def parse(cls, data: bytes) -> List[AnyRtcpPacket]: + pos = 0 + packets = [] + + while pos < len(data): + if len(data) < pos + RTCP_HEADER_LENGTH: + raise ValueError( + "RTCP packet length is less than %d bytes" % RTCP_HEADER_LENGTH + ) + + v_p_count, packet_type, length = unpack("!BBH", data[pos : pos + 4]) + version = v_p_count >> 6 + padding = (v_p_count >> 5) & 1 + count = v_p_count & 0x1F + if version != 2: + raise ValueError("RTCP packet has invalid version") + pos += 4 + + end = pos + length * 4 + if len(data) < end: + raise ValueError("RTCP packet is truncated") + payload = data[pos:end] + pos = end + + if padding: + if not payload or not payload[-1] or payload[-1] > len(payload): + raise ValueError("RTCP packet padding length is invalid") + payload = payload[0 : -payload[-1]] + + if packet_type == RTCP_BYE: + packets.append(RtcpByePacket.parse(payload, count)) + elif packet_type == RTCP_SDES: + packets.append(RtcpSdesPacket.parse(payload, count)) + elif packet_type == RTCP_SR: + packets.append(RtcpSrPacket.parse(payload, count)) + elif packet_type == RTCP_RR: + packets.append(RtcpRrPacket.parse(payload, count)) + elif packet_type == RTCP_RTPFB: + packets.append(RtcpRtpfbPacket.parse(payload, count)) + elif packet_type == RTCP_PSFB: + packets.append(RtcpPsfbPacket.parse(payload, count)) + + return packets + + class RtpPacket: def __init__( self, @@ -746,6 +756,3 @@ def wrap_rtx( rtx.csrc = packet.csrc rtx.extensions = packet.extensions return rtx - - -AnyRtcpPacket = Union[RtcpByePacket, RtcpSdesPacket, RtcpSrPacket] diff --git a/aiortc/sdp.py b/aiortc/sdp.py index 4ff6c4be3..1ba381af5 100644 --- a/aiortc/sdp.py +++ b/aiortc/sdp.py @@ -133,21 +133,21 @@ def parse_attr(line: str) -> Tuple[str, Optional[str]]: return line[2:], None -def parse_group(dest, value, type=str): - bits = value.split() - if bits: - dest.append(GroupDescription(semantic=bits[0], items=list(map(type, bits[1:])))) - - @attr.s class GroupDescription: - semantic = attr.ib() - items = attr.ib() + semantic = attr.ib() # type: str + items = attr.ib() # List[Union[int, str]] def __str__(self) -> str: return "%s %s" % (self.semantic, " ".join(map(str, self.items))) +def parse_group(dest: List[GroupDescription], value: str, type=str) -> None: + bits = value.split() + if bits: + dest.append(GroupDescription(semantic=bits[0], items=list(map(type, bits[1:])))) + + @attr.s class SsrcDescription: ssrc = attr.ib() @@ -195,7 +195,7 @@ def __init__(self, kind: str, port: int, profile: str, fmt: List[Any]) -> None: self.ice = RTCIceParameters() self.ice_candidates = [] # type: List[RTCIceCandidate] self.ice_candidates_complete = False - self.ice_options = None + self.ice_options = None # type: Optional[str] def __str__(self) -> str: lines = [] @@ -285,19 +285,19 @@ def __init__(self) -> None: self.origin = None # type: Optional[str] self.name = "-" self.time = "0 0" - self.host = None + self.host = None # type: Optional[str] self.group = [] # type: List[GroupDescription] self.msid_semantic = [] # type: List[GroupDescription] self.media = [] # type: List[MediaDescription] self.type = None # type: str @classmethod - def parse(cls, sdp): - current_media = None + def parse(cls, sdp: str): + current_media = None # type: Optional[MediaDescription] dtls_fingerprints = [] ice_options = None - def find_codec(pt): + def find_codec(pt: int) -> RTCRtpCodecParameters: for codec in current_media.rtp.codecs: if codec.payloadType == pt: return codec @@ -339,14 +339,15 @@ def find_codec(pt): # check payload types are valid kind = m.group(1) fmt = m.group(4).split() + fmt_int = None # type: Optional[List[int]] if kind in ["audio", "video"]: - fmt = [int(x) for x in fmt] - for pt in fmt: + fmt_int = [int(x) for x in fmt] + for pt in fmt_int: assert pt >= 0 and pt < 256 assert pt not in rtp.FORBIDDEN_PAYLOAD_TYPES current_media = MediaDescription( - kind=kind, port=int(m.group(2)), profile=m.group(3), fmt=fmt + kind=kind, port=int(m.group(2)), profile=m.group(3), fmt=fmt_int or fmt ) current_media.dtls = RTCDtlsParameters( fingerprints=dtls_fingerprints[:], role=None @@ -402,7 +403,6 @@ def find_codec(pt): current_media.direction = attr elif attr == "rtpmap": format_id, format_desc = value.split(" ", 1) - format_id = int(format_id) bits = format_desc.split("/") if current_media.kind == "audio": if len(bits) > 2: @@ -426,8 +426,8 @@ def find_codec(pt): elif attr == "ssrc-group": parse_group(current_media.ssrc_group, value, type=int) elif attr == "ssrc": - ssrc, ssrc_desc = value.split(" ", 1) - ssrc = int(ssrc) + ssrc_str, ssrc_desc = value.split(" ", 1) + ssrc = int(ssrc_str) ssrc_attr, ssrc_value = ssrc_desc.split(":") try: diff --git a/stubs/audioop.pyi b/stubs/audioop.pyi index d75923cb0..d39a91260 100644 --- a/stubs/audioop.pyi +++ b/stubs/audioop.pyi @@ -3,7 +3,9 @@ from typing import Any, Optional, Tuple class error(Exception): ... def add(fragment1: bytes, fragment2: bytes, width: int) -> bytes: ... -def adpcm2lin(adpcmfragment: bytes, width: int, state: Optional[Tuple]) -> Tuple[bytes, Tuple]: ... +def adpcm2lin( + adpcmfragment: bytes, width: int, state: Optional[Tuple] +) -> Tuple[bytes, Tuple]: ... def alaw2lin(fragment: bytes, width: int) -> bytes: ... def avg(fragment: bytes, width: int) -> int: ... def avgpp(fragment: bytes, width: int) -> int: ... @@ -14,7 +16,9 @@ def findfactor(fragment: bytes, reference: bytes) -> float: ... def findfit(fragment: bytes, reference: bytes) -> Tuple[int, float]: ... def findmax(fragment: bytes, length: int) -> int: ... def getsample(fragment: bytes, width: int, index: int) -> int: ... -def lin2adpcm(fragment: bytes, width: int, state: Optional[Tuple]) -> Tuple[bytes, Tuple]: ... +def lin2adpcm( + fragment: bytes, width: int, state: Optional[Tuple] +) -> Tuple[bytes, Tuple]: ... def lin2alaw(fragment: bytes, width: int) -> bytes: ... def lin2lin(fragment: bytes, width: int, newwidth: int) -> bytes: ... def lin2ulaw(fragment: bytes, width: int) -> bytes: ... diff --git a/stubs/pyee.pyi b/stubs/pyee.pyi index 3e9be65a2..0da4ab7a3 100644 --- a/stubs/pyee.pyi +++ b/stubs/pyee.pyi @@ -1,2 +1,6 @@ +from typing import Callable + class AsyncIOEventEmitter: - def emit(self, name: str) -> None: ... + def emit(self, name: str, *args) -> None: ... + def on(self, name: str, cb: Callable[..., None] = ...) -> Callable[..., None]: ... + def remove_all_listeners(self) -> None: ... diff --git a/tests/test_rtcrtpreceiver.py b/tests/test_rtcrtpreceiver.py index 40ddccd95..80305ffc9 100644 --- a/tests/test_rtcrtpreceiver.py +++ b/tests/test_rtcrtpreceiver.py @@ -284,8 +284,8 @@ def test_capabilities(self): ) # bogus - capabilities = RTCRtpReceiver.getCapabilities("bogus") - self.assertIsNone(capabilities) + with self.assertRaises(ValueError): + RTCRtpReceiver.getCapabilities("bogus") def test_connection_error(self): """ diff --git a/tests/test_rtcrtpsender.py b/tests/test_rtcrtpsender.py index b6b9497cb..814a20883 100644 --- a/tests/test_rtcrtpsender.py +++ b/tests/test_rtcrtpsender.py @@ -115,8 +115,8 @@ def test_capabilities(self): ) # bogus - capabilities = RTCRtpSender.getCapabilities("bogus") - self.assertIsNone(capabilities) + with self.assertRaises(ValueError): + RTCRtpSender.getCapabilities("bogus") def test_construct(self): sender = RTCRtpSender("audio", self.local_transport) diff --git a/tests/utils.py b/tests/utils.py index bca71743e..d84ede45e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -92,7 +92,7 @@ def dummy_dtls_transport_pair(): return (dtls_a, dtls_b) -def load(name): +def load(name: str) -> bytes: path = os.path.join(os.path.dirname(__file__), name) with open(path, "rb") as fp: return fp.read()