diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/MediaSourceDesc.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/MediaSourceDesc.kt index e09c7f2ee7..7256df7659 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/MediaSourceDesc.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/MediaSourceDesc.kt @@ -137,13 +137,11 @@ class MediaSourceDesc fun getRtpLayerByQualityIdx(idx: Int): RtpLayerDesc? = layersByIndex[idx] @Synchronized - fun findRtpLayerDesc(videoRtpPacket: VideoRtpPacket): RtpLayerDesc? { + fun findRtpLayerDescs(videoRtpPacket: VideoRtpPacket): Collection { if (ArrayUtils.isNullOrEmpty(rtpEncodings)) { - return null + return emptyList() } - val encodingId = videoRtpPacket.getEncodingId() - val desc = layersById[encodingId] - return desc + return videoRtpPacket.getEncodingIds().mapNotNull { layersById[it] } } @Synchronized @@ -188,10 +186,14 @@ class MediaSourceDesc */ fun Array.copy() = Array(this.size) { i -> this[i].copy() } -fun Array.findRtpLayerDesc(packet: VideoRtpPacket): RtpLayerDesc? { +fun Array.findRtpLayerDescs(packet: VideoRtpPacket): Collection { + return this.flatMap { it.findRtpLayerDescs(packet) } +} + +fun Array.findRtpEncodingId(packet: VideoRtpPacket): Int? { for (source in this) { - source.findRtpLayerDesc(packet)?.let { - return it + source.findRtpEncodingDesc(packet.ssrc)?.let { + return it.eid } } return null diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpEncodingDesc.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpEncodingDesc.kt index 8cfee2f807..570c5dde53 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpEncodingDesc.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpEncodingDesc.kt @@ -73,9 +73,17 @@ constructor( validateLayerEids(initialLayers) } + private var nominalHeight = initialLayers.getNominalHeight() + internal var layers = initialLayers set(newLayers) { validateLayerEids(newLayers) + /* Check if the new layer set is a single spatial layer that doesn't specify a height - if so, we + * want to apply the nominal height to them. + */ + val useNominalHeight = nominalHeight != RtpLayerDesc.NO_HEIGHT && + newLayers.all { it.sid == 0 } && + newLayers.all { it.height == RtpLayerDesc.NO_HEIGHT } /* Copy the rate statistics objects from the old layers to the new layers * with matching layer IDs. */ @@ -89,6 +97,15 @@ constructor( oldLayerMap[newLayer.layerId]?.let { newLayer.inheritFrom(it) } + if (useNominalHeight) { + newLayer.height = nominalHeight + } + } + if (!useNominalHeight) { + val newNominalHeight = newLayers.getNominalHeight() + if (newNominalHeight != RtpLayerDesc.NO_HEIGHT) { + nominalHeight = newNominalHeight + } } field = newLayers } @@ -157,6 +174,7 @@ constructor( addNumber("rtx_ssrc", getSecondarySsrc(SsrcAssociationType.RTX)) addNumber("fec_ssrc", getSecondarySsrc(SsrcAssociationType.FEC)) addNumber("eid", eid) + addNumber("nominal_height", nominalHeight) for (layer in layers) { addBlock(layer.getNodeStats()) } @@ -167,6 +185,23 @@ constructor( } } -fun VideoRtpPacket.getEncodingId(): Long { - return RtpEncodingDesc.calcEncodingId(ssrc, this.layerId) +fun VideoRtpPacket.getEncodingIds(): Collection { + return this.layerIds.map { RtpEncodingDesc.calcEncodingId(ssrc, it) } +} + +/** + * Get the "nominal" height of a set of layers - if they all indicate the same spatial layer and same height. + */ +private fun Array.getNominalHeight(): Int { + if (isEmpty()) { + return RtpLayerDesc.NO_HEIGHT + } + val firstHeight = first().height + if (!(all { it.sid == 0 } || all { it.sid == -1 })) { + return RtpLayerDesc.NO_HEIGHT + } + if (any { it.height != firstHeight }) { + return RtpLayerDesc.NO_HEIGHT + } + return firstHeight } diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt index 2f7019b06d..9cd30d9383 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt @@ -20,106 +20,60 @@ import org.jitsi.nlj.transform.node.incoming.BitrateCalculator import org.jitsi.nlj.util.Bandwidth import org.jitsi.nlj.util.BitrateTracker import org.jitsi.nlj.util.DataSize -import org.jitsi.nlj.util.sum +import org.jitsi.utils.OrderedJsonObject /** * Keeps track of its subjective quality index, * its last stable bitrate and other useful things for adaptivity/routing. * - * Note: this class and [getBitrate] are only open to allow to be overridden for testing. We found that mocking has - * severe overhead and is not suitable for performance tests. - * * @author George Politis */ -open class RtpLayerDesc -@JvmOverloads +abstract class RtpLayerDesc constructor( /** * The index of this instance's encoding in the source encoding array. */ val eid: Int, /** - * The temporal layer ID of this instance, or negative for unknown. + * The temporal layer ID of this instance. */ val tid: Int, /** - * The spatial layer ID of this instance, or negative for unknown. + * The spatial layer ID of this instance. */ val sid: Int, /** * The max height of the bitstream that this instance represents. The actual - * height may be less due to bad network or system load. + * height may be less due to bad network or system load. [NO_HEIGHT] for unknown. * * XXX we should be able to sniff the actual height from the RTP packets. */ - val height: Int, + var height: Int, /** * The max frame rate (in fps) of the bitstream that this instance * represents. The actual frame rate may be less due to bad network or - * system load. + * system load. [NO_FRAME_RATE] for unknown. */ val frameRate: Double, - /** - * The [RtpLayerDesc]s on which this layer definitely depends. - */ - private val dependencyLayers: Array = emptyArray(), - /** - * The [RtpLayerDesc]s on which this layer possibly depends. - * (The intended use case is K-SVC mode.) - */ - private val softDependencyLayers: Array = emptyArray() ) { - init { - require(tid < 8) { "Invalid temporal ID $tid" } - require(sid < 8) { "Invalid spatial ID $sid" } - } - - /** - * Clone an existing layer desc, inheriting its statistics, - * modifying only specific values. - */ - fun copy( - eid: Int = this.eid, - tid: Int = this.tid, - sid: Int = this.sid, - height: Int = this.height, - frameRate: Double = this.frameRate, - dependencyLayers: Array = this.dependencyLayers, - softDependencyLayers: Array = this.softDependencyLayers - ) = RtpLayerDesc(eid, tid, sid, height, frameRate, dependencyLayers, softDependencyLayers).also { - it.inheritFrom(this) - } - - /** - * Whether softDependencyLayers are to be used. - */ - var useSoftDependencies = true + abstract fun copy(height: Int = this.height): RtpLayerDesc /** * The [BitrateTracker] instance used to calculate the receiving bitrate of this RTP layer. */ - private var bitrateTracker = BitrateCalculator.createBitrateTracker() + protected var bitrateTracker = BitrateCalculator.createBitrateTracker() /** * @return the "id" of this layer within this encoding. This is a server-side id and should * not be confused with any encoding id defined in the client (such as the * rid). */ - val layerId = getIndex(0, sid, tid) + abstract val layerId: Int /** * A local index of this track. */ - val index = getIndex(eid, sid, tid) - - /** - * {@inheritDoc} - */ - override fun toString(): String { - return "subjective_quality=" + index + - ",temporal_id=" + tid + - ",spatial_id=" + sid - } + abstract val index: Int /** * Inherit a [BitrateTracker] object @@ -131,9 +85,8 @@ constructor( /** * Inherit another layer description's [BitrateTracker] object. */ - internal fun inheritFrom(other: RtpLayerDesc) { + internal open fun inheritFrom(other: RtpLayerDesc) { inheritStatistics(other.bitrateTracker) - useSoftDependencies = other.useSoftDependencies } /** @@ -152,12 +105,10 @@ constructor( /** * Gets the cumulative bitrate (in bps) of this [RtpLayerDesc] and its dependencies. * - * This is left open for use in testing. - * * @param nowMs * @return the cumulative bitrate (in bps) of this [RtpLayerDesc] and its dependencies. */ - open fun getBitrate(nowMs: Long): Bandwidth = calcBitrate(nowMs).values.sum() + abstract fun getBitrate(nowMs: Long): Bandwidth /** * Expose [getBitrate] as a [Double] in order to make it accessible from java (since [Bandwidth] is an inline @@ -165,67 +116,27 @@ constructor( */ fun getBitrateBps(nowMs: Long): Double = getBitrate(nowMs).bps - /** - * Recursively adds the bitrate (in bps) of this [RTPLayerDesc] and - * its dependencies in the map passed in as an argument. - * - * This is necessary to ensure we don't double-count layers in cases - * of multiple dependencies. - * - * @param nowMs - */ - private fun calcBitrate(nowMs: Long, rates: MutableMap = HashMap()): MutableMap { - if (rates.containsKey(index)) { - return rates - } - rates[index] = bitrateTracker.getRate(nowMs) - - dependencyLayers.forEach { it.calcBitrate(nowMs, rates) } - - if (useSoftDependencies) { - softDependencyLayers.forEach { it.calcBitrate(nowMs, rates) } - } - - return rates - } - - /** - * Returns true if this layer, alone, has a zero bitrate. - */ - private fun layerHasZeroBitrate(nowMs: Long) = bitrateTracker.getAccumulatedSize(nowMs).bits == 0L - /** * Recursively checks this layer and its dependencies to see if the bitrate is zero. * Note that unlike [calcBitrate] this does not avoid double-visiting layers; the overhead * of the hash table is usually more than the cost of any double-visits. - * - * This is left open for use in testing. */ - open fun hasZeroBitrate(nowMs: Long): Boolean { - if (!layerHasZeroBitrate(nowMs)) { - return false - } - if (dependencyLayers.any { !it.layerHasZeroBitrate(nowMs) }) { - return false - } - if (useSoftDependencies && softDependencyLayers.any { !it.layerHasZeroBitrate(nowMs) }) { - return false - } - return true - } + abstract fun hasZeroBitrate(nowMs: Long): Boolean /** * Extracts a [NodeStatsBlock] from an [RtpLayerDesc]. */ - fun getNodeStats() = NodeStatsBlock(indexString(index)).apply { + open fun getNodeStats() = NodeStatsBlock(indexString()).apply { addNumber("frameRate", frameRate) addNumber("height", height) addNumber("index", index) addNumber("bitrate_bps", getBitrate(System.currentTimeMillis()).bps) - addNumber("tid", tid) - addNumber("sid", sid) } + fun debugState(): OrderedJsonObject = getNodeStats().toJson().apply { put("indexString", indexString()) } + + abstract fun indexString(): String + companion object { /** * The index value that is used to represent that forwarding is suspended. @@ -270,14 +181,14 @@ constructor( fun getEidFromIndex(index: Int) = index shr 6 /** - * Get an spatial ID from a layer index. If the index is [SUSPENDED_INDEX], + * Get a spatial ID from a layer index. If the index is [SUSPENDED_INDEX], * the value is unspecified. */ @JvmStatic fun getSidFromIndex(index: Int) = (index and 0x38) shr 3 /** - * Get an temporal ID from a layer index. If the index is [SUSPENDED_INDEX], + * Get a temporal ID from a layer index. If the index is [SUSPENDED_INDEX], * the value is unspecified. */ @JvmStatic diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpReceiverImpl.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpReceiverImpl.kt index edd15e2be6..8f749fef38 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpReceiverImpl.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpReceiverImpl.kt @@ -143,7 +143,7 @@ class RtpReceiverImpl @JvmOverloads constructor( private val videoBitrateCalculator = VideoBitrateCalculator(parentLogger) private val audioBitrateCalculator = BitrateCalculator("Audio bitrate calculator") - private val videoParser = VideoParser(streamInformationStore, logger) + private val videoParser = VideoParser(streamInformationStore, logger, diagnosticContext) override fun isReceivingAudio() = audioBitrateCalculator.active override fun isReceivingVideo() = videoBitrateCalculator.active diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/ParsedVideoPacket.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/ParsedVideoPacket.kt index bc9de8c01a..d4928295b8 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/ParsedVideoPacket.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/ParsedVideoPacket.kt @@ -25,8 +25,8 @@ abstract class ParsedVideoPacket( buffer: ByteArray, offset: Int, length: Int, - encodingIndex: Int? -) : VideoRtpPacket(buffer, offset, length, encodingIndex) { + encodingId: Int +) : VideoRtpPacket(buffer, offset, length, encodingId) { abstract val isKeyframe: Boolean abstract val isStartOfFrame: Boolean diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/RtpExtensions.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/RtpExtensions.kt index 1e869642ec..90799e11a8 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/RtpExtensions.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/RtpExtensions.kt @@ -96,7 +96,14 @@ enum class RtpExtensionType(val uri: String) { /** * The URN which identifies the RTP Header Extension for Video Orientation. */ - VIDEO_ORIENTATION("urn:3gpp:video-orientation"); + VIDEO_ORIENTATION("urn:3gpp:video-orientation"), + + /** + * The URN which identifies the AV1 Dependency Descriptor RTP Header Extension + */ + AV1_DEPENDENCY_DESCRIPTOR( + "https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension" + ); companion object { private val uriMap = RtpExtensionType.values().associateBy(RtpExtensionType::uri) diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/VideoRtpPacket.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/VideoRtpPacket.kt index dc4835e3c1..19dc0716b5 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/VideoRtpPacket.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/VideoRtpPacket.kt @@ -15,6 +15,7 @@ */ package org.jitsi.nlj.rtp +import org.jitsi.nlj.RtpLayerDesc import org.jitsi.rtp.rtp.RtpPacket /** @@ -22,35 +23,22 @@ import org.jitsi.rtp.rtp.RtpPacket * parsed (i.e. we don't know information gained from * parsing codec-specific data). */ -open class VideoRtpPacket protected constructor( +open class VideoRtpPacket @JvmOverloads constructor( buffer: ByteArray, offset: Int, length: Int, - qualityIndex: Int? + /** The encoding ID of this packet. */ + var encodingId: Int = RtpLayerDesc.SUSPENDED_ENCODING_ID ) : RtpPacket(buffer, offset, length) { - constructor( - buffer: ByteArray, - offset: Int, - length: Int - ) : this( - buffer, - offset, - length, - qualityIndex = null - ) - - /** The index of this packet relative to its source's RtpLayers. */ - var qualityIndex: Int = qualityIndex ?: -1 - - open val layerId = 0 + open val layerIds: Collection = listOf(0) override fun clone(): VideoRtpPacket { return VideoRtpPacket( cloneBuffer(BYTES_TO_LEAVE_AT_START_OF_PACKET), BYTES_TO_LEAVE_AT_START_OF_PACKET, length, - qualityIndex = qualityIndex + encodingId = encodingId ) } } diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/VideoCodecParser.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/VideoCodecParser.kt index 2ffe3e977a..28fad3f251 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/VideoCodecParser.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/VideoCodecParser.kt @@ -19,8 +19,7 @@ package org.jitsi.nlj.rtp.codec import org.jitsi.nlj.MediaSourceDesc import org.jitsi.nlj.PacketInfo import org.jitsi.nlj.RtpEncodingDesc -import org.jitsi.nlj.RtpLayerDesc -import org.jitsi.nlj.findRtpLayerDesc +import org.jitsi.nlj.findRtpLayerDescs import org.jitsi.nlj.rtp.VideoRtpPacket /** @@ -52,5 +51,5 @@ abstract class VideoCodecParser( return null } - protected fun findRtpLayerDesc(packet: VideoRtpPacket): RtpLayerDesc? = sources.findRtpLayerDesc(packet) + protected fun findRtpLayerDescs(packet: VideoRtpPacket) = sources.findRtpLayerDescs(packet) } diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDPacket.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDPacket.kt new file mode 100644 index 0000000000..476beca682 --- /dev/null +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDPacket.kt @@ -0,0 +1,226 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jitsi.nlj.rtp.codec.av1 + +import org.jitsi.nlj.RtpEncodingDesc +import org.jitsi.nlj.RtpLayerDesc +import org.jitsi.nlj.rtp.ParsedVideoPacket +import org.jitsi.rtp.rtp.RtpPacket +import org.jitsi.rtp.rtp.header_extensions.Av1DependencyDescriptorHeaderExtension +import org.jitsi.rtp.rtp.header_extensions.Av1DependencyDescriptorReader +import org.jitsi.rtp.rtp.header_extensions.Av1DependencyDescriptorStatelessSubset +import org.jitsi.rtp.rtp.header_extensions.Av1DependencyException +import org.jitsi.rtp.rtp.header_extensions.Av1TemplateDependencyStructure +import org.jitsi.rtp.rtp.header_extensions.FrameInfo +import org.jitsi.utils.logging2.Logger + +/** A video packet carrying an AV1 Dependency Descriptor. Note that this may or may not be an actual AV1 packet; + * other video codecs can also carry the AV1 DD. + */ +class Av1DDPacket : ParsedVideoPacket { + var descriptor: Av1DependencyDescriptorHeaderExtension? + val statelessDescriptor: Av1DependencyDescriptorStatelessSubset + val frameInfo: FrameInfo? + val av1DDHeaderExtensionId: Int + + private constructor( + buffer: ByteArray, + offset: Int, + length: Int, + av1DDHeaderExtensionId: Int, + encodingId: Int, + descriptor: Av1DependencyDescriptorHeaderExtension?, + statelessDescriptor: Av1DependencyDescriptorStatelessSubset, + frameInfo: FrameInfo? + ) : super(buffer, offset, length, encodingId) { + this.descriptor = descriptor + this.statelessDescriptor = statelessDescriptor + this.frameInfo = frameInfo + this.av1DDHeaderExtensionId = av1DDHeaderExtensionId + } + + constructor( + packet: RtpPacket, + av1DDHeaderExtensionId: Int, + templateDependencyStructure: Av1TemplateDependencyStructure?, + logger: Logger + ) : super(packet.buffer, packet.offset, packet.length, RtpLayerDesc.SUSPENDED_ENCODING_ID) { + this.av1DDHeaderExtensionId = av1DDHeaderExtensionId + val ddExt = packet.getHeaderExtension(av1DDHeaderExtensionId) + requireNotNull(ddExt) { + "Packet did not have Dependency Descriptor" + } + val parser = Av1DependencyDescriptorReader(ddExt) + descriptor = try { + parser.parse(templateDependencyStructure) + } catch (e: Av1DependencyException) { + logger.warn( + "Could not parse AV1 Dependency Descriptor for ssrc ${packet.ssrc} seq ${packet.sequenceNumber}: " + + e.message + ) + null + } + statelessDescriptor = descriptor ?: parser.parseStateless() + frameInfo = try { + descriptor?.frameInfo + } catch (e: Av1DependencyException) { + logger.warn( + "Could not extract frame info from AV1 Dependency Descriptor for " + + "ssrc ${packet.ssrc} seq ${packet.sequenceNumber}: ${e.message}" + ) + null + } + } + + /* "template_dependency_structure_present_flag MUST be set to 1 for the first packet of a coded video sequence, + * and MUST be set to 0 otherwise" + */ + override val isKeyframe: Boolean + get() = statelessDescriptor.newTemplateDependencyStructure != null + + override val isStartOfFrame: Boolean + get() = statelessDescriptor.startOfFrame + + override val isEndOfFrame: Boolean + get() = statelessDescriptor.endOfFrame + + override val layerIds: Collection + get() = frameInfo?.dtisPresent + ?: run { super.layerIds } + + val frameNumber + get() = statelessDescriptor.frameNumber + + val activeDecodeTargets + get() = descriptor?.activeDecodeTargetsBitmask + + override fun toString(): String = buildString { + append(super.toString()) + append(", DTIs=${frameInfo?.dtisPresent}") + activeDecodeTargets?.let { append(", ActiveTargets=$it") } + } + + override fun clone(): Av1DDPacket { + val descriptor = descriptor?.clone() + val statelessDescriptor = descriptor ?: statelessDescriptor.clone() + return Av1DDPacket( + cloneBuffer(BYTES_TO_LEAVE_AT_START_OF_PACKET), + BYTES_TO_LEAVE_AT_START_OF_PACKET, + length, + av1DDHeaderExtensionId = av1DDHeaderExtensionId, + encodingId = encodingId, + descriptor = descriptor, + statelessDescriptor = statelessDescriptor, + frameInfo = frameInfo + ) + } + + fun getScalabilityStructure(eid: Int = 0, baseFrameRate: Double = 30.0): RtpEncodingDesc? { + val descriptor = this.descriptor + requireNotNull(descriptor) { + "Can't get scalability structure from packet without a descriptor" + } + return descriptor.getScalabilityStructure(ssrc, eid, baseFrameRate) + } + + /** Re-encode the current descriptor to the header extension. For use after modifying it. */ + fun reencodeDdExt() { + val descriptor = this.descriptor + requireNotNull(descriptor) { + "Can't re-encode extension from a packet without a descriptor" + } + + var ext = getHeaderExtension(av1DDHeaderExtensionId) + if (ext == null || ext.dataLengthBytes != descriptor.encodedLength) { + removeHeaderExtension(av1DDHeaderExtensionId) + ext = addHeaderExtension(av1DDHeaderExtensionId, descriptor.encodedLength) + } + descriptor.write(ext) + } +} + +fun Av1DependencyDescriptorHeaderExtension.getScalabilityStructure( + ssrc: Long, + eid: Int = 0, + baseFrameRate: Double = 30.0 +): RtpEncodingDesc? { + val activeDecodeTargetsBitmask = this.activeDecodeTargetsBitmask + ?: // Can't get scalability structure from dependency descriptor that doesn't specify decode targets + return null + + val layerCounts = Array(structure.maxSpatialId + 1) { + IntArray(structure.maxTemporalId + 1) + } + + // Figure out the frame rates per spatial/temporal layer. + structure.templateInfo.forEach { t -> + if (!t.hasInterPictureDependency()) { + // This is a template that doesn't reference any previous frames, so is probably a key frame or + // part of the same temporal picture with one, i.e. not part of the regular structure. + return@forEach + } + layerCounts[t.spatialId][t.temporalId]++ + } + + // Sum up counts per spatial layer + layerCounts.forEach { a -> + var total = 0 + for (i in a.indices) { + val entry = a[i] + a[i] += total + total += entry + } + } + + val maxFrameGroup = layerCounts.maxOf { it.maxOrNull()!! } + + val layers = ArrayList() + + structure.decodeTargetInfo.forEachIndexed { i, dt -> + if (!activeDecodeTargetsBitmask.containsDecodeTarget(i)) { + return@forEachIndexed + } + val height = structure.maxRenderResolutions.getOrNull(dt.spatialId)?.height ?: -1 + + // Calculate the fraction of this spatial layer's framerate this DT comprises. + val frameRate = baseFrameRate * layerCounts[dt.spatialId][dt.temporalId] / maxFrameGroup + + layers.add(Av1DDRtpLayerDesc(eid, i, dt.temporalId, dt.spatialId, height, frameRate)) + } + return RtpEncodingDesc(ssrc, layers.toArray(arrayOf()), eid) +} + +/** Check whether an activeDecodeTargetsBitmask contains a specific decode target. */ +fun Int.containsDecodeTarget(dt: Int) = ((1 shl dt) and this) != 0 + +/** + * Returns the delta between two AV1 templateID values, taking into account + * rollover. This will return the 'positive' delta between the two + * picture IDs in the form of the number you'd add to b to get a. e.g.: + * getTl0PicIdxDelta(1, 10) -> 55 (10 + 55 = 1) + * getTl0PicIdxDelta(1, 58) -> 7 (58 + 7 = 1) + */ +fun getTemplateIdDelta(a: Int, b: Int): Int = (a - b + 64) % 64 + +/** + * Apply a delta to a given templateID and return the result (taking + * rollover into account) + * @param start the starting templateID + * @param delta the delta to be applied + * @return the templateID resulting from doing "start + delta" + */ +fun applyTemplateIdDelta(start: Int, delta: Int): Int = (start + delta) % 64 diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDParser.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDParser.kt new file mode 100644 index 0000000000..db169f1a76 --- /dev/null +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDParser.kt @@ -0,0 +1,179 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jitsi.nlj.rtp.codec.av1 + +import org.jitsi.nlj.MediaSourceDesc +import org.jitsi.nlj.PacketInfo +import org.jitsi.nlj.rtp.codec.VideoCodecParser +import org.jitsi.nlj.util.Rfc3711IndexTracker +import org.jitsi.nlj.util.TreeCache +import org.jitsi.rtp.rtp.RtpPacket +import org.jitsi.rtp.rtp.header_extensions.Av1TemplateDependencyStructure +import org.jitsi.utils.LRUCache +import org.jitsi.utils.logging.DiagnosticContext +import org.jitsi.utils.logging.TimeSeriesLogger +import org.jitsi.utils.logging2.Logger +import org.jitsi.utils.logging2.createChildLogger + +/** + * Some [Av1DDPacket] fields are not able to be determined by looking at a single packet with an AV1 DD + * (for example the template dependency structure is only carried in keyframes). This class updates the layer + * descriptions with information from frames, and also diagnoses packet format variants that the Jitsi videobridge + * won't be able to route. + */ +class Av1DDParser( + sources: Array, + parentLogger: Logger, + private val diagnosticContext: DiagnosticContext +) : VideoCodecParser(sources) { + private val logger = createChildLogger(parentLogger) + + /** History of AV1 templates. */ + private val ddStateHistory = LRUCache(STATE_HISTORY_SIZE, true) + + fun createFrom(packet: RtpPacket, av1DdExtId: Int): Av1DDPacket { + val history = ddStateHistory.getOrPut(packet.ssrc) { + TemplateHistory(TEMPLATE_HISTORY_SIZE) + } + + val priorEntry = history.get(packet.sequenceNumber) + + val priorStructure = priorEntry?.value?.structure?.clone() + + val av1Packet = Av1DDPacket(packet, av1DdExtId, priorStructure, logger) + + val newStructure = av1Packet.descriptor?.newTemplateDependencyStructure + if (newStructure != null) { + val structureChanged = newStructure.templateIdOffset != priorStructure?.templateIdOffset + history.insert(packet.sequenceNumber, Av1DdInfo(newStructure.clone(), structureChanged)) + logger.debug { + "Inserting new structure with templates ${newStructure.templateIdOffset} .. " + + "${(newStructure.templateIdOffset + newStructure.templateCount - 1) % 64} " + + "for RTP packet ssrc ${packet.ssrc} seq ${packet.sequenceNumber}. " + + "Changed from previous: $structureChanged." + } + } + + if (timeSeriesLogger.isTraceEnabled) { + val point = diagnosticContext + .makeTimeSeriesPoint("av1_parser") + .addField("rtp.ssrc", packet.ssrc) + .addField("rtp.seq", packet.sequenceNumber) + .addField("rtp.timestamp", packet.timestamp) + .addField("av1_parser.key", priorEntry?.key) + .addField("av1.startOfFrame", av1Packet.statelessDescriptor.startOfFrame) + .addField("av1.endOfFrame", av1Packet.statelessDescriptor.endOfFrame) + .addField("av1.templateId", av1Packet.statelessDescriptor.frameDependencyTemplateId) + .addField("av1.frameNum", av1Packet.statelessDescriptor.frameNumber) + .addField("av1.frameInfo", av1Packet.frameInfo?.toString()) + .addField("av1.structure", newStructure != null) + .addField("av1.activeTargets", av1Packet.descriptor?.activeDecodeTargetsBitmask) + val packetStructure = av1Packet.descriptor?.structure + if (packetStructure != null) { + point.addField("av1.structureIdOffset", packetStructure.templateIdOffset) + .addField("av1.templateCount", packetStructure.templateCount) + .addField("av1.structureId", System.identityHashCode(packetStructure)) + } + if (newStructure != null) { + point.addField("av1.newStructureIdOffset", newStructure.templateIdOffset) + .addField("av1.newTemplateCount", newStructure.templateCount) + .addField("av1.newStructureId", System.identityHashCode(newStructure)) + } + timeSeriesLogger.trace(point) + } + + return av1Packet + } + + override fun parse(packetInfo: PacketInfo) { + val av1Packet = packetInfo.packetAs() + val history = ddStateHistory[av1Packet.ssrc] + + if (history == null) { + /** Probably getting spammed with SSRCs? */ + logger.warn("History for ${av1Packet.ssrc} disappeared between createFrom and parse!") + return + } + + val activeDecodeTargets = av1Packet.activeDecodeTargets + + if (activeDecodeTargets != null) { + val changed = history.updateDecodeTargets(av1Packet.sequenceNumber, activeDecodeTargets) + + if (changed) { + packetInfo.layeringChanged = true + logger.debug { + "Decode targets for ${av1Packet.ssrc} changed in seq ${av1Packet.sequenceNumber}: " + + "now 0x${Integer.toHexString(activeDecodeTargets)}. Updating layering." + } + + findSourceDescAndRtpEncodingDesc(av1Packet)?.let { (src, enc) -> + av1Packet.getScalabilityStructure(eid = enc.eid)?.let { + src.setEncodingLayers(it.layers, av1Packet.ssrc) + } + for (otherEnc in src.rtpEncodings) { + if (!ddStateHistory.keys.contains(otherEnc.primarySSRC)) { + src.setEncodingLayers(emptyArray(), otherEnc.primarySSRC) + } + } + } + } + } + } + + companion object { + const val STATE_HISTORY_SIZE = 500 + const val TEMPLATE_HISTORY_SIZE = 500 + + private val timeSeriesLogger = TimeSeriesLogger.getTimeSeriesLogger(Av1DDParser::class.java) + } +} + +class TemplateHistory(minHistory: Int) { + private val indexTracker = Rfc3711IndexTracker() + private val history = TreeCache(minHistory) + private var latestDecodeTargets = -1 + private var latestDecodeTargetIndex = -1 + + fun get(seqNo: Int): Map.Entry? { + val index = indexTracker.update(seqNo) + return history.getEntryBefore(index) + } + + fun insert(seqNo: Int, value: Av1DdInfo) { + val index = indexTracker.update(seqNo) + return history.insert(index, value) + } + + /** Update the current decode targets. + * Return true if the decode target set or the template structure has changed. */ + fun updateDecodeTargets(seqNo: Int, decodeTargets: Int): Boolean { + val index = indexTracker.update(seqNo) + if (index < latestDecodeTargetIndex) { + return false + } + val changed = decodeTargets != latestDecodeTargets || history.get(index)?.changed == true + latestDecodeTargetIndex = index + latestDecodeTargets = decodeTargets + return changed + } +} + +data class Av1DdInfo( + val structure: Av1TemplateDependencyStructure, + val changed: Boolean +) diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDRtpLayerDesc.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDRtpLayerDesc.kt new file mode 100644 index 0000000000..38952a2415 --- /dev/null +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDRtpLayerDesc.kt @@ -0,0 +1,120 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.nlj.rtp.codec.av1 + +import org.jitsi.nlj.RtpLayerDesc +import org.jitsi.nlj.stats.NodeStatsBlock + +/** + * An RtpLayerDesc of the type needed to describe AV1 DD scalability. + */ + +class Av1DDRtpLayerDesc( + /** + * The index of this instance's encoding in the source encoding array. + */ + eid: Int, + /** + * The decoding target of this instance, or negative for unknown. + */ + val dt: Int, + /** + * The temporal layer ID of this instance,. + */ + tid: Int, + /** + * The spatial layer ID of this instance. + */ + sid: Int, + /** + * The max height of the bitstream that this instance represents. The actual + * height may be less due to bad network or system load. + */ + + height: Int, + /** + * The max frame rate (in fps) of the bitstream that this instance + * represents. The actual frame rate may be less due to bad network or + * system load. + */ + frameRate: Double, +) : RtpLayerDesc(eid, tid, sid, height, frameRate) { + override fun copy(height: Int): RtpLayerDesc = Av1DDRtpLayerDesc(eid, dt, tid, sid, height, frameRate) + + override val layerId = dt + override val index = getIndex(eid, dt) + + override fun getBitrate(nowMs: Long) = bitrateTracker.getRate(nowMs) + + override fun hasZeroBitrate(nowMs: Long) = bitrateTracker.getAccumulatedSize(nowMs).bits == 0L + + /** + * Extracts a [NodeStatsBlock] from an [RtpLayerDesc]. + */ + override fun getNodeStats() = super.getNodeStats().apply { + addNumber("dt", dt) + } + + override fun indexString(): String = indexString(index) + + /** + * {@inheritDoc} + */ + override fun toString(): String { + return "subjective_quality=" + index + + ",DT=" + dt + } + + companion object { + /** + * The index value that is used to represent that forwarding is suspended. + */ + const val SUSPENDED_INDEX = -1 + + const val SUSPENDED_DT = -1 + + /** + * Calculate the "index" of a layer based on its encoding and decode target. + * This is a server-side id and should not be confused with any encoding id defined + * in the client (such as the rid) or the encodingId. This is used by the videobridge's + * adaptive source projection for filtering. + */ + @JvmStatic + fun getIndex(eid: Int, dt: Int): Int { + val e = if (eid < 0) 0 else eid + val d = if (dt < 0) 0 else dt + + return (e shl 6) or d + } + + /** + * Get a decode target ID from a layer index. If the index is [SUSPENDED_INDEX], + * the value is unspecified. + */ + @JvmStatic + fun getDtFromIndex(index: Int) = if (index == SUSPENDED_INDEX) SUSPENDED_DT else index and 0x3f + + /** + * Get a string description of a layer index. + */ + @JvmStatic + fun indexString(index: Int): String = if (index == SUSPENDED_INDEX) { + "SUSP" + } else { + "E${getEidFromIndex(index)}DT${getDtFromIndex(index)}" + } + } +} diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp8/Vp8Packet.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp8/Vp8Packet.kt index 93ea7b741a..2798125d7e 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp8/Vp8Packet.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp8/Vp8Packet.kt @@ -38,11 +38,11 @@ class Vp8Packet private constructor( length: Int, isKeyframe: Boolean?, isStartOfFrame: Boolean?, - encodingIndex: Int?, + encodingId: Int, height: Int?, pictureId: Int?, TL0PICIDX: Int? -) : ParsedVideoPacket(buffer, offset, length, encodingIndex) { +) : ParsedVideoPacket(buffer, offset, length, encodingId) { constructor( buffer: ByteArray, @@ -52,7 +52,7 @@ class Vp8Packet private constructor( buffer, offset, length, isKeyframe = null, isStartOfFrame = null, - encodingIndex = null, + encodingId = RtpLayerDesc.SUSPENDED_ENCODING_ID, height = null, pictureId = null, TL0PICIDX = null @@ -116,8 +116,12 @@ class Vp8Packet private constructor( val temporalLayerIndex: Int = Vp8Utils.getTemporalLayerIdOfFrame(this) - override val layerId: Int - get() = if (hasTemporalLayerIndex) RtpLayerDesc.getIndex(0, 0, temporalLayerIndex) else super.layerId + override val layerIds: Collection + get() = if (hasTemporalLayerIndex) { + listOf(RtpLayerDesc.getIndex(0, 0, temporalLayerIndex)) + } else { + super.layerIds + } /** * This is currently used as an overall spatial index, not an in-band spatial quality index a la vp9. That is, @@ -154,7 +158,7 @@ class Vp8Packet private constructor( length, isKeyframe = isKeyframe, isStartOfFrame = isStartOfFrame, - encodingIndex = qualityIndex, + encodingId = encodingId, height = height, pictureId = pictureId, TL0PICIDX = TL0PICIDX diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Packet.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Packet.kt index bacff0b341..ef26c01f89 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Packet.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Packet.kt @@ -19,6 +19,7 @@ package org.jitsi.nlj.rtp.codec.vp9 import org.jitsi.nlj.RtpEncodingDesc import org.jitsi.nlj.RtpLayerDesc import org.jitsi.nlj.rtp.ParsedVideoPacket +import org.jitsi.nlj.rtp.codec.vpx.VpxRtpLayerDesc import org.jitsi.rtp.extensions.bytearray.hashCodeOfSegment import org.jitsi.utils.logging2.createLogger import org.jitsi.utils.logging2.cwarn @@ -39,10 +40,10 @@ class Vp9Packet private constructor( isKeyframe: Boolean?, isStartOfFrame: Boolean?, isEndOfFrame: Boolean?, - encodingIndex: Int?, + encodingId: Int, pictureId: Int?, TL0PICIDX: Int? -) : ParsedVideoPacket(buffer, offset, length, encodingIndex) { +) : ParsedVideoPacket(buffer, offset, length, encodingId) { constructor( buffer: ByteArray, @@ -53,7 +54,7 @@ class Vp9Packet private constructor( isKeyframe = null, isStartOfFrame = null, isEndOfFrame = null, - encodingIndex = null, + encodingId = RtpLayerDesc.SUSPENDED_ENCODING_ID, pictureId = null, TL0PICIDX = null ) @@ -67,8 +68,12 @@ class Vp9Packet private constructor( override val isEndOfFrame: Boolean = isEndOfFrame ?: DePacketizer.VP9PayloadDescriptor.isEndOfFrame(buffer, payloadOffset, payloadLength) - override val layerId: Int - get() = if (hasLayerIndices) RtpLayerDesc.getIndex(0, spatialLayerIndex, temporalLayerIndex) else super.layerId + override val layerIds: Collection + get() = if (hasLayerIndices) { + listOf(RtpLayerDesc.getIndex(0, spatialLayerIndex, temporalLayerIndex)) + } else { + super.layerIds + } /** End of VP9 picture is the marker bit. Note frame/picture distinction. */ /* TODO: not sure this should be the override from [ParsedVideoPacket] */ @@ -194,7 +199,7 @@ class Vp9Packet private constructor( isKeyframe = isKeyframe, isStartOfFrame = isStartOfFrame, isEndOfFrame = isEndOfFrame, - encodingIndex = qualityIndex, + encodingId = encodingId, pictureId = pictureId, TL0PICIDX = TL0PICIDX ) @@ -301,12 +306,12 @@ class Vp9Packet private constructor( tlCounts[t] += tlCounts[t - 1] } - val layers = ArrayList() + val layers = ArrayList() for (s in 0 until numSpatial) { for (t in 0 until numTemporal) { - val dependencies = ArrayList() - val softDependencies = ArrayList() + val dependencies = ArrayList() + val softDependencies = ArrayList() if (s > 0) { /* Because of K-SVC, spatial layer dependencies are soft */ layers.find { it.sid == s - 1 && it.tid == t }?.let { softDependencies.add(it) } @@ -314,7 +319,7 @@ class Vp9Packet private constructor( if (t > 0) { layers.find { it.sid == s && it.tid == t - 1 }?.let { dependencies.add(it) } } - val layerDesc = RtpLayerDesc( + val layerDesc = VpxRtpLayerDesc( eid = eid, tid = t, sid = s, @@ -324,8 +329,8 @@ class Vp9Packet private constructor( } else { RtpLayerDesc.NO_FRAME_RATE }, - dependencyLayers = dependencies.toArray(arrayOf()), - softDependencyLayers = softDependencies.toArray(arrayOf()) + dependencyLayers = dependencies.toArray(arrayOf()), + softDependencyLayers = softDependencies.toArray(arrayOf()) ) layers.add(layerDesc) } diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Parser.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Parser.kt index 8d0f50db3a..04518ddbf6 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Parser.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Parser.kt @@ -18,8 +18,8 @@ package org.jitsi.nlj.rtp.codec.vp9 import org.jitsi.nlj.MediaSourceDesc import org.jitsi.nlj.PacketInfo -import org.jitsi.nlj.findRtpLayerDesc import org.jitsi.nlj.rtp.codec.VideoCodecParser +import org.jitsi.nlj.rtp.codec.vpx.VpxRtpLayerDesc import org.jitsi.nlj.util.StateChangeLogger import org.jitsi.rtp.extensions.toHex import org.jitsi.utils.logging2.Logger @@ -75,7 +75,11 @@ class Vp9Parser( * when calculating layers' bitrates. These values are small enough this is probably * fine, but revisit this if it turns out to be a problem. */ - findRtpLayerDesc(vp9Packet)?.useSoftDependencies = vp9Packet.usesInterLayerDependency + findRtpLayerDescs(vp9Packet).forEach { + if (it is VpxRtpLayerDesc) { + it.useSoftDependencies = vp9Packet.usesInterLayerDependency + } + } } pictureIdState.setState(vp9Packet.hasPictureId, vp9Packet) { diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vpx/VpxRtpLayerDesc.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vpx/VpxRtpLayerDesc.kt new file mode 100644 index 0000000000..9c42b14e04 --- /dev/null +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vpx/VpxRtpLayerDesc.kt @@ -0,0 +1,190 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jitsi.nlj.rtp.codec.vpx + +import org.jitsi.nlj.RtpLayerDesc +import org.jitsi.nlj.stats.NodeStatsBlock +import org.jitsi.nlj.util.Bandwidth +import org.jitsi.nlj.util.BitrateTracker +import org.jitsi.nlj.util.sum + +/** + * An RtpLayerDesc of the type needed to describe VP8 and VP9 scalability. + */ +class VpxRtpLayerDesc +@JvmOverloads +constructor( + /** + * The index of this instance's encoding in the source encoding array. + */ + eid: Int, + /** + * The temporal layer ID of this instance, or negative for unknown. + */ + tid: Int, + /** + * The spatial layer ID of this instance, or negative for unknown. + */ + sid: Int, + /** + * The max height of the bitstream that this instance represents. The actual + * height may be less due to bad network or system load. [RtpLayerDesc.NO_HEIGHT] for unknown. + * + * XXX we should be able to sniff the actual height from the RTP + * packets. + */ + height: Int, + /** + * The max frame rate (in fps) of the bitstream that this instance + * represents. The actual frame rate may be less due to bad network or + * system load. [RtpLayerDesc.NO_FRAME_RATE] for unknown. + */ + frameRate: Double, + /** + * The [RtpLayerDesc]s on which this layer definitely depends. + */ + val dependencyLayers: Array = emptyArray(), + /** + * The [RtpLayerDesc]s on which this layer possibly depends. + * (The intended use case is K-SVC mode.) + */ + val softDependencyLayers: Array = emptyArray() +) : RtpLayerDesc(eid, tid, sid, height, frameRate) { + init { + require(tid < 8) { "Invalid temporal ID $tid" } + require(sid < 8) { "Invalid spatial ID $sid" } + } + + /** + * Clone an existing layer desc, inheriting its statistics, + * modifying only specific values. + */ + override fun copy(height: Int) = VpxRtpLayerDesc( + eid = this.eid, + tid = this.tid, + sid = this.sid, + height = height, + frameRate = this.frameRate, + dependencyLayers = this.dependencyLayers, + softDependencyLayers = this.softDependencyLayers + ).also { + it.inheritFrom(this) + } + + /** + * Whether softDependencyLayers are to be used. + */ + var useSoftDependencies = true + + /** + * @return the "id" of this layer within this encoding. This is a server-side id and should + * not be confused with any encoding id defined in the client (such as the + * rid). + */ + override val layerId = getIndex(0, sid, tid) + + /** + * A local index of this track. + */ + override val index = getIndex(eid, sid, tid) + + /** + * Inherit another layer description's [BitrateTracker] object. + */ + override fun inheritFrom(other: RtpLayerDesc) { + super.inheritFrom(other) + if (other is VpxRtpLayerDesc) { + useSoftDependencies = other.useSoftDependencies + } + } + + /** + * {@inheritDoc} + */ + override fun toString(): String { + return "subjective_quality=$index,temporal_id=$tid,spatial_id=$sid,height=$height" + } + + /** + * Gets the cumulative bitrate (in bps) of this [RtpLayerDesc] and its dependencies. + * + * This is left open for use in testing. + * + * @param nowMs + * @return the cumulative bitrate (in bps) of this [RtpLayerDesc] and its dependencies. + */ + override fun getBitrate(nowMs: Long): Bandwidth = calcBitrate(nowMs).values.sum() + + /** + * Recursively adds the bitrate (in bps) of this [RtpLayerDesc] and + * its dependencies in the map passed in as an argument. + * + * This is necessary to ensure we don't double-count layers in cases + * of multiple dependencies. + * + * @param nowMs + */ + private fun calcBitrate(nowMs: Long, rates: MutableMap = HashMap()): MutableMap { + if (rates.containsKey(index)) { + return rates + } + rates[index] = bitrateTracker.getRate(nowMs) + + dependencyLayers.forEach { it.calcBitrate(nowMs, rates) } + + if (useSoftDependencies) { + softDependencyLayers.forEach { it.calcBitrate(nowMs, rates) } + } + + return rates + } + + /** + * Returns true if this layer, alone, has a zero bitrate. + */ + private fun layerHasZeroBitrate(nowMs: Long) = bitrateTracker.getAccumulatedSize(nowMs).bits == 0L + + /** + * Recursively checks this layer and its dependencies to see if the bitrate is zero. + * Note that unlike [calcBitrate] this does not avoid double-visiting layers; the overhead + * of the hash table is usually more than the cost of any double-visits. + * + * This is left open for use in testing. + */ + override fun hasZeroBitrate(nowMs: Long): Boolean { + if (!layerHasZeroBitrate(nowMs)) { + return false + } + if (dependencyLayers.any { !it.layerHasZeroBitrate(nowMs) }) { + return false + } + if (useSoftDependencies && softDependencyLayers.any { !it.layerHasZeroBitrate(nowMs) }) { + return false + } + return true + } + + /** + * Extracts a [NodeStatsBlock] from an [RtpLayerDesc]. + */ + override fun getNodeStats() = super.getNodeStats().apply { + addNumber("tid", tid) + addNumber("sid", sid) + } + + override fun indexString(): String = indexString(index) +} diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/BitrateCalculator.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/BitrateCalculator.kt index 62105d6994..2daf4cc1da 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/BitrateCalculator.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/BitrateCalculator.kt @@ -22,7 +22,7 @@ import org.jitsi.nlj.Event import org.jitsi.nlj.MediaSourceDesc import org.jitsi.nlj.PacketInfo import org.jitsi.nlj.SetMediaSourcesEvent -import org.jitsi.nlj.findRtpLayerDesc +import org.jitsi.nlj.findRtpLayerDescs import org.jitsi.nlj.rtp.VideoRtpPacket import org.jitsi.nlj.stats.NodeStatsBlock import org.jitsi.nlj.transform.node.ObserverNode @@ -55,8 +55,8 @@ class VideoBitrateCalculator( super.observe(packetInfo) val videoRtpPacket: VideoRtpPacket = packetInfo.packet as VideoRtpPacket - mediaSourceDescs.findRtpLayerDesc(videoRtpPacket)?.let { - val now = clock.millis() + val now = clock.millis() + mediaSourceDescs.findRtpLayerDescs(videoRtpPacket).forEach { if (it.updateBitrate(videoRtpPacket.length.bytes, now)) { /* When a layer is started when it was previously inactive, * we want to recalculate bandwidth allocation. diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VideoParser.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VideoParser.kt index a31326f595..685c70639a 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VideoParser.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VideoParser.kt @@ -21,7 +21,9 @@ import org.jitsi.nlj.PacketInfo import org.jitsi.nlj.SetMediaSourcesEvent import org.jitsi.nlj.format.Vp8PayloadType import org.jitsi.nlj.format.Vp9PayloadType +import org.jitsi.nlj.rtp.RtpExtensionType import org.jitsi.nlj.rtp.codec.VideoCodecParser +import org.jitsi.nlj.rtp.codec.av1.Av1DDParser import org.jitsi.nlj.rtp.codec.vp8.Vp8Packet import org.jitsi.nlj.rtp.codec.vp8.Vp8Parser import org.jitsi.nlj.rtp.codec.vp9.Vp9Packet @@ -32,6 +34,7 @@ import org.jitsi.nlj.util.ReadOnlyStreamInformationStore import org.jitsi.rtp.extensions.bytearray.toHex import org.jitsi.rtp.rtp.RtpPacket import org.jitsi.utils.OrderedJsonObject +import org.jitsi.utils.logging.DiagnosticContext import org.jitsi.utils.logging2.Logger import org.jitsi.utils.logging2.cdebug import org.jitsi.utils.logging2.createChildLogger @@ -41,7 +44,8 @@ import org.jitsi.utils.logging2.createChildLogger */ class VideoParser( private val streamInformationStore: ReadOnlyStreamInformationStore, - parentLogger: Logger + parentLogger: Logger, + private val diagnosticContext: DiagnosticContext ) : TransformerNode("Video parser") { private val logger = createChildLogger(parentLogger) private val stats = Stats() @@ -49,18 +53,27 @@ class VideoParser( private var sources: Array = arrayOf() private var signaledSources: Array = sources + private var av1DDExtId: Int? = null + private var videoCodecParser: VideoCodecParser? = null + init { + streamInformationStore.onRtpExtensionMapping(RtpExtensionType.AV1_DEPENDENCY_DESCRIPTOR) { + av1DDExtId = it + } + } + override fun transform(packetInfo: PacketInfo): PacketInfo? { val packet = packetInfo.packetAs() + val av1DDExtId = this.av1DDExtId // So null checks work val payloadType = streamInformationStore.rtpPayloadTypes[packet.payloadType.toByte()] ?: run { logger.error("Unrecognized video payload type ${packet.payloadType}, cannot parse video information") stats.numPacketsDroppedUnknownPt++ return null } val parsedPacket = try { - when (payloadType) { - is Vp8PayloadType -> { + when { + payloadType is Vp8PayloadType -> { val vp8Packet = packetInfo.packet.toOtherType(::Vp8Packet) packetInfo.packet = vp8Packet packetInfo.resetPayloadVerification() @@ -75,7 +88,7 @@ class VideoParser( } vp8Packet } - is Vp9PayloadType -> { + payloadType is Vp9PayloadType -> { val vp9Packet = packetInfo.packet.toOtherType(::Vp9Packet) packetInfo.packet = vp9Packet packetInfo.resetPayloadVerification() @@ -90,6 +103,22 @@ class VideoParser( } vp9Packet } + av1DDExtId != null && packet.getHeaderExtension(av1DDExtId) != null -> { + if (videoCodecParser !is Av1DDParser) { + logger.cdebug { + "Creating new Av1DDParser, current videoCodecParser is ${videoCodecParser?.javaClass}" + } + resetSources() + packetInfo.layeringChanged = true + videoCodecParser = Av1DDParser(sources, logger, diagnosticContext) + } + + val av1DDPacket = (videoCodecParser as Av1DDParser).createFrom(packet, av1DDExtId) + packetInfo.packet = av1DDPacket + packetInfo.resetPayloadVerification() + + av1DDPacket + } else -> { if (videoCodecParser != null) { logger.cdebug { diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VideoQualityLayerLookup.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VideoQualityLayerLookup.kt index 6b2935d41b..399dc79fdc 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VideoQualityLayerLookup.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VideoQualityLayerLookup.kt @@ -19,7 +19,7 @@ import org.jitsi.nlj.Event import org.jitsi.nlj.MediaSourceDesc import org.jitsi.nlj.PacketInfo import org.jitsi.nlj.SetMediaSourcesEvent -import org.jitsi.nlj.findRtpLayerDesc +import org.jitsi.nlj.findRtpEncodingId import org.jitsi.nlj.rtp.VideoRtpPacket import org.jitsi.nlj.stats.NodeStatsBlock import org.jitsi.nlj.transform.node.TransformerNode @@ -37,10 +37,10 @@ class VideoQualityLayerLookup( private var sources: Array = arrayOf() private val numPacketsDroppedNoEncoding = AtomicInteger() - /* TODO: combine this with VideoBitrateCalculator? They both do findRtpLayerDesc. */ override fun transform(packetInfo: PacketInfo): PacketInfo? { val videoPacket = packetInfo.packetAs() - val encodingDesc = sources.findRtpLayerDesc(videoPacket) ?: run { + val encodingId = sources.findRtpEncodingId(videoPacket) + if (encodingId == null) { logger.warn( "Unable to find encoding matching packet! packet=$videoPacket; " + "sources=${sources.joinToString(separator = "\n")}" @@ -48,7 +48,7 @@ class VideoQualityLayerLookup( numPacketsDroppedNoEncoding.incrementAndGet() return null } - videoPacket.qualityIndex = encodingDesc.index + videoPacket.encodingId = encodingId return packetInfo } diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/outgoing/HeaderExtStripper.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/outgoing/HeaderExtStripper.kt index 009bf1288e..b697d9a763 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/outgoing/HeaderExtStripper.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/outgoing/HeaderExtStripper.kt @@ -17,30 +17,41 @@ package org.jitsi.nlj.transform.node.outgoing import org.jitsi.nlj.PacketInfo import org.jitsi.nlj.rtp.RtpExtensionType +import org.jitsi.nlj.rtp.codec.av1.Av1DDPacket import org.jitsi.nlj.transform.node.ModifierNode import org.jitsi.nlj.util.ReadOnlyStreamInformationStore import org.jitsi.rtp.rtp.RtpPacket /** - * Strip all hop-by-hop header extensions. Currently this leaves only ssrc-audio-level and video-orientation. + * Strip all hop-by-hop header extensions. Currently this leaves ssrc-audio-level and video-orientation, + * plus the AV1 dependency descriptor if the packet is an Av1DDPacket. */ class HeaderExtStripper( streamInformationStore: ReadOnlyStreamInformationStore ) : ModifierNode("Strip header extensions") { private var retainedExts: Set = emptySet() + private var retainedExtsWithAv1DD: Set = emptySet() init { retainedExtTypes.forEach { rtpExtensionType -> streamInformationStore.onRtpExtensionMapping(rtpExtensionType) { - it?.let { retainedExts = retainedExts.plus(it) } + it?.let { + retainedExts = retainedExts.plus(it) + retainedExtsWithAv1DD = retainedExtsWithAv1DD.plus(it) + } } } + streamInformationStore.onRtpExtensionMapping(RtpExtensionType.AV1_DEPENDENCY_DESCRIPTOR) { + it?.let { retainedExtsWithAv1DD = retainedExtsWithAv1DD.plus(it) } + } } override fun modify(packetInfo: PacketInfo): PacketInfo { val rtpPacket = packetInfo.packetAs() - rtpPacket.removeHeaderExtensionsExcept(retainedExts) + val retained = if (rtpPacket is Av1DDPacket) retainedExtsWithAv1DD else retainedExts + + rtpPacket.removeHeaderExtensionsExcept(retained) return packetInfo } diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/util/Rfc3711IndexTracker.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/util/Rfc3711IndexTracker.kt index 65928583cf..fbe7fa679f 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/util/Rfc3711IndexTracker.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/util/Rfc3711IndexTracker.kt @@ -16,6 +16,7 @@ package org.jitsi.nlj.util +import org.jitsi.rtp.util.RtpUtils import org.jitsi.rtp.util.isNewerThan import org.jitsi.rtp.util.rolledOverTo @@ -78,4 +79,16 @@ class Rfc3711IndexTracker { fun interpret(seqNum: Int): Int { return getIndex(seqNum, false) } + + /** Force this sequence number to be interpreted as the new highest, regardless + * of its rollover state. + */ + fun resetAt(seq: Int) { + val delta = RtpUtils.getSequenceNumberDelta(seq, highestSeqNumReceived) + if (delta < 0) { + roc++ + highestSeqNumReceived = seq + } + getIndex(seq, true) + } } diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/util/TreeCache.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/util/TreeCache.kt new file mode 100644 index 0000000000..7288eb420a --- /dev/null +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/util/TreeCache.kt @@ -0,0 +1,62 @@ +/* + * Copyright @ 2019 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jitsi.nlj.util + +import java.util.* + +/** + * Implements a cache based on integer values, optimized for sparse values, such that you can find the most + * recent cached value before a specific value. + * + * The intended use case is AV1 Dependency Descriptor history. + */ +open class TreeCache( + private val minSize: Int +) { + private val map = TreeMap() + + private var highestIndex = -1 + + fun insert(index: Int, value: T) { + map[index] = value + + updateState(index) + } + + fun getEntryBefore(index: Int): Map.Entry? { + updateState(index) + return map.floorEntry(index) + } + + fun get(index: Int): T? = map[index] + + private fun updateState(index: Int) { + if (highestIndex < index) { + highestIndex = index + } + + /* Keep at most one entry older than highestIndex - minSize. */ + val headMap = map.headMap(highestIndex - minSize) + if (headMap.size > 1) { + val last = headMap.keys.last() + headMap.keys.removeIf { it < last } + } + } + + val size + get() = map.size +} diff --git a/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/MediaSourceDescTest.kt b/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/MediaSourceDescTest.kt index b1526978bc..a5eb419ff3 100644 --- a/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/MediaSourceDescTest.kt +++ b/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/MediaSourceDescTest.kt @@ -17,7 +17,10 @@ package org.jitsi.nlj import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.should import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.beInstanceOf +import org.jitsi.nlj.rtp.codec.vpx.VpxRtpLayerDesc import org.jitsi.nlj.util.Bandwidth import org.jitsi.nlj.util.BitrateTracker import org.jitsi.nlj.util.bps @@ -56,6 +59,9 @@ class MediaSourceDescTest : ShouldSpec() { e.layers.size shouldBe 3 for (j in e.layers.indices) { val l = e.layers[j] + l should beInstanceOf() + l as VpxRtpLayerDesc + l.eid shouldBe i l.tid shouldBe j l.sid shouldBe -1 @@ -124,13 +130,12 @@ private fun idx(spatialIdx: Int, temporalIdx: Int, temporalLen: Int) = spatialId * @return an array that holds the layer descriptions. */ private fun createRTPLayerDescs(spatialLen: Int, temporalLen: Int, encodingIdx: Int, height: Int): Array { - val rtpLayers = arrayOfNulls(spatialLen * temporalLen) + val rtpLayers = arrayOfNulls(spatialLen * temporalLen) for (spatialIdx in 0 until spatialLen) { var frameRate = 30.toDouble() / (1 shl temporalLen - 1) for (temporalIdx in 0 until temporalLen) { val idx: Int = idx(spatialIdx, temporalIdx, temporalLen) - var dependencies: Array - dependencies = if (spatialIdx > 0 && temporalIdx > 0) { + val dependencies: Array = if (spatialIdx > 0 && temporalIdx > 0) { // this layer depends on spatialIdx-1 and temporalIdx-1. arrayOf( rtpLayers[ @@ -176,7 +181,7 @@ private fun createRTPLayerDescs(spatialLen: Int, temporalLen: Int, encodingIdx: } val temporalId = if (temporalLen > 1) temporalIdx else -1 val spatialId = if (spatialLen > 1) spatialIdx else -1 - rtpLayers[idx] = RtpLayerDesc( + rtpLayers[idx] = VpxRtpLayerDesc( encodingIdx, temporalId, spatialId, @@ -226,7 +231,7 @@ private fun createSource( name: String, videoType: VideoType, ): MediaSourceDesc { - var height = 720 + var height = 180 val encodings = Array(primarySsrcs.size) { encodingIdx -> val primarySsrc: Long = primarySsrcs[encodingIdx] diff --git a/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/RtpLayerDescTest.kt b/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/RtpLayerDescTest.kt index 31fbf55972..7333842171 100644 --- a/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/RtpLayerDescTest.kt +++ b/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/RtpLayerDescTest.kt @@ -17,21 +17,22 @@ package org.jitsi.nlj import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.shouldBe +import org.jitsi.nlj.rtp.codec.vpx.VpxRtpLayerDesc class RtpLayerDescTest : FunSpec({ test("VP8 layer ids") { // mostly for documenting the encoding -> index mapping. val vp8Layers = arrayOf( - RtpLayerDesc(0, 0, 0, 180, 7.5), - RtpLayerDesc(0, 1, 0, 180, 15.0), - RtpLayerDesc(0, 2, 0, 180, 30.0), - RtpLayerDesc(1, 0, 0, 360, 7.5), - RtpLayerDesc(1, 1, 0, 360, 15.0), - RtpLayerDesc(1, 2, 0, 360, 30.0), - RtpLayerDesc(2, 0, 0, 720, 7.5), - RtpLayerDesc(2, 1, 0, 720, 15.0), - RtpLayerDesc(2, 2, 0, 720, 30.0) + VpxRtpLayerDesc(0, 0, 0, 180, 7.5), + VpxRtpLayerDesc(0, 1, 0, 180, 15.0), + VpxRtpLayerDesc(0, 2, 0, 180, 30.0), + VpxRtpLayerDesc(1, 0, 0, 360, 7.5), + VpxRtpLayerDesc(1, 1, 0, 360, 15.0), + VpxRtpLayerDesc(1, 2, 0, 360, 30.0), + VpxRtpLayerDesc(2, 0, 0, 720, 7.5), + VpxRtpLayerDesc(2, 1, 0, 720, 15.0), + VpxRtpLayerDesc(2, 2, 0, 720, 30.0) ) vp8Layers[0].index shouldBe 0 @@ -58,15 +59,15 @@ class RtpLayerDescTest : FunSpec({ test("VP9 layer ids") { // same here, mostly for documenting the encoding -> index mapping. val vp9Layers = arrayOf( - RtpLayerDesc(0, 0, 0, 180, 7.5), - RtpLayerDesc(0, 1, 0, 180, 15.0), - RtpLayerDesc(0, 2, 0, 180, 30.0), - RtpLayerDesc(0, 0, 1, 360, 7.5), - RtpLayerDesc(0, 1, 1, 360, 15.0), - RtpLayerDesc(0, 2, 1, 360, 30.0), - RtpLayerDesc(0, 0, 2, 720, 7.5), - RtpLayerDesc(0, 1, 2, 720, 15.0), - RtpLayerDesc(0, 2, 2, 720, 30.0) + VpxRtpLayerDesc(0, 0, 0, 180, 7.5), + VpxRtpLayerDesc(0, 1, 0, 180, 15.0), + VpxRtpLayerDesc(0, 2, 0, 180, 30.0), + VpxRtpLayerDesc(0, 0, 1, 360, 7.5), + VpxRtpLayerDesc(0, 1, 1, 360, 15.0), + VpxRtpLayerDesc(0, 2, 1, 360, 30.0), + VpxRtpLayerDesc(0, 0, 2, 720, 7.5), + VpxRtpLayerDesc(0, 1, 2, 720, 15.0), + VpxRtpLayerDesc(0, 2, 2, 720, 30.0) ) vp9Layers[0].index shouldBe 0 diff --git a/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDPacketTest.kt b/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDPacketTest.kt new file mode 100644 index 0000000000..f123be9038 --- /dev/null +++ b/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDPacketTest.kt @@ -0,0 +1,211 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.nlj.rtp.codec.av1 + +import io.kotest.assertions.withClue +import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.should +import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe +import io.kotest.matchers.types.beInstanceOf +import org.jitsi.nlj.RtpEncodingDesc +import org.jitsi.rtp.rtp.RtpPacket +import org.jitsi.utils.logging2.LoggerImpl +import javax.xml.bind.DatatypeConverter + +class Av1DDPacketTest : ShouldSpec() { + val logger = LoggerImpl(javaClass.name) + + private data class SampleAv1DDPacket( + val description: String, + val data: ByteArray, + val structureSource: SampleAv1DDPacket?, + // ... + val scalabilityStructure: RtpEncodingDesc? = null + + ) { + constructor( + description: String, + hexData: String, + structureSource: SampleAv1DDPacket?, + // ... + scalabilityStructure: RtpEncodingDesc? = null + ) : this( + description = description, + data = DatatypeConverter.parseHexBinary(hexData), + structureSource = structureSource, + // ... + scalabilityStructure = scalabilityStructure + ) + } + + /* Packets captured from AV1 test calls */ + + private val nonScalableKeyframe = + SampleAv1DDPacket( + "non-scalable keyframe", + // RTP header + "902939b91ff6a695114ed316" + + // Header extension header + "bede0006" + + // Other header extensions + "3248c482" + + "510002" + + // AV1 DD + "bc80000180003a410180ef808680" + + // Padding + "000000" + + // AV1 Media payload, truncated. + "680b0800000004477e1a00d004301061", + null, + RtpEncodingDesc( + 0x114ed316L, + arrayOf( + Av1DDRtpLayerDesc(0, 0, 0, 0, 270, 30.0) + ) + ) + ) + + private val scalableKeyframe = + SampleAv1DDPacket( + "scalable keyframe", + // RTP header + "906519aed780a8f32ab2873c" + + // Header extension header (2-byte) + "1000001a" + + // Other header extensions + "030302228e" + + "050219f9" + + // AV1 DD + "0b5a8003ca80081485214eaaaaa8000600004000100002aa80a80006" + + "00004000100002a000a800060000400016d549241b5524906d549231" + + "57e001974ca864330e222396eca8655304224390eca87753013f00b3" + + "027f016704ff02cf" + + // Padding + "000000" + + // AV1 Media payload, truncated. + "aabbdc101a58014000b4028001680500", + null, + RtpEncodingDesc( + 0x2ab2873cL, + arrayOf( + Av1DDRtpLayerDesc(0, 0, 0, 0, 180, 7.5), + Av1DDRtpLayerDesc(0, 1, 1, 0, 180, 15.0), + Av1DDRtpLayerDesc(0, 2, 2, 0, 180, 30.0), + Av1DDRtpLayerDesc(0, 3, 1, 0, 360, 7.5), + Av1DDRtpLayerDesc(0, 4, 1, 1, 360, 15.0), + Av1DDRtpLayerDesc(0, 5, 1, 2, 360, 30.0), + Av1DDRtpLayerDesc(0, 6, 2, 0, 720, 7.5), + Av1DDRtpLayerDesc(0, 7, 2, 1, 720, 15.0), + Av1DDRtpLayerDesc(0, 8, 2, 2, 720, 30.0) + ) + ) + ) + + private val testPackets = arrayOf( + nonScalableKeyframe, + SampleAv1DDPacket( + "non-scalable following packet", + // RTP header + "90a939ba1ff6a695114ed316" + + // Header extension header + "bede0003" + + // Other header extensions + "3248d02a" + + "510003" + + // AV1 DD + "b2400001" + + // Padding + "00" + + // AV1 Media payload, truncated. + "9057f7c3f51ba803b0c397e938589750", + nonScalableKeyframe + ), + scalableKeyframe, + SampleAv1DDPacket( + "scalable following packet changing available DTs", + // RTP header + "90e519c2d780b1bd2ab2873c" + + // Header extension header + "bede0004" + + // Other header extensions + "3202f824" + + "511a19" + + // AV1 DD + "b4c303cd401c" + + // Padding + "000000" + + // AV1 Media payload, truncated. + "edbbdd501a87000000027e016704ff02", + scalableKeyframe, + RtpEncodingDesc( + 0x2ab2873cL, + arrayOf( + Av1DDRtpLayerDesc(0, 0, 0, 0, 180, 7.5), + Av1DDRtpLayerDesc(0, 1, 0, 1, 180, 15.0), + Av1DDRtpLayerDesc(0, 2, 0, 2, 180, 30.0), + ) + ) + ) + ) + + init { + context("AV1 packets") { + should("be parsed correctly") { + for (t in testPackets) { + withClue(t.description) { + val structure = if (t.structureSource != null) { + val sourceR = + RtpPacket(t.structureSource.data, 0, t.structureSource.data.size) + val sourceP = Av1DDPacket(sourceR, AV1_DD_HEADER_EXTENSION_ID, null, logger) + sourceP.descriptor?.structure + } else { + null + } + + val r = RtpPacket(t.data, 0, t.data.size) + val p = Av1DDPacket(r, AV1_DD_HEADER_EXTENSION_ID, structure, logger) + + if (t.scalabilityStructure != null) { + val tss = t.scalabilityStructure + val ss = p.getScalabilityStructure() + ss shouldNotBe null + ss!!.primarySSRC shouldBe tss.primarySSRC + ss.layers.size shouldBe tss.layers.size + for ((index, layer) in ss.layers.withIndex()) { + val tLayer = tss.layers[index] + layer.layerId shouldBe tLayer.layerId + layer.index shouldBe tLayer.index + layer should beInstanceOf() + layer as Av1DDRtpLayerDesc + tLayer as Av1DDRtpLayerDesc + layer.dt shouldBe tLayer.dt + layer.height shouldBe tLayer.height + layer.frameRate shouldBe tLayer.frameRate + } + } else { + p.getScalabilityStructure() shouldBe null + } + } + } + } + } + } + + companion object { + const val AV1_DD_HEADER_EXTENSION_ID = 11 + } +} diff --git a/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9PacketTest.kt b/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9PacketTest.kt index 3aee52fc19..93b08517bf 100644 --- a/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9PacketTest.kt +++ b/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9PacketTest.kt @@ -1,11 +1,28 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.jitsi.nlj.rtp.codec.vp9 import io.kotest.assertions.withClue import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.should import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe +import io.kotest.matchers.types.beInstanceOf import org.jitsi.nlj.RtpEncodingDesc -import org.jitsi.nlj.RtpLayerDesc +import org.jitsi.nlj.rtp.codec.vpx.VpxRtpLayerDesc import org.jitsi_modified.impl.neomedia.codec.video.vp9.DePacketizer import javax.xml.bind.DatatypeConverter @@ -136,16 +153,16 @@ class Vp9PacketTest : ShouldSpec() { scalabilityStructure = RtpEncodingDesc( 0x6098017bL, arrayOf( - RtpLayerDesc(0, 0, 0, 180, 7.5), + VpxRtpLayerDesc(0, 0, 0, 180, 7.5), /* TODO: dependencies */ - RtpLayerDesc(0, 1, 0, 180, 15.0), - RtpLayerDesc(0, 2, 0, 180, 30.0), - RtpLayerDesc(0, 0, 1, 360, 7.5), - RtpLayerDesc(0, 1, 1, 360, 15.0), - RtpLayerDesc(0, 2, 1, 360, 30.0), - RtpLayerDesc(0, 0, 2, 720, 7.5), - RtpLayerDesc(0, 1, 2, 720, 15.0), - RtpLayerDesc(0, 2, 2, 720, 30.0) + VpxRtpLayerDesc(0, 1, 0, 180, 15.0), + VpxRtpLayerDesc(0, 2, 0, 180, 30.0), + VpxRtpLayerDesc(0, 0, 1, 360, 7.5), + VpxRtpLayerDesc(0, 1, 1, 360, 15.0), + VpxRtpLayerDesc(0, 2, 1, 360, 30.0), + VpxRtpLayerDesc(0, 0, 2, 720, 7.5), + VpxRtpLayerDesc(0, 1, 2, 720, 15.0), + VpxRtpLayerDesc(0, 2, 2, 720, 30.0) ) ) ), @@ -480,7 +497,7 @@ class Vp9PacketTest : ShouldSpec() { scalabilityStructure = RtpEncodingDesc( 0x184b0cc4L, arrayOf( - RtpLayerDesc(0, 0, 0, 1158, 30.0) + VpxRtpLayerDesc(0, 0, 0, 1158, 30.0) ) ) ), @@ -600,7 +617,7 @@ class Vp9PacketTest : ShouldSpec() { scalabilityStructure = RtpEncodingDesc( 0x6538459eL, arrayOf( - RtpLayerDesc(0, 0, 0, 720, 30.0) + VpxRtpLayerDesc(0, 0, 0, 720, 30.0) ) ) ), @@ -663,9 +680,9 @@ class Vp9PacketTest : ShouldSpec() { scalabilityStructure = RtpEncodingDesc( 0xa4d04528L, arrayOf( - RtpLayerDesc(0, 0, 0, 720, 7.5), - RtpLayerDesc(0, 1, 0, 720, 15.0), - RtpLayerDesc(0, 2, 0, 720, 30.0) + VpxRtpLayerDesc(0, 0, 0, 720, 7.5), + VpxRtpLayerDesc(0, 1, 0, 720, 15.0), + VpxRtpLayerDesc(0, 2, 0, 720, 30.0) ) ) ), @@ -773,6 +790,9 @@ class Vp9PacketTest : ShouldSpec() { val tLayer = tss.layers[index] layer.layerId shouldBe tLayer.layerId layer.index shouldBe tLayer.index + layer should beInstanceOf() + layer as VpxRtpLayerDesc + tLayer as VpxRtpLayerDesc layer.sid shouldBe tLayer.sid layer.tid shouldBe tLayer.tid layer.height shouldBe tLayer.height diff --git a/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/util/TreeCacheTest.kt b/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/util/TreeCacheTest.kt new file mode 100644 index 0000000000..95e3113515 --- /dev/null +++ b/jitsi-media-transform/src/test/kotlin/org/jitsi/nlj/util/TreeCacheTest.kt @@ -0,0 +1,100 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.nlj.util + +import io.kotest.core.spec.IsolationMode +import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.shouldBe +import java.util.AbstractMap.SimpleImmutableEntry + +class TreeCacheTest : ShouldSpec() { + override fun isolationMode() = IsolationMode.InstancePerLeaf + + data class Dummy(val data: String) + + private val treeCache = TreeCache(16) + + /** Shorthand for a Map.Entry mapping [key] to a Dummy containing [dummyVal] */ + private fun ed(key: Int, dummyVal: String) = SimpleImmutableEntry(key, Dummy(dummyVal)) + + init { + context("Reading from an empty TreeCache") { + should("return null") { + treeCache.getEntryBefore(10) shouldBe null + } + should("have size 0") { + treeCache.size shouldBe 0 + } + } + context("An entry in a TreeCache") { + treeCache.insert(5, Dummy("A")) + should("be found looking up values after it") { + treeCache.getEntryBefore(10) shouldBe ed(5, "A") + } + should("be found looking up the same value") { + treeCache.getEntryBefore(5) shouldBe ed(5, "A") + } + should("not be found looking up values before it") { + treeCache.getEntryBefore(3) shouldBe null + } + should("not be expired even if values long after it are looked up") { + treeCache.getEntryBefore(10000) shouldBe ed(5, "A") + } + should("cause the tree to have size 1") { + treeCache.size shouldBe 1 + } + } + context("Multiple values in a TreeCache") { + treeCache.insert(5, Dummy("A")) + treeCache.insert(10, Dummy("B")) + should("Be looked up properly") { + treeCache.getEntryBefore(13) shouldBe ed(10, "B") + treeCache.size shouldBe 2 + } + should("Persist within the cache window") { + treeCache.getEntryBefore(8) shouldBe ed(5, "A") + treeCache.size shouldBe 2 + } + should("Not expire an older one if it is the only value outside the cache window") { + treeCache.getEntryBefore(25) shouldBe ed(10, "B") + treeCache.getEntryBefore(8) shouldBe ed(5, "A") + treeCache.size shouldBe 2 + } + should("Expire older ones when newer ones are outside the cache window") { + treeCache.getEntryBefore(30) shouldBe ed(10, "B") + treeCache.getEntryBefore(8) shouldBe null + treeCache.size shouldBe 1 + } + should("Expire only older ones when later values are inserted") { + treeCache.insert(40, Dummy("C")) + treeCache.getEntryBefore(13) shouldBe ed(10, "B") + treeCache.getEntryBefore(8) shouldBe null + treeCache.size shouldBe 2 + } + should("Persist values within the window while expiring values outside it") { + treeCache.insert(15, Dummy("C")) + treeCache.getEntryBefore(8) shouldBe ed(5, "A") + treeCache.getEntryBefore(25) shouldBe ed(15, "C") + treeCache.getEntryBefore(13) shouldBe ed(10, "B") + treeCache.getEntryBefore(8) shouldBe ed(5, "A") + treeCache.size shouldBe 3 + treeCache.insert(30, Dummy("D")) + treeCache.getEntryBefore(8) shouldBe null + treeCache.size shouldBe 3 + } + } + } +} diff --git a/jvb/src/main/java/org/jitsi/videobridge/cc/AdaptiveSourceProjection.java b/jvb/src/main/java/org/jitsi/videobridge/cc/AdaptiveSourceProjection.java index a0f545f364..c59ec62c9f 100644 --- a/jvb/src/main/java/org/jitsi/videobridge/cc/AdaptiveSourceProjection.java +++ b/jvb/src/main/java/org/jitsi/videobridge/cc/AdaptiveSourceProjection.java @@ -17,19 +17,21 @@ import org.jetbrains.annotations.*; import org.jitsi.nlj.*; -import org.jitsi.nlj.format.*; import org.jitsi.nlj.rtp.*; +import org.jitsi.nlj.rtp.codec.av1.*; import org.jitsi.nlj.rtp.codec.vp8.*; +import org.jitsi.nlj.rtp.codec.vp9.*; import org.jitsi.rtp.rtcp.*; -import org.jitsi.utils.collections.*; import org.jitsi.utils.logging.*; import org.jitsi.utils.logging2.Logger; +import org.jitsi.videobridge.cc.av1.*; import org.jitsi.videobridge.cc.vp8.*; import org.jitsi.videobridge.cc.vp9.*; import org.json.simple.*; import java.lang.*; import java.util.*; +import java.util.stream.*; /** * Filters the packets coming from a specific {@link MediaSourceDesc} @@ -88,38 +90,26 @@ public class AdaptiveSourceProjection */ private AdaptiveSourceProjectionContext context; - /** - * The payload type that was used to determine the {@link #context} type. - */ - private int contextPayloadType = -1; - /** * The target quality index for this source projection. */ private int targetIndex = RtpLayerDesc.SUSPENDED_INDEX; - private final Map payloadTypes; - /** * Ctor. * * @param source the {@link MediaSourceDesc} that owns the packets * that this instance filters. - * - * @param payloadTypes a reference to a map of payload types. This map - * should be updated as the payload types change. */ public AdaptiveSourceProjection( @NotNull DiagnosticContext diagnosticContext, @NotNull MediaSourceDesc source, Runnable keyframeRequester, - Map payloadTypes, Logger parentLogger ) { targetSsrc = source.getPrimarySSRC(); this.diagnosticContext = diagnosticContext; - this.payloadTypes = payloadTypes; this.parentLogger = parentLogger; this.logger = parentLogger.createChildLogger(AdaptiveSourceProjection.class.getName(), Map.of("targetSsrc", Long.toString(targetSsrc), @@ -160,7 +150,8 @@ public boolean accept(@NotNull PacketInfo packetInfo) // suspended so that it can raise the needsKeyframe flag and also allow // it to compute a sequence number delta when the target becomes > -1. - if (videoRtpPacket.getQualityIndex() < 0) + int encodingId = videoRtpPacket.getEncodingId(); + if (encodingId == RtpLayerDesc.SUSPENDED_ENCODING_ID) { logger.warn( "Dropping an RTP packet, because egress was unable to find " + @@ -169,8 +160,7 @@ public boolean accept(@NotNull PacketInfo packetInfo) } int targetIndexCopy = targetIndex; - boolean accept = contextCopy.accept( - packetInfo, videoRtpPacket.getQualityIndex(), targetIndexCopy); + boolean accept = contextCopy.accept(packetInfo, encodingId, targetIndexCopy); // We check if the context needs a keyframe regardless of whether or not // the packet was accepted. @@ -190,8 +180,8 @@ public boolean accept(@NotNull PacketInfo packetInfo) /** * Gets or creates the adaptive source projection context that corresponds to - * the payload type of the RTP packet that is specified as a parameter. If - * the payload type is different from {@link #contextPayloadType}, then a + * the parsed class of the RTP packet that is specified as a parameter. If + * the payload type is different from the appropriate one for the context, then a * new adaptive source projection context is created that is appropriate for * the new payload type. * @@ -204,26 +194,9 @@ public boolean accept(@NotNull PacketInfo packetInfo) */ private AdaptiveSourceProjectionContext getContext(@NotNull VideoRtpPacket rtpPacket) { - PayloadType payloadTypeObject; int payloadType = rtpPacket.getPayloadType(); - if (context == null || contextPayloadType != payloadType) - { - payloadTypeObject = payloadTypes.get((byte)payloadType); - if (payloadTypeObject == null) - { - logger.error("No payload type object signalled for payload type " + payloadType + " yet, " + - "cannot create source projection context"); - return null; - } - } - else - { - // No need to call the expensive getDynamicRTPPayloadTypes. - payloadTypeObject = context.getPayloadType(); - } - - if (payloadTypeObject instanceof Vp8PayloadType) + if (rtpPacket instanceof Vp8Packet) { // Context switch between VP8 simulcast and VP8 non-simulcast (sort // of pretend that they're different codecs). @@ -233,8 +206,7 @@ private AdaptiveSourceProjectionContext getContext(@NotNull VideoRtpPacket rtpPa // then simulcast is disabled. /* Check whether this stream is projectable by the VP8AdaptiveSourceProjectionContext. */ - boolean projectable = rtpPacket instanceof Vp8Packet && - ((Vp8Packet)rtpPacket).getHasTemporalLayerIndex() && + boolean projectable = ((Vp8Packet)rtpPacket).getHasTemporalLayerIndex() && ((Vp8Packet)rtpPacket).getHasPictureId(); if (projectable @@ -244,39 +216,33 @@ private AdaptiveSourceProjectionContext getContext(@NotNull VideoRtpPacket rtpPa RtpState rtpState = getRtpState(); logger.debug(() -> "adaptive source projection " + (context == null ? "creating new" : "changing to") + - " VP8 context for payload type " - + payloadType + - ", source packet ssrc " + rtpPacket.getSsrc()); + " VP8 context for source packet ssrc " + rtpPacket.getSsrc()); context = new VP8AdaptiveSourceProjectionContext( - diagnosticContext, payloadTypeObject, rtpState, parentLogger); - contextPayloadType = payloadType; + diagnosticContext, rtpState, parentLogger); } else if (!projectable - && !(context instanceof GenericAdaptiveSourceProjectionContext)) + && (!(context instanceof GenericAdaptiveSourceProjectionContext) || + ((GenericAdaptiveSourceProjectionContext)context).getPayloadType() != payloadType)) { RtpState rtpState = getRtpState(); // context switch logger.debug(() -> { - boolean hasTemporalLayer = rtpPacket instanceof Vp8Packet && - ((Vp8Packet)rtpPacket).getHasTemporalLayerIndex(); - boolean hasPictureId = rtpPacket instanceof Vp8Packet && - ((Vp8Packet)rtpPacket).getHasPictureId(); + boolean hasTemporalLayer = ((Vp8Packet)rtpPacket).getHasTemporalLayerIndex(); + boolean hasPictureId = ((Vp8Packet)rtpPacket).getHasPictureId(); return "adaptive source projection " + (context == null ? "creating new" : "changing to") + - " generic context for non-scalable VP8 payload type " - + payloadType + + " generic context for non-scalable VP8 payload " + " (packet is " + rtpPacket.getClass().getSimpleName() + ", ssrc " + rtpPacket.getSsrc() + ", hasTL=" + hasTemporalLayer + ", hasPID=" + hasPictureId + ")"; }); - context = new GenericAdaptiveSourceProjectionContext(payloadTypeObject, rtpState, parentLogger); - contextPayloadType = payloadType; + context = new GenericAdaptiveSourceProjectionContext(payloadType, rtpState, parentLogger); } // no context switch return context; } - else if (payloadTypeObject instanceof Vp9PayloadType) + else if (rtpPacket instanceof Vp9Packet) { if (!(context instanceof Vp9AdaptiveSourceProjectionContext)) { @@ -284,28 +250,39 @@ else if (payloadTypeObject instanceof Vp9PayloadType) RtpState rtpState = getRtpState(); logger.debug(() -> "adaptive source projection " + (context == null ? "creating new" : "changing to") + - " VP9 context for payload type " - + payloadType + - ", source packet ssrc " + rtpPacket.getSsrc()); + " VP9 context for source packet ssrc " + rtpPacket.getSsrc()); context = new Vp9AdaptiveSourceProjectionContext( - diagnosticContext, payloadTypeObject, rtpState, parentLogger); - contextPayloadType = payloadType; + diagnosticContext, rtpState, parentLogger); } return context; } - else if (context == null || contextPayloadType != payloadType) + else if (rtpPacket instanceof Av1DDPacket) { - RtpState rtpState = getRtpState(); - logger.debug(() -> "adaptive source projection " + - (context == null ? "creating new" : "changing to") + - " generic context for payload type " + payloadType); - context = new GenericAdaptiveSourceProjectionContext(payloadTypeObject, rtpState, parentLogger); - contextPayloadType = payloadType; + if (!(context instanceof Av1DDAdaptiveSourceProjectionContext)) + { + // context switch + RtpState rtpState = getRtpState(); + logger.debug(() -> "adaptive source projection " + + (context == null ? "creating new" : "changing to") + + " AV1 DD context for source packet ssrc " + rtpPacket.getSsrc()); + context = new Av1DDAdaptiveSourceProjectionContext( + diagnosticContext, rtpState, parentLogger); + } + return context; } else { + if (!(context instanceof GenericAdaptiveSourceProjectionContext) || + ((GenericAdaptiveSourceProjectionContext)context).getPayloadType() != payloadType) + { + RtpState rtpState = getRtpState(); + logger.debug(() -> "adaptive source projection " + + (context == null ? "creating new" : "changing to") + + " generic context for payload type " + rtpPacket.getPayloadType()); + context = new GenericAdaptiveSourceProjectionContext(payloadType, rtpState, parentLogger); + } return context; } } @@ -384,7 +361,6 @@ public JSONObject getDebugState() debugState.put( "context", contextCopy == null ? null : contextCopy.getDebugState()); - debugState.put("contextPayloadType", contextPayloadType); debugState.put("targetIndex", targetIndex); return debugState; diff --git a/jvb/src/main/java/org/jitsi/videobridge/cc/AdaptiveSourceProjectionContext.java b/jvb/src/main/java/org/jitsi/videobridge/cc/AdaptiveSourceProjectionContext.java index 72629e9f1a..af56947d67 100644 --- a/jvb/src/main/java/org/jitsi/videobridge/cc/AdaptiveSourceProjectionContext.java +++ b/jvb/src/main/java/org/jitsi/videobridge/cc/AdaptiveSourceProjectionContext.java @@ -16,10 +16,11 @@ package org.jitsi.videobridge.cc; import org.jitsi.nlj.*; -import org.jitsi.nlj.format.*; import org.jitsi.rtp.rtcp.*; import org.json.simple.*; +import java.util.*; + /** * Implementations of this interface are responsible for projecting a specific * video source of a specific payload type. @@ -40,11 +41,11 @@ public interface AdaptiveSourceProjectionContext * Determines whether an RTP packet should be accepted or not. * * @param packetInfo the RTP packet to determine whether to accept or not. - * @param incomingIndex the quality index of the incoming RTP packet. + * @param incomingEncoding The encoding of the incoming packet. * @param targetIndex the target quality index * @return true if the packet should be accepted, false otherwise. */ - boolean accept(PacketInfo packetInfo, int incomingIndex, int targetIndex); + boolean accept(PacketInfo packetInfo, int incomingEncoding, int targetIndex); /** * @return true if this stream context needs a keyframe in order to either @@ -81,12 +82,6 @@ void rewriteRtp(PacketInfo packetInfo) */ RtpState getRtpState(); - /** - * @return the {@link PayloadType} of the RTP packets that this context - * processes. - */ - PayloadType getPayloadType(); - /** * Gets a JSON representation of the parts of this object's state that * are deemed useful for debugging. diff --git a/jvb/src/main/java/org/jitsi/videobridge/cc/GenericAdaptiveSourceProjectionContext.java b/jvb/src/main/java/org/jitsi/videobridge/cc/GenericAdaptiveSourceProjectionContext.java index 6a6e6d9987..f1561bab66 100644 --- a/jvb/src/main/java/org/jitsi/videobridge/cc/GenericAdaptiveSourceProjectionContext.java +++ b/jvb/src/main/java/org/jitsi/videobridge/cc/GenericAdaptiveSourceProjectionContext.java @@ -17,13 +17,14 @@ import org.jetbrains.annotations.*; import org.jitsi.nlj.*; -import org.jitsi.nlj.format.*; import org.jitsi.nlj.rtp.*; import org.jitsi.rtp.rtcp.*; import org.jitsi.rtp.util.*; import org.jitsi.utils.logging2.*; import org.json.simple.*; +import java.util.*; + /** * A generic implementation of an adaptive source projection context that can be * used with non-SVC codecs or when simulcast is not enabled/used or when @@ -69,9 +70,9 @@ class GenericAdaptiveSourceProjectionContext private boolean needsKeyframe = true; /** - * Useful to determine whether a packet is a "keyframe". + * Useful to determine when we need to change generic projection contexts. */ - private final PayloadType payloadType; + private final int payloadType; /** * The maximum sequence number that we have sent. @@ -109,7 +110,7 @@ class GenericAdaptiveSourceProjectionContext * @param rtpState the RTP state (i.e. seqnum, timestamp to start with, etc). */ GenericAdaptiveSourceProjectionContext( - @NotNull PayloadType payloadType, + int payloadType, @NotNull RtpState rtpState, @NotNull Logger parentLogger) { @@ -131,13 +132,13 @@ class GenericAdaptiveSourceProjectionContext * thread) accessing this method at a time. * * @param packetInfo the RTP packet to determine whether to accept or not. - * @param incomingIndex the quality index of the + * @param incomingEncoding The encoding index of the packet * @param targetIndex the target quality index * @return true if the packet should be accepted, false otherwise. */ @Override public synchronized boolean - accept(@NotNull PacketInfo packetInfo, int incomingIndex, int targetIndex) + accept(@NotNull PacketInfo packetInfo, int incomingEncoding, int targetIndex) { VideoRtpPacket rtpPacket = packetInfo.packetAs(); if (targetIndex == RtpLayerDesc.SUSPENDED_INDEX) @@ -327,8 +328,7 @@ public RtpState getRtpState() ssrc, maxDestinationSequenceNumber, maxDestinationTimestamp); } - @Override - public PayloadType getPayloadType() + public int getPayloadType() { return payloadType; } diff --git a/jvb/src/main/java/org/jitsi/videobridge/cc/vp8/VP8AdaptiveSourceProjectionContext.java b/jvb/src/main/java/org/jitsi/videobridge/cc/vp8/VP8AdaptiveSourceProjectionContext.java index da073e321c..97be97d72a 100644 --- a/jvb/src/main/java/org/jitsi/videobridge/cc/vp8/VP8AdaptiveSourceProjectionContext.java +++ b/jvb/src/main/java/org/jitsi/videobridge/cc/vp8/VP8AdaptiveSourceProjectionContext.java @@ -18,11 +18,9 @@ import org.jetbrains.annotations.*; import org.jitsi.nlj.*; import org.jitsi.nlj.codec.vpx.*; -import org.jitsi.nlj.format.*; import org.jitsi.nlj.rtp.codec.vp8.*; import org.jitsi.rtp.rtcp.*; import org.jitsi.rtp.util.*; -import org.jitsi.utils.*; import org.jitsi.utils.logging.*; import org.jitsi.utils.logging2.Logger; import org.jitsi.videobridge.cc.*; @@ -67,30 +65,19 @@ public class VP8AdaptiveSourceProjectionContext */ private VP8FrameProjection lastVP8FrameProjection; - /** - * The VP8 media format. No essential functionality relies on this field, - * it's only used as a cache of the {@link PayloadType} instance for VP8 in - * case we have to do a context switch (see {@link AdaptiveSourceProjection}), - * in order to avoid having to resolve the format. - */ - private final PayloadType payloadType; - /** * Ctor. * - * @param payloadType the VP8 media format. * @param rtpState the RTP state to begin with. */ public VP8AdaptiveSourceProjectionContext( @NotNull DiagnosticContext diagnosticContext, - @NotNull PayloadType payloadType, @NotNull RtpState rtpState, @NotNull Logger parentLogger) { this.diagnosticContext = diagnosticContext; this.logger = parentLogger.createChildLogger( VP8AdaptiveSourceProjectionContext.class.getName()); - this.payloadType = payloadType; this.vp8QualityFilter = new VP8QualityFilter(parentLogger); lastVP8FrameProjection = new VP8FrameProjection(diagnosticContext, @@ -265,13 +252,13 @@ private boolean frameIsNewSsrc(VP8Frame frame) * Determines whether a packet should be accepted or not. * * @param packetInfo the RTP packet to determine whether to project or not. - * @param incomingIndex the quality index of the incoming RTP packet + * @param incomingEncoding the encoding of the incoming RTP packet * @param targetIndex the target quality index we want to achieve * @return true if the packet should be accepted, false otherwise. */ @Override public synchronized boolean accept( - @NotNull PacketInfo packetInfo, int incomingIndex, int targetIndex) + @NotNull PacketInfo packetInfo, int incomingEncoding, int targetIndex) { if (!(packetInfo.getPacket() instanceof Vp8Packet)) { @@ -309,7 +296,7 @@ public synchronized boolean accept( Instant receivedTime = packetInfo.getReceivedTime(); boolean accepted = vp8QualityFilter - .acceptFrame(frame, incomingIndex, targetIndex, receivedTime); + .acceptFrame(frame, incomingEncoding, targetIndex, receivedTime); if (accepted) { @@ -672,12 +659,6 @@ public RtpState getRtpState() lastVP8FrameProjection.getTimestamp()); } - @Override - public PayloadType getPayloadType() - { - return payloadType; - } - /** * Rewrites the RTP packet that is specified as an argument. * @@ -744,7 +725,6 @@ public synchronized JSONObject getDebugState() debugState.put( "vp8FrameMaps", mapSizes); debugState.put("vp8QualityFilter", vp8QualityFilter.getDebugState()); - debugState.put("payloadType", payloadType.toString()); return debugState; } diff --git a/jvb/src/main/java/org/jitsi/videobridge/cc/vp8/VP8QualityFilter.java b/jvb/src/main/java/org/jitsi/videobridge/cc/vp8/VP8QualityFilter.java index e3b635d515..9461b37a91 100644 --- a/jvb/src/main/java/org/jitsi/videobridge/cc/vp8/VP8QualityFilter.java +++ b/jvb/src/main/java/org/jitsi/videobridge/cc/vp8/VP8QualityFilter.java @@ -106,7 +106,7 @@ boolean needsKeyframe() * method at a time. * * @param frame the VP8 frame. - * @param incomingIndex the quality index of the incoming RTP packet + * @param incomingEncoding the encoding index of the incoming RTP packet * @param externalTargetIndex the target quality index that the user of this * instance wants to achieve. * @param receivedTime the current time @@ -114,7 +114,7 @@ boolean needsKeyframe() */ synchronized boolean acceptFrame( @NotNull VP8Frame frame, - int incomingIndex, + int incomingEncoding, int externalTargetIndex, Instant receivedTime) { // We make local copies of the externalTemporalLayerIdTarget and the @@ -154,16 +154,15 @@ synchronized boolean acceptFrame( temporalLayerIdOfFrame = 0; } - int encodingId = RtpLayerDesc.getEidFromIndex(incomingIndex); if (frame.isKeyframe()) { logger.debug(() -> "Quality filter got keyframe for stream " + frame.getSsrc()); - return acceptKeyframe(encodingId, receivedTime); + return acceptKeyframe(incomingEncoding, receivedTime); } else if (currentEncodingId > SUSPENDED_ENCODING_ID) { - if (isOutOfSwitchingPhase(receivedTime) && isPossibleToSwitch(encodingId)) + if (isOutOfSwitchingPhase(receivedTime) && isPossibleToSwitch(incomingEncoding)) { // XXX(george) i've noticed some "rogue" base layer keyframes // that trigger this. what happens is the client sends a base @@ -176,7 +175,7 @@ else if (currentEncodingId > SUSPENDED_ENCODING_ID) needsKeyframe = true; } - if (encodingId != currentEncodingId) + if (incomingEncoding != currentEncodingId) { // for non-keyframes, we can't route anything but the current encoding return false; @@ -206,13 +205,12 @@ else if (currentEncodingId < externalEncodingIdTarget) else { // In this branch we're not processing a keyframe and the - // currentSpatialLayerId is in suspended state, which means we need - // a keyframe to start streaming again. Reaching this point also - // means that we want to forward something (because both - // externalEncodingIdTarget and externalTemporalLayerIdTarget - // are greater than 0) so we set the request keyframe flag. + // currentEncodingId is in suspended state, which means we need + // a keyframe to start streaming again. + + // We should have already requested a keyframe, either above or when the + // internal target encoding was first moved off SUSPENDED_ENCODING. - // assert needsKeyframe == true; return false; } } @@ -222,7 +220,7 @@ else if (currentEncodingId < externalEncodingIdTarget) * or not. * * @param receivedTime the time the latest frame was received - * @return true if we're in layer switching phase, false otherwise. + * @return false if we're in layer switching phase, true otherwise. */ private synchronized boolean isOutOfSwitchingPhase(@Nullable Instant receivedTime) { diff --git a/jvb/src/main/java/org/jitsi/videobridge/xmpp/MediaSourceFactory.java b/jvb/src/main/java/org/jitsi/videobridge/xmpp/MediaSourceFactory.java index 943b66f5da..c8582b9002 100644 --- a/jvb/src/main/java/org/jitsi/videobridge/xmpp/MediaSourceFactory.java +++ b/jvb/src/main/java/org/jitsi/videobridge/xmpp/MediaSourceFactory.java @@ -17,6 +17,7 @@ import org.jitsi.nlj.*; import org.jitsi.nlj.rtp.*; +import org.jitsi.nlj.rtp.codec.vpx.*; import org.jitsi.utils.logging2.*; import org.jitsi.xmpp.extensions.colibri.*; import org.jitsi.xmpp.extensions.jingle.*; @@ -79,7 +80,7 @@ Map getSecondarySsrcTypeMap() return secondarySsrcTypeMap; } - private static final RtpLayerDesc[] noDependencies = new RtpLayerDesc[0]; + private static final VpxRtpLayerDesc[] noDependencies = new VpxRtpLayerDesc[0]; /* * Creates layers for an encoding. @@ -92,8 +93,8 @@ Map getSecondarySsrcTypeMap() private static RtpLayerDesc[] createRTPLayerDescs( int spatialLen, int temporalLen, int encodingIdx, int height) { - RtpLayerDesc[] rtpLayers - = new RtpLayerDesc[spatialLen * temporalLen]; + VpxRtpLayerDesc[] rtpLayers + = new VpxRtpLayerDesc[spatialLen * temporalLen]; for (int spatialIdx = 0; spatialIdx < spatialLen; spatialIdx++) { @@ -104,11 +105,11 @@ private static RtpLayerDesc[] createRTPLayerDescs( int idx = idx(spatialIdx, temporalIdx, temporalLen); - RtpLayerDesc[] dependencies; + VpxRtpLayerDesc[] dependencies; if (spatialIdx > 0 && temporalIdx > 0) { // this layer depends on spatialIdx-1 and temporalIdx-1. - dependencies = new RtpLayerDesc[]{ + dependencies = new VpxRtpLayerDesc[]{ rtpLayers[ idx(spatialIdx, temporalIdx - 1, temporalLen)], @@ -120,7 +121,7 @@ private static RtpLayerDesc[] createRTPLayerDescs( else if (spatialIdx > 0) { // this layer depends on spatialIdx-1. - dependencies = new RtpLayerDesc[] + dependencies = new VpxRtpLayerDesc[] {rtpLayers[ idx(spatialIdx - 1, temporalIdx, temporalLen)]}; @@ -128,7 +129,7 @@ else if (spatialIdx > 0) else if (temporalIdx > 0) { // this layer depends on temporalIdx-1. - dependencies = new RtpLayerDesc[] + dependencies = new VpxRtpLayerDesc[] {rtpLayers[ idx(spatialIdx, temporalIdx - 1, temporalLen)]}; @@ -143,13 +144,11 @@ else if (temporalIdx > 0) int spatialId = spatialLen > 1 ? spatialIdx : -1; rtpLayers[idx] - = new RtpLayerDesc(encodingIdx, + = new VpxRtpLayerDesc(encodingIdx, temporalId, spatialId, height, frameRate, dependencies); frameRate *= 2; } - - } return rtpLayers; } diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/SsrcCache.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/SsrcCache.kt index 5189d7f391..1754ab6aa0 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/SsrcCache.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/SsrcCache.kt @@ -20,6 +20,9 @@ import org.jitsi.nlj.MediaSourceDesc import org.jitsi.nlj.VideoType import org.jitsi.nlj.codec.vpx.VpxUtils import org.jitsi.nlj.rtp.SsrcAssociationType +import org.jitsi.nlj.rtp.codec.av1.Av1DDPacket +import org.jitsi.nlj.rtp.codec.av1.applyTemplateIdDelta +import org.jitsi.nlj.rtp.codec.av1.getTemplateIdDelta import org.jitsi.nlj.rtp.codec.vp8.Vp8Packet import org.jitsi.nlj.rtp.codec.vp9.Vp9Packet import org.jitsi.rtp.rtcp.RtcpPacket @@ -660,10 +663,65 @@ private class Vp9CodecDeltas(val tl0IndexDelta: Int) : CodecDeltas { override fun toString() = "[VP9 TL0Idx]$tl0IndexDelta" } +private class Av1DDCodecState : CodecState { + val lastFrameNum: Int + val lastTemplateIdx: Int + constructor(lastFrameNum: Int, lastTemplateIdx: Int) { + this.lastFrameNum = lastFrameNum + this.lastTemplateIdx = lastTemplateIdx + } + + constructor(packet: Av1DDPacket) { + val descriptor = packet.descriptor + requireNotNull(descriptor) { "AV1 Packet being routed must have non-null descriptor" } + this.lastFrameNum = packet.frameNumber + this.lastTemplateIdx = descriptor.structure.templateIdOffset + descriptor.structure.templateCount + } + + override fun getDeltas(otherState: CodecState?): CodecDeltas? { + if (otherState !is Av1DDCodecState) { + return null + } + val frameNumDelta = RtpUtils.getSequenceNumberDelta(lastFrameNum, otherState.lastFrameNum) + val templateIdDelta = getTemplateIdDelta(lastTemplateIdx, otherState.lastTemplateIdx) + return Av1DDCodecDeltas(frameNumDelta, templateIdDelta) + } + + override fun getDeltas(packet: RtpPacket): CodecDeltas? { + if (packet !is Av1DDPacket) { + return null + } + val descriptor = packet.descriptor ?: return null + val frameNumDelta = RtpUtils.getSequenceNumberDelta(lastFrameNum, packet.frameNumber - 1) + val packetLastTemplateIdx = descriptor.structure.templateIdOffset + descriptor.structure.templateCount + val templateIdDelta = getTemplateIdDelta(lastTemplateIdx, packetLastTemplateIdx - 1) + return Av1DDCodecDeltas(frameNumDelta, templateIdDelta) + } +} + +private class Av1DDCodecDeltas(val frameNumDelta: Int, val templateIdDelta: Int) : CodecDeltas { + override fun rewritePacket(packet: RtpPacket) { + require(packet is Av1DDPacket) + val descriptor = packet.descriptor + requireNotNull(descriptor) + + descriptor.frameNumber = RtpUtils.applySequenceNumberDelta(descriptor.frameNumber, frameNumDelta) + descriptor.frameDependencyTemplateId = + applyTemplateIdDelta(descriptor.frameDependencyTemplateId, templateIdDelta) + descriptor.structure.templateIdOffset = + applyTemplateIdDelta(descriptor.structure.templateIdOffset, templateIdDelta) + + packet.reencodeDdExt() + } + + override fun toString() = "[AV1DD FrameNum]$frameNumDelta [Av1DD templateId]$templateIdDelta" +} + private fun RtpPacket.getCodecState(): CodecState? { return when (this) { is Vp8Packet -> Vp8CodecState(this) is Vp9Packet -> Vp9CodecState(this) + is Av1DDPacket -> Av1DDCodecState(this) else -> null } } diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BandwidthAllocation.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BandwidthAllocation.kt index ebca3fd157..16681fac83 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BandwidthAllocation.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BandwidthAllocation.kt @@ -60,6 +60,13 @@ class BandwidthAllocation @JvmOverloads constructor( put("oversending", oversending) put("has_suspended_sources", hasSuspendedSources) put("suspended_sources", suspendedSources) + val allocations = JSONObject().apply { + allocations.forEach { + val name = it.mediaSource?.sourceName ?: it.endpointId + put(name, it.debugState) + } + } + put("allocations", allocations) } } @@ -93,4 +100,10 @@ data class SingleAllocation( override fun toString(): String = "[id=$endpointId target=${targetLayer?.height}/${targetLayer?.frameRate} " + "ideal=${idealLayer?.height}/${idealLayer?.frameRate}]" + + val debugState: JSONObject + get() = JSONObject().apply { + put("target", targetLayer?.debugState()) + put("ideal", idealLayer?.debugState()) + } } diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BitrateController.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BitrateController.kt index 851468a01c..de2e61400d 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BitrateController.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BitrateController.kt @@ -163,8 +163,6 @@ class BitrateController @JvmOverloads constructor( } fun addPayloadType(payloadType: PayloadType) { - packetHandler.addPayloadType(payloadType) - if (payloadType.encoding == PayloadTypeEncoding.RTX) { supportsRtx = true } diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/PacketHandler.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/PacketHandler.kt index ed0d2c040c..3cff98225c 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/PacketHandler.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/PacketHandler.kt @@ -19,7 +19,6 @@ import org.jitsi.nlj.MediaSourceDesc import org.jitsi.nlj.PacketInfo import org.jitsi.nlj.PacketInfo.Companion.ENABLE_PAYLOAD_VERIFICATION import org.jitsi.nlj.RtpLayerDesc -import org.jitsi.nlj.format.PayloadType import org.jitsi.nlj.rtp.VideoRtpPacket import org.jitsi.rtp.rtcp.RtcpSrPacket import org.jitsi.utils.event.EventEmitter @@ -59,7 +58,6 @@ internal class PacketHandler( private var firstMedia: Instant? = null private val numDroppedPacketsUnknownSsrc = AtomicInteger(0) - private val payloadTypes: MutableMap = ConcurrentHashMap() /** * The [AdaptiveSourceProjection]s that this instance is managing, keyed @@ -158,7 +156,6 @@ internal class PacketHandler( { eventEmitter.fireEvent { keyframeNeeded(endpointID, source.primarySSRC) } }, - payloadTypes, logger ) logger.debug { "new source projection for $source" } @@ -173,10 +170,6 @@ internal class PacketHandler( fun timeSinceFirstMedia(): Duration = firstMedia?.let { Duration.between(it, clock.instant()) } ?: Duration.ZERO - fun addPayloadType(payloadType: PayloadType) { - payloadTypes[payloadType.pt] = payloadType - } - val debugState: JSONObject get() = JSONObject().apply { this["numDroppedPacketsUnknownSsrc"] = numDroppedPacketsUnknownSsrc.toInt() diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/SingleSourceAllocation.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/SingleSourceAllocation.kt index 509b42616d..82d139f860 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/SingleSourceAllocation.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/SingleSourceAllocation.kt @@ -18,7 +18,6 @@ package org.jitsi.videobridge.cc.allocation import org.jitsi.nlj.MediaSourceDesc import org.jitsi.nlj.RtpLayerDesc -import org.jitsi.nlj.RtpLayerDesc.Companion.indexString import org.jitsi.nlj.VideoType import org.jitsi.utils.logging.DiagnosticContext import org.jitsi.utils.logging.TimeSeriesLogger @@ -63,7 +62,7 @@ internal class SingleSourceAllocation( .addField("remote_endpoint_id", endpointId) for ((l, bitrate) in layers.layers) { ratesTimeSeriesPoint.addField( - "${indexString(l.index)}_${l.height}p_${l.frameRate}fps_bps", + "${l.indexString()}_${l.height}p_${l.frameRate}fps_bps", bitrate ) } diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDAdaptiveSourceProjectionContext.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDAdaptiveSourceProjectionContext.kt new file mode 100644 index 0000000000..78159c4f0c --- /dev/null +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDAdaptiveSourceProjectionContext.kt @@ -0,0 +1,659 @@ +/* + * Copyright @ 2019-present 8x8, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.videobridge.cc.av1 + +import org.jitsi.nlj.PacketInfo +import org.jitsi.nlj.RtpLayerDesc.Companion.getEidFromIndex +import org.jitsi.nlj.rtp.codec.av1.Av1DDPacket +import org.jitsi.nlj.rtp.codec.av1.Av1DDRtpLayerDesc +import org.jitsi.nlj.rtp.codec.av1.Av1DDRtpLayerDesc.Companion.getDtFromIndex +import org.jitsi.nlj.rtp.codec.av1.getTemplateIdDelta +import org.jitsi.rtp.rtcp.RtcpSrPacket +import org.jitsi.rtp.rtp.header_extensions.DTI +import org.jitsi.rtp.rtp.header_extensions.toShortString +import org.jitsi.rtp.util.RtpUtils +import org.jitsi.rtp.util.isNewerThan +import org.jitsi.utils.logging.DiagnosticContext +import org.jitsi.utils.logging.TimeSeriesLogger +import org.jitsi.utils.logging2.Logger +import org.jitsi.utils.logging2.createChildLogger +import org.jitsi.videobridge.cc.AdaptiveSourceProjectionContext +import org.jitsi.videobridge.cc.RewriteException +import org.jitsi.videobridge.cc.RtpState +import org.json.simple.JSONArray +import org.json.simple.JSONObject +import java.time.Duration +import java.time.Instant + +class Av1DDAdaptiveSourceProjectionContext( + private val diagnosticContext: DiagnosticContext, + rtpState: RtpState, + parentLogger: Logger +) : AdaptiveSourceProjectionContext { + private val logger: Logger = createChildLogger(parentLogger) + + /** + * A map that stores the per-encoding AV1 frame maps. + */ + private val av1FrameMaps = HashMap() + + /** + * The [Av1DDQualityFilter] instance that does quality filtering on the + * incoming pictures, to choose encodings and layers to forward. + */ + private val av1QualityFilter = Av1DDQualityFilter(av1FrameMaps, logger) + + private var lastAv1FrameProjection = Av1DDFrameProjection( + diagnosticContext, + rtpState.ssrc, + rtpState.maxSequenceNumber, + rtpState.maxTimestamp + ) + + /** + * The frame number index that started the latest stream resumption. + * We can't send frames with frame number less than this, because we don't have + * space in the projected sequence number/frame number counts. + */ + private var lastFrameNumberIndexResumption = -1 + + override fun accept(packetInfo: PacketInfo, incomingEncoding: Int, targetIndex: Int): Boolean { + val packet = packetInfo.packet + + if (packet !is Av1DDPacket) { + logger.warn("Packet is not AV1 DD Packet") + return false + } + + /* If insertPacketInMap returns null, this is a very old picture, more than Av1FrameMap.PICTURE_MAP_SIZE old, + or something is wrong with the stream. */ + val result = insertPacketInMap(packet) ?: return false + + val frame = result.frame + + if (result.isNewFrame) { + if (packet.isKeyframe && frameIsNewSsrc(frame)) { + /* If we're not currently projecting this SSRC, check if we've + * already decided to drop a subsequent required frame of this SSRC for the DT. + * If we have, we can't turn on the encoding starting from this + * packet, so treat this frame as though it weren't a keyframe. + * Note that this may mean re-analyzing the packets with a now-available template dependency structure. + */ + if (haveSubsequentNonAcceptedChain(frame, incomingEncoding, targetIndex)) { + frame.isKeyframe = false + } + } + val receivedTime = packetInfo.receivedTime + val acceptResult = av1QualityFilter + .acceptFrame(frame, incomingEncoding, targetIndex, receivedTime) + frame.isAccepted = acceptResult.accept && frame.index >= lastFrameNumberIndexResumption + if (frame.isAccepted) { + val projection: Av1DDFrameProjection + try { + projection = createProjection( + frame = frame, + initialPacket = packet, + isResumption = acceptResult.isResumption, + isReset = result.isReset, + mark = acceptResult.mark, + newDt = acceptResult.newDt, + receivedTime = receivedTime + ) + } catch (e: Exception) { + logger.warn("Failed to create frame projection", e) + /* Make sure we don't have an accepted frame without a projection in the map. */ + frame.isAccepted = false + return false + } + frame.projection = projection + if (projection.earliestProjectedSeqNum isNewerThan lastAv1FrameProjection.latestProjectedSeqNum) { + lastAv1FrameProjection = projection + } + } + } + + val accept = frame.isAccepted && frame.projection?.accept(packet) == true + + if (timeSeriesLogger.isTraceEnabled) { + val pt = diagnosticContext.makeTimeSeriesPoint("rtp_av1") + .addField("ssrc", packet.ssrc) + .addField("timestamp", packet.timestamp) + .addField("seq", packet.sequenceNumber) + .addField("frameNumber", packet.frameNumber) + .addField("templateId", packet.statelessDescriptor.frameDependencyTemplateId) + .addField("hasStructure", packet.descriptor?.newTemplateDependencyStructure != null) + .addField("spatialLayer", packet.frameInfo?.spatialId) + .addField("temporalLayer", packet.frameInfo?.temporalId) + .addField("dti", packet.frameInfo?.dti?.toShortString()) + .addField("hasInterPictureDependency", packet.frameInfo?.hasInterPictureDependency()) + // TODO add more relevant fields from AV1 DD for debugging + .addField("targetIndex", Av1DDRtpLayerDesc.indexString(targetIndex)) + .addField("new_frame", result.isNewFrame) + .addField("accept", accept) + av1QualityFilter.addDiagnosticContext(pt) + timeSeriesLogger.trace(pt) + } + + return accept + } + + /** Look up an Av1DDFrame for a packet. */ + private fun lookupAv1Frame(av1Packet: Av1DDPacket): Av1DDFrame? = av1FrameMaps[av1Packet.ssrc]?.findFrame(av1Packet) + + /** + * Insert a packet in the appropriate [Av1DDFrameMap]. + */ + private fun insertPacketInMap(av1Packet: Av1DDPacket) = + av1FrameMaps.getOrPut(av1Packet.ssrc) { Av1DDFrameMap(logger) }.insertPacket(av1Packet) + + private fun haveSubsequentNonAcceptedChain(frame: Av1DDFrame, incomingEncoding: Int, targetIndex: Int): Boolean { + val map = av1FrameMaps[frame.ssrc] ?: return false + val structure = frame.structure ?: return false + val dtsToCheck = if (incomingEncoding == getEidFromIndex(targetIndex)) { + setOf(getDtFromIndex(targetIndex)) + } else { + frame.frameInfo?.dtisPresent ?: emptySet() + } + val chainsToCheck = dtsToCheck.map { structure.decodeTargetInfo[it].protectedBy }.toSet() + return map.nextFrameWith(frame) { + if (it.isAccepted) return@nextFrameWith false + if (it.frameInfo == null) { + it.updateParse(structure, logger) + } + return@nextFrameWith chainsToCheck.any { chainIdx -> + it.partOfActiveChain(chainIdx) + } + } != null + } + + private fun Av1DDFrame.partOfActiveChain(chainIdx: Int): Boolean { + val structure = structure ?: return false + val frameInfo = frameInfo ?: return false + for (i in structure.decodeTargetInfo.indices) { + if (structure.decodeTargetInfo[i].protectedBy != chainIdx) { + continue + } + if (frameInfo.dti[i] == DTI.NOT_PRESENT || frameInfo.dti[i] == DTI.DISCARDABLE) { + return false + } + } + return true + } + + /** + * Calculate the projected sequence number gap between two frames (of the same encoding), + * allowing collapsing for unaccepted frames. + */ + private fun seqGap(frame1: Av1DDFrame, frame2: Av1DDFrame): Int { + var seqGap = RtpUtils.getSequenceNumberDelta( + frame2.earliestKnownSequenceNumber, + frame1.latestKnownSequenceNumber + ) + if (!frame1.isAccepted && !frame2.isAccepted && frame2.isImmediatelyAfter(frame1)) { + /* If neither frame is being projected, and they have consecutive + frame numbers, we don't need to leave any gap. */ + seqGap = 0 + } else { + /* If the earlier frame wasn't projected, and we haven't seen its + * final packet, we know it has to consume at least one more sequence number. */ + if (!frame1.isAccepted && !frame1.seenEndOfFrame && seqGap > 1) { + seqGap-- + } + /* Similarly, if the later frame wasn't projected and we haven't seen + * its first packet. */ + if (!frame2.isAccepted && !frame2.seenStartOfFrame && seqGap > 1) { + seqGap-- + } + /* If the frame wasn't accepted, it has to have consumed at least one sequence number, + * which we can collapse out. */ + if (!frame1.isAccepted && seqGap > 0) { + seqGap-- + } + } + return seqGap + } + + private fun frameIsNewSsrc(frame: Av1DDFrame): Boolean = lastAv1FrameProjection.av1Frame?.matchesSSRC(frame) != true + + /** + * Find the previous frame before the given one. + */ + @Synchronized + private fun prevFrame(frame: Av1DDFrame) = av1FrameMaps[frame.ssrc]?.prevFrame(frame) + + /** + * Find the next frame after the given one. + */ + @Synchronized + private fun nextFrame(frame: Av1DDFrame) = av1FrameMaps[frame.ssrc]?.nextFrame(frame) + + /** + * Find the previous accepted frame before the given one. + */ + private fun findPrevAcceptedFrame(frame: Av1DDFrame) = av1FrameMaps[frame.ssrc]?.findPrevAcceptedFrame(frame) + + /** + * Find the next accepted frame after the given one. + */ + private fun findNextAcceptedFrame(frame: Av1DDFrame) = av1FrameMaps[frame.ssrc]?.findNextAcceptedFrame(frame) + + /** + * Create a projection for this frame. + */ + private fun createProjection( + frame: Av1DDFrame, + initialPacket: Av1DDPacket, + mark: Boolean, + isResumption: Boolean, + isReset: Boolean, + newDt: Int?, + receivedTime: Instant? + ): Av1DDFrameProjection { + if (frameIsNewSsrc(frame)) { + return createEncodingSwitchProjection(frame, initialPacket, mark, newDt, receivedTime) + } else if (isResumption) { + return createResumptionProjection(frame, initialPacket, mark, newDt, receivedTime) + } else if (isReset) { + return createResetProjection(frame, initialPacket, mark, newDt, receivedTime) + } + + return createInEncodingProjection(frame, initialPacket, mark, newDt, receivedTime) + } + + /** + * Create an projection for the first frame after an encoding switch. + */ + private fun createEncodingSwitchProjection( + frame: Av1DDFrame, + initialPacket: Av1DDPacket, + mark: Boolean, + newDt: Int?, + receivedTime: Instant? + ): Av1DDFrameProjection { + // We can only switch on packets that carry a scalability structure, which is the first packet of a keyframe + assert(frame.isKeyframe) + assert(initialPacket.isStartOfFrame) + + var projectedSeqGap = 1 + + if (lastAv1FrameProjection.av1Frame?.seenEndOfFrame == false) { + /* Leave a gap to signal to the decoder that the previously routed + frame was incomplete. */ + projectedSeqGap++ + + /* Make sure subsequent packets of the previous projection won't + overlap the new one. (This means the gap, above, will never be + filled in.) + */ + lastAv1FrameProjection.close() + } + + val projectedSeq = + RtpUtils.applySequenceNumberDelta(lastAv1FrameProjection.latestProjectedSeqNum, projectedSeqGap) + + // this is a simulcast switch. The typical incremental value = + // 90kHz / 30 = 90,000Hz / 30 = 3000 per frame or per 33ms + val tsDelta = lastAv1FrameProjection.created?.let { created -> + receivedTime?.let { + 3000 * Duration.between(created, receivedTime).dividedBy(33).seconds.coerceAtLeast(1L) + } + } ?: 3000 + val projectedTs = RtpUtils.applyTimestampDelta(lastAv1FrameProjection.timestamp, tsDelta) + + val frameNumber: Int + val templateIdDelta: Int + if (lastAv1FrameProjection.av1Frame != null) { + frameNumber = RtpUtils.applySequenceNumberDelta( + lastAv1FrameProjection.frameNumber, + 1 + ) + val nextTemplateId = lastAv1FrameProjection.getNextTemplateId() + templateIdDelta = if (nextTemplateId != null) { + val structure = frame.structure + check(structure != null) + getTemplateIdDelta(nextTemplateId, structure.templateIdOffset) + } else { + 0 + } + } else { + frameNumber = frame.frameNumber + templateIdDelta = 0 + } + + return Av1DDFrameProjection( + diagnosticContext = diagnosticContext, + av1Frame = frame, + ssrc = lastAv1FrameProjection.ssrc, + timestamp = projectedTs, + sequenceNumberDelta = RtpUtils.getSequenceNumberDelta(projectedSeq, initialPacket.sequenceNumber), + frameNumber = frameNumber, + templateIdDelta = templateIdDelta, + dti = newDt?.let { frame.structure?.getDtBitmaskForDt(it) }, + mark = mark, + created = receivedTime + ) + } + + /** + * Create a projection for the first frame after a resumption, i.e. when a source is turned back on. + */ + private fun createResumptionProjection( + frame: Av1DDFrame, + initialPacket: Av1DDPacket, + mark: Boolean, + newDt: Int?, + receivedTime: Instant? + ): Av1DDFrameProjection { + lastFrameNumberIndexResumption = frame.index + + /* These must be non-null because we don't execute this function unless + frameIsNewSsrc has returned false. + */ + val lastFrame = prevFrame(frame)!! + val lastProjectedFrame = lastAv1FrameProjection.av1Frame!! + + /* Project timestamps linearly. */ + val tsDelta = RtpUtils.getTimestampDiff( + lastAv1FrameProjection.timestamp, + lastProjectedFrame.timestamp + ) + val projectedTs = RtpUtils.applyTimestampDelta(frame.timestamp, tsDelta) + + /* Increment frameNumber by 1 from the last projected frame. */ + val projectedFrameNumber = RtpUtils.applySequenceNumberDelta(lastAv1FrameProjection.frameNumber, 1) + + /** If this packet has a template structure, rewrite it to follow after the pre-resumption structure. + * (We could check to see if the structure is unchanged, but that's an unnecessary optimization.) + */ + val templateIdDelta = if (frame.structure != null) { + val nextTemplateId = lastAv1FrameProjection.getNextTemplateId() + + if (nextTemplateId != null) { + val structure = frame.structure + check(structure != null) + getTemplateIdDelta(nextTemplateId, structure.templateIdOffset) + } else { + 0 + } + } else { + 0 + } + + /* Increment sequence numbers based on the last projected frame, but leave a gap + * for packet reordering in case this isn't the first packet of the keyframe. + */ + val seqGap = RtpUtils.getSequenceNumberDelta(initialPacket.sequenceNumber, lastFrame.latestKnownSequenceNumber) + val newSeq = RtpUtils.applySequenceNumberDelta(lastAv1FrameProjection.latestProjectedSeqNum, seqGap) + val seqDelta = RtpUtils.getSequenceNumberDelta(newSeq, initialPacket.sequenceNumber) + + return Av1DDFrameProjection( + diagnosticContext = diagnosticContext, + av1Frame = frame, + ssrc = lastAv1FrameProjection.ssrc, + timestamp = projectedTs, + sequenceNumberDelta = seqDelta, + frameNumber = projectedFrameNumber, + templateIdDelta = templateIdDelta, + dti = newDt?.let { frame.structure?.getDtBitmaskForDt(it) }, + mark = mark, + created = receivedTime + ) + } + + /** + * Create a projection for the first frame after a frame reset, i.e. after a large gap in sequence numbers. + */ + private fun createResetProjection( + frame: Av1DDFrame, + initialPacket: Av1DDPacket, + mark: Boolean, + newDt: Int?, + receivedTime: Instant? + ): Av1DDFrameProjection { + /* This must be non-null because we don't execute this function unless + frameIsNewSsrc has returned false. + */ + val lastFrame = lastAv1FrameProjection.av1Frame!! + + /* Apply the latest projected frame's projections out, linearly. */ + val seqDelta = RtpUtils.getSequenceNumberDelta( + lastAv1FrameProjection.latestProjectedSeqNum, + lastFrame.latestKnownSequenceNumber + ) + val tsDelta = RtpUtils.getTimestampDiff( + lastAv1FrameProjection.timestamp, + lastFrame.timestamp + ) + val frameNumberDelta = RtpUtils.applySequenceNumberDelta( + lastAv1FrameProjection.frameNumber, + lastFrame.frameNumber + ) + + val projectedTs = RtpUtils.applyTimestampDelta(frame.timestamp, tsDelta) + val projectedFrameNumber = RtpUtils.applySequenceNumberDelta(frame.frameNumber, frameNumberDelta) + + /** If this packet has a template structure, rewrite it to follow after the pre-reset structure. + * (We could check to see if the structure is unchanged, but that's an unnecessary optimization.) + */ + val templateIdDelta = if (frame.structure != null) { + val nextTemplateId = lastAv1FrameProjection.getNextTemplateId() + + if (nextTemplateId != null) { + val structure = frame.structure + check(structure != null) + getTemplateIdDelta(nextTemplateId, structure.templateIdOffset) + } else { + 0 + } + } else { + 0 + } + return Av1DDFrameProjection( + diagnosticContext = diagnosticContext, + av1Frame = frame, + ssrc = lastAv1FrameProjection.ssrc, + timestamp = projectedTs, + sequenceNumberDelta = seqDelta, + frameNumber = projectedFrameNumber, + templateIdDelta = templateIdDelta, + dti = newDt?.let { frame.structure?.getDtBitmaskForDt(it) }, + mark = mark, + created = receivedTime + ) + } + + /** + * Create a frame projection for the normal case, i.e. as part of the same encoding as the + * previously-projected frame. + */ + private fun createInEncodingProjection( + frame: Av1DDFrame, + initialPacket: Av1DDPacket, + mark: Boolean, + newDt: Int?, + receivedTime: Instant? + ): Av1DDFrameProjection { + val prevFrame = findPrevAcceptedFrame(frame) + if (prevFrame != null) { + return createInEncodingProjection(frame, prevFrame, initialPacket, mark, newDt, receivedTime) + } + + /* prev frame has rolled off beginning of frame map, try next frame */ + val nextFrame = findNextAcceptedFrame(frame) + if (nextFrame != null) { + return createInEncodingProjection(frame, nextFrame, initialPacket, mark, newDt, receivedTime) + } + + /* Neither previous or next is found. Very big frame? Use previous projected. + (This must be valid because we don't execute this function unless + frameIsNewSsrc has returned false.) + */ + return createInEncodingProjection( + frame, + lastAv1FrameProjection.av1Frame!!, + initialPacket, + mark, + newDt, + receivedTime + ) + } + + /** + * Create a frame projection for the normal case, i.e. as part of the same encoding as the + * previously-projected frame, based on a specific chosen previously-projected frame. + */ + private fun createInEncodingProjection( + frame: Av1DDFrame, + refFrame: Av1DDFrame, + initialPacket: Av1DDPacket, + mark: Boolean, + newDt: Int?, + receivedTime: Instant? + ): Av1DDFrameProjection { + val tsGap = RtpUtils.getTimestampDiff(frame.timestamp, refFrame.timestamp) + val frameNumGap = RtpUtils.getSequenceNumberDelta(frame.frameNumber, refFrame.frameNumber) + var seqGap = 0 + + var f1 = refFrame + var f2: Av1DDFrame? + val refSeq: Int + if (frameNumGap > 0) { + /* refFrame is earlier than frame in decode order. */ + do { + f2 = nextFrame(f1) + checkNotNull(f2) { + "No next frame found after frame with frame number ${f1.frameNumber}, " + + "even though refFrame ${refFrame.frameNumber} is before " + + "frame ${frame.frameNumber}!" + } + seqGap += seqGap(f1, f2) + f1 = f2 + } while (f2 !== frame) + /* refFrame is a projected frame, so it has a projection. */ + refSeq = refFrame.projection!!.latestProjectedSeqNum + } else { + /* refFrame is later than frame in decode order. */ + do { + f2 = prevFrame(f1) + checkNotNull(f2) { + "No previous frame found before frame with frame number ${f1.frameNumber}, " + + "even though refFrame ${refFrame.frameNumber} is after " + + "frame ${frame.frameNumber}!" + } + seqGap += -seqGap(f2, f1) + f1 = f2 + } while (f2 !== frame) + refSeq = refFrame.projection!!.earliestProjectedSeqNum + } + + val projectedSeq = RtpUtils.applySequenceNumberDelta(refSeq, seqGap) + val projectedTs = RtpUtils.applyTimestampDelta(refFrame.projection!!.timestamp, tsGap) + val projectedFrameNumber = RtpUtils.applySequenceNumberDelta(refFrame.projection!!.frameNumber, frameNumGap) + + return Av1DDFrameProjection( + diagnosticContext = diagnosticContext, + av1Frame = frame, + ssrc = lastAv1FrameProjection.ssrc, + timestamp = projectedTs, + sequenceNumberDelta = RtpUtils.getSequenceNumberDelta(projectedSeq, initialPacket.sequenceNumber), + frameNumber = projectedFrameNumber, + templateIdDelta = lastAv1FrameProjection.templateIdDelta, + dti = newDt?.let { frame.structure?.getDtBitmaskForDt(it) }, + mark = mark, + created = receivedTime + ) + } + + override fun needsKeyframe(): Boolean { + if (av1QualityFilter.needsKeyframe) { + return true + } + + return lastAv1FrameProjection.av1Frame == null + } + + override fun rewriteRtp(packetInfo: PacketInfo) { + if (packetInfo.packet !is Av1DDPacket) { + logger.info("Got a non-AV1 DD packet.") + throw RewriteException("Non-AV1 DD packet in AV1 DD source projection") + } + val av1Packet = packetInfo.packetAs() + + val av1Frame = lookupAv1Frame(av1Packet) + ?: throw RewriteException("Frame not in tracker (aged off?)") + + val av1Projection = av1Frame.projection + ?: throw RewriteException("Frame does not have projection?") + /* Shouldn't happen for an accepted packet whose frame is still known? */ + + logger.trace { "Rewriting packet with structure ${System.identityHashCode(av1Packet.descriptor?.structure)}" } + av1Projection.rewriteRtp(av1Packet) + } + + override fun rewriteRtcp(rtcpSrPacket: RtcpSrPacket): Boolean { + val lastAv1FrameProjectionCopy: Av1DDFrameProjection = lastAv1FrameProjection + if (rtcpSrPacket.senderSsrc != lastAv1FrameProjectionCopy.av1Frame?.ssrc) { + return false + } + + rtcpSrPacket.senderSsrc = lastAv1FrameProjectionCopy.ssrc + + val srcTs = rtcpSrPacket.senderInfo.rtpTimestamp + val delta = RtpUtils.getTimestampDiff( + lastAv1FrameProjectionCopy.timestamp, + lastAv1FrameProjectionCopy.av1Frame.timestamp + ) + + val dstTs = RtpUtils.applyTimestampDelta(srcTs, delta) + + if (srcTs != dstTs) { + rtcpSrPacket.senderInfo.rtpTimestamp = dstTs + } + + return true + } + + override fun getRtpState() = RtpState( + lastAv1FrameProjection.ssrc, + lastAv1FrameProjection.latestProjectedSeqNum, + lastAv1FrameProjection.timestamp + ) + + override fun getDebugState(): JSONObject { + val debugState = JSONObject() + debugState["class"] = Av1DDAdaptiveSourceProjectionContext::class.java.simpleName + + val mapSizes = JSONArray() + for ((key, value) in av1FrameMaps.entries) { + val sizeInfo = JSONObject() + sizeInfo["ssrc"] = key + sizeInfo["size"] = value.size() + mapSizes.add(sizeInfo) + } + debugState["av1FrameMaps"] = mapSizes + debugState["av1QualityFilter"] = av1QualityFilter.debugState + + return debugState + } + + companion object { + /** + * The time series logger for this class. + */ + private val timeSeriesLogger = + TimeSeriesLogger.getTimeSeriesLogger(Av1DDAdaptiveSourceProjectionContext::class.java) + } +} diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDFrame.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDFrame.kt new file mode 100644 index 0000000000..fcf82cf027 --- /dev/null +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDFrame.kt @@ -0,0 +1,309 @@ +/* + * Copyright @ 2019 8x8, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.videobridge.cc.av1 + +import org.jitsi.nlj.rtp.codec.av1.Av1DDPacket +import org.jitsi.rtp.rtp.RtpPacket +import org.jitsi.rtp.rtp.header_extensions.Av1DependencyDescriptorReader +import org.jitsi.rtp.rtp.header_extensions.Av1DependencyException +import org.jitsi.rtp.rtp.header_extensions.Av1TemplateDependencyStructure +import org.jitsi.rtp.rtp.header_extensions.FrameInfo +import org.jitsi.rtp.util.RtpUtils.Companion.applySequenceNumberDelta +import org.jitsi.rtp.util.isNewerThan +import org.jitsi.rtp.util.isOlderThan +import org.jitsi.utils.logging2.Logger + +class Av1DDFrame internal constructor( + /** + * The RTP SSRC of the incoming frame that this instance refers to + * (RFC3550). + */ + val ssrc: Long, + + /** + * The RTP timestamp of the incoming frame that this instance refers to + * (RFC3550). + */ + val timestamp: Long, + + /** + * The earliest RTP sequence number seen of the incoming frame that this instance + * refers to (RFC3550). + */ + earliestKnownSequenceNumber: Int, + + /** + * The latest RTP sequence number seen of the incoming frame that this instance + * refers to (RFC3550). + */ + latestKnownSequenceNumber: Int, + + /** + * A boolean that indicates whether we've seen the first packet of the frame. + * If so, its sequence is earliestKnownSequenceNumber. + */ + seenStartOfFrame: Boolean, + + /** + * A boolean that indicates whether we've seen the last packet of the frame. + * If so, its sequence is latestKnownSequenceNumber. + */ + seenEndOfFrame: Boolean, + + /** + * A boolean that indicates whether we've seen a packet with the marker bit set. + */ + seenMarker: Boolean, + + /** + * AV1 FrameInfo for the frame + */ + frameInfo: FrameInfo?, + + /** + * The AV1 DD Frame Number of this frame. + */ + val frameNumber: Int, + + /** + * The FrameID index (FrameID plus cycles) of this frame. + */ + val index: Int, + + /** + * The template ID of this frame + */ + val templateId: Int, + + /** + * The AV1 Template Dependency Structure in effect for this frame, if known + */ + structure: Av1TemplateDependencyStructure?, + + /** + * A new activeDecodeTargets specified for this frame, if any. + * TODO: is this always specified in all packets of the frame? + */ + val activeDecodeTargets: Int?, + + /** + * A boolean that indicates whether the incoming AV1 frame that this + * instance refers to is a keyframe. + */ + var isKeyframe: Boolean, + + /** + * The raw dependency descriptor included in the packet. Stored if it could not be parsed initially. + */ + val rawDependencyDescriptor: RtpPacket.HeaderExtension? +) { + /** + * AV1 FrameInfo for the frame + */ + var frameInfo = frameInfo + private set + + /** + * The AV1 Template Dependency Structure in effect for this frame, if known + */ + var structure = structure + private set + + /** + * The earliest RTP sequence number seen of the incoming frame that this instance + * refers to (RFC3550). + */ + var earliestKnownSequenceNumber = earliestKnownSequenceNumber + private set + + /** + * The latest RTP sequence number seen of the incoming frame that this instance + * refers to (RFC3550). + */ + var latestKnownSequenceNumber: Int = latestKnownSequenceNumber + private set + + /** + * A boolean that indicates whether we've seen the first packet of the frame. + * If so, its sequence is earliestKnownSequenceNumber. + */ + var seenStartOfFrame: Boolean = seenStartOfFrame + private set + + /** + * A boolean that indicates whether we've seen the last packet of the frame. + * If so, its sequence is latestKnownSequenceNumber. + */ + var seenEndOfFrame: Boolean = seenEndOfFrame + private set + + /** + * A boolean that indicates whether we've seen a packet with the marker bit set. + */ + var seenMarker: Boolean = seenMarker + private set + + /** + * A record of how this frame was projected, or null if not. + */ + var projection: Av1DDFrameProjection? = null + + /** + * A boolean that records whether this frame was accepted, i.e. should be forwarded to the receiver + * given the decoding target currently being forwarded. + */ + var isAccepted = false + + // Validate that the index matches the pictureId + init { + assert((index and 0xffff) == frameNumber) + } + + constructor(packet: Av1DDPacket, index: Int) : this( + ssrc = packet.ssrc, + timestamp = packet.timestamp, + earliestKnownSequenceNumber = packet.sequenceNumber, + latestKnownSequenceNumber = packet.sequenceNumber, + seenStartOfFrame = packet.isStartOfFrame, + seenEndOfFrame = packet.isEndOfFrame, + seenMarker = packet.isMarked, + frameInfo = packet.frameInfo, + frameNumber = packet.statelessDescriptor.frameNumber, + index = index, + templateId = packet.statelessDescriptor.frameDependencyTemplateId, + structure = packet.descriptor?.structure, + activeDecodeTargets = packet.activeDecodeTargets, + isKeyframe = packet.isKeyframe, + rawDependencyDescriptor = if (packet.frameInfo == null) { + packet.getHeaderExtension(packet.av1DDHeaderExtensionId)?.clone() + } else { + null + } + ) + + /** + * Remember another packet of this frame. + * Note: this assumes every packet is received only once, i.e. a filter + * like [org.jitsi.nlj.transform.node.incoming.PaddingTermination] is in use. + * @param packet The packet to remember. This should be a packet which + * has tested true with [matchesFrame]. + */ + fun addPacket(packet: Av1DDPacket) { + require(matchesFrame(packet)) { "Non-matching packet added to frame" } + val seq = packet.sequenceNumber + if (seq isOlderThan earliestKnownSequenceNumber) { + earliestKnownSequenceNumber = seq + } + if (seq isNewerThan latestKnownSequenceNumber) { + latestKnownSequenceNumber = seq + } + if (packet.isStartOfFrame) { + seenStartOfFrame = true + } + if (packet.isEndOfFrame) { + seenEndOfFrame = true + } + if (packet.isMarked) { + seenMarker = true + } + + if (structure == null && packet.descriptor?.structure != null) { + structure = packet.descriptor?.structure + } + + if (frameInfo == null && packet.frameInfo != null) { + frameInfo = packet.frameInfo + } + } + + fun updateParse(templateDependencyStructure: Av1TemplateDependencyStructure, logger: Logger) { + if (rawDependencyDescriptor == null) { + return + } + val parser = Av1DependencyDescriptorReader(rawDependencyDescriptor) + val descriptor = try { + parser.parse(templateDependencyStructure) + } catch (e: Av1DependencyException) { + logger.warn("Could not parse updated AV1 Dependency Descriptor: ${e.message}") + return + } + structure = descriptor.structure + frameInfo = try { + descriptor.frameInfo + } catch (e: Av1DependencyException) { + logger.warn("Could not extract frame info from updated AV1 Dependency Descriptor: ${e.message}") + null + } + } + + /** + * Small utility method that checks whether the [Av1DDFrame] that is + * specified as a parameter belongs to the same RTP stream as the frame that + * this instance refers to. + * + * @param av1Frame the [Av1DDFrame] to check whether it belongs to the + * same RTP stream as the frame that this instance refers to. + * @return true if the [Av1DDFrame] that is specified as a parameter + * belongs to the same RTP stream as the frame that this instance refers to, + * false otherwise. + */ + fun matchesSSRC(av1Frame: Av1DDFrame): Boolean { + return ssrc == av1Frame.ssrc + } + + /** + * Checks whether the specified RTP packet is part of this frame. + * + * @param pkt the RTP packet to check whether it's part of this frame. + * @return true if the specified RTP packet is part of this frame, false + * otherwise. + */ + fun matchesFrame(pkt: Av1DDPacket): Boolean { + return ssrc == pkt.ssrc && timestamp == pkt.timestamp && + frameNumber == pkt.frameNumber + } + + fun validateConsistency(pkt: Av1DDPacket) { + if (frameInfo == null) { + return + } + if (frameInfo == pkt.frameInfo) { + return + } + + throw RuntimeException( + buildString { + with(pkt) { + append("Packet ssrc $ssrc, seq $sequenceNumber, frame number $frameNumber, timestamp $timestamp ") + append("packet template ${statelessDescriptor.frameDependencyTemplateId} ") + append("frame info $frameInfo ") + } + append("is not consistent with frame ${this@Av1DDFrame}") + } + ) + } + fun isImmediatelyAfter(otherFrame: Av1DDFrame): Boolean { + return frameNumber == + applySequenceNumberDelta(otherFrame.frameNumber, 1) + } + + override fun toString() = buildString { + append("$ssrc, ") + append("seq $earliestKnownSequenceNumber-$latestKnownSequenceNumber ") + append("frame number $frameNumber, timestamp $timestamp: ") + append("frame template $templateId info $frameInfo") + } +} diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDFrameMap.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDFrameMap.kt new file mode 100644 index 0000000000..e3722e1e2f --- /dev/null +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDFrameMap.kt @@ -0,0 +1,254 @@ +package org.jitsi.videobridge.cc.av1 + +import org.jitsi.nlj.rtp.codec.av1.Av1DDPacket +import org.jitsi.nlj.util.ArrayCache +import org.jitsi.nlj.util.Rfc3711IndexTracker +import org.jitsi.rtp.util.RtpUtils +import org.jitsi.rtp.util.isNewerThan +import org.jitsi.utils.logging2.Logger +import org.jitsi.utils.logging2.createChildLogger + +/** + * A history of recent frames on a Av1 stream. + */ +class Av1DDFrameMap( + parentLogger: Logger +) { + /** Cache mapping frame IDs to frames. */ + private val frameHistory = FrameHistory(FRAME_MAP_SIZE) + private val logger: Logger = createChildLogger(parentLogger) + + /** Find a frame in the frame map, based on a packet. */ + @Synchronized + fun findFrame(packet: Av1DDPacket): Av1DDFrame? { + return frameHistory[packet.frameNumber] + } + + /** Get the current size of the map. */ + fun size(): Int { + return frameHistory.numCached + } + + /** Check whether this is a large jump from previous state, so the map should be reset. */ + private fun isLargeJump(packet: Av1DDPacket): Boolean { + val latestFrame: Av1DDFrame = frameHistory.latestFrame ?: return false + val picDelta = RtpUtils.getSequenceNumberDelta(packet.frameNumber, latestFrame.frameNumber) + if (picDelta > FRAME_MAP_SIZE) { + return true + } + val tsDelta: Long = RtpUtils.getTimestampDiff(packet.timestamp, latestFrame.timestamp) + if (picDelta < 0) { + /* if picDelta is negative but timestamp or sequence delta is positive, we've cycled. */ + if (tsDelta > 0) { + return true + } + if (packet.sequenceNumber isNewerThan latestFrame.latestKnownSequenceNumber) { + return true + } + } + + /* If tsDelta is more than twice the frame map size at 1 fps, we've cycled. */ + return tsDelta > FRAME_MAP_SIZE * 90000 * 2 + } + + /** Insert a packet into the frame map. Return a frameInsertionResult + * describing what happened. + * @param packet The packet to insert. + * @return What happened. null if insertion failed. + */ + @Synchronized + fun insertPacket(packet: Av1DDPacket): PacketInsertionResult? { + val frameNumber = packet.frameNumber + + if (isLargeJump(packet)) { + frameHistory.indexTracker.resetAt(frameNumber) + val frame = frameHistory.insert(packet) ?: return null + + return PacketInsertionResult(frame, true, isReset = true) + } + val frame = frameHistory[frameNumber] + if (frame != null) { + if (!frame.matchesFrame(packet)) { + check(frame.frameNumber == frameNumber) { + "frame map returned frame with frame number ${frame.frameNumber} " + "when asked for frame with frame ID $frameNumber" + } + logger.warn( + "Cannot insert packet in frame map: " + + with(frame) { + "frame with ssrc $ssrc, timestamp $timestamp, " + + "and sequence number range $earliestKnownSequenceNumber-$latestKnownSequenceNumber, " + } + + with(packet) { + "and packet $sequenceNumber with ssrc $ssrc, timestamp $timestamp, " + + "and sequence number $sequenceNumber" + } + + " both have frame ID $frameNumber" + ) + return null + } + try { + frame.validateConsistency(packet) + } catch (e: Exception) { + logger.warn(e) + } + + frame.addPacket(packet) + return PacketInsertionResult(frame, isNewFrame = false) + } + + val newframe = frameHistory.insert(packet) ?: return null + + return PacketInsertionResult(newframe, true) + } + + /** Insert a frame. Only used for unit testing. */ + @Synchronized + internal fun insertFrame(frame: Av1DDFrame) { + frameHistory.insert(frame) + } + + @Synchronized + fun getIndex(frameIndex: Int) = frameHistory.getIndex(frameIndex) + + @Synchronized + fun nextFrame(frame: Av1DDFrame): Av1DDFrame? { + return frameHistory.findAfter(frame) { true } + } + + @Synchronized + fun nextFrameWith(frame: Av1DDFrame, pred: (Av1DDFrame) -> Boolean): Av1DDFrame? { + return frameHistory.findAfter(frame, pred) + } + + @Synchronized + fun prevFrame(frame: Av1DDFrame): Av1DDFrame? { + return frameHistory.findBefore(frame) { true } + } + + @Synchronized + fun prevFrameWith(frame: Av1DDFrame, pred: (Av1DDFrame) -> Boolean): Av1DDFrame? { + return frameHistory.findBefore(frame, pred) + } + + fun findPrevAcceptedFrame(frame: Av1DDFrame): Av1DDFrame? { + return prevFrameWith(frame) { it.isAccepted } + } + + fun findNextAcceptedFrame(frame: Av1DDFrame): Av1DDFrame? { + return nextFrameWith(frame) { it.isAccepted } + } + + companion object { + const val FRAME_MAP_SIZE = 500 // Matches PacketCache default size. + } +} + +internal class FrameHistory(size: Int) : + ArrayCache( + size, + cloneItem = { k -> k }, + synchronize = false + ) { + var numCached = 0 + var firstIndex = -1 + var indexTracker = Rfc3711IndexTracker() + + /** + * Gets a frame with a given frame number from the cache. + */ + operator fun get(frameNumber: Int): Av1DDFrame? { + val index = indexTracker.interpret(frameNumber) + return getIndex(index) + } + + /** + * Gets a frame with a given frame number index from the cache. + */ + fun getIndex(index: Int): Av1DDFrame? { + if (index <= lastIndex - size) { + /* We don't want to remember old frames even if they're still + tracked; their neighboring frames may have been evicted, + so findBefore / findAfter will return bogus data. */ + return null + } + val c = getContainer(index) ?: return null + return c.item + } + + /** Get the latest frame in the tracker. */ + val latestFrame: Av1DDFrame? + get() = getIndex(lastIndex) + + fun insert(frame: Av1DDFrame): Boolean { + val ret = super.insertItem(frame, frame.index) + if (ret) { + numCached++ + if (firstIndex == -1 || frame.index < firstIndex) { + firstIndex = frame.index + } + } + return ret + } + + fun insert(packet: Av1DDPacket): Av1DDFrame? { + val index = indexTracker.update(packet.frameNumber) + val frame = Av1DDFrame(packet, index) + return if (insert(frame)) frame else null + } + + /** + * Called when an item in the cache is replaced/discarded. + */ + override fun discardItem(item: Av1DDFrame) { + numCached-- + } + + fun findBefore(frame: Av1DDFrame, pred: (Av1DDFrame) -> Boolean): Av1DDFrame? { + val lastIndex = lastIndex + if (lastIndex == -1) { + return null + } + val index = frame.index + val searchStartIndex = Integer.min(index - 1, lastIndex) + val searchEndIndex = Integer.max(lastIndex - size, firstIndex - 1) + return doFind(pred, searchStartIndex, searchEndIndex, -1) + } + + fun findAfter(frame: Av1DDFrame, pred: (Av1DDFrame) -> Boolean): Av1DDFrame? { + val lastIndex = lastIndex + if (lastIndex == -1) { + return null + } + val index = frame.index + if (index >= lastIndex) { + return null + } + val searchStartIndex = Integer.max(index + 1, Integer.max(lastIndex - size + 1, firstIndex)) + return doFind(pred, searchStartIndex, lastIndex + 1, 1) + } + + private fun doFind(pred: (Av1DDFrame) -> Boolean, startIndex: Int, endIndex: Int, increment: Int): Av1DDFrame? { + var index = startIndex + while (index != endIndex) { + val frame = getIndex(index) + if (frame != null && pred(frame)) { + return frame + } + index += increment + } + return null + } +} + +/** + * The result of calling [insertPacket] + */ +class PacketInsertionResult( + /** The frame corresponding to the packet that was inserted. */ + val frame: Av1DDFrame, + /** Whether inserting the packet created a new frame. */ + val isNewFrame: Boolean, + /** Whether inserting the packet caused a reset */ + val isReset: Boolean = false +) diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDFrameProjection.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDFrameProjection.kt new file mode 100644 index 0000000000..88e34cb2b7 --- /dev/null +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDFrameProjection.kt @@ -0,0 +1,240 @@ +/* + * Copyright @ 2019 8x8, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.videobridge.cc.av1 + +import org.jitsi.nlj.rtp.codec.av1.Av1DDPacket +import org.jitsi.nlj.rtp.codec.av1.applyTemplateIdDelta +import org.jitsi.rtp.rtp.header_extensions.toShortString +import org.jitsi.rtp.util.RtpUtils.Companion.applySequenceNumberDelta +import org.jitsi.rtp.util.isOlderThan +import org.jitsi.utils.logging.DiagnosticContext +import org.jitsi.utils.logging.TimeSeriesLogger +import java.time.Instant + +/** + * Represents an AV1 DD frame projection. It puts together all the necessary bits + * and pieces that are useful when projecting an accepted AV1 DD frame. A + * projection is responsible for rewriting a AV1 DD packet. Instances of this class + * are thread-safe. + */ +class Av1DDFrameProjection internal constructor( + /** + * The diagnostic context for this instance. + */ + private val diagnosticContext: DiagnosticContext, + /** + * The projected [Av1DDFrame]. + */ + val av1Frame: Av1DDFrame?, + /** + * The RTP SSRC of the projection (RFC7667, RFC3550). + */ + val ssrc: Long, + /** + * The RTP timestamp of the projection (RFC7667, RFC3550). + */ + val timestamp: Long, + /** + * The sequence number delta for packets of this frame. + */ + private val sequenceNumberDelta: Int, + /** + * The AV1 frame number of the projection. + */ + val frameNumber: Int, + /** + * The template ID delta for this frame. This applies both to the frame's own template ID, and to + * the template ID offset in the dependency structure, if present. + */ + val templateIdDelta: Int, + /** + * The decode target indication to set on this frame, if any. + */ + val dti: Int?, + /** + * Whether to add a marker bit to the last packet of this frame. + * (Note this will not clear already-existing marker bits.) + */ + val mark: Boolean, + /** + * A timestamp of when this instance was created. It's used to calculate + * RTP timestamps when we switch encodings. + */ + val created: Instant? +) { + + /** + * -1 if this projection is still "open" for new, later packets. + * Projections can be closed when we switch away from their encodings. + */ + private var closedSeq = -1 + + /** + * Ctor. + * + * @param ssrc the SSRC of the destination AV1 frame. + * @param timestamp The RTP timestamp of the projected frame that this + * instance refers to (RFC3550). + * @param sequenceNumberDelta The starting RTP sequence number of the + * projected frame that this instance refers to (RFC3550). + */ + internal constructor( + diagnosticContext: DiagnosticContext, + ssrc: Long, + sequenceNumberDelta: Int, + timestamp: Long + ) : this( + diagnosticContext = diagnosticContext, + av1Frame = null, + ssrc = ssrc, + timestamp = timestamp, + sequenceNumberDelta = sequenceNumberDelta, + frameNumber = 0, + templateIdDelta = 0, + dti = null, + mark = false, + created = null + ) + + fun rewriteSeqNo(seq: Int): Int = applySequenceNumberDelta(seq, sequenceNumberDelta) + + fun rewriteTemplateId(id: Int): Int = applyTemplateIdDelta(id, templateIdDelta) + + /** + * Rewrites an RTP packet. + * + * @param pkt the RTP packet to rewrite. + */ + fun rewriteRtp(pkt: Av1DDPacket) { + val sequenceNumber = rewriteSeqNo(pkt.sequenceNumber) + val templateId = rewriteTemplateId(pkt.statelessDescriptor.frameDependencyTemplateId) + if (timeSeriesLogger.isTraceEnabled) { + timeSeriesLogger.trace( + diagnosticContext + .makeTimeSeriesPoint("rtp_av1_rewrite") + .addField("orig.rtp.ssrc", pkt.ssrc) + .addField("orig.rtp.timestamp", pkt.timestamp) + .addField("orig.rtp.seq", pkt.sequenceNumber) + .addField("orig.av1.framenum", pkt.frameNumber) + .addField("orig.av1.dti", pkt.frameInfo?.dti?.toShortString() ?: "-") + .addField("orig.av1.templateid", pkt.statelessDescriptor.frameDependencyTemplateId) + .addField("proj.rtp.ssrc", ssrc) + .addField("proj.rtp.timestamp", timestamp) + .addField("proj.rtp.seq", sequenceNumber) + .addField("proj.av1.framenum", frameNumber) + .addField("proj.av1.dti", dti ?: -1) + .addField("proj.av1.templateid", templateId) + .addField("proj.rtp.mark", mark) + ) + } + + // update ssrc, sequence number, timestamp, frameNumber, and templateID + pkt.ssrc = ssrc + pkt.timestamp = timestamp + pkt.sequenceNumber = sequenceNumber + if (mark && pkt.isEndOfFrame) pkt.isMarked = true + + val descriptor = pkt.descriptor + if (descriptor != null && ( + frameNumber != pkt.frameNumber || templateId != descriptor.frameDependencyTemplateId || + dti != null + ) + ) { + descriptor.frameNumber = frameNumber + descriptor.frameDependencyTemplateId = templateId + val structure = descriptor.structure + check( + descriptor.newTemplateDependencyStructure == null || + descriptor.newTemplateDependencyStructure === descriptor.structure + ) + + structure.templateIdOffset = rewriteTemplateId(structure.templateIdOffset) + if (dti != null && ( + descriptor.newTemplateDependencyStructure == null || + dti != (1 shl structure.decodeTargetCount) - 1 + ) + ) { + descriptor.activeDecodeTargetsBitmask = dti + } + + pkt.descriptor = descriptor + pkt.reencodeDdExt() + } + } + + /** + * Determines whether a packet can be forwarded as part of this + * [Av1DDFrameProjection] instance. The check is based on the sequence + * of the incoming packet and whether or not the [Av1DDFrameProjection] + * has been "closed" or not. + * + * @param rtpPacket the [Av1DDPacket] that will be examined. + * @return true if the packet can be forwarded as part of this + * [Av1DDFrameProjection], false otherwise. + */ + fun accept(rtpPacket: Av1DDPacket): Boolean { + if (av1Frame?.matchesFrame(rtpPacket) != true) { + // The packet does not belong to this AV1 picture. + return false + } + synchronized(av1Frame) { + return if (closedSeq < 0) { + true + } else { + rtpPacket.sequenceNumber isOlderThan closedSeq + } + } + } + + val earliestProjectedSeqNum: Int + get() { + if (av1Frame == null) { + return sequenceNumberDelta + } + synchronized(av1Frame) { return rewriteSeqNo(av1Frame.earliestKnownSequenceNumber) } + } + + val latestProjectedSeqNum: Int + get() { + if (av1Frame == null) { + return sequenceNumberDelta + } + synchronized(av1Frame) { return rewriteSeqNo(av1Frame.latestKnownSequenceNumber) } + } + + /** + * Prevents the max sequence number of this frame to grow any further. + */ + fun close() { + if (av1Frame != null) { + synchronized(av1Frame) { closedSeq = av1Frame.latestKnownSequenceNumber } + } + } + + /** + * Get the next template ID that would come after the template IDs in this projection's structure + */ + fun getNextTemplateId(): Int? { + return av1Frame?.structure?.let { rewriteTemplateId(it.templateIdOffset + it.templateCount) } + } + + companion object { + /** + * The time series logger for this class. + */ + private val timeSeriesLogger = TimeSeriesLogger.getTimeSeriesLogger(Av1DDFrameProjection::class.java) + } +} diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDQualityFilter.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDQualityFilter.kt new file mode 100644 index 0000000000..331a80faa9 --- /dev/null +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDQualityFilter.kt @@ -0,0 +1,460 @@ +/* + * Copyright @ 2019 - present 8x8, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.videobridge.cc.av1 + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings +import org.jitsi.nlj.RtpLayerDesc.Companion.SUSPENDED_ENCODING_ID +import org.jitsi.nlj.RtpLayerDesc.Companion.SUSPENDED_INDEX +import org.jitsi.nlj.RtpLayerDesc.Companion.getEidFromIndex +import org.jitsi.nlj.RtpLayerDesc.Companion.indexString +import org.jitsi.nlj.rtp.codec.av1.Av1DDRtpLayerDesc +import org.jitsi.nlj.rtp.codec.av1.Av1DDRtpLayerDesc.Companion.SUSPENDED_DT +import org.jitsi.nlj.rtp.codec.av1.Av1DDRtpLayerDesc.Companion.getDtFromIndex +import org.jitsi.nlj.rtp.codec.av1.Av1DDRtpLayerDesc.Companion.getIndex +import org.jitsi.nlj.rtp.codec.av1.containsDecodeTarget +import org.jitsi.rtp.rtp.header_extensions.DTI +import org.jitsi.utils.logging.DiagnosticContext +import org.jitsi.utils.logging2.Logger +import org.jitsi.utils.logging2.createChildLogger +import org.json.simple.JSONObject +import java.time.Duration +import java.time.Instant + +/** + * This class is responsible for dropping AV1 simulcast/svc packets based on + * their quality, i.e. packets that correspond to qualities that are above a + * given quality target. Instances of this class are thread-safe. + */ +internal class Av1DDQualityFilter( + val av1FrameMap: Map, + parentLogger: Logger +) { + /** + * The [Logger] to be used by this instance to print debug + * information. + */ + private val logger: Logger = createChildLogger(parentLogger) + + /** + * Holds the arrival time of the most recent keyframe group. + * Reading/writing of this field is synchronized on this instance. + */ + private var mostRecentKeyframeGroupArrivalTime: Instant? = null + + /** + * A boolean flag that indicates whether a keyframe is needed, due to an + * encoding or (in some cases) a decode target switch. + */ + var needsKeyframe = false + private set + + /** + * The encoding ID that this instance tries to achieve. Upon + * receipt of a packet, we check whether encoding in the externalTargetIndex + * (that's specified as an argument to the + * [acceptFrame] method) is set to something different, + * in which case we set [needsKeyframe] equal to true and + * update. + */ + private var internalTargetEncoding = SUSPENDED_ENCODING_ID + + /** + * The decode target that this instance tries to achieve. + */ + private var internalTargetDt = SUSPENDED_DT + + /** + * The layer index that we're currently forwarding. [SUSPENDED_INDEX] + * indicates that we're not forwarding anything. Reading/writing of this + * field is synchronized on this instance. + */ + private var currentIndex = SUSPENDED_INDEX + + /** + * Determines whether to accept or drop an AV1 frame. + * + * Note that, at the time of this writing, there's no practical need for a + * synchronized keyword because there's only one thread accessing this + * method at a time. + * + * @param frame the AV1 frame. + * @param incomingEncoding The encoding ID of the incoming packet + * @param externalTargetIndex the target quality index that the user of this + * instance wants to achieve. + * @param receivedTime the current time (as an Instant) + * @return true to accept the AV1 frame, otherwise false. + */ + @Synchronized + fun acceptFrame( + frame: Av1DDFrame, + incomingEncoding: Int, + externalTargetIndex: Int, + receivedTime: Instant? + ): AcceptResult { + val prevIndex = currentIndex + val accept = doAcceptFrame(frame, incomingEncoding, externalTargetIndex, receivedTime) + val currentDt = getDtFromIndex(currentIndex) + val mark = currentDt != SUSPENDED_DT && + (frame.frameInfo?.spatialId == frame.structure?.decodeTargetInfo?.get(currentDt)?.spatialId) + val isResumption = (prevIndex == SUSPENDED_INDEX && currentIndex != SUSPENDED_INDEX) + if (isResumption) { + check(accept) { + // Every code path that can turn off SUSPENDED_INDEX also accepts + "isResumption=$isResumption but accept=$accept for frame ${frame.frameNumber}, " + + "frameInfo=${frame.frameInfo}" + } + } + val dtChanged = (prevIndex != currentIndex) + if (dtChanged && currentDt != SUSPENDED_DT) { + check(accept) { + // Every code path that changes DT also accepts + "dtChanged=$dtChanged but accept=$accept for frame ${frame.frameNumber}, frameInfo=${frame.frameInfo}" + } + } + val newDt = if (dtChanged || frame.activeDecodeTargets != null) currentDt else null + return AcceptResult(accept = accept, isResumption = isResumption, mark = mark, newDt = newDt) + } + + private fun doAcceptFrame( + frame: Av1DDFrame, + incomingEncoding: Int, + externalTargetIndex: Int, + receivedTime: Instant? + ): Boolean { + val externalTargetEncoding = getEidFromIndex(externalTargetIndex) + val currentEncoding = getEidFromIndex(currentIndex) + + if (externalTargetEncoding != internalTargetEncoding) { + // The externalEncodingIdTarget has changed since accept last + // ran; perhaps we should request a keyframe. + internalTargetEncoding = externalTargetEncoding + if (externalTargetEncoding != SUSPENDED_ENCODING_ID && + externalTargetEncoding != currentEncoding + ) { + needsKeyframe = true + } + } + if (externalTargetEncoding == SUSPENDED_ENCODING_ID) { + // We stop forwarding immediately. We will need a keyframe in order + // to resume. + currentIndex = SUSPENDED_INDEX + return false + } + return if (frame.isKeyframe) { + logger.debug { + "Quality filter got keyframe for stream ${frame.ssrc}" + } + acceptKeyframe(frame, incomingEncoding, externalTargetIndex, receivedTime) + } else if (currentEncoding != SUSPENDED_ENCODING_ID) { + if (isOutOfSwitchingPhase(receivedTime) && isPossibleToSwitch(incomingEncoding)) { + // XXX(george) i've noticed some "rogue" base layer keyframes + // that trigger this. what happens is the client sends a base + // layer key frame, the bridge switches to that layer because + // for all it knows it may be the only keyframe sent by the + // client engine. then the bridge notices that packets from the + // higher quality streams are flowing and execution ends-up + // here. it is a mystery why the engine is "leaking" base layer + // key frames + needsKeyframe = true + } + if (incomingEncoding != currentEncoding) { + // for non-keyframes, we can't route anything but the current encoding + return false + } + + /** Logic to forward a non-keyframe: + * If the frame does not have FrameInfo, reject and set needsKeyframe (we couldn't decode its templates). + * If we're trying to switch DTs in the current encoding, check the template structure to ensure that + * there is at least one template which is SWITCH for the target DT and not NOT_PRESENT for the current DT. + * If there is not, request a keyframe. + * If the current frame is SWITCH for the target DT, and we've forwarded all the frames on which (by + * its fdiffs) it depends, forward it, and change the current DT to the target DT. + * In normal circumstances (when the target index == the current index), or when we're trying to switch + * *up* encodings, forward all frames whose DT for the current DT is not NOT_PRESENT. + * If we're trying to switch *down* encodings, only forward frames which are REQUIRED or SWITCH for the + * current DT. + */ + val frameInfo = frame.frameInfo ?: run { + needsKeyframe = true + return@doAcceptFrame false + } + var currentDt = getDtFromIndex(currentIndex) + val externalTargetDt = if (currentEncoding == externalTargetEncoding) { + getDtFromIndex(externalTargetIndex) + } else { + currentDt + } + + if ( + frame.activeDecodeTargets != null && + !frame.activeDecodeTargets.containsDecodeTarget(externalTargetDt) + ) { + /* This shouldn't happen, because we should have set layeringChanged for this packet. */ + logger.warn { + "External target DT $externalTargetDt not present in current decode targets 0x" + + Integer.toHexString(frame.activeDecodeTargets) + " for frame $frame." + } + return false + } + + if (currentDt != externalTargetDt) { + val frameMap = av1FrameMap[frame.ssrc] + if (frameInfo.dti.getOrNull(externalTargetDt) == null) { + logger.warn { "Target DT $externalTargetDt not present for frame $frame [frameInfo $frameInfo]" } + } + if (frameInfo.dti[externalTargetDt] == DTI.SWITCH && + frameMap != null && + frameInfo.fdiff.all { + frameMap.getIndex(frame.index - it)?.isAccepted == true + } + ) { + logger.debug { "Switching to DT $externalTargetDt from $currentDt" } + currentDt = externalTargetDt + currentIndex = externalTargetIndex + } else { + if (frame.structure?.canSwitchWithoutKeyframe(currentDt, externalTargetDt) != true) { + logger.debug { + "Want to switch to DT $externalTargetDt from $currentDt, requesting keyframe" + } + needsKeyframe = true + } + } + } + + val currentFrameDti = frameInfo.dti[currentDt] + if (currentEncoding > externalTargetEncoding) { + (currentFrameDti == DTI.SWITCH || currentFrameDti == DTI.REQUIRED) + } else { + (currentFrameDti != DTI.NOT_PRESENT) + } + } else { + // In this branch we're not processing a keyframe and the + // currentEncoding is in suspended state, which means we need + // a keyframe to start streaming again. + + // We should have already requested a keyframe, either above or when the + // internal target encoding was first moved off SUSPENDED_ENCODING. + false + } + } + + /** + * Returns a boolean that indicates whether we are in layer switching phase + * or not. + * + * @param receivedTime the time the latest frame was received + * @return false if we're in layer switching phase, true otherwise. + */ + @Synchronized + private fun isOutOfSwitchingPhase(receivedTime: Instant?): Boolean { + if (receivedTime == null) { + return false + } + if (mostRecentKeyframeGroupArrivalTime == null) { + return true + } + val delta = Duration.between(mostRecentKeyframeGroupArrivalTime, receivedTime) + return delta > MIN_KEY_FRAME_WAIT + } + + /** + * @return true if it looks like we can re-scale (see implementation of + * method for specific details). + */ + @Synchronized + private fun isPossibleToSwitch(incomingEncoding: Int): Boolean { + val currentEncoding = getEidFromIndex(currentIndex) + + if (incomingEncoding == SUSPENDED_ENCODING_ID) { + // We failed to resolve the spatial/quality layer of the packet. + return false + } + return when { + incomingEncoding > currentEncoding && currentEncoding < internalTargetEncoding -> + // It looks like upscaling is possible + true + incomingEncoding < currentEncoding && currentEncoding > internalTargetEncoding -> + // It looks like downscaling is possible. + true + else -> + false + } + } + + /** + * Determines whether to accept or drop an AV1 keyframe. This method updates + * the encoding id. + * + * Note that, at the time of this writing, there's no practical need for a + * synchronized keyword because there's only one thread accessing this + * method at a time. + * + * @param receivedTime the time the frame was received + * @return true to accept the AV1 keyframe, otherwise false. + */ + @Synchronized + private fun acceptKeyframe( + frame: Av1DDFrame, + incomingEncoding: Int, + externalTargetIndex: Int, + receivedTime: Instant? + ): Boolean { + // This branch writes the {@link #currentSpatialLayerId} and it + // determines whether or not we should switch to another simulcast + // stream. + if (incomingEncoding < 0) { + // something went terribly wrong, normally we should be able to + // extract the layer id from a keyframe. + logger.error("unable to get layer id from keyframe") + return false + } + val frameInfo = frame.frameInfo ?: run { + // something went terribly wrong, normally we should be able to + // extract the frame info from a keyframe. + logger.error("unable to get frame info from keyframe") + return@acceptKeyframe false + } + logger.debug { + "Received a keyframe of encoding: $incomingEncoding" + } + + val currentEncoding = getEidFromIndex(currentIndex) + val externalTargetEncoding = getEidFromIndex(externalTargetIndex) + + val indexIfSwitched = when { + incomingEncoding == externalTargetEncoding -> externalTargetIndex + incomingEncoding == internalTargetEncoding && internalTargetDt != -1 -> + getIndex(currentEncoding, internalTargetDt) + else -> frameInfo.dtisPresent.maxOrNull()!! + } + val dtIfSwitched = getDtFromIndex(indexIfSwitched) + val dtiIfSwitched = frameInfo.dti[dtIfSwitched] + val acceptIfSwitched = dtiIfSwitched != DTI.NOT_PRESENT + + // The keyframe request has been fulfilled at this point, regardless of + // whether we'll be able to achieve the internalEncodingIdTarget. + needsKeyframe = false + return if (isOutOfSwitchingPhase(receivedTime)) { + // During the switching phase we always project the first + // keyframe because it may very well be the only one that we + // receive (i.e. the endpoint is sending low quality only). Then + // we try to approach the target. + mostRecentKeyframeGroupArrivalTime = receivedTime + logger.debug { + "First keyframe in this kf group " + + "currentEncodingId: $incomingEncoding. " + + "Target is $internalTargetEncoding" + } + if (incomingEncoding <= internalTargetEncoding) { + // If the target is 180p and the first keyframe of a group of + // keyframes is a 720p keyframe we don't project it. If we + // receive a 720p keyframe, we know that there MUST be a 180p + // keyframe shortly after. + if (acceptIfSwitched) { + currentIndex = indexIfSwitched + } + acceptIfSwitched + } else { + false + } + } else { + // We're within the 300ms window since the reception of the + // first key frame of a key frame group, let's check whether an + // upscale/downscale is possible. + when { + currentEncoding <= incomingEncoding && + incomingEncoding <= internalTargetEncoding -> { + // upscale or current quality case + if (acceptIfSwitched) { + currentIndex = indexIfSwitched + logger.debug { + "Upscaling to encoding $incomingEncoding. " + + "The target is $internalTargetEncoding" + } + } + acceptIfSwitched + } + incomingEncoding <= internalTargetEncoding && + internalTargetEncoding < currentEncoding -> { + // downscale case + if (acceptIfSwitched) { + currentIndex = indexIfSwitched + logger.debug { + "Downscaling to encoding $incomingEncoding. " + + "The target is $internalTargetEncoding" + } + } + acceptIfSwitched + } + else -> { + false + } + } + } + } + + /** + * Adds internal state to a diagnostic context time series point. + */ + @SuppressFBWarnings( + value = ["IS2_INCONSISTENT_SYNC"], + justification = "We intentionally avoid synchronizing while reading fields only used in debug output." + ) + internal fun addDiagnosticContext(pt: DiagnosticContext.TimeSeriesPoint) { + pt.addField("qf.currentIndex", Av1DDRtpLayerDesc.indexString(currentIndex)) + .addField("qf.internalTargetEncoding", internalTargetEncoding) + .addField("qf.needsKeyframe", needsKeyframe) + .addField( + "qf.mostRecentKeyframeGroupArrivalTimeMs", + mostRecentKeyframeGroupArrivalTime?.toEpochMilli() ?: -1 + ) + /* TODO any other fields necessary */ + } + + /** + * Gets a JSON representation of the parts of this object's state that + * are deemed useful for debugging. + */ + @get:SuppressFBWarnings( + value = ["IS2_INCONSISTENT_SYNC"], + justification = "We intentionally avoid synchronizing while reading fields only used in debug output." + ) + val debugState: JSONObject + get() { + val debugState = JSONObject() + debugState["mostRecentKeyframeGroupArrivalTimeMs"] = + mostRecentKeyframeGroupArrivalTime?.toEpochMilli() ?: -1 + debugState["needsKeyframe"] = needsKeyframe + debugState["internalTargetEncoding"] = internalTargetEncoding + debugState["currentIndex"] = Av1DDRtpLayerDesc.indexString(currentIndex) + return debugState + } + + data class AcceptResult( + val accept: Boolean, + val isResumption: Boolean, + val mark: Boolean, + val newDt: Int? + ) + + companion object { + /** + * The default maximum frequency at which the media engine + * generates key frame. + */ + private val MIN_KEY_FRAME_WAIT = Duration.ofMillis(300) + } +} diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9AdaptiveSourceProjectionContext.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9AdaptiveSourceProjectionContext.kt index 1fe7d7b1e8..1417f82c34 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9AdaptiveSourceProjectionContext.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9AdaptiveSourceProjectionContext.kt @@ -21,7 +21,6 @@ import org.jitsi.nlj.codec.vpx.VpxUtils.Companion.applyExtendedPictureIdDelta import org.jitsi.nlj.codec.vpx.VpxUtils.Companion.applyTl0PicIdxDelta import org.jitsi.nlj.codec.vpx.VpxUtils.Companion.getExtendedPictureIdDelta import org.jitsi.nlj.codec.vpx.VpxUtils.Companion.getTl0PicIdxDelta -import org.jitsi.nlj.format.PayloadType import org.jitsi.nlj.rtp.codec.vp9.Vp9Packet import org.jitsi.rtp.rtcp.RtcpSrPacket import org.jitsi.rtp.util.RtpUtils.Companion.applySequenceNumberDelta @@ -33,7 +32,6 @@ import org.jitsi.utils.logging.DiagnosticContext import org.jitsi.utils.logging.TimeSeriesLogger import org.jitsi.utils.logging2.Logger import org.jitsi.utils.logging2.createChildLogger -import org.jitsi.utils.times import org.jitsi.videobridge.cc.AdaptiveSourceProjectionContext import org.jitsi.videobridge.cc.RewriteException import org.jitsi.videobridge.cc.RtpState @@ -49,7 +47,6 @@ import java.time.Instant */ class Vp9AdaptiveSourceProjectionContext( private val diagnosticContext: DiagnosticContext, - private val payloadType: PayloadType, rtpState: RtpState, parentLogger: Logger ) : AdaptiveSourceProjectionContext { @@ -81,7 +78,7 @@ class Vp9AdaptiveSourceProjectionContext( private var lastPicIdIndexResumption = -1 @Synchronized - override fun accept(packetInfo: PacketInfo, incomingIndex: Int, targetIndex: Int): Boolean { + override fun accept(packetInfo: PacketInfo, incomingEncoding: Int, targetIndex: Int): Boolean { val packet = packetInfo.packet if (packet !is Vp9Packet) { logger.warn("Packet is not Vp9 packet") @@ -108,7 +105,7 @@ class Vp9AdaptiveSourceProjectionContext( } val receivedTime = packetInfo.receivedTime val acceptResult = vp9QualityFilter - .acceptFrame(frame, incomingIndex, targetIndex, receivedTime) + .acceptFrame(frame, incomingEncoding, targetIndex, receivedTime) frame.isAccepted = acceptResult.accept && frame.index >= lastPicIdIndexResumption if (frame.isAccepted) { val projection: Vp9FrameProjection @@ -142,7 +139,9 @@ class Vp9AdaptiveSourceProjectionContext( .addField("timestamp", packet.timestamp) .addField("seq", packet.sequenceNumber) .addField("pictureId", packet.pictureId) - .addField("index", indexString(incomingIndex)) + .addField("encoding", incomingEncoding) + .addField("spatialLayer", packet.spatialLayerIndex) + .addField("temporalLayer", packet.temporalLayerIndex) .addField("isInterPicturePredicted", packet.isInterPicturePredicted) .addField("usesInterLayerDependency", packet.usesInterLayerDependency) .addField("isUpperLevelReference", packet.isUpperLevelReference) @@ -584,10 +583,6 @@ class Vp9AdaptiveSourceProjectionContext( lastVp9FrameProjection.timestamp ) - override fun getPayloadType(): PayloadType { - return payloadType - } - @Synchronized override fun getDebugState(): JSONObject { val debugState = JSONObject() @@ -603,8 +598,6 @@ class Vp9AdaptiveSourceProjectionContext( debugState["vp9FrameMaps"] = mapSizes debugState["vp9QualityFilter"] = vp9QualityFilter.debugState - debugState["payloadType"] = payloadType.toString() - return debugState } diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9PictureMap.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9PictureMap.kt index 2913590852..b7a65b63af 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9PictureMap.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9PictureMap.kt @@ -165,7 +165,7 @@ constructor(size: Int) : ArrayCache( ) { var numCached = 0 var firstIndex = -1 - var indexTracker = PictureIdIndexTracker() + val indexTracker = PictureIdIndexTracker() /** * Gets a picture with a given VP9 picture ID from the cache. diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9QualityFilter.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9QualityFilter.kt index c17e94665b..863b4c8786 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9QualityFilter.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/vp9/Vp9QualityFilter.kt @@ -90,7 +90,7 @@ internal class Vp9QualityFilter(parentLogger: Logger) { * method at a time. * * @param frame the VP9 frame. - * @param incomingIndex the quality index of the incoming RTP packet + * @param incomingEncoding the encoding of the incoming RTP packet * @param externalTargetIndex the target quality index that the user of this * instance wants to achieve. * @param receivedTime the current time (as an Instant) @@ -99,19 +99,19 @@ internal class Vp9QualityFilter(parentLogger: Logger) { @Synchronized fun acceptFrame( frame: Vp9Frame, - incomingIndex: Int, + incomingEncoding: Int, externalTargetIndex: Int, receivedTime: Instant? ): AcceptResult { val prevIndex = currentIndex - val accept = doAcceptFrame(frame, incomingIndex, externalTargetIndex, receivedTime) + val accept = doAcceptFrame(frame, incomingEncoding, externalTargetIndex, receivedTime) val mark = if (frame.isInterPicturePredicted) { - getSidFromIndex(incomingIndex) == getSidFromIndex(currentIndex) + frame.spatialLayer.coerceAtLeast(0) == getSidFromIndex(currentIndex) } else { /* This is wrong if the stream isn't actually currently encoding the target index's spatial layer */ /* However, in that case the final (lower) spatial layer should have the marker bit set by the encoder, so I think this shouldn't be a problem? */ - getSidFromIndex(incomingIndex) == getSidFromIndex(externalTargetIndex) + frame.spatialLayer.coerceAtLeast(0) == getSidFromIndex(externalTargetIndex) } val isResumption = (prevIndex == SUSPENDED_INDEX && currentIndex != SUSPENDED_INDEX) if (isResumption) assert(accept) // Every code path that can turn off SUSPENDED_INDEX also accepts @@ -120,7 +120,7 @@ internal class Vp9QualityFilter(parentLogger: Logger) { private fun doAcceptFrame( frame: Vp9Frame, - incomingIndex: Int, + incomingEncoding: Int, externalTargetIndex: Int, receivedTime: Instant? ): Boolean { @@ -145,12 +145,11 @@ internal class Vp9QualityFilter(parentLogger: Logger) { } // If temporal scalability is not enabled, pretend that this is the base temporal layer. val temporalLayerIdOfFrame = frame.temporalLayer.coerceAtLeast(0) - val incomingEncoding = getEidFromIndex(incomingIndex) return if (frame.isKeyframe) { logger.debug { "Quality filter got keyframe for stream ${frame.ssrc}" } - val accept = acceptKeyframe(incomingIndex, receivedTime) + val accept = acceptKeyframe(frame, incomingEncoding, receivedTime) if (accept) { // Keyframes reset layer forwarding, whether or not they're an encoding switch for (i in layers.indices) { @@ -189,7 +188,7 @@ internal class Vp9QualityFilter(parentLogger: Logger) { * return accept */ - val spatialLayerOfFrame = getSidFromIndex(incomingIndex) + val spatialLayerOfFrame = frame.spatialLayer.coerceAtLeast(0) var externalTargetSpatialId = getSidFromIndex(externalTargetIndex) var currentSpatialLayer = getSidFromIndex(currentIndex) @@ -211,7 +210,7 @@ internal class Vp9QualityFilter(parentLogger: Logger) { if (wantToSwitch) { if (canForwardLayer) { logger.debug { "Switching to spatial layer $externalTargetSpatialId from $currentSpatialLayer" } - currentIndex = incomingIndex + currentIndex = RtpLayerDesc.getIndex(incomingEncoding, frame.spatialLayer, frame.temporalLayer) currentSpatialLayer = spatialLayerOfFrame } else { if (internalTargetSpatialId != externalTargetSpatialId) { @@ -273,11 +272,11 @@ internal class Vp9QualityFilter(parentLogger: Logger) { } else { // In this branch we're not processing a keyframe and the // currentEncoding is in suspended state, which means we need - // a keyframe to start streaming again. Reaching this point also - // means that we want to forward something (because both - // externalEncodingTarget is not suspended) so we set the request keyframe flag. + // a keyframe to start streaming again. + + // We should have already requested a keyframe, either above or when the + // internal target encoding was first moved off SUSPENDED_ENCODING. - // assert needsKeyframe == true; false } } @@ -287,7 +286,7 @@ internal class Vp9QualityFilter(parentLogger: Logger) { * or not. * * @param receivedTime the time the latest frame was received - * @return true if we're in layer switching phase, false otherwise. + * @return false if we're in layer switching phase, true otherwise. */ @Synchronized private fun isOutOfSwitchingPhase(receivedTime: Instant?): Boolean { @@ -337,24 +336,25 @@ internal class Vp9QualityFilter(parentLogger: Logger) { * @return true to accept the VP9 keyframe, otherwise false. */ @Synchronized - private fun acceptKeyframe(incomingIndex: Int, receivedTime: Instant?): Boolean { - val encodingIdOfKeyframe = getEidFromIndex(incomingIndex) + private fun acceptKeyframe(frame: Vp9Frame, incomingEncoding: Int, receivedTime: Instant?): Boolean { // This branch writes the {@link #currentSpatialLayerId} and it // determines whether or not we should switch to another simulcast // stream. - if (encodingIdOfKeyframe < 0) { + if (incomingEncoding < 0) { // something went terribly wrong, normally we should be able to // extract the layer id from a keyframe. - logger.error("unable to get layer id from keyframe") + logger.error("invalid encoding id for keyframe") return false } // Keyframes have to be sid 0, tid 0, unless something screwy is going on. - if (getSidFromIndex(incomingIndex) != 0 || getTidFromIndex(incomingIndex) != 0) { - logger.warn("Surprising index ${indexString(incomingIndex)} on keyframe") + // The layers can also be -1 if the layers aren't known + if (frame.spatialLayer > 0 || frame.temporalLayer > 0) { + logger.warn("Surprising layers S${frame.spatialLayer}T${frame.temporalLayer} on keyframe") } logger.debug { - "Received a keyframe of encoding: $encodingIdOfKeyframe" + "Received a keyframe of encoding: $incomingEncoding" } + val incomingIndex = RtpLayerDesc.getIndex(incomingEncoding, frame.spatialLayer, frame.temporalLayer) // The keyframe request has been fulfilled at this point, regardless of // whether we'll be able to achieve the internalEncodingIdTarget. @@ -367,16 +367,16 @@ internal class Vp9QualityFilter(parentLogger: Logger) { mostRecentKeyframeGroupArrivalTime = receivedTime logger.debug { "First keyframe in this kf group " + - "currentEncodingId: $encodingIdOfKeyframe. " + + "currentEncodingId: $incomingEncoding. " + "Target is $internalTargetEncoding" } - if (encodingIdOfKeyframe <= internalTargetEncoding) { + if (incomingEncoding <= internalTargetEncoding) { val currentEncoding = getEidFromIndex(currentIndex) // If the target is 180p and the first keyframe of a group of // keyframes is a 720p keyframe we don't project it. If we // receive a 720p keyframe, we know that there MUST be a 180p // keyframe shortly after. - if (currentEncoding != encodingIdOfKeyframe) { + if (currentEncoding != incomingEncoding) { currentIndex = incomingIndex } true @@ -389,24 +389,24 @@ internal class Vp9QualityFilter(parentLogger: Logger) { // upscale/downscale is possible. val currentEncoding = getEidFromIndex(currentIndex) when { - currentEncoding <= encodingIdOfKeyframe && - encodingIdOfKeyframe <= internalTargetEncoding -> { + currentEncoding <= incomingEncoding && + incomingEncoding <= internalTargetEncoding -> { // upscale or current quality case - if (currentEncoding != encodingIdOfKeyframe) { + if (currentEncoding != incomingEncoding) { currentIndex = incomingIndex } logger.debug { - "Upscaling to encoding $encodingIdOfKeyframe. " + + "Upscaling to encoding $incomingEncoding. " + "The target is $internalTargetEncoding" } true } - encodingIdOfKeyframe <= internalTargetEncoding && + incomingEncoding <= internalTargetEncoding && internalTargetEncoding < currentEncoding -> { // downscale case currentIndex = incomingIndex logger.debug { - "Downscaling to encoding $encodingIdOfKeyframe. " + + "Downscaling to encoding $incomingEncoding. " + "The target is $internalTargetEncoding" } true diff --git a/jvb/src/test/java/org/jitsi/videobridge/cc/vp8/VP8AdaptiveSourceProjectionTest.java b/jvb/src/test/java/org/jitsi/videobridge/cc/vp8/VP8AdaptiveSourceProjectionTest.java index 4257733130..56e9782943 100644 --- a/jvb/src/test/java/org/jitsi/videobridge/cc/vp8/VP8AdaptiveSourceProjectionTest.java +++ b/jvb/src/test/java/org/jitsi/videobridge/cc/vp8/VP8AdaptiveSourceProjectionTest.java @@ -39,8 +39,6 @@ public class VP8AdaptiveSourceProjectionTest { private final Logger logger = new LoggerImpl(getClass().getName()); - private final PayloadType payloadType = new Vp8PayloadType((byte)96, - new ConcurrentHashMap<>(), new CopyOnWriteArraySet<>()); @Test public void singlePacketProjectionTest() throws RewriteException @@ -52,8 +50,7 @@ public void singlePacketProjectionTest() throws RewriteException new RtpState(1, 10000, 1000000); VP8AdaptiveSourceProjectionContext context = - new VP8AdaptiveSourceProjectionContext(diagnosticContext, payloadType, - initialState, logger); + new VP8AdaptiveSourceProjectionContext(diagnosticContext, initialState, logger); Vp8PacketGenerator generator = new Vp8PacketGenerator(1); @@ -62,7 +59,7 @@ public void singlePacketProjectionTest() throws RewriteException int targetIndex = RtpLayerDesc.getIndex(0, 0, 0); - assertTrue(context.accept(packetInfo, packet.getTemporalLayerIndex(), targetIndex)); + assertTrue(context.accept(packetInfo, 0, targetIndex)); context.rewriteRtp(packetInfo); @@ -84,7 +81,7 @@ private void runInOrderTest(Vp8PacketGenerator generator, int targetTid) int targetIndex = RtpLayerDesc.getIndex(0, 0, targetTid); VP8AdaptiveSourceProjectionContext context = - new VP8AdaptiveSourceProjectionContext(diagnosticContext, payloadType, + new VP8AdaptiveSourceProjectionContext(diagnosticContext, initialState, logger); int expectedSeq = 10001; @@ -97,7 +94,7 @@ private void runInOrderTest(Vp8PacketGenerator generator, int targetTid) PacketInfo packetInfo = generator.nextPacket(); Vp8Packet packet = packetInfo.packetAs(); - boolean accepted = context.accept(packetInfo, packet.getTemporalLayerIndex(), targetIndex); + boolean accepted = context.accept(packetInfo, 0, targetIndex); if (packet.isStartOfFrame() && packet.getTemporalLayerIndex() == 0) { @@ -176,7 +173,6 @@ private void doRunOutOfOrderTest(Vp8PacketGenerator generator, int targetTid, VP8AdaptiveSourceProjectionContext context = new VP8AdaptiveSourceProjectionContext(diagnosticContext, - payloadType, initialState, logger); int latestSeq = buffer.get(0).packetAs().getSequenceNumber(); @@ -198,7 +194,7 @@ private void doRunOutOfOrderTest(Vp8PacketGenerator generator, int targetTid, { latestSeq = origSeq; } - boolean accepted = context.accept(packetInfo, packet.getTemporalLayerIndex(), targetIndex); + boolean accepted = context.accept(packetInfo, 0, targetIndex); int oldestValidSeq = RtpUtils.applySequenceNumberDelta(latestSeq, -((VP8FrameMap.FRAME_MAP_SIZE - 1) * generator.packetsPerFrame)); @@ -385,8 +381,7 @@ public void slightlyDelayedKeyframeTest() throws RewriteException new RtpState(1, 10000, 1000000); VP8AdaptiveSourceProjectionContext context = - new VP8AdaptiveSourceProjectionContext(diagnosticContext, payloadType, - initialState, logger); + new VP8AdaptiveSourceProjectionContext(diagnosticContext, initialState, logger); PacketInfo firstPacketInfo = generator.nextPacket(); Vp8Packet firstPacket = firstPacketInfo.packetAs(); @@ -398,10 +393,10 @@ public void slightlyDelayedKeyframeTest() throws RewriteException PacketInfo packetInfo = generator.nextPacket(); Vp8Packet packet = packetInfo.packetAs(); - assertFalse(context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex)); + assertFalse(context.accept(packetInfo, 0, targetIndex)); } - assertTrue(context.accept(firstPacketInfo, RtpLayerDesc.getIndex(0, 0, firstPacket.getTemporalLayerIndex()), targetIndex)); + assertTrue(context.accept(firstPacketInfo, 0, targetIndex)); context.rewriteRtp(firstPacketInfo); for (int i = 0; i < 9996; i++) @@ -409,7 +404,7 @@ public void slightlyDelayedKeyframeTest() throws RewriteException PacketInfo packetInfo = generator.nextPacket(); Vp8Packet packet = packetInfo.packetAs(); - assertTrue(context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex)); + assertTrue(context.accept(packetInfo, 0, targetIndex)); context.rewriteRtp(packetInfo); } } @@ -426,8 +421,7 @@ public void veryDelayedKeyframeTest() throws RewriteException new RtpState(1, 10000, 1000000); VP8AdaptiveSourceProjectionContext context = - new VP8AdaptiveSourceProjectionContext(diagnosticContext, payloadType, - initialState, logger); + new VP8AdaptiveSourceProjectionContext(diagnosticContext, initialState, logger); PacketInfo firstPacketInfo = generator.nextPacket(); Vp8Packet firstPacket = firstPacketInfo.packetAs(); @@ -439,17 +433,17 @@ public void veryDelayedKeyframeTest() throws RewriteException PacketInfo packetInfo = generator.nextPacket(); Vp8Packet packet = packetInfo.packetAs(); - assertFalse(context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex)); + assertFalse(context.accept(packetInfo, 0, targetIndex)); } - assertFalse(context.accept(firstPacketInfo, RtpLayerDesc.getIndex(0, 0, firstPacket.getTemporalLayerIndex()), targetIndex)); + assertFalse(context.accept(firstPacketInfo, 0, targetIndex)); for (int i = 0; i < 10; i++) { PacketInfo packetInfo = generator.nextPacket(); Vp8Packet packet = packetInfo.packetAs(); - assertFalse(context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex)); + assertFalse(context.accept(packetInfo, 0, targetIndex)); } generator.requestKeyframe(); @@ -459,7 +453,7 @@ public void veryDelayedKeyframeTest() throws RewriteException PacketInfo packetInfo = generator.nextPacket(); Vp8Packet packet = packetInfo.packetAs(); - assertTrue(context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex)); + assertTrue(context.accept(packetInfo, 0, targetIndex)); context.rewriteRtp(packetInfo); } } @@ -476,8 +470,7 @@ public void delayedPartialKeyframeTest() throws RewriteException new RtpState(1, 10000, 1000000); VP8AdaptiveSourceProjectionContext context = - new VP8AdaptiveSourceProjectionContext(diagnosticContext, payloadType, - initialState, logger); + new VP8AdaptiveSourceProjectionContext(diagnosticContext, initialState, logger); PacketInfo firstPacketInfo = generator.nextPacket(); Vp8Packet firstPacket = firstPacketInfo.packetAs(); @@ -489,17 +482,17 @@ public void delayedPartialKeyframeTest() throws RewriteException PacketInfo packetInfo = generator.nextPacket(); Vp8Packet packet = packetInfo.packetAs(); - assertFalse(context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex)); + assertFalse(context.accept(packetInfo, 0, targetIndex)); } - assertFalse(context.accept(firstPacketInfo, firstPacket.getTemporalLayerIndex(), 2)); + assertFalse(context.accept(firstPacketInfo, 0, 2)); for (int i = 0; i < 30; i++) { PacketInfo packetInfo = generator.nextPacket(); Vp8Packet packet = packetInfo.packetAs(); - assertFalse(context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex)); + assertFalse(context.accept(packetInfo, 0, targetIndex)); } generator.requestKeyframe(); @@ -509,7 +502,7 @@ public void delayedPartialKeyframeTest() throws RewriteException PacketInfo packetInfo = generator.nextPacket(); Vp8Packet packet = packetInfo.packetAs(); - assertTrue(context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex)); + assertTrue(context.accept(packetInfo, 0, targetIndex)); context.rewriteRtp(packetInfo); } } @@ -528,8 +521,7 @@ public void twoStreamsNoSwitchingTest() throws RewriteException new RtpState(1, 10000, 1000000); VP8AdaptiveSourceProjectionContext context = - new VP8AdaptiveSourceProjectionContext(diagnosticContext, payloadType, - initialState, logger); + new VP8AdaptiveSourceProjectionContext(diagnosticContext, initialState, logger); int targetIndex = RtpLayerDesc.getIndex(1, 0, 2); @@ -540,12 +532,12 @@ public void twoStreamsNoSwitchingTest() throws RewriteException PacketInfo packetInfo1 = generator1.nextPacket(); Vp8Packet packet1 = packetInfo1.packetAs(); - assertTrue(context.accept(packetInfo1, RtpLayerDesc.getIndex(1, 0, packet1.getTemporalLayerIndex()), targetIndex)); + assertTrue(context.accept(packetInfo1, 1, targetIndex)); PacketInfo packetInfo2 = generator2.nextPacket(); Vp8Packet packet2 = packetInfo2.packetAs(); - assertFalse(context.accept(packetInfo2, RtpLayerDesc.getIndex(0, 0, packet2.getTemporalLayerIndex()), targetIndex)); + assertFalse(context.accept(packetInfo2, 0, targetIndex)); context.rewriteRtp(packetInfo1); @@ -574,8 +566,7 @@ public void twoStreamsSwitchingTest() throws RewriteException new RtpState(1, 10000, 1000000); VP8AdaptiveSourceProjectionContext context = - new VP8AdaptiveSourceProjectionContext(diagnosticContext, payloadType, - initialState, logger); + new VP8AdaptiveSourceProjectionContext(diagnosticContext, initialState, logger); int expectedSeq = 10001; long expectedTs = 1003000; @@ -596,7 +587,7 @@ public void twoStreamsSwitchingTest() throws RewriteException expectedTl0PicIdx = VpxUtils.applyTl0PicIdxDelta(expectedTl0PicIdx, 1); } - assertTrue(context.accept(packetInfo1, RtpLayerDesc.getIndex(0, 0, packet1.getTemporalLayerIndex()), targetIndex)); + assertTrue(context.accept(packetInfo1, 0, targetIndex)); context.rewriteRtp(packetInfo1); @@ -608,7 +599,7 @@ public void twoStreamsSwitchingTest() throws RewriteException PacketInfo packetInfo2 = generator2.nextPacket(); Vp8Packet packet2 = packetInfo2.packetAs(); - assertFalse(context.accept(packetInfo2, RtpLayerDesc.getIndex(1, 0, packet2.getTemporalLayerIndex()), targetIndex)); + assertFalse(context.accept(packetInfo2, 1, targetIndex)); assertFalse(context.rewriteRtcp(srPacket2)); assertEquals(expectedSeq, packet1.getSequenceNumber()); @@ -637,7 +628,7 @@ public void twoStreamsSwitchingTest() throws RewriteException expectedTl0PicIdx = VpxUtils.applyTl0PicIdxDelta(expectedTl0PicIdx, 1); } - assertTrue(context.accept(packetInfo1, RtpLayerDesc.getIndex(0, 0, packet1.getTemporalLayerIndex()), targetIndex)); + assertTrue(context.accept(packetInfo1, 0, targetIndex)); context.rewriteRtp(packetInfo1); @@ -649,7 +640,7 @@ public void twoStreamsSwitchingTest() throws RewriteException PacketInfo packetInfo2 = generator2.nextPacket(); Vp8Packet packet2 = packetInfo2.packetAs(); - assertFalse(context.accept(packetInfo2, RtpLayerDesc.getIndex(1, 0, packet2.getTemporalLayerIndex()), targetIndex)); + assertFalse(context.accept(packetInfo2, 1, targetIndex)); assertFalse(context.rewriteRtcp(srPacket2)); assertEquals(expectedSeq, packet1.getSequenceNumber()); @@ -682,7 +673,7 @@ public void twoStreamsSwitchingTest() throws RewriteException } /* We will cut off the layer 0 keyframe after 1 packet, once we see the layer 1 keyframe. */ - assertEquals(i == 0, context.accept(packetInfo1, RtpLayerDesc.getIndex(0, 0, packet1.getTemporalLayerIndex()), targetIndex)); + assertEquals(i == 0, context.accept(packetInfo1, 0, targetIndex)); assertEquals(i == 0, context.rewriteRtcp(srPacket1)); if (i == 0) @@ -701,7 +692,7 @@ public void twoStreamsSwitchingTest() throws RewriteException expectedTl0PicIdx = VpxUtils.applyTl0PicIdxDelta(expectedTl0PicIdx, 1); } - assertTrue(context.accept(packetInfo2, RtpLayerDesc.getIndex(1, 0, packet2.getTemporalLayerIndex()), targetIndex)); + assertTrue(context.accept(packetInfo2, 1, targetIndex)); context.rewriteRtp(packetInfo2); @@ -745,8 +736,7 @@ public void temporalLayerSwitchingTest() throws RewriteException new RtpState(1, 10000, 1000000); VP8AdaptiveSourceProjectionContext context = - new VP8AdaptiveSourceProjectionContext(diagnosticContext, payloadType, - initialState, logger); + new VP8AdaptiveSourceProjectionContext(diagnosticContext, initialState, logger); int targetTid = 0; int decodableTid = 0; @@ -763,7 +753,7 @@ public void temporalLayerSwitchingTest() throws RewriteException PacketInfo packetInfo = generator.nextPacket(); Vp8Packet packet = packetInfo.packetAs(); - boolean accepted = context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex); + boolean accepted = context.accept(packetInfo, 0, targetIndex); if (packet.isStartOfFrame() && packet.getTemporalLayerIndex() == 0) { @@ -830,9 +820,7 @@ private void runLargeDropoutTest(Vp8PacketGenerator generator, new RtpState(1, 10000, 1000000); VP8AdaptiveSourceProjectionContext context = - new VP8AdaptiveSourceProjectionContext(diagnosticContext, - payloadType, - initialState, logger); + new VP8AdaptiveSourceProjectionContext(diagnosticContext, initialState, logger); int expectedSeq = 10001; long expectedTs = 1003000; @@ -845,7 +833,7 @@ private void runLargeDropoutTest(Vp8PacketGenerator generator, Vp8Packet packet = packetInfo.packetAs(); boolean accepted = - context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex); + context.accept(packetInfo, 0, targetIndex); if (packet.isStartOfFrame() && packet.getTemporalLayerIndex() == 0) { @@ -898,7 +886,7 @@ private void runLargeDropoutTest(Vp8PacketGenerator generator, } while (packet.getTemporalLayerIndex() > targetIndex); - assertTrue(context.accept(packetInfo, RtpLayerDesc.getIndex(0, 0, packet.getTemporalLayerIndex()), targetIndex)); + assertTrue(context.accept(packetInfo, 0, targetIndex)); context.rewriteRtp(packetInfo); /* Allow any values after a gap. */ @@ -921,7 +909,7 @@ private void runLargeDropoutTest(Vp8PacketGenerator generator, packet = packetInfo.packetAs(); boolean accepted = context - .accept(packetInfo, packet.getTemporalLayerIndex(), targetIndex); + .accept(packetInfo, 0, targetIndex); if (packet.isStartOfFrame() && packet.getTemporalLayerIndex() == 0) diff --git a/jvb/src/test/kotlin/org/jitsi/videobridge/cc/allocation/BitrateControllerTest.kt b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/allocation/BitrateControllerTest.kt index bd9435c513..e50c0f7197 100644 --- a/jvb/src/test/kotlin/org/jitsi/videobridge/cc/allocation/BitrateControllerTest.kt +++ b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/allocation/BitrateControllerTest.kt @@ -1532,9 +1532,16 @@ class MockRtpLayerDesc( var bitrate: Bandwidth, sid: Int = -1 ) : RtpLayerDesc(eid, tid, sid, height, frameRate) { + override fun copy(height: Int): RtpLayerDesc { + TODO("Not yet implemented") + } + + override val layerId = getIndex(0, sid, tid) + override val index = getIndex(eid, sid, tid) override fun getBitrate(nowMs: Long): Bandwidth = bitrate override fun hasZeroBitrate(nowMs: Long): Boolean = bitrate == 0.bps + override fun indexString() = indexString(index) } typealias History = MutableList> diff --git a/jvb/src/test/kotlin/org/jitsi/videobridge/cc/av1/Av1DDAdaptiveSourceProjectionTest.kt b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/av1/Av1DDAdaptiveSourceProjectionTest.kt new file mode 100644 index 0000000000..e3bbead0ef --- /dev/null +++ b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/av1/Av1DDAdaptiveSourceProjectionTest.kt @@ -0,0 +1,1571 @@ +/* + * Copyright @ 2019 - Present, 8x8 Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.videobridge.cc.av1 + +import jakarta.xml.bind.DatatypeConverter.parseHexBinary +import org.jitsi.nlj.PacketInfo +import org.jitsi.nlj.RtpLayerDesc +import org.jitsi.nlj.rtp.codec.av1.Av1DDPacket +import org.jitsi.nlj.rtp.codec.av1.Av1DDRtpLayerDesc.Companion.getIndex +import org.jitsi.nlj.util.Rfc3711IndexTracker +import org.jitsi.rtp.rtcp.RtcpSrPacket +import org.jitsi.rtp.rtcp.RtcpSrPacketBuilder +import org.jitsi.rtp.rtp.RtpPacket +import org.jitsi.rtp.rtp.header_extensions.Av1DependencyDescriptorHeaderExtension +import org.jitsi.rtp.rtp.header_extensions.Av1DependencyDescriptorReader +import org.jitsi.rtp.rtp.header_extensions.Av1TemplateDependencyStructure +import org.jitsi.rtp.rtp.header_extensions.FrameInfo +import org.jitsi.rtp.util.RtpUtils +import org.jitsi.rtp.util.isNewerThan +import org.jitsi.rtp.util.isOlderThan +import org.jitsi.utils.logging.DiagnosticContext +import org.jitsi.utils.logging2.Logger +import org.jitsi.utils.logging2.LoggerImpl +import org.jitsi.videobridge.cc.RtpState +import org.junit.Assert +import org.junit.Test +import java.time.Duration +import java.time.Instant +import java.util.* +import javax.xml.bind.DatatypeConverter + +class Av1DDAdaptiveSourceProjectionTest { + private val logger: Logger = LoggerImpl(javaClass.name) + + @Test + fun singlePacketProjectionTest() { + val diagnosticContext = DiagnosticContext() + diagnosticContext["test"] = "singlePacketProjectionTest" + val initialState = RtpState(1, 10000, 1000000) + val context = Av1DDAdaptiveSourceProjectionContext( + diagnosticContext, + initialState, + logger + ) + val generator = ScalableAv1PacketGenerator(1) + val packetInfo = generator.nextPacket() + val packet = packetInfo.packetAs() + val targetIndex = getIndex(eid = 0, dt = 0) + Assert.assertTrue(context.accept(packetInfo, 0, targetIndex)) + context.rewriteRtp(packetInfo) + Assert.assertEquals(10001, packet.sequenceNumber) + Assert.assertEquals(1003000, packet.timestamp) + Assert.assertEquals(0, packet.frameNumber) + Assert.assertEquals(0, packet.frameInfo?.spatialId) + Assert.assertEquals(0, packet.frameInfo?.temporalId) + } + + private fun runInOrderTest(generator: Av1PacketGenerator, targetIndex: Int, expectAccept: (FrameInfo) -> Boolean) { + val diagnosticContext = DiagnosticContext() + diagnosticContext["test"] = Thread.currentThread().stackTrace[2].methodName + val initialState = RtpState(1, 10000, 1000000) + val context = Av1DDAdaptiveSourceProjectionContext( + diagnosticContext, + initialState, + logger + ) + var expectedSeq = 10001 + var expectedTs: Long = 1003000 + var expectedFrameNumber = 0 + for (i in 0..99999) { + val packetInfo = generator.nextPacket() + val packet = packetInfo.packetAs() + val frameInfo = packet.frameInfo!! + + val accepted = context.accept(packetInfo, 0, targetIndex) + val endOfFrame = packet.isEndOfFrame + val endOfPicture = packet.isMarked // Save this before rewriteRtp + if (expectAccept(frameInfo)) { + Assert.assertTrue(accepted) + context.rewriteRtp(packetInfo) + Assert.assertEquals(expectedSeq, packet.sequenceNumber) + Assert.assertEquals(expectedTs, packet.timestamp) + Assert.assertEquals(expectedFrameNumber, packet.frameNumber) + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + } else { + Assert.assertFalse(accepted) + } + if (endOfFrame) { + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + if (endOfPicture) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + } + + @Test + fun simpleNonScalableTest() { + val generator = NonScalableAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 0)) { + true + } + } + + @Test + fun simpleTemporalProjectionTest() { + val generator = TemporallyScaledPacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 2)) { + true + } + } + + @Test + fun filteredTemporalProjectionTest() { + val generator = TemporallyScaledPacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.temporalId == 0 + } + } + + @Test + fun largerFrameTemporalProjectionTest() { + val generator = TemporallyScaledPacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 2)) { + true + } + } + + @Test + fun largerFrameTemporalFilteredTest() { + val generator = TemporallyScaledPacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.temporalId == 0 + } + } + + @Test + fun hugeFrameTest() { + val generator = TemporallyScaledPacketGenerator(200) + runInOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.temporalId == 0 + } + } + + @Test + fun simpleSvcTest() { + val generator = ScalableAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2)) { + true + } + } + + @Test + fun filteredSvcTest() { + val generator = ScalableAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 2)) { + it.spatialId == 0 + } + } + + @Test + fun temporalFilteredSvcTest() { + val generator = ScalableAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2)) { + it.temporalId == 0 + } + } + + @Test + fun spatialAndTemporalFilteredSvcTest() { + val generator = ScalableAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun largerSvcTest() { + val generator = ScalableAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2)) { + true + } + } + + @Test + fun largerFilteredSvcTest() { + val generator = ScalableAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 2)) { + it.spatialId == 0 + } + } + + @Test + fun largerTemporalFilteredSvcTest() { + val generator = ScalableAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2)) { + it.temporalId == 0 + } + } + + @Test + fun largerSpatialAndTemporalFilteredSvcTest() { + val generator = ScalableAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun simpleKSvcTest() { + val generator = KeyScalableAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2)) { + it.spatialId == 2 || !it.hasInterPictureDependency() + } + } + + @Test + fun filteredKSvcTest() { + val generator = KeyScalableAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 2)) { + it.spatialId == 0 + } + } + + @Test + fun temporalFilteredKSvcTest() { + val generator = KeyScalableAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2)) { + it.temporalId == 0 && (it.spatialId == 2 || !it.hasInterPictureDependency()) + } + } + + @Test + fun spatialAndTemporalFilteredKSvcTest() { + val generator = KeyScalableAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun largerKSvcTest() { + val generator = KeyScalableAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2)) { + it.spatialId == 2 || !it.hasInterPictureDependency() + } + } + + @Test + fun largerFilteredKSvcTest() { + val generator = KeyScalableAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 2)) { + it.spatialId == 0 + } + } + + @Test + fun largerTemporalFilteredKSvcTest() { + val generator = KeyScalableAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2)) { + it.temporalId == 0 && (it.spatialId == 2 || !it.hasInterPictureDependency()) + } + } + + @Test + fun largerSpatialAndTemporalFilteredKSvcTest() { + val generator = KeyScalableAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun simpleSingleEncodingSimulcastTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2)) { + it.spatialId == 2 + } + } + + @Test + fun filteredSingleEncodingSimulcastTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 2)) { + it.spatialId == 0 + } + } + + @Test + fun temporalFilteredSingleEncodingSimulcastTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2)) { + it.temporalId == 0 && it.spatialId == 2 + } + } + + @Test + fun spatialAndTemporalFilteredSingleEncodingSimulcastTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(1) + runInOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun largerSingleEncodingSimulcastTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2)) { + it.spatialId == 2 + } + } + + @Test + fun largerFilteredSingleEncodingSimulcastTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 2)) { + it.spatialId == 0 + } + } + + @Test + fun largerTemporalFilteredSingleEncodingSimulcastTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 3 * 2)) { + it.temporalId == 0 && it.spatialId == 2 + } + } + + @Test + fun largerSpatialAndTemporalFilteredSingleEncodingSimulcastTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(3) + runInOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + private class ProjectedPacket constructor( + val packet: Av1DDPacket, + val origSeq: Int, + val extOrigSeq: Int, + val extFrameNum: Int, + ) + + /** Run an out-of-order test on a single stream, randomized order except for the first + * [initialOrderedCount] packets. */ + private fun doRunOutOfOrderTest( + generator: Av1PacketGenerator, + targetIndex: Int, + initialOrderedCount: Int, + seed: Long, + expectAccept: (FrameInfo) -> Boolean + ) { + val diagnosticContext = DiagnosticContext() + diagnosticContext["test"] = Thread.currentThread().stackTrace[2].methodName + val initialState = RtpState(1, 10000, 1000000) + val expectedInitialTs: Long = RtpUtils.applyTimestampDelta(initialState.maxTimestamp, 3000) + val expectedTsOffset: Long = RtpUtils.getTimestampDiff(expectedInitialTs, generator.ts) + val reorderSize = 64 + val buffer = ArrayList(reorderSize) + for (i in 0 until reorderSize) { + buffer.add(generator.nextPacket()) + } + val random = Random(seed) + var orderedCount = initialOrderedCount - 1 + val context = Av1DDAdaptiveSourceProjectionContext( + diagnosticContext, + initialState, + logger + ) + var latestSeq = buffer[0].packetAs().sequenceNumber + val projectedPackets = TreeMap() + val origSeqIdxTracker = Rfc3711IndexTracker() + val newSeqIdxTracker = Rfc3711IndexTracker() + val frameNumsDropped = HashSet() + val frameNumsIndexTracker = Rfc3711IndexTracker() + for (i in 0..99999) { + val packetInfo = buffer[0] + val packet = packetInfo.packetAs() + val origSeq = packet.sequenceNumber + val origTs = packet.timestamp + if (latestSeq isOlderThan origSeq) { + latestSeq = origSeq + } + val frameInfo = packet.frameInfo!! + + val accepted = context.accept(packetInfo, 0, targetIndex) + val oldestValidSeq: Int = + RtpUtils.applySequenceNumberDelta( + latestSeq, + -((Av1DDFrameMap.FRAME_MAP_SIZE - 1) * generator.packetsPerFrame) + ) + if (origSeq isOlderThan oldestValidSeq && !accepted) { + /* This is fine; packets that are too old get ignored. */ + /* Note we don't want assertFalse(accepted) here because slightly-too-old packets + * that are part of an existing accepted frame will be accepted. + */ + val extFrameNum = frameNumsIndexTracker.update(packet.frameNumber) + frameNumsDropped.add(extFrameNum) + } else if (expectAccept(frameInfo) + ) { + Assert.assertTrue(accepted) + + context.rewriteRtp(packetInfo) + Assert.assertEquals(RtpUtils.applyTimestampDelta(origTs, expectedTsOffset), packet.timestamp) + val newSeq = packet.sequenceNumber + val extNewSeq = newSeqIdxTracker.update(newSeq) + val extOrigSeq = origSeqIdxTracker.update(origSeq) + Assert.assertFalse(projectedPackets.containsKey(extNewSeq)) + val extFrameNum = frameNumsIndexTracker.update(packet.frameNumber) + projectedPackets[extNewSeq] = ProjectedPacket(packet, origSeq, extOrigSeq, extFrameNum) + } else { + Assert.assertFalse(accepted) + } + if (orderedCount > 0) { + buffer.removeAt(0) + buffer.add(generator.nextPacket()) + orderedCount-- + } else { + buffer[0] = generator.nextPacket() + buffer.shuffle(random) + } + } + val frameNumsSeen = HashSet() + + /* Add packets that weren't added yet, or that were dropped for being too old, to frameNumsSeen. */ + frameNumsSeen.addAll(frameNumsDropped) + buffer.forEach { + frameNumsSeen.add(frameNumsIndexTracker.update(it.packetAs().frameNumber)) + } + + val iter = projectedPackets.keys.iterator() + var prevPacket = projectedPackets[iter.next()]!! + frameNumsSeen.add(prevPacket.extFrameNum) + while (iter.hasNext()) { + val packet = projectedPackets[iter.next()] + Assert.assertTrue(packet!!.origSeq isNewerThan prevPacket.origSeq) + frameNumsSeen.add(packet.extFrameNum) + Assert.assertTrue( + RtpUtils.getSequenceNumberDelta( + prevPacket.packet.frameNumber, + packet.packet.frameNumber + ) <= 0 + ) + if (packet.packet.isStartOfFrame) { + Assert.assertTrue( + RtpUtils.getSequenceNumberDelta( + prevPacket.packet.frameNumber, + packet.packet.frameNumber + ) < 0 + ) + if (prevPacket.packet.sequenceNumber == + RtpUtils.applySequenceNumberDelta(packet.packet.sequenceNumber, -1) + ) { + Assert.assertTrue(prevPacket.packet.isEndOfFrame) + } + } else { + if (prevPacket.packet.sequenceNumber == + RtpUtils.applySequenceNumberDelta(packet.packet.sequenceNumber, -1) + ) { + Assert.assertEquals(prevPacket.packet.frameNumber, packet.packet.frameNumber) + Assert.assertEquals(prevPacket.packet.timestamp, packet.packet.timestamp) + } + } + packet.packet.frameInfo?.fdiff?.forEach { + Assert.assertTrue(frameNumsSeen.contains(packet.extFrameNum - it)) + } + prevPacket = packet + } + + /* Overall, we should not have expanded sequence numbers. */ + val firstPacket = projectedPackets.firstEntry().value + val lastPacket = projectedPackets.lastEntry().value + val origDelta = lastPacket!!.extOrigSeq - firstPacket!!.extOrigSeq + val projDelta = projectedPackets.lastKey() - projectedPackets.firstKey() + Assert.assertTrue(projDelta <= origDelta) + } + + /** Run multiple instances of out-of-order test on a single stream, with different + * random seeds. */ + private fun runOutOfOrderTest( + generator: Av1PacketGenerator, + targetIndex: Int, + initialOrderedCount: Int = 1, + expectAccept: (FrameInfo) -> Boolean + ) { + /* Seeds that have triggered problems in the past for this or VP8/VP9, plus a random one. */ + val seeds = longArrayOf( + 1576267371838L, + 1578347926155L, + 1579620018479L, + 5786714086792432950L, + 5929140296748347521L, + -8226056792707023108L, + System.currentTimeMillis() + ) + for (seed in seeds) { + try { + doRunOutOfOrderTest(generator, targetIndex, initialOrderedCount, seed, expectAccept) + } catch (e: Throwable) { + logger.error( + "Exception thrown in randomized test, seed = $seed", + e + ) + throw e + } + generator.reset() + } + } + + @Test + fun simpleOutOfOrderNonScalableTest() { + val generator = NonScalableAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 0)) { + true + } + } + + @Test + fun simpleOutOfOrderTemporalProjectionTest() { + val generator = TemporallyScaledPacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 2)) { + true + } + } + + @Test + fun filteredOutOfOrderTemporalProjectionTest() { + val generator = TemporallyScaledPacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.temporalId == 0 + } + } + + @Test + fun largerFrameOutOfOrderTemporalProjectionTest() { + val generator = TemporallyScaledPacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 2)) { + true + } + } + + @Test + fun largerFrameOutOfOrderTemporalFilteredTest() { + val generator = TemporallyScaledPacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.temporalId == 0 + } + } + + @Test + fun hugeFrameOutOfOrderTest() { + val generator = TemporallyScaledPacketGenerator(200) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.temporalId == 0 + } + } + + @Test + fun simpleSvcOutOfOrderTest() { + val generator = ScalableAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2)) { + true + } + } + + @Test + fun filteredSvcOutOfOrderTest() { + val generator = ScalableAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 2)) { + it.spatialId == 0 + } + } + + @Test + fun temporalFilteredOutOfOrderSvcOutOfOrderTest() { + val generator = ScalableAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2)) { + it.temporalId == 0 + } + } + + @Test + fun spatialAndTemporalFilteredSvcOutOfOrderTest() { + val generator = ScalableAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun largerSvcOutOfOrderTest() { + val generator = ScalableAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2)) { + true + } + } + + @Test + fun largerFilteredSvcOutOfOrderTest() { + val generator = ScalableAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 2)) { + it.spatialId == 0 + } + } + + @Test + fun largerTemporalFilteredSvcOutOfOrderTest() { + val generator = ScalableAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2)) { + it.temporalId == 0 + } + } + + @Test + fun largerSpatialAndTemporalFilteredSvcOutOfOrderTest() { + val generator = ScalableAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun simpleKSvcOutOfOrderTest() { + val generator = KeyScalableAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2)) { + it.spatialId == 2 || !it.hasInterPictureDependency() + } + } + + @Test + fun filteredKSvcOutOfOrderTest() { + val generator = KeyScalableAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 2)) { + it.spatialId == 0 + } + } + + @Test + fun temporalFilteredKSvcOutOfOrderTest() { + val generator = KeyScalableAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2)) { + it.temporalId == 0 && (it.spatialId == 2 || !it.hasInterPictureDependency()) + } + } + + @Test + fun spatialAndTemporalFilteredKSvcOutOfOrderTest() { + val generator = KeyScalableAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun largerKSvcOutOfOrderTest() { + val generator = KeyScalableAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2)) { + it.spatialId == 2 || !it.hasInterPictureDependency() + } + } + + @Test + fun largerFilteredKSvcOutOfOrderTest() { + val generator = KeyScalableAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 2)) { + it.spatialId == 0 + } + } + + @Test + fun largerTemporalFilteredKSvcOutOfOrderTest() { + val generator = KeyScalableAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2)) { + it.temporalId == 0 && (it.spatialId == 2 || !it.hasInterPictureDependency()) + } + } + + @Test + fun largerSpatialAndTemporalFilteredKSvcOutOfOrderTest() { + val generator = KeyScalableAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 0)) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun simpleSingleEncodingSimulcastOutOfOrderTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2), 3) { + it.spatialId == 2 + } + } + + @Test + fun filteredSingleEncodingSimulcastOutOfOrderTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 2), 3) { + it.spatialId == 0 + } + } + + @Test + fun temporalFilteredSingleEncodingSimulcastOutOfOrderTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2), 3) { + it.temporalId == 0 && it.spatialId == 2 + } + } + + @Test + fun spatialAndTemporalFilteredSingleEncodingSimulcastOutOfOrderTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(1) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 0), 3) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun largerSingleEncodingSimulcastOutOfOrderTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2 + 2), 7) { + it.spatialId == 2 + } + } + + @Test + fun largerFilteredSingleEncodingSimulcastOutOfOrderTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 2), 7) { + it.spatialId == 0 + } + } + + @Test + fun largerTemporalFilteredSingleEncodingSimulcastOutOfOrderTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 3 * 2), 7) { + it.temporalId == 0 && it.spatialId == 2 + } + } + + @Test + fun largerSpatialAndTemporalFilteredSingleEncodingSimulcastOutOfOrderTest() { + val generator = SingleEncodingSimulcastAv1PacketGenerator(3) + runOutOfOrderTest(generator, getIndex(eid = 0, dt = 0), 7) { + it.spatialId == 0 && it.temporalId == 0 + } + } + + @Test + fun slightlyDelayedKeyframeTest() { + val generator = TemporallyScaledPacketGenerator(1) + val diagnosticContext = DiagnosticContext() + diagnosticContext["test"] = "slightlyDelayedKeyframeTest" + val initialState = RtpState(1, 10000, 1000000) + val context = Av1DDAdaptiveSourceProjectionContext( + diagnosticContext, + initialState, + logger + ) + val firstPacketInfo = generator.nextPacket() + val targetIndex = getIndex(eid = 0, dt = 2) + for (i in 0..2) { + val packetInfo = generator.nextPacket() + + Assert.assertFalse(context.accept(packetInfo, 0, targetIndex)) + } + Assert.assertTrue(context.accept(firstPacketInfo, 0, targetIndex)) + context.rewriteRtp(firstPacketInfo) + for (i in 0..9995) { + val packetInfo = generator.nextPacket() + Assert.assertTrue(context.accept(packetInfo, 0, targetIndex)) + context.rewriteRtp(packetInfo) + } + } + + @Test + fun veryDelayedKeyframeTest() { + val generator = TemporallyScaledPacketGenerator(1) + val diagnosticContext = DiagnosticContext() + diagnosticContext["test"] = "veryDelayedKeyframeTest" + val initialState = RtpState(1, 10000, 1000000) + val context = Av1DDAdaptiveSourceProjectionContext( + diagnosticContext, + initialState, + logger + ) + val firstPacketInfo = generator.nextPacket() + val targetIndex = getIndex(eid = 0, dt = 2) + for (i in 0..3) { + val packetInfo = generator.nextPacket(missedStructure = true) + Assert.assertFalse(context.accept(packetInfo, 0, targetIndex)) + } + Assert.assertFalse(context.accept(firstPacketInfo, 0, targetIndex)) + for (i in 0..9) { + val packetInfo = generator.nextPacket() + Assert.assertFalse(context.accept(packetInfo, 0, targetIndex)) + } + generator.requestKeyframe() + for (i in 0..9995) { + val packetInfo = generator.nextPacket() + Assert.assertTrue(context.accept(packetInfo, 0, targetIndex)) + context.rewriteRtp(packetInfo) + } + } + + @Test + fun twoStreamsNoSwitchingTest() { + val generator1 = TemporallyScaledPacketGenerator(3) + val generator2 = TemporallyScaledPacketGenerator(3) + generator2.ssrc = 0xdeadbeefL + val diagnosticContext = DiagnosticContext() + diagnosticContext["test"] = "twoStreamsNoSwitchingTest" + val initialState = RtpState(1, 10000, 1000000) + val context = Av1DDAdaptiveSourceProjectionContext(diagnosticContext, initialState, logger) + val targetIndex = getIndex(eid = 1, dt = 2) + var expectedSeq = 10001 + var expectedTs: Long = 1003000 + for (i in 0..9999) { + val packetInfo1 = generator1.nextPacket() + val packet1 = packetInfo1.packetAs() + + Assert.assertTrue(context.accept(packetInfo1, 1, targetIndex)) + val packetInfo2 = generator2.nextPacket() + Assert.assertFalse(context.accept(packetInfo2, 0, targetIndex)) + context.rewriteRtp(packetInfo1) + Assert.assertEquals(expectedSeq, packet1.sequenceNumber) + Assert.assertEquals(expectedTs, packet1.timestamp) + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + if (packet1.isEndOfFrame) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + } + + @Test + fun twoStreamsSwitchingTest() { + val generator1 = TemporallyScaledPacketGenerator(3) + val generator2 = TemporallyScaledPacketGenerator(3) + generator2.ssrc = 0xdeadbeefL + val diagnosticContext = DiagnosticContext() + diagnosticContext["test"] = "twoStreamsSwitchingTest" + val initialState = RtpState(1, 10000, 1000000) + val context = Av1DDAdaptiveSourceProjectionContext(diagnosticContext, initialState, logger) + var expectedSeq = 10001 + var expectedTs: Long = 1003000 + var expectedFrameNumber = 0 + var expectedTemplateOffset = 0 + var targetIndex = getIndex(eid = 0, dt = 2) + + /* Start by wanting encoding 0 */ + for (i in 0..899) { + val srPacket1 = generator1.srPacket + val packetInfo1 = generator1.nextPacket() + val packet1 = packetInfo1.packetAs() + if (i == 0) { + expectedTemplateOffset = packet1.descriptor!!.structure.templateIdOffset + } + Assert.assertTrue(context.accept(packetInfo1, 0, targetIndex)) + context.rewriteRtp(packetInfo1) + Assert.assertTrue(context.rewriteRtcp(srPacket1)) + Assert.assertEquals(packet1.ssrc, srPacket1.senderSsrc) + Assert.assertEquals(packet1.timestamp, srPacket1.senderInfo.rtpTimestamp) + val srPacket2 = generator2.srPacket + val packetInfo2 = generator2.nextPacket() + Assert.assertFalse(context.accept(packetInfo2, 1, targetIndex)) + Assert.assertFalse(context.rewriteRtcp(srPacket2)) + Assert.assertEquals(expectedSeq, packet1.sequenceNumber) + Assert.assertEquals(expectedTs, packet1.timestamp) + Assert.assertEquals(expectedFrameNumber, packet1.frameNumber) + Assert.assertEquals(expectedTemplateOffset, packet1.descriptor?.structure?.templateIdOffset) + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + if (packet1.isEndOfFrame) { + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + if (packet1.isMarked) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + + /* Switch to wanting encoding 1, but don't send a keyframe. We should stay at the first encoding. */ + targetIndex = getIndex(eid = 1, dt = 2) + for (i in 0..89) { + val srPacket1 = generator1.srPacket + val packetInfo1 = generator1.nextPacket() + val packet1 = packetInfo1.packetAs() + Assert.assertTrue(context.accept(packetInfo1, 0, targetIndex)) + context.rewriteRtp(packetInfo1) + Assert.assertTrue(context.rewriteRtcp(srPacket1)) + Assert.assertEquals(packet1.ssrc, srPacket1.senderSsrc) + Assert.assertEquals(packet1.timestamp, srPacket1.senderInfo.rtpTimestamp) + val srPacket2 = generator2.srPacket + val packetInfo2 = generator2.nextPacket() + Assert.assertFalse(context.accept(packetInfo2, 1, targetIndex)) + Assert.assertFalse(context.rewriteRtcp(srPacket2)) + Assert.assertEquals(expectedSeq, packet1.sequenceNumber) + Assert.assertEquals(expectedTs, packet1.timestamp) + Assert.assertEquals(expectedFrameNumber, packet1.frameNumber) + Assert.assertEquals(expectedTemplateOffset, packet1.descriptor?.structure?.templateIdOffset) + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + if (packet1.isEndOfFrame) { + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + if (packet1.isMarked) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + generator1.requestKeyframe() + generator2.requestKeyframe() + + /* After a keyframe we should accept spatial layer 1 */ + for (i in 0..8999) { + val srPacket1 = generator1.srPacket + val packetInfo1 = generator1.nextPacket() + val packet1 = packetInfo1.packetAs() + + /* We will cut off the layer 0 keyframe after 1 packet, once we see the layer 1 keyframe. */ + Assert.assertEquals(i == 0, context.accept(packetInfo1, 0, targetIndex)) + Assert.assertEquals(i == 0, context.rewriteRtcp(srPacket1)) + if (i == 0) { + context.rewriteRtp(packetInfo1) + Assert.assertEquals(packet1.ssrc, srPacket1.senderSsrc) + Assert.assertEquals(packet1.timestamp, srPacket1.senderInfo.rtpTimestamp) + expectedTemplateOffset += packet1.descriptor!!.structure.templateCount + } + val srPacket2 = generator2.srPacket + val packetInfo2 = generator2.nextPacket() + val packet2 = packetInfo2.packetAs() + Assert.assertTrue(context.accept(packetInfo2, 1, targetIndex)) + val expectedTemplateId = (packet2.descriptor!!.frameDependencyTemplateId + expectedTemplateOffset) % 64 + context.rewriteRtp(packetInfo2) + Assert.assertTrue(context.rewriteRtcp(srPacket2)) + Assert.assertEquals(packet2.ssrc, srPacket2.senderSsrc) + Assert.assertEquals(packet2.timestamp, srPacket2.senderInfo.rtpTimestamp) + if (i == 0) { + /* We leave a 1-packet gap for the layer 0 keyframe. */ + expectedSeq += 2 + /* ts will advance by an extra 3000 samples for the extra frame. */ + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + /* frame number will advance by 1 for the extra keyframe. */ + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + Assert.assertEquals(expectedSeq, packet2.sequenceNumber) + Assert.assertEquals(expectedTs, packet2.timestamp) + Assert.assertEquals(expectedFrameNumber, packet2.frameNumber) + Assert.assertEquals(expectedTemplateId, packet2.descriptor?.frameDependencyTemplateId) + if (packet2.descriptor?.newTemplateDependencyStructure != null) { + Assert.assertEquals( + expectedTemplateOffset, + packet2.descriptor?.newTemplateDependencyStructure?.templateIdOffset + ) + } + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + if (packet2.isEndOfFrame) { + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + if (packet2.isMarked) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + } + + @Test + fun temporalLayerSwitchingTest() { + val generator = TemporallyScaledPacketGenerator(3) + val diagnosticContext = DiagnosticContext() + diagnosticContext["test"] = "temporalLayerSwitchingTest" + val initialState = RtpState(1, 10000, 1000000) + val context = Av1DDAdaptiveSourceProjectionContext( + diagnosticContext, + initialState, + logger + ) + var targetTid = 0 + var decodableTid = 0 + var targetIndex = getIndex(0, targetTid) + var expectedSeq = 10001 + var expectedTs: Long = 1003000 + var expectedFrameNumber = 0 + for (i in 0..9999) { + val packetInfo = generator.nextPacket() + val packet = packetInfo.packetAs() + val accepted = context.accept(packetInfo, 0, targetIndex) + if (accepted) { + if (decodableTid < packet.frameInfo!!.temporalId) { + decodableTid = packet.frameInfo!!.temporalId + } + } else { + if (decodableTid > packet.frameInfo!!.temporalId - 1) { + decodableTid = packet.frameInfo!!.temporalId - 1 + } + } + if (packet.frameInfo!!.temporalId <= decodableTid) { + Assert.assertTrue(accepted) + context.rewriteRtp(packetInfo) + Assert.assertEquals(expectedSeq, packet.sequenceNumber) + Assert.assertEquals(expectedTs, packet.timestamp) + Assert.assertEquals(expectedFrameNumber, packet.frameNumber) + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + } else { + Assert.assertFalse(accepted) + } + if (packet.isEndOfFrame) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + if (i % 97 == 0) { // Prime number so it's out of sync with packet cycles. + targetTid = (targetTid + 2) % 3 + targetIndex = getIndex(0, targetTid) + } + } + } + } + + private fun runLargeDropoutTest( + generator: Av1PacketGenerator, + targetIndex: Int, + expectAccept: (FrameInfo) -> Boolean + ) { + val diagnosticContext = DiagnosticContext() + diagnosticContext["test"] = Thread.currentThread().stackTrace[2].methodName + val initialState = RtpState(1, 10000, 1000000) + val context = Av1DDAdaptiveSourceProjectionContext( + diagnosticContext, + initialState, + logger + ) + var expectedSeq = 10001 + var expectedTs: Long = 1003000 + var expectedFrameNumber = 0 + for (i in 0..999) { + val packetInfo = generator.nextPacket() + val packet = packetInfo.packetAs() + + val accepted = context.accept(packetInfo, 0, targetIndex) + val frameInfo = packet.frameInfo!! + val endOfPicture = packet.isMarked + if (expectAccept(frameInfo)) { + Assert.assertTrue(accepted) + context.rewriteRtp(packetInfo) + Assert.assertEquals(expectedSeq, packet.sequenceNumber) + Assert.assertEquals(expectedTs, packet.timestamp) + Assert.assertEquals(expectedFrameNumber, packet.frameNumber) + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + } else { + Assert.assertFalse(accepted) + } + if (packet.isEndOfFrame) { + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + if (endOfPicture) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + for (gap in 64..65536 step { it * 2 }) { + for (i in 0 until gap) { + generator.nextPacket() + } + var packetInfo: PacketInfo + var packet: Av1DDPacket + var frameInfo: FrameInfo + do { + packetInfo = generator.nextPacket() + packet = packetInfo.packetAs() + frameInfo = packet.frameInfo!! + } while (!expectAccept(frameInfo)) + var endOfPicture = packet.isMarked + Assert.assertTrue(context.accept(packetInfo, 0, targetIndex)) + context.rewriteRtp(packetInfo) + + /* Allow any values after a gap. */ + expectedSeq = RtpUtils.applySequenceNumberDelta(packet.sequenceNumber, 1) + expectedTs = packet.timestamp + expectedFrameNumber = packet.frameNumber + if (packet.isEndOfFrame) { + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + if (endOfPicture) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + for (i in 0..999) { + packetInfo = generator.nextPacket() + packet = packetInfo.packetAs() + val accepted = context.accept(packetInfo, 0, targetIndex) + endOfPicture = packet.isMarked + frameInfo = packet.frameInfo!! + if (expectAccept(frameInfo)) { + Assert.assertTrue(accepted) + context.rewriteRtp(packetInfo) + Assert.assertEquals(expectedSeq, packet.sequenceNumber) + Assert.assertEquals(expectedTs, packet.timestamp) + Assert.assertEquals(expectedFrameNumber, packet.frameNumber) + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + } else { + Assert.assertFalse(accepted) + } + if (packet.isEndOfFrame) { + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + if (endOfPicture) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + } + } + + @Test + fun largeDropoutTest() { + val generator = TemporallyScaledPacketGenerator(1) + runLargeDropoutTest(generator, getIndex(eid = 0, dt = 2)) { + true + } + } + + @Test + fun filteredDropoutTest() { + val generator = TemporallyScaledPacketGenerator(1) + runLargeDropoutTest(generator, getIndex(eid = 0, dt = 0)) { + it.temporalId == 0 + } + } + + @Test + fun largeFrameDropoutTest() { + val generator = TemporallyScaledPacketGenerator(3) + runLargeDropoutTest(generator, getIndex(eid = 0, dt = 2)) { + true + } + } + + @Test + fun largeFrameFilteredDropoutTest() { + val generator = TemporallyScaledPacketGenerator(3) + runLargeDropoutTest(generator, getIndex(eid = 0, dt = 0)) { + it.temporalId == 0 + } + } + + private fun runSourceSuspensionTest( + generator: Av1PacketGenerator, + targetIndex: Int, + expectAccept: (FrameInfo) -> Boolean + ) { + val diagnosticContext = DiagnosticContext() + diagnosticContext["test"] = Thread.currentThread().stackTrace[2].methodName + val initialState = RtpState(1, 10000, 1000000) + val context = Av1DDAdaptiveSourceProjectionContext( + diagnosticContext, + initialState, + logger + ) + var expectedSeq = 10001 + var expectedTs: Long = 1003000 + var expectedFrameNumber = 0 + + var packetInfo: PacketInfo + var packet: Av1DDPacket + var frameInfo: FrameInfo + + var lastPacketAccepted = false + var lastFrameAccepted = -1 + + for (i in 0..999) { + packetInfo = generator.nextPacket() + packet = packetInfo.packetAs() + frameInfo = packet.frameInfo!! + val accepted = context.accept(packetInfo, 0, targetIndex) + val endOfPicture = packet.isMarked + if (expectAccept(frameInfo)) { + Assert.assertTrue(accepted) + context.rewriteRtp(packetInfo) + Assert.assertEquals(expectedSeq, packet.sequenceNumber) + Assert.assertEquals(expectedTs, packet.timestamp) + Assert.assertEquals(expectedFrameNumber, packet.frameNumber) + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + lastPacketAccepted = true + lastFrameAccepted = packet.frameNumber + } else { + Assert.assertFalse(accepted) + lastPacketAccepted = false + } + if (packet.isEndOfFrame) { + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + if (endOfPicture) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + for (suspended in 64..65536 step { it * 2 }) { + /* If the last frame was accepted, finish the current frame if this generator is creating multi-packet + frames. */ + if (lastPacketAccepted) { + while (generator.packetOfFrame != 0) { + packetInfo = generator.nextPacket() + packet = packetInfo.packetAs() + + val accepted = context.accept(packetInfo, 0, targetIndex) + val endOfPicture = packet.isMarked + Assert.assertTrue(accepted) + context.rewriteRtp(packetInfo) + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + if (packet.isEndOfFrame) { + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + if (endOfPicture) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + } + /* Turn the source off for a time. */ + for (i in 0 until suspended) { + packetInfo = generator.nextPacket() + packet = packetInfo.packetAs() + + val accepted = context.accept(packetInfo, 0, RtpLayerDesc.SUSPENDED_INDEX) + Assert.assertFalse(accepted) + val endOfPicture = packet.isMarked + if (endOfPicture) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + + /* Switch back to wanting [targetIndex], but don't send a keyframe for a while. + * Should still be dropped. */ + for (i in 0 until 30) { + packetInfo = generator.nextPacket() + packet = packetInfo.packetAs() + + val accepted = context.accept(packetInfo, 0, targetIndex) + val endOfPicture = packet.isMarked + Assert.assertFalse(accepted) + if (endOfPicture) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + + /* Request a keyframe. Will be sent as of the next frame. */ + generator.requestKeyframe() + /* If this generator is creating multi-packet frames, finish the previous frame. */ + while (generator.packetOfFrame != 0) { + packetInfo = generator.nextPacket() + packet = packetInfo.packetAs() + val accepted = context.accept(packetInfo, 0, targetIndex) + val endOfPicture = packet.isMarked + Assert.assertFalse(accepted) + if (endOfPicture) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(lastFrameAccepted, 1) + + for (i in 0..999) { + packetInfo = generator.nextPacket() + packet = packetInfo.packetAs() + frameInfo = packet.frameInfo!! + val accepted = context.accept(packetInfo, 0, targetIndex) + val endOfPicture = packet.isMarked + if (expectAccept(frameInfo)) { + Assert.assertTrue(accepted) + context.rewriteRtp(packetInfo) + Assert.assertEquals(expectedSeq, packet.sequenceNumber) + Assert.assertEquals(expectedTs, packet.timestamp) + Assert.assertEquals(expectedFrameNumber, packet.frameNumber) + expectedSeq = RtpUtils.applySequenceNumberDelta(expectedSeq, 1) + lastPacketAccepted = true + lastFrameAccepted = packet.frameNumber + } else { + Assert.assertFalse(accepted) + lastPacketAccepted = false + } + if (packet.isEndOfFrame) { + expectedFrameNumber = RtpUtils.applySequenceNumberDelta(expectedFrameNumber, 1) + } + if (endOfPicture) { + expectedTs = RtpUtils.applyTimestampDelta(expectedTs, 3000) + } + } + } + } + + @Test + fun sourceSuspensionTest() { + val generator = TemporallyScaledPacketGenerator(1) + runSourceSuspensionTest(generator, getIndex(eid = 0, dt = 2)) { + true + } + } + + @Test + fun filteredSourceSuspensionTest() { + val generator = TemporallyScaledPacketGenerator(1) + runSourceSuspensionTest(generator, getIndex(eid = 0, dt = 0)) { + it.temporalId == 0 + } + } + + @Test + fun largeFrameSourceSuspensionTest() { + val generator = TemporallyScaledPacketGenerator(3) + runSourceSuspensionTest(generator, getIndex(eid = 0, dt = 2)) { + true + } + } + + @Test + fun largeFrameFilteredSourceSuspensionTest() { + val generator = TemporallyScaledPacketGenerator(3) + runSourceSuspensionTest(generator, getIndex(eid = 0, dt = 0)) { + it.temporalId == 0 + } + } +} + +private open class Av1PacketGenerator( + val packetsPerFrame: Int, + val keyframeTemplates: Array, + val normalTemplates: Array, + // Equivalent to number of layers + val framesPerTimestamp: Int, + templateDdHex: String, + val allKeyframesGetStructure: Boolean = false +) { + private val logger: Logger = LoggerImpl(javaClass.name) + + var packetOfFrame = 0 + private set + private var frameOfPicture = 0 + + private var seq = 0 + var ts: Long = 0L + private set + var ssrc: Long = 0xcafebabeL + private var frameNumber = 0 + private var keyframePicture = false + private var keyframeRequested = false + private var pictureCount = 0 + private var receivedTime = baseReceivedTime + private var templateIdx = 0 + private var packetCount = 0 + private var octetCount = 0 + + private val structure: Av1TemplateDependencyStructure + + init { + val dd = parseHexBinary(templateDdHex) + structure = Av1DependencyDescriptorReader(dd, 0, dd.size).parse(null).structure + } + + fun reset() { + val useRandom = true // switch off to ease debugging + val seed = System.currentTimeMillis() + val random = Random(seed) + seq = if (useRandom) random.nextInt() % 0x10000 else 0 + ts = if (useRandom) random.nextLong() % 0x100000000L else 0 + frameNumber = 0 + packetOfFrame = 0 + frameOfPicture = 0 + keyframePicture = true + keyframeRequested = false + ssrc = 0xcafebabeL + pictureCount = 0 + receivedTime = baseReceivedTime + templateIdx = 0 + packetCount = 0 + octetCount = 0 + } + + fun nextPacket(missedStructure: Boolean = false): PacketInfo { + val startOfFrame = packetOfFrame == 0 + val endOfFrame = packetOfFrame == packetsPerFrame - 1 + val startOfPicture = startOfFrame && frameOfPicture == 0 + val endOfPicture = endOfFrame && frameOfPicture == framesPerTimestamp - 1 + + val templateId = ( + (if (keyframePicture) keyframeTemplates[templateIdx] else normalTemplates[templateIdx]) + + structure.templateIdOffset + ) % 64 + + val buffer = packetTemplate.clone() + val rtpPacket = RtpPacket(buffer, 0, buffer.size) + rtpPacket.ssrc = ssrc + rtpPacket.sequenceNumber = seq + rtpPacket.timestamp = ts + rtpPacket.isMarked = endOfPicture + + val dd = Av1DependencyDescriptorHeaderExtension( + startOfFrame = startOfFrame, + endOfFrame = endOfFrame, + frameDependencyTemplateId = templateId, + frameNumber = frameNumber, + newTemplateDependencyStructure = + if (keyframePicture && startOfFrame && (startOfPicture || allKeyframesGetStructure)) { + structure + } else { + null + }, + activeDecodeTargetsBitmask = null, + customDtis = null, + customFdiffs = null, + customChains = null, + structure = structure + ) + + val ext = rtpPacket.addHeaderExtension(AV1_DD_HEADER_EXTENSION_ID, dd.encodedLength) + dd.write(ext) + rtpPacket.encodeHeaderExtensions() + + val av1Packet = Av1DDPacket( + rtpPacket, + AV1_DD_HEADER_EXTENSION_ID, + if (missedStructure) null else structure, + logger + ) + + val info = PacketInfo(av1Packet) + info.receivedTime = receivedTime + + seq = RtpUtils.applySequenceNumberDelta(seq, 1) + packetCount++ + octetCount += av1Packet.length + + if (endOfFrame) { + packetOfFrame = 0 + if (endOfPicture) { + frameOfPicture = 0 + } else { + frameOfPicture++ + } + templateIdx++ + if (keyframeRequested) { + keyframePicture = true + templateIdx = 0 + } else if (keyframePicture) { + if (templateIdx >= keyframeTemplates.size) { + keyframePicture = false + } + } + frameNumber = RtpUtils.applySequenceNumberDelta(frameNumber, 1) + } else { + packetOfFrame++ + } + + if (endOfPicture) { + ts = RtpUtils.applyTimestampDelta(ts, 3000) + + keyframeRequested = false + if (templateIdx >= normalTemplates.size) { + templateIdx = 0 + } + pictureCount++ + receivedTime = baseReceivedTime + Duration.ofMillis(pictureCount * 100L / 3) + } + + return info + } + + fun requestKeyframe() { + if (packetOfFrame == 0) { + keyframePicture = true + templateIdx = 0 + } else { + keyframeRequested = true + } + } + + val srPacket: RtcpSrPacket + get() { + val srPacketBuilder = RtcpSrPacketBuilder() + srPacketBuilder.rtcpHeader.senderSsrc = ssrc + val siBuilder = srPacketBuilder.senderInfo + siBuilder.setNtpFromJavaTime(receivedTime.toEpochMilli()) + siBuilder.rtpTimestamp = ts + siBuilder.sendersOctetCount = packetCount.toLong() + siBuilder.sendersOctetCount = octetCount.toLong() + return srPacketBuilder.build() + } + + init { + reset() + } + + companion object { + val baseReceivedTime: Instant = Instant.ofEpochMilli(1577836800000L) // 2020-01-01 00:00:00 UTC + + const val AV1_DD_HEADER_EXTENSION_ID = 11 + + private val packetTemplate = DatatypeConverter.parseHexBinary( + // RTP Header + "80" + // V, P, X, CC + "29" + // M, PT + "0000" + // Seq + "00000000" + // TS + "cafebabe" + // SSRC + // Header extension will be added dynamically + // Dummy payload data + "0000000000000000000000" + ) + } +} + +private class NonScalableAv1PacketGenerator( + packetsPerFrame: Int +) : + Av1PacketGenerator( + packetsPerFrame, + arrayOf(0), + arrayOf(1), + 1, + "80000180003a410180ef808680" + ) + +private class TemporallyScaledPacketGenerator(packetsPerFrame: Int) : Av1PacketGenerator( + packetsPerFrame, + arrayOf(0), + arrayOf(1, 3, 2, 4), + 1, + "800001800214eaa860414d141020842701df010d" +) + +private class ScalableAv1PacketGenerator( + packetsPerFrame: Int +) : + Av1PacketGenerator( + packetsPerFrame, + arrayOf(1, 6, 11), + arrayOf(0, 5, 10, 3, 8, 13, 2, 7, 12, 4, 9, 14), + 3, + "d0013481e81485214eafffaaaa863cf0430c10c302afc0aaa0063c00430010c002a000a800060000" + + "40001d954926e082b04a0941b820ac1282503157f974000ca864330e222222eca8655304224230ec" + + "a87753013f00b3027f016704ff02cf" + ) + +private class KeyScalableAv1PacketGenerator( + packetsPerFrame: Int +) : + Av1PacketGenerator( + packetsPerFrame, + arrayOf(0, 5, 10), + arrayOf(1, 6, 11, 3, 8, 13, 2, 7, 12, 4, 9, 14), + 3, + "8f008581e81485214eaaaaa8000600004000100002aa80a8000600004000100002a000a80006000040" + + "0016d549241b5524906d54923157e001974ca864330e222396eca8655304224390eca87753013f00b3027f016704ff02cf" + ) + +private class SingleEncodingSimulcastAv1PacketGenerator( + packetsPerFrame: Int +) : + Av1PacketGenerator( + packetsPerFrame, + arrayOf(1, 6, 11), + arrayOf(0, 5, 10, 3, 8, 13, 2, 7, 12, 4, 9, 14), + 3, + "c1000180081485214ea000a8000600004000100002a000a8000600004000100002a000a8000600004" + + "0001d954926caa493655248c55fe5d00032a190cc38e58803b2a1954c10e10843b2a1dd4c01dc010803bc0218077c0434", + allKeyframesGetStructure = true + ) + +private infix fun IntRange.step(next: (Int) -> Int) = + generateSequence(first, next).takeWhile { if (first < last) it <= last else it >= last } diff --git a/jvb/src/test/kotlin/org/jitsi/videobridge/cc/av1/Av1DDQualityFilterTest.kt b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/av1/Av1DDQualityFilterTest.kt new file mode 100644 index 0000000000..0ab04c8ec5 --- /dev/null +++ b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/av1/Av1DDQualityFilterTest.kt @@ -0,0 +1,896 @@ +/* + * Copyright @ 2019 - present 8x8, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jitsi.videobridge.cc.av1 + +import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.shouldBe +import jakarta.xml.bind.DatatypeConverter +import org.jitsi.nlj.rtp.codec.av1.Av1DDRtpLayerDesc +import org.jitsi.rtp.rtp.header_extensions.Av1DependencyDescriptorReader +import org.jitsi.rtp.rtp.header_extensions.Av1TemplateDependencyStructure +import org.jitsi.utils.logging2.LoggerImpl +import org.jitsi.utils.logging2.getClassForLogging +import java.time.Instant + +internal class Av1DDQualityFilterTest : ShouldSpec() { + init { + context("A non-scalable stream") { + should("be entirely projected") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SingleLayerFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 0) + + testGenerator(generator, filter, targetIndex) { _, result -> + result.accept shouldBe true + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + context("A temporally scalable stream") { + should("be entirely projected when TL2 is requested") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = TemporallyScaledFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 2) + + testGenerator(generator, filter, targetIndex) { _, result -> + result.accept shouldBe true + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + should("project only the base temporal layer when targeted") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = TemporallyScaledFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 0) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.temporalId == 0) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + should("project only the intermediate temporal layer when targeted") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = TemporallyScaledFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 1) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.temporalId <= 1) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + should("be able to switch the targeted layers, without a keyframe") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = TemporallyScaledFrameGenerator(av1FrameMaps) + val targetIndex1 = Av1DDRtpLayerDesc.getIndex(0, 0) + + testGenerator(generator, filter, targetIndex1, numFrames = 500) { f, result -> + result.accept shouldBe (f.frameInfo!!.temporalId == 0) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + val targetIndex2 = Av1DDRtpLayerDesc.getIndex(0, 2) + + testGenerator(generator, filter, targetIndex2) { _, result -> + result.accept shouldBe true + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + context("A spatially scalable stream") { + should("be entirely projected when SL2/TL2 is requested (L3T3)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SVCFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 3 * 2 + 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe true + result.mark shouldBe (f.frameInfo!!.spatialId == 2) + filter.needsKeyframe shouldBe false + } + } + should("be able to be shaped to SL0/TL2 (L3T3)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SVCFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 0) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 0) + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to SL1/TL2 (L3T3)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SVCFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 3 * 1 + 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId <= 1) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 1) + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to SL2/TL0 (L3T3)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SVCFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 3 * 2 + 0) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.temporalId == 0) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 2) + filter.needsKeyframe shouldBe false + } + } + } + should("be able to switch spatial layers (L3T3)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SVCFrameGenerator(av1FrameMaps) + + /* Start by sending spatial layer 0. */ + val targetIndex1 = Av1DDRtpLayerDesc.getIndex(0, 2) + + testGenerator(generator, filter, targetIndex1, numFrames = 1200) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 0) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 0) + filter.needsKeyframe shouldBe false + } + } + + /* Switch to spatial layer 2. Need a keyframe. */ + val targetIndex2 = Av1DDRtpLayerDesc.getIndex(0, 3 * 2 + 2) + var sawKeyframe = false + testGenerator(generator, filter, targetIndex2, numFrames = 1200) { f, result -> + if (f.isKeyframe) sawKeyframe = true + result.accept shouldBe if (!sawKeyframe) (f.frameInfo!!.spatialId == 0) else true + if (result.accept) { + result.mark shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 0) + } else { + (f.frameInfo!!.spatialId == 2) + } + filter.needsKeyframe shouldBe (!sawKeyframe) + } + } + + /* Switch to spatial layer 1. For SVC, dropping down in spatial layers can happen immediately. */ + val targetIndex3 = Av1DDRtpLayerDesc.getIndex(0, 3 * 1 + 2) + testGenerator(generator, filter, targetIndex3) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId <= 1) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 1) + filter.needsKeyframe shouldBe false + } + } + } + } + context("A K-SVC spatially scalable stream") { + should("be able to be shaped to SL2/TL2 (L3T3_KEY)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = KSVCFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 3 * 2 + 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe ( + f.frameInfo!!.spatialId == 2 || !f.frameInfo!!.hasInterPictureDependency() + ) + result.mark shouldBe (f.frameInfo!!.spatialId == 2) + filter.needsKeyframe shouldBe false + } + } + should("be able to be shaped to SL0/TL2 (L3T3_KEY)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = KSVCFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 0) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 0) + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to SL1/TL2 (L3T3_KEY)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = KSVCFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 3 * 1 + 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe ( + f.frameInfo!!.spatialId == 1 || ( + f.frameInfo!!.spatialId == 0 && !f.frameInfo!!.hasInterPictureDependency() + ) + ) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 1) + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to SL2/TL0 (L3T3_KEY)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = KSVCFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 3 * 2 + 0) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe ( + f.frameInfo!!.temporalId == 0 && ( + f.frameInfo!!.spatialId == 2 || !f.frameInfo!!.hasInterPictureDependency() + ) + ) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 2) + filter.needsKeyframe shouldBe false + } + } + } + should("be able to switch spatial layers (L3T3_KEY)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = KSVCFrameGenerator(av1FrameMaps) + + /* Start by sending spatial layer 0. */ + val targetIndex1 = Av1DDRtpLayerDesc.getIndex(0, 2) + + testGenerator(generator, filter, targetIndex1, numFrames = 1200) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 0) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 0) + filter.needsKeyframe shouldBe false + } + } + + /* Switch to spatial layer 2. Need a keyframe. */ + val targetIndex2 = Av1DDRtpLayerDesc.getIndex(0, 3 * 2 + 2) + var sawKeyframe = false + testGenerator(generator, filter, targetIndex2, numFrames = 1200) { f, result -> + if (f.isKeyframe) sawKeyframe = true + result.accept shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 0) + } else { + (f.frameInfo!!.spatialId == 2 || !f.frameInfo!!.hasInterPictureDependency()) + } + if (result.accept) { + result.mark shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 0) + } else { + (f.frameInfo!!.spatialId == 2) + } + filter.needsKeyframe shouldBe (!sawKeyframe) + } + } + + /* Switch to spatial layer 1. For K-SVC, dropping down in spatial layers needs a keyframe. */ + val targetIndex3 = Av1DDRtpLayerDesc.getIndex(0, 3 * 1 + 2) + sawKeyframe = false + testGenerator(generator, filter, targetIndex3) { f, result -> + if (f.isKeyframe) sawKeyframe = true + result.accept shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 2 || !f.frameInfo!!.hasInterPictureDependency()) + } else { + ( + f.frameInfo!!.spatialId == 1 || ( + f.frameInfo!!.spatialId == 0 && !f.frameInfo!!.hasInterPictureDependency() + ) + ) + } + if (result.accept) { + result.mark shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 2) + } else { + (f.frameInfo!!.spatialId == 1) + } + filter.needsKeyframe shouldBe (!sawKeyframe) + } + } + } + } + context("A K-SVC spatially scalable stream with a temporal shift") { + should("be able to be shaped to SL1/TL1 (L2S2_KEY_SHIFT)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = KSVCShiftFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 2 * 1 + 1) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe ( + f.frameInfo!!.spatialId == 1 || !f.frameInfo!!.hasInterPictureDependency() + ) + result.mark shouldBe (f.frameInfo!!.spatialId == 1) + filter.needsKeyframe shouldBe false + } + } + should("be able to be shaped to SL0/TL1 (L2S2_KEY_SHIFT)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = KSVCShiftFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 1) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 0) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 0) + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to SL1/TL0 (L2S2_KEY_SHIFT)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = KSVCShiftFrameGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 2 * 1 + 0) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe ( + f.frameInfo!!.temporalId == 0 && ( + f.frameInfo!!.spatialId == 1 || !f.frameInfo!!.hasInterPictureDependency() + ) + ) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 1) + filter.needsKeyframe shouldBe false + } + } + } + should("be able to switch spatial layers (L2S2_KEY_SHIFT)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = KSVCShiftFrameGenerator(av1FrameMaps) + + /* Start by sending spatial layer 0. */ + val targetIndex1 = Av1DDRtpLayerDesc.getIndex(0, 1) + + testGenerator(generator, filter, targetIndex1, numFrames = 1200) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 0) + if (result.accept) { + result.mark shouldBe (f.frameInfo!!.spatialId == 0) + filter.needsKeyframe shouldBe false + } + } + + /* Switch to spatial layer 1. Need a keyframe. */ + val targetIndex2 = Av1DDRtpLayerDesc.getIndex(0, 2 * 1 + 1) + var sawKeyframe = false + testGenerator(generator, filter, targetIndex2, numFrames = 1200) { f, result -> + if (f.isKeyframe) sawKeyframe = true + result.accept shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 0) + } else { + (f.frameInfo!!.spatialId == 1 || !f.frameInfo!!.hasInterPictureDependency()) + } + if (result.accept) { + result.mark shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 0) + } else { + (f.frameInfo!!.spatialId == 1) + } + filter.needsKeyframe shouldBe (!sawKeyframe) + } + } + + /* Switch back to spatial layer 0. For K-SVC, dropping down in spatial layers needs a keyframe. */ + val targetIndex3 = Av1DDRtpLayerDesc.getIndex(0, 1) + sawKeyframe = false + testGenerator(generator, filter, targetIndex3) { f, result -> + if (f.isKeyframe) sawKeyframe = true + result.accept shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 1 || !f.frameInfo!!.hasInterPictureDependency()) + } else { + (f.frameInfo!!.spatialId == 0) + } + if (result.accept) { + result.mark shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 1) + } else { + (f.frameInfo!!.spatialId == 0) + } + filter.needsKeyframe shouldBe (!sawKeyframe) + } + } + } + } + context("A single-encoding simulcast stream") { + should("project all of layer 2 when when SL2/TL2 is requested (S3T3)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SingleEncodingSimulcastGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 3 * 2 + 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 2) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to SL0/TL2 (S3T3)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SingleEncodingSimulcastGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 0) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to SL1/TL2 (S3T3)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SingleEncodingSimulcastGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 3 * 1 + 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 1) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to SL2/TL0 (S3T3)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SingleEncodingSimulcastGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 3 * 2 + 0) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 2 && f.frameInfo!!.temporalId == 0) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + should("be able to switch spatial layers (S3T3)") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = SingleEncodingSimulcastGenerator(av1FrameMaps) + + /* Start by sending spatial layer 0. */ + val targetIndex1 = Av1DDRtpLayerDesc.getIndex(0, 2) + + testGenerator(generator, filter, targetIndex1, numFrames = 1200) { f, result -> + result.accept shouldBe (f.frameInfo!!.spatialId == 0) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + + /* Switch to spatial layer 2. Need a keyframe. */ + val targetIndex2 = Av1DDRtpLayerDesc.getIndex(0, 3 * 2 + 2) + var sawKeyframe = false + testGenerator(generator, filter, targetIndex2, numFrames = 1200) { f, result -> + if (f.isKeyframe) sawKeyframe = true + result.accept shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 0) + } else { + (f.frameInfo!!.spatialId == 2) + } + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe (!sawKeyframe) + } + } + + /* Switch to spatial layer 1. Need a keyframe. */ + val targetIndex3 = Av1DDRtpLayerDesc.getIndex(0, 3 * 1 + 2) + sawKeyframe = false + testGenerator(generator, filter, targetIndex3) { f, result -> + if (f.isKeyframe) sawKeyframe = true + result.accept shouldBe if (!sawKeyframe) { + (f.frameInfo!!.spatialId == 2) + } else { + (f.frameInfo!!.spatialId == 1) + } + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe (!sawKeyframe) + } + } + } + } + context("A multi-encoding simulcast stream") { + should("project all of encoding 2 when when Enc 2/TL2 is requested") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = MultiEncodingSimulcastGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(2, 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.ssrc == 2L || f.isKeyframe) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to Enc 0/TL2") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = MultiEncodingSimulcastGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(0, 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.ssrc == 0L) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to Enc 1/TL2") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = MultiEncodingSimulcastGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(1, 2) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe (f.ssrc == 1L || (f.isKeyframe && f.ssrc == 0L)) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + should("be able to be shaped to Enc 2/TL0") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = MultiEncodingSimulcastGenerator(av1FrameMaps) + val targetIndex = Av1DDRtpLayerDesc.getIndex(2, 0) + + testGenerator(generator, filter, targetIndex) { f, result -> + result.accept shouldBe ((f.ssrc == 2L || f.isKeyframe) && f.frameInfo!!.temporalId == 0) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + } + should("be able to switch encodings") { + val av1FrameMaps = HashMap() + + val filter = Av1DDQualityFilter(av1FrameMaps, logger) + val generator = MultiEncodingSimulcastGenerator(av1FrameMaps) + + /* Start by sending encoding 0. */ + val targetIndex1 = Av1DDRtpLayerDesc.getIndex(0, 2) + + testGenerator(generator, filter, targetIndex1, numFrames = 1200) { f, result -> + result.accept shouldBe (f.ssrc == 0L) + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe false + } + } + + /* Switch to encoding 2. Need a keyframe. */ + val targetIndex2 = Av1DDRtpLayerDesc.getIndex(2, 2) + var sawKeyframe = false + testGenerator(generator, filter, targetIndex2, numFrames = 1200) { f, result -> + if (f.isKeyframe) sawKeyframe = true + result.accept shouldBe if (!sawKeyframe) { + (f.ssrc == 0L) + } else { + (f.ssrc == 2L || f.isKeyframe) + } + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe (!sawKeyframe) + } + } + + /* Switch to encoding 1. Need a keyframe. */ + val targetIndex3 = Av1DDRtpLayerDesc.getIndex(1, 2) + sawKeyframe = false + testGenerator(generator, filter, targetIndex3) { f, result -> + if (f.isKeyframe) sawKeyframe = true + result.accept shouldBe if (!sawKeyframe) { + // We don't send discardable frames for the DT while there's a pending encoding downswitch + (f.ssrc == 2L && f.frameInfo!!.temporalId != 2) + } else { + (f.ssrc == 1L || (f.ssrc == 0L && f.isKeyframe)) + } + if (result.accept) { + result.mark shouldBe true + filter.needsKeyframe shouldBe (!sawKeyframe) + } + } + } + } + } + + private fun testGenerator( + g: FrameGenerator, + filter: Av1DDQualityFilter, + targetIndex: Int, + numFrames: Int = Int.MAX_VALUE, + evaluator: (Av1DDFrame, Av1DDQualityFilter.AcceptResult) -> Unit + ) { + var lastTs = -1L + var ms = -1L + var frames = 0 + while (g.hasNext() && frames < numFrames) { + val f = g.next() + + ms = if (f.timestamp != lastTs) { + f.timestamp / 90 + } else { + ms + 1 + } + lastTs = f.timestamp + + val result = filter.acceptFrame( + frame = f, + externalTargetIndex = targetIndex, + incomingEncoding = f.ssrc.toInt(), + receivedTime = Instant.ofEpochMilli(ms) + ) + f.isAccepted = result.accept + evaluator(f, result) + frames++ + } + } + + companion object { + val logger = LoggerImpl(getClassForLogging(this::class.java).name) + } +} + +private abstract class FrameGenerator : Iterator + +private open class DDBasedGenerator( + val av1FrameMaps: HashMap, + val keyframeInterval: Int, + val keyframeTemplates: Array, + val normalTemplates: Array, + ddHex: String +) : FrameGenerator() { + private var frameCount = 0 + private val structure: Av1TemplateDependencyStructure + + init { + val dd = DatatypeConverter.parseHexBinary(ddHex) + structure = Av1DependencyDescriptorReader(dd, 0, dd.size).parse(null).structure + } + + override fun hasNext(): Boolean = frameCount < TOTAL_FRAMES + + protected open fun isKeyframe(keyCycle: Int) = keyCycle == 0 + + override fun next(): Av1DDFrame { + val tCycle = frameCount % normalTemplates.size + val keyCycle = frameCount % keyframeInterval + + val templateId = if (keyCycle < keyframeTemplates.size) { + keyframeTemplates[tCycle] + } else { + normalTemplates[tCycle] + } + + val f = Av1DDFrame( + ssrc = 0, + timestamp = frameCount * 3000L, + earliestKnownSequenceNumber = frameCount, + latestKnownSequenceNumber = frameCount, + seenStartOfFrame = true, + seenEndOfFrame = true, + seenMarker = true, + frameInfo = structure.templateInfo[templateId], + // Will be less than 0xffff + frameNumber = frameCount, + index = frameCount, + templateId = templateId, + structure = structure, + activeDecodeTargets = null, + isKeyframe = isKeyframe(keyCycle), + rawDependencyDescriptor = null + ) + av1FrameMaps.getOrPut(f.ssrc) { Av1DDFrameMap(Av1DDQualityFilterTest.logger) }.insertFrame(f) + frameCount++ + return f + } + + companion object { + private const val TOTAL_FRAMES = 10000 + } +} + +/** Generate a non-scalable AV1 stream, with a single keyframe at the start. */ +private class SingleLayerFrameGenerator(av1FrameMaps: HashMap) : DDBasedGenerator( + av1FrameMaps, + 10000, + arrayOf(0), + arrayOf(1), + "80000180003a410180ef808680" +) + +/** Generate a temporally-scaled series of AV1 frames, with a single keyframe at the start. */ +private class TemporallyScaledFrameGenerator(av1FrameMaps: HashMap) : DDBasedGenerator( + av1FrameMaps, + 10000, + arrayOf(0), + arrayOf(1, 3, 2, 4), + "800001800214eaa860414d141020842701df010d" +) + +/** Generate a spatially-scaled series of AV1 frames (L3T3), with full spatial dependencies and periodic keyframes. */ +private class SVCFrameGenerator(av1FrameMaps: HashMap) : DDBasedGenerator( + av1FrameMaps, + 144, + arrayOf(1, 6, 11), + arrayOf(0, 5, 10, 3, 8, 13, 2, 7, 12, 4, 9, 14), + "d0013481e81485214eafffaaaa863cf0430c10c302afc0aaa0063c00430010c002a000a800060000" + + "40001d954926e082b04a0941b820ac1282503157f974000ca864330e222222eca8655304224230ec" + + "a87753013f00b3027f016704ff02cf" +) + +/** Generate a spatially-scaled series of AV1 frames (L3T3), with keyframe spatial dependencies and periodic + * keyframes. */ +private class KSVCFrameGenerator(av1FrameMaps: HashMap) : DDBasedGenerator( + av1FrameMaps, + 144, + arrayOf(0, 5, 10), + arrayOf(1, 6, 11, 3, 8, 13, 2, 7, 12, 4, 9, 14), + "8f008581e81485214eaaaaa8000600004000100002aa80a8000600004000100002a000a80006000040" + + "0016d549241b5524906d54923157e001974ca864330e222396eca8655304224390eca87753013f00b3027f016704ff02cf" +) + +/** Generate a spatially-scaled series of AV1 frames (L2T2), with keyframe spatial dependencies and periodic + * keyframes, with temporal structures shifted. */ +/* Note that as of Chrome 111, L3T3_KEY_SHIFT is not supported yet, so we're testing L2T2_KEY_SHIFT instead. */ +private class KSVCShiftFrameGenerator(av1FrameMaps: HashMap) : DDBasedGenerator( + av1FrameMaps, + 144, + arrayOf(0, 4, 1), + arrayOf(2, 6, 3, 5), + "8700ed80e3061eaa82804028280514d14134518010a091889a09409fc059c13fc0b3c0" +) + +/** Generate a single-stream temporally-scaled simulcast (S3T3) series of AV1 frames, with periodic keyframes. */ +private class SingleEncodingSimulcastGenerator(av1FrameMaps: HashMap) : DDBasedGenerator( + av1FrameMaps, + 144, + arrayOf(1, 6, 11), + arrayOf(0, 5, 10, 3, 8, 13, 2, 7, 12, 4, 9, 14), + "c1000180081485214ea000a8000600004000100002a000a8000600004000100002a000a8000600004" + + "0001d954926caa493655248c55fe5d00032a190cc38e58803b2a1954c10e10843b2a1dd4c01dc010803bc0218077c0434" +) { + // All frames of the initial picture get the DD structure attached + override fun isKeyframe(keyCycle: Int) = keyCycle < keyframeTemplates.size +} + +private class MultiEncodingSimulcastGenerator(val av1FrameMaps: HashMap) : FrameGenerator() { + private var frameCount = 0 + + override fun hasNext(): Boolean = frameCount < TOTAL_FRAMES + + override fun next(): Av1DDFrame { + val pictureCount = frameCount / NUM_ENCODINGS + val encoding = frameCount % NUM_ENCODINGS + val tCycle = pictureCount % normalTemplates.size + val keyCycle = pictureCount % KEYFRAME_INTERVAL + + val templateId = if (keyCycle < keyframeTemplates.size) { + keyframeTemplates[tCycle] + } else { + normalTemplates[tCycle] + } + + val keyframePicture = keyCycle == 0 + + val f = Av1DDFrame( + ssrc = encoding.toLong(), + timestamp = pictureCount * 3000L, + earliestKnownSequenceNumber = pictureCount, + latestKnownSequenceNumber = pictureCount, + seenStartOfFrame = true, + seenEndOfFrame = true, + seenMarker = true, + frameInfo = structure.templateInfo[templateId], + // Will be less than 0xffff + frameNumber = pictureCount, + index = pictureCount, + templateId = templateId, + structure = structure, + activeDecodeTargets = null, + isKeyframe = keyframePicture, + rawDependencyDescriptor = null + ) + av1FrameMaps.getOrPut(f.ssrc) { Av1DDFrameMap(Av1DDQualityFilterTest.logger) }.insertFrame(f) + frameCount++ + return f + } + + companion object { + private const val TOTAL_FRAMES = 10000 + private const val KEYFRAME_INTERVAL = 144 + private const val NUM_ENCODINGS = 3 + private val keyframeTemplates = arrayOf(0) + private val normalTemplates = arrayOf(1, 3, 2, 4) + + private val structure: Av1TemplateDependencyStructure + + init { + val dd = DatatypeConverter.parseHexBinary("800001800214eaa860414d141020842701df010d") + structure = Av1DependencyDescriptorReader(dd, 0, dd.size).parse(null).structure + } + } +} diff --git a/jvb/src/test/kotlin/org/jitsi/videobridge/cc/vp9/Vp9AdaptiveSourceProjectionTest.kt b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/vp9/Vp9AdaptiveSourceProjectionTest.kt index 5c89fd6c36..1eaac25854 100644 --- a/jvb/src/test/kotlin/org/jitsi/videobridge/cc/vp9/Vp9AdaptiveSourceProjectionTest.kt +++ b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/vp9/Vp9AdaptiveSourceProjectionTest.kt @@ -23,13 +23,10 @@ import org.jitsi.nlj.RtpLayerDesc.Companion.getTidFromIndex import org.jitsi.nlj.codec.vpx.VpxUtils.Companion.applyExtendedPictureIdDelta import org.jitsi.nlj.codec.vpx.VpxUtils.Companion.applyTl0PicIdxDelta import org.jitsi.nlj.codec.vpx.VpxUtils.Companion.getExtendedPictureIdDelta -import org.jitsi.nlj.format.PayloadType -import org.jitsi.nlj.format.Vp9PayloadType import org.jitsi.nlj.rtp.codec.vp9.Vp9Packet import org.jitsi.nlj.util.Rfc3711IndexTracker import org.jitsi.rtp.rtcp.RtcpSrPacket import org.jitsi.rtp.rtcp.RtcpSrPacketBuilder -import org.jitsi.rtp.rtcp.SenderInfoBuilder import org.jitsi.rtp.rtp.RtpPacket import org.jitsi.rtp.util.RtpUtils import org.jitsi.rtp.util.isNewerThan @@ -46,18 +43,11 @@ import java.time.Duration import java.time.Instant import java.util.Random import java.util.TreeMap -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.CopyOnWriteArraySet import javax.xml.bind.DatatypeConverter import kotlin.collections.ArrayList class Vp9AdaptiveSourceProjectionTest { private val logger: Logger = LoggerImpl(javaClass.name) - private val payloadType: PayloadType = Vp9PayloadType( - 96.toByte(), - ConcurrentHashMap(), - CopyOnWriteArraySet() - ) @Test fun singlePacketProjectionTest() { @@ -66,7 +56,6 @@ class Vp9AdaptiveSourceProjectionTest { val initialState = RtpState(1, 10000, 1000000) val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) @@ -74,13 +63,7 @@ class Vp9AdaptiveSourceProjectionTest { val packetInfo = generator.nextPacket() val packet = packetInfo.packetAs() val targetIndex = getIndex(eid = 0, sid = 0, tid = 0) - Assert.assertTrue( - context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(packetInfo, 0, targetIndex)) context.rewriteRtp(packetInfo) Assert.assertEquals(10001, packet.sequenceNumber) Assert.assertEquals(1003000, packet.timestamp) @@ -95,7 +78,6 @@ class Vp9AdaptiveSourceProjectionTest { val initialState = RtpState(1, 10000, 1000000) val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) @@ -110,7 +92,7 @@ class Vp9AdaptiveSourceProjectionTest { val packet = packetInfo.packetAs() val accepted = context.accept( packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), + 0, targetIndex ) if (!packet.hasLayerIndices) { @@ -146,7 +128,7 @@ class Vp9AdaptiveSourceProjectionTest { } } - private class ProjectedPacket internal constructor( + private class ProjectedPacket constructor( val packet: Vp9Packet, val origSeq: Int, val extOrigSeq: Int, @@ -176,7 +158,6 @@ class Vp9AdaptiveSourceProjectionTest { var orderedCount = initialOrderedCount - 1 val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) @@ -193,11 +174,7 @@ class Vp9AdaptiveSourceProjectionTest { if (latestSeq isOlderThan origSeq) { latestSeq = origSeq } - val accepted = context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) + val accepted = context.accept(packetInfo, 0, targetIndex) val oldestValidSeq: Int = RtpUtils.applySequenceNumberDelta( latestSeq, @@ -501,42 +478,20 @@ class Vp9AdaptiveSourceProjectionTest { val initialState = RtpState(1, 10000, 1000000) val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) val firstPacketInfo = generator.nextPacket() - val firstPacket = firstPacketInfo.packetAs() val targetIndex = getIndex(eid = 0, sid = 0, tid = 2) for (i in 0..2) { val packetInfo = generator.nextPacket() - val packet = packetInfo.packetAs() - Assert.assertFalse( - context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertFalse(context.accept(packetInfo, 0, targetIndex)) } - Assert.assertTrue( - context.accept( - firstPacketInfo, - getIndex(0, firstPacket.spatialLayerIndex, firstPacket.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(firstPacketInfo, 0, targetIndex)) context.rewriteRtp(firstPacketInfo) for (i in 0..9995) { val packetInfo = generator.nextPacket() - val packet = packetInfo.packetAs() - Assert.assertTrue( - context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(packetInfo, 0, targetIndex)) context.rewriteRtp(packetInfo) } } @@ -549,53 +504,24 @@ class Vp9AdaptiveSourceProjectionTest { val initialState = RtpState(1, 10000, 1000000) val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) val firstPacketInfo = generator.nextPacket() - val firstPacket = firstPacketInfo.packetAs() val targetIndex = getIndex(eid = 0, sid = 0, tid = 2) for (i in 0..3) { val packetInfo = generator.nextPacket() - val packet = packetInfo.packetAs() - Assert.assertFalse( - context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertFalse(context.accept(packetInfo, 0, targetIndex)) } - Assert.assertFalse( - context.accept( - firstPacketInfo, - getIndex(0, firstPacket.spatialLayerIndex, firstPacket.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertFalse(context.accept(firstPacketInfo, 0, targetIndex)) for (i in 0..9) { val packetInfo = generator.nextPacket() - val packet = packetInfo.packetAs() - Assert.assertFalse( - context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertFalse(context.accept(packetInfo, 0, targetIndex)) } generator.requestKeyframe() for (i in 0..9995) { val packetInfo = generator.nextPacket() - val packet = packetInfo.packetAs() - Assert.assertTrue( - context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(packetInfo, 0, targetIndex)) context.rewriteRtp(packetInfo) } } @@ -608,7 +534,6 @@ class Vp9AdaptiveSourceProjectionTest { val initialState = RtpState(1, 10000, 1000000) val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) @@ -619,38 +544,19 @@ class Vp9AdaptiveSourceProjectionTest { for (i in 0..10) { val packetInfo = generator.nextPacket() val packet = packetInfo.packetAs() - Assert.assertTrue( - context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(packetInfo, 0, targetIndex)) context.rewriteRtp(packetInfo) Assert.assertTrue(packet.sequenceNumber > 10001) lowestSeq = minOf(lowestSeq, packet.sequenceNumber) } - Assert.assertTrue( - context.accept( - firstPacketInfo, - getIndex(0, firstPacket.spatialLayerIndex, firstPacket.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(firstPacketInfo, 0, targetIndex)) context.rewriteRtp(firstPacketInfo) Assert.assertEquals(lowestSeq - 1, firstPacket.sequenceNumber) for (i in 0..9980) { val packetInfo = generator.nextPacket() - val packet = packetInfo.packetAs() - Assert.assertTrue( - context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(packetInfo, 0, targetIndex)) context.rewriteRtp(packetInfo) } } @@ -665,7 +571,6 @@ class Vp9AdaptiveSourceProjectionTest { val initialState = RtpState(1, 10000, 1000000) val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) @@ -675,22 +580,9 @@ class Vp9AdaptiveSourceProjectionTest { for (i in 0..9999) { val packetInfo1 = generator1.nextPacket() val packet1 = packetInfo1.packetAs() - Assert.assertTrue( - context.accept( - packetInfo1, - getIndex(1, packet1.spatialLayerIndex, packet1.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(packetInfo1, 1, targetIndex)) val packetInfo2 = generator2.nextPacket() - val packet2 = packetInfo2.packetAs() - Assert.assertFalse( - context.accept( - packetInfo2, - getIndex(0, packet2.spatialLayerIndex, packet2.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertFalse(context.accept(packetInfo2, 0, targetIndex)) context.rewriteRtp(packetInfo1) Assert.assertEquals(expectedSeq, packet1.sequenceNumber) Assert.assertEquals(expectedTs, packet1.timestamp) @@ -711,7 +603,6 @@ class Vp9AdaptiveSourceProjectionTest { val initialState = RtpState(1, 10000, 1000000) val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) @@ -729,35 +620,14 @@ class Vp9AdaptiveSourceProjectionTest { if (packet1.isStartOfFrame && packet1.temporalLayerIndex == 0) { expectedTl0PicIdx = applyTl0PicIdxDelta(expectedTl0PicIdx, 1) } - Assert.assertTrue( - context.accept( - packetInfo1, - getIndex( - 0, - packet1.spatialLayerIndex, - packet1.temporalLayerIndex - ), - targetIndex - ) - ) + Assert.assertTrue(context.accept(packetInfo1, 0, targetIndex)) context.rewriteRtp(packetInfo1) Assert.assertTrue(context.rewriteRtcp(srPacket1)) Assert.assertEquals(packet1.ssrc, srPacket1.senderSsrc) Assert.assertEquals(packet1.timestamp, srPacket1.senderInfo.rtpTimestamp) val srPacket2 = generator2.srPacket val packetInfo2 = generator2.nextPacket() - val packet2 = packetInfo2.packetAs() - Assert.assertFalse( - context.accept( - packetInfo2, - getIndex( - 1, - packet2.spatialLayerIndex, - packet2.temporalLayerIndex - ), - targetIndex - ) - ) + Assert.assertFalse(context.accept(packetInfo2, 1, targetIndex)) Assert.assertFalse(context.rewriteRtcp(srPacket2)) Assert.assertEquals(expectedSeq, packet1.sequenceNumber) Assert.assertEquals(expectedTs, packet1.timestamp) @@ -779,27 +649,14 @@ class Vp9AdaptiveSourceProjectionTest { if (packet1.isStartOfFrame && packet1.temporalLayerIndex == 0) { expectedTl0PicIdx = applyTl0PicIdxDelta(expectedTl0PicIdx, 1) } - Assert.assertTrue( - context.accept( - packetInfo1, - getIndex(0, packet1.spatialLayerIndex, packet1.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(packetInfo1, 0, targetIndex)) context.rewriteRtp(packetInfo1) Assert.assertTrue(context.rewriteRtcp(srPacket1)) Assert.assertEquals(packet1.ssrc, srPacket1.senderSsrc) Assert.assertEquals(packet1.timestamp, srPacket1.senderInfo.rtpTimestamp) val srPacket2 = generator2.srPacket val packetInfo2 = generator2.nextPacket() - val packet2 = packetInfo2.packetAs() - Assert.assertFalse( - context.accept( - packetInfo2, - getIndex(1, packet2.spatialLayerIndex, packet2.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertFalse(context.accept(packetInfo2, 1, targetIndex)) Assert.assertFalse(context.rewriteRtcp(srPacket2)) Assert.assertEquals(expectedSeq, packet1.sequenceNumber) Assert.assertEquals(expectedTs, packet1.timestamp) @@ -824,14 +681,7 @@ class Vp9AdaptiveSourceProjectionTest { } /* We will cut off the layer 0 keyframe after 1 packet, once we see the layer 1 keyframe. */ - Assert.assertEquals( - i == 0, - context.accept( - packetInfo1, - getIndex(0, packet1.spatialLayerIndex, packet1.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertEquals(i == 0, context.accept(packetInfo1, 0, targetIndex)) Assert.assertEquals(i == 0, context.rewriteRtcp(srPacket1)) if (i == 0) { context.rewriteRtp(packetInfo1) @@ -844,13 +694,7 @@ class Vp9AdaptiveSourceProjectionTest { if (packet2.isStartOfFrame && packet2.temporalLayerIndex == 0) { expectedTl0PicIdx = applyTl0PicIdxDelta(expectedTl0PicIdx, 1) } - Assert.assertTrue( - context.accept( - packetInfo2, - getIndex(1, packet2.spatialLayerIndex, packet2.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(packetInfo2, 1, targetIndex)) context.rewriteRtp(packetInfo2) Assert.assertTrue(context.rewriteRtcp(srPacket2)) Assert.assertEquals(packet2.ssrc, srPacket2.senderSsrc) @@ -883,7 +727,6 @@ class Vp9AdaptiveSourceProjectionTest { val initialState = RtpState(1, 10000, 1000000) val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) @@ -897,11 +740,7 @@ class Vp9AdaptiveSourceProjectionTest { for (i in 0..9999) { val packetInfo = generator.nextPacket() val packet = packetInfo.packetAs() - val accepted = context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) + val accepted = context.accept(packetInfo, 0, targetIndex) if (packet.isStartOfFrame && packet.temporalLayerIndex == 0) { expectedTl0PicIdx = applyTl0PicIdxDelta(expectedTl0PicIdx, 1) } @@ -942,7 +781,6 @@ class Vp9AdaptiveSourceProjectionTest { val initialState = RtpState(1, 10000, 1000000) val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) @@ -957,7 +795,7 @@ class Vp9AdaptiveSourceProjectionTest { val packet = packetInfo.packetAs() val accepted = context.accept( packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), + 0, targetIndex ) if (packet.isStartOfFrame && packet.temporalLayerIndex == 0) { @@ -995,13 +833,7 @@ class Vp9AdaptiveSourceProjectionTest { packetInfo = generator.nextPacket() packet = packetInfo.packetAs() } while (packet.temporalLayerIndex > targetIndex) - Assert.assertTrue( - context.accept( - packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), - targetIndex - ) - ) + Assert.assertTrue(context.accept(packetInfo, 0, targetIndex)) context.rewriteRtp(packetInfo) /* Allow any values after a gap. */ @@ -1018,7 +850,7 @@ class Vp9AdaptiveSourceProjectionTest { packet = packetInfo.packetAs() val accepted = context.accept( packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), + 0, targetIndex ) if (packet.isStartOfFrame && packet.temporalLayerIndex == 0) { @@ -1074,7 +906,6 @@ class Vp9AdaptiveSourceProjectionTest { val initialState = RtpState(1, 10000, 1000000) val context = Vp9AdaptiveSourceProjectionContext( diagnosticContext, - payloadType, initialState, logger ) @@ -1096,7 +927,7 @@ class Vp9AdaptiveSourceProjectionTest { packet = packetInfo.packetAs() val accepted = context.accept( packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), + 0, targetIndex ) if (packet.isStartOfFrame && packet.temporalLayerIndex == 0) { @@ -1136,7 +967,7 @@ class Vp9AdaptiveSourceProjectionTest { packet = packetInfo.packetAs() val accepted = context.accept( packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), + 0, targetIndex ) Assert.assertTrue(accepted) @@ -1154,7 +985,7 @@ class Vp9AdaptiveSourceProjectionTest { packet = packetInfo.packetAs() val accepted = context.accept( packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), + 0, RtpLayerDesc.SUSPENDED_INDEX ) Assert.assertFalse(accepted) @@ -1170,7 +1001,7 @@ class Vp9AdaptiveSourceProjectionTest { packet = packetInfo.packetAs() val accepted = context.accept( packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), + 0, targetIndex ) Assert.assertFalse(accepted) @@ -1187,7 +1018,7 @@ class Vp9AdaptiveSourceProjectionTest { packet = packetInfo.packetAs() val accepted = context.accept( packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), + 0, targetIndex ) Assert.assertFalse(accepted) @@ -1202,7 +1033,7 @@ class Vp9AdaptiveSourceProjectionTest { packet = packetInfo.packetAs() val accepted = context.accept( packetInfo, - getIndex(0, packet.spatialLayerIndex, packet.temporalLayerIndex), + 0, targetIndex ) if (packet.isStartOfFrame && packet.temporalLayerIndex == 0) { @@ -1561,7 +1392,7 @@ class Vp9AdaptiveSourceProjectionTest { val srPacketBuilder = RtcpSrPacketBuilder() srPacketBuilder.rtcpHeader.senderSsrc = ssrc val siBuilder = srPacketBuilder.senderInfo - setSIBuilderNtp(srPacketBuilder.senderInfo, receivedTime.toEpochMilli()) + siBuilder.setNtpFromJavaTime(receivedTime.toEpochMilli()) siBuilder.rtpTimestamp = ts siBuilder.sendersOctetCount = packetCount.toLong() siBuilder.sendersOctetCount = octetCount.toLong() @@ -1594,16 +1425,6 @@ class Vp9AdaptiveSourceProjectionTest { // Dummy payload data "000000" ) - - /* TODO: move this to jitsi-rtp */ - const val JAVA_TO_NTP_EPOCH_OFFSET_SECS = 2208988800L - - fun setSIBuilderNtp(siBuilder: SenderInfoBuilder, wallTime: Long) { - val wallSecs = wallTime / 1000 - val wallMs = wallTime % 1000 - siBuilder.ntpTimestampMsw = wallSecs + JAVA_TO_NTP_EPOCH_OFFSET_SECS - siBuilder.ntpTimestampLsw = wallMs * (1L shl 32) / 1000 - } } } } diff --git a/jvb/src/test/kotlin/org/jitsi/videobridge/cc/vp9/Vp9QualityFilterTest.kt b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/vp9/Vp9QualityFilterTest.kt index d32cb07ed5..f2cb5a7ae7 100644 --- a/jvb/src/test/kotlin/org/jitsi/videobridge/cc/vp9/Vp9QualityFilterTest.kt +++ b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/vp9/Vp9QualityFilterTest.kt @@ -423,12 +423,10 @@ internal class Vp9QualityFilterTest : ShouldSpec() { } lastTs = f.timestamp - val packetIndex = RtpLayerDesc.getIndex(f.ssrc.toInt(), f.spatialLayer, f.temporalLayer) - val result = filter.acceptFrame( frame = f, + incomingEncoding = f.ssrc.toInt(), externalTargetIndex = targetIndex, - incomingIndex = packetIndex, receivedTime = Instant.ofEpochMilli(ms) ) evaluator(f, result) @@ -445,7 +443,7 @@ private abstract class FrameGenerator : Iterator /** Generate a non-scalable series of VP9 frames, with a single keyframe at the start. */ private class SingleLayerFrameGenerator : FrameGenerator() { - private val totalPictures = 1000 + private val totalPictures = 10000 private var pictureCount = 0 override fun hasNext(): Boolean = pictureCount < totalPictures @@ -478,7 +476,7 @@ private class SingleLayerFrameGenerator : FrameGenerator() { /** Generate a temporally-scaled series of VP9 frames, with a single keyframe at the start. */ private class TemporallyScaledFrameGenerator : FrameGenerator() { - private val totalPictures = 1000 + private val totalPictures = 10000 private var pictureCount = 0 private var tl0Count = -1 @@ -521,7 +519,7 @@ private class TemporallyScaledFrameGenerator : FrameGenerator() { /** Generate a spatially-scaled series of VP9 frames, with full spatial dependencies and periodic keyframes. */ private class SVCFrameGenerator : FrameGenerator() { - private val totalPictures = 1000 + private val totalPictures = 10000 private var pictureCount = 0 private var frameCount = 0 private var tl0Count = -1 @@ -574,7 +572,7 @@ private class SVCFrameGenerator : FrameGenerator() { /** Generate a spatially-scaled series of VP9 frames, with K-SVC spatial dependencies and periodic keyframes. */ private class KSVCFrameGenerator : FrameGenerator() { - private val totalPictures = 1000 + private val totalPictures = 10000 private var pictureCount = 0 private var frameCount = 0 private var tl0Count = -1 diff --git a/rtp/pom.xml b/rtp/pom.xml index f03c567fca..001ead4e8c 100644 --- a/rtp/pom.xml +++ b/rtp/pom.xml @@ -55,6 +55,12 @@ 3.0.3 test + + jakarta.xml.bind + jakarta.xml.bind-api + 4.0.0 + test + diff --git a/rtp/spotbugs-exclude.xml b/rtp/spotbugs-exclude.xml index f7cf298ceb..7327eba7c3 100644 --- a/rtp/spotbugs-exclude.xml +++ b/rtp/spotbugs-exclude.xml @@ -18,6 +18,8 @@ + + diff --git a/rtp/src/main/kotlin/org/jitsi/rtp/rtcp/RtcpSrPacket.kt b/rtp/src/main/kotlin/org/jitsi/rtp/rtcp/RtcpSrPacket.kt index 750f92e590..b821f0250c 100644 --- a/rtp/src/main/kotlin/org/jitsi/rtp/rtcp/RtcpSrPacket.kt +++ b/rtp/src/main/kotlin/org/jitsi/rtp/rtcp/RtcpSrPacket.kt @@ -77,6 +77,12 @@ data class SenderInfoBuilder( var sendersPacketCount: Long = -1, var sendersOctetCount: Long = -1 ) { + fun setNtpFromJavaTime(javaTime: Long) { + val wallSecs = javaTime / 1000 + val wallMs = javaTime % 1000 + ntpTimestampMsw = wallSecs + JAVA_TO_NTP_EPOCH_OFFSET_SECS + ntpTimestampLsw = wallMs * (1L shl 32) / 1000 + } fun writeTo(buf: ByteArray, offset: Int) { SenderInfoParser.setNtpTimestampMsw(buf, offset, ntpTimestampMsw) @@ -85,6 +91,10 @@ data class SenderInfoBuilder( SenderInfoParser.setSendersPacketCount(buf, offset, sendersPacketCount) SenderInfoParser.setSendersOctetCount(buf, offset, sendersOctetCount) } + + companion object { + const val JAVA_TO_NTP_EPOCH_OFFSET_SECS = 2208988800L + } } /** diff --git a/rtp/src/main/kotlin/org/jitsi/rtp/rtp/RtpPacket.kt b/rtp/src/main/kotlin/org/jitsi/rtp/rtp/RtpPacket.kt index 6741a65f8b..f0246eca04 100644 --- a/rtp/src/main/kotlin/org/jitsi/rtp/rtp/RtpPacket.kt +++ b/rtp/src/main/kotlin/org/jitsi/rtp/rtp/RtpPacket.kt @@ -432,8 +432,24 @@ open class RtpPacket( val dataLengthBytes: Int val totalLengthBytes: Int + + fun clone(): HeaderExtension = StandaloneHeaderExtension(this) + } + + @SuppressFBWarnings("CN_IMPLEMENTS_CLONE_BUT_NOT_CLONEABLE") + class StandaloneHeaderExtension(ext: HeaderExtension) : HeaderExtension { + override val buffer: ByteArray = ByteArray(ext.dataLengthBytes).also { + System.arraycopy(ext.buffer, ext.dataOffset, it, 0, ext.dataLengthBytes) + } + override val dataOffset = 0 + override var id = ext.id + override val dataLengthBytes: Int + get() = buffer.size + override val totalLengthBytes: Int + get() = buffer.size } + @SuppressFBWarnings("CN_IMPLEMENTS_CLONE_BUT_NOT_CLONEABLE") inner class EncodedHeaderExtension : HeaderExtension { private var currExtOffset: Int = 0 private var currExtLength: Int = 0 @@ -471,7 +487,7 @@ open class RtpPacket( } @SuppressFBWarnings( - value = ["EI_EXPOSE_REP"], + value = ["EI_EXPOSE_REP", "CN_IMPLEMENTS_CLONE_BUT_NOT_CLONEABLE"], justification = "We intentionally expose the internal buffer." ) inner class PendingHeaderExtension(override var id: Int, override val dataLengthBytes: Int) : HeaderExtension { diff --git a/rtp/src/main/kotlin/org/jitsi/rtp/rtp/header_extensions/Av1DependencyDescriptorHeaderExtension.kt b/rtp/src/main/kotlin/org/jitsi/rtp/rtp/header_extensions/Av1DependencyDescriptorHeaderExtension.kt new file mode 100644 index 0000000000..679c6e3e83 --- /dev/null +++ b/rtp/src/main/kotlin/org/jitsi/rtp/rtp/header_extensions/Av1DependencyDescriptorHeaderExtension.kt @@ -0,0 +1,902 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.rtp.rtp.header_extensions + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings +import org.jitsi.rtp.rtp.RtpPacket +import org.jitsi.rtp.util.BitReader +import org.jitsi.rtp.util.BitWriter +import org.jitsi.utils.OrderedJsonObject +import org.json.simple.JSONAware + +/** + * The subset of the fields of an AV1 Dependency Descriptor that can be parsed statelessly. + */ +@SuppressFBWarnings("CN_IMPLEMENTS_CLONE_BUT_NOT_CLONEABLE") +open class Av1DependencyDescriptorStatelessSubset( + val startOfFrame: Boolean, + val endOfFrame: Boolean, + var frameDependencyTemplateId: Int, + var frameNumber: Int, + + val newTemplateDependencyStructure: Av1TemplateDependencyStructure?, +) { + open fun clone(): Av1DependencyDescriptorStatelessSubset { + return Av1DependencyDescriptorStatelessSubset( + startOfFrame = startOfFrame, + endOfFrame = endOfFrame, + frameDependencyTemplateId = frameDependencyTemplateId, + frameNumber = frameNumber, + newTemplateDependencyStructure = newTemplateDependencyStructure?.clone() + ) + } +} + +/** + * The AV1 Dependency Descriptor header extension, as defined in https://aomediacodec.github.io/av1-rtp-spec/#appendix + */ +@SuppressFBWarnings("CN_IMPLEMENTS_CLONE_BUT_NOT_CLONEABLE") +class Av1DependencyDescriptorHeaderExtension( + startOfFrame: Boolean, + endOfFrame: Boolean, + frameDependencyTemplateId: Int, + frameNumber: Int, + + newTemplateDependencyStructure: Av1TemplateDependencyStructure?, + + var activeDecodeTargetsBitmask: Int?, + + val customDtis: List?, + val customFdiffs: List?, + val customChains: List?, + + val structure: Av1TemplateDependencyStructure +) : Av1DependencyDescriptorStatelessSubset( + startOfFrame, + endOfFrame, + frameDependencyTemplateId, + frameNumber, + newTemplateDependencyStructure +), + JSONAware { + val frameInfo: FrameInfo by lazy { + val templateIndex = (frameDependencyTemplateId + 64 - structure.templateIdOffset) % 64 + if (templateIndex >= structure.templateCount) { + val maxTemplate = (structure.templateIdOffset + structure.templateCount - 1) % 64 + throw Av1DependencyException( + "Invalid template ID $frameDependencyTemplateId. " + + "Should be in range ${structure.templateIdOffset} .. $maxTemplate. " + + "Missed a keyframe?" + ) + } + val templateVal = structure.templateInfo[templateIndex] + + FrameInfo( + spatialId = templateVal.spatialId, + temporalId = templateVal.temporalId, + dti = customDtis ?: templateVal.dti, + fdiff = customFdiffs ?: templateVal.fdiff, + chains = customChains ?: templateVal.chains + ) + } + + val encodedLength: Int + get() = (unpaddedLengthBits + 7) / 8 + + private val unpaddedLengthBits: Int + get() { + var length = 24 + if (newTemplateDependencyStructure != null || + activeDecodeTargetsBitmask != null || + customDtis != null || + customFdiffs != null || + customChains != null + ) { + length += 5 + } + if (newTemplateDependencyStructure != null) { + length += newTemplateDependencyStructure.unpaddedLengthBits + } + if (activeDecodeTargetsBitmask != null && + ( + newTemplateDependencyStructure == null || + activeDecodeTargetsBitmask != ((1 shl newTemplateDependencyStructure.decodeTargetCount) - 1) + ) + ) { + length += structure.decodeTargetCount + } + if (customDtis != null) { + length += 2 * structure.decodeTargetCount + } + if (customFdiffs != null) { + customFdiffs.forEach { + length += 2 + it.bitsForFdiff() + } + length += 2 + } + if (customChains != null) { + length += 8 * customChains.size + } + + return length + } + + override fun clone(): Av1DependencyDescriptorHeaderExtension { + val structureCopy = structure.clone() + val newStructure = if (newTemplateDependencyStructure == null) null else structureCopy + return Av1DependencyDescriptorHeaderExtension( + startOfFrame, + endOfFrame, + frameDependencyTemplateId, + frameNumber, + newStructure, + activeDecodeTargetsBitmask, + // These values are not mutable so it's safe to copy them by reference + customDtis, + customFdiffs, + customChains, + structureCopy + ) + } + + fun write(ext: RtpPacket.HeaderExtension) = write(ext.buffer, ext.dataOffset, ext.dataLengthBytes) + + fun write(buffer: ByteArray, offset: Int, length: Int) { + check(length <= encodedLength) { + "Cannot write AV1 DD to buffer: buffer length $length must be at least $encodedLength" + } + val writer = BitWriter(buffer, offset, length) + + writeMandatoryDescriptorFields(writer) + + if (newTemplateDependencyStructure != null || + activeDecodeTargetsBitmask != null || + customDtis != null || + customFdiffs != null || + customChains != null + ) { + writeOptionalDescriptorFields(writer) + writePadding(writer) + } else { + check(length == 3) { + "AV1 DD without optional descriptors must be 3 bytes in length" + } + } + } + + private fun writeMandatoryDescriptorFields(writer: BitWriter) { + writer.writeBit(startOfFrame) + writer.writeBit(endOfFrame) + writer.writeBits(6, frameDependencyTemplateId) + writer.writeBits(16, frameNumber) + } + + private fun writeOptionalDescriptorFields(writer: BitWriter) { + val templateDependencyStructurePresent = newTemplateDependencyStructure != null + val activeDecodeTargetsPresent = activeDecodeTargetsBitmask != null && + ( + newTemplateDependencyStructure == null || + activeDecodeTargetsBitmask != ((1 shl newTemplateDependencyStructure.decodeTargetCount) - 1) + ) + + val customDtisFlag = customDtis != null + val customFdiffsFlag = customFdiffs != null + val customChainsFlag = customChains != null + + writer.writeBit(templateDependencyStructurePresent) + writer.writeBit(activeDecodeTargetsPresent) + writer.writeBit(customDtisFlag) + writer.writeBit(customFdiffsFlag) + writer.writeBit(customChainsFlag) + + if (templateDependencyStructurePresent) { + newTemplateDependencyStructure!!.write(writer) + } + + if (activeDecodeTargetsPresent) { + writeActiveDecodeTargets(writer) + } + + if (customDtisFlag) { + writeFrameDtis(writer) + } + + if (customFdiffsFlag) { + writeFrameFdiffs(writer) + } + + if (customChainsFlag) { + writeFrameChains(writer) + } + } + + private fun writeActiveDecodeTargets(writer: BitWriter) { + writer.writeBits(structure.decodeTargetCount, activeDecodeTargetsBitmask!!) + } + + private fun writeFrameDtis(writer: BitWriter) { + customDtis!!.forEach { dti -> + writer.writeBits(2, dti.dti) + } + } + + private fun writeFrameFdiffs(writer: BitWriter) { + customFdiffs!!.forEach { fdiff -> + val bits = fdiff.bitsForFdiff() + writer.writeBits(2, bits / 4) + writer.writeBits(bits, fdiff - 1) + } + writer.writeBits(2, 0) + } + + private fun writeFrameChains(writer: BitWriter) { + customChains!!.forEach { chain -> + writer.writeBits(8, chain) + } + } + + private fun writePadding(writer: BitWriter) { + writer.writeBits(writer.remainingBits, 0) + } + + override fun toJSONString(): String { + return OrderedJsonObject().apply { + put("startOfFrame", startOfFrame) + put("endOfFrame", endOfFrame) + put("frameDependencyTemplateId", frameDependencyTemplateId) + put("frameNumber", frameNumber) + newTemplateDependencyStructure?.let { put("templateStructure", it) } + customDtis?.let { put("customDTIs", it) } + customFdiffs?.let { put("customFdiffs", it) } + customChains?.let { put("customChains", it) } + }.toJSONString() + } + + override fun toString(): String = toJSONString() +} + +fun Int.bitsForFdiff() = when { + this <= 0x10 -> 4 + this <= 0x100 -> 8 + this <= 0x1000 -> 12 + else -> throw IllegalArgumentException("Invalid FDiff value $this") +} + +/** + * The template information about a stream described by AV1 dependency descriptors. This is carried in the + * first packet of a codec video sequence (i.e. the first packet of a keyframe), and is necessary to interpret + * dependency descriptors carried in subsequent packets of the sequence. + */ +@SuppressFBWarnings("CN_IMPLEMENTS_CLONE_BUT_NOT_CLONEABLE") +class Av1TemplateDependencyStructure( + var templateIdOffset: Int, + val templateInfo: List, + val decodeTargetInfo: List, + val maxRenderResolutions: List, + val maxSpatialId: Int, + val maxTemporalId: Int +) : JSONAware { + val templateCount + get() = templateInfo.size + + val decodeTargetCount + get() = decodeTargetInfo.size + + val chainCount: Int = + templateInfo.first().chains.size + + init { + check(templateInfo.all { it.chains.size == chainCount }) { + "Templates have inconsistent chain sizes" + } + check(templateInfo.all { it.temporalId <= maxTemporalId }) { + "Incorrect maxTemporalId" + } + check(maxRenderResolutions.isEmpty() || maxRenderResolutions.size == maxSpatialId + 1) { + "Non-zero number of render resolutions does not match maxSpatialId" + } + check(templateInfo.all { it.spatialId <= maxSpatialId }) { + "Incorrect maxSpatialId" + } + } + + val unpaddedLengthBits: Int + get() { + var length = 6 // Template ID offset + + length += 5 // DT Count - 1 + + length += templateCount * 2 // templateLayers + length += templateCount * decodeTargetCount * 2 // TemplateDTIs + templateInfo.forEach { + length += it.fdiffCnt * 5 + 1 // TemplateFDiffs + } + // TemplateChains + length += nsBits(decodeTargetCount + 1, chainCount) + if (chainCount > 0) { + decodeTargetInfo.forEach { + length += nsBits(chainCount, it.protectedBy) + } + length += templateCount * chainCount * 4 + } + length += 1 // ResolutionsPresent + length += maxRenderResolutions.size * 32 // RenderResolutions + + return length + } + + fun clone(): Av1TemplateDependencyStructure { + return Av1TemplateDependencyStructure( + templateIdOffset, + // These objects are not mutable so it's safe to copy them by reference + templateInfo, + decodeTargetInfo, + maxRenderResolutions, + maxSpatialId, + maxTemporalId + ) + } + + fun write(writer: BitWriter) { + writer.writeBits(6, templateIdOffset) + + writer.writeBits(5, decodeTargetCount - 1) + + writeTemplateLayers(writer) + writeTemplateDtis(writer) + writeTemplateFdiffs(writer) + writeTemplateChains(writer) + + writeRenderResolutions(writer) + } + + private fun writeTemplateLayers(writer: BitWriter) { + check(templateInfo[0].spatialId == 0 && templateInfo[0].temporalId == 0) { + "First template must have spatial and temporal IDs 0/0, but found " + + "${templateInfo[0].spatialId}/${templateInfo[0].temporalId}" + } + for (templateNum in 1 until templateInfo.size) { + val layerIdc = when { + templateInfo[templateNum].spatialId == templateInfo[templateNum - 1].spatialId && + templateInfo[templateNum].temporalId == templateInfo[templateNum - 1].temporalId -> + 0 + templateInfo[templateNum].spatialId == templateInfo[templateNum - 1].spatialId && + templateInfo[templateNum].temporalId == templateInfo[templateNum - 1].temporalId + 1 -> + 1 + templateInfo[templateNum].spatialId == templateInfo[templateNum - 1].spatialId + 1 && + templateInfo[templateNum].temporalId == 0 -> + 2 + else -> + throw IllegalStateException( + "Template $templateNum with spatial and temporal IDs " + + "${templateInfo[templateNum].spatialId}/${templateInfo[templateNum].temporalId} " + + "cannot follow template ${templateNum - 1} with spatial and temporal IDs " + + "${templateInfo[templateNum - 1].spatialId}/${templateInfo[templateNum - 1].temporalId}." + ) + } + writer.writeBits(2, layerIdc) + } + writer.writeBits(2, 3) + } + + private fun writeTemplateDtis(writer: BitWriter) { + templateInfo.forEach { t -> + t.dti.forEach { dti -> + writer.writeBits(2, dti.dti) + } + } + } + + private fun writeTemplateFdiffs(writer: BitWriter) { + templateInfo.forEach { t -> + t.fdiff.forEach { fdiff -> + writer.writeBit(true) + writer.writeBits(4, fdiff - 1) + } + writer.writeBit(false) + } + } + + private fun writeTemplateChains(writer: BitWriter) { + writer.writeNs(decodeTargetCount + 1, chainCount) + decodeTargetInfo.forEach { + writer.writeNs(chainCount, it.protectedBy) + } + templateInfo.forEach { t -> + t.chains.forEach { chain -> + writer.writeBits(4, chain) + } + } + } + + private fun writeRenderResolutions(writer: BitWriter) { + if (maxRenderResolutions.isEmpty()) { + writer.writeBit(false) + } else { + writer.writeBit(true) + maxRenderResolutions.forEach { r -> + writer.writeBits(16, r.width - 1) + writer.writeBits(16, r.height - 1) + } + } + } + + /** Return whether, in this structure, it's possible to switch from DT [fromDt] to DT [toDt] + * without a keyframe. + * Note this makes certain assumptions about the encoding structure. + */ + fun canSwitchWithoutKeyframe(fromDt: Int, toDt: Int): Boolean = templateInfo.any { + it.hasInterPictureDependency() && it.dti[fromDt] != DTI.NOT_PRESENT && it.dti[toDt] == DTI.SWITCH + } + + /** Given that we are sending packets for a given DT, return a decodeTargetBitmask corresponding to all DTs + * contained in that DT. + */ + fun getDtBitmaskForDt(dt: Int): Int { + var mask = (1 shl decodeTargetCount) - 1 + templateInfo.forEach { frameInfo -> + frameInfo.dti.forEachIndexed { i, dti -> + if (frameInfo.dti[dt] == DTI.NOT_PRESENT && dti != DTI.NOT_PRESENT) { + mask = mask and (1 shl i).inv() + } + } + } + return mask + } + + override fun toJSONString(): String { + return OrderedJsonObject().apply { + put("templateIdOffset", templateIdOffset) + put("templateInfo", templateInfo.toIndexedMap()) + put("decodeTargetInfo", decodeTargetInfo.toIndexedMap()) + if (maxRenderResolutions.isNotEmpty()) { + put("maxRenderResolutions", maxRenderResolutions.toIndexedMap()) + } + put("maxSpatialId", maxSpatialId) + put("maxTemporalId", maxTemporalId) + }.toJSONString() + } + + override fun toString() = toJSONString() +} + +fun nsBits(n: Int, v: Int): Int { + require(n > 0) + if (n == 1) return 0 + var w = 0 + var x = n + while (x != 0) { + x = x shr 1 + w++ + } + val m = (1 shl w) - n + if (v < m) return w - 1 + return w +} + +class Av1DependencyDescriptorReader( + buffer: ByteArray, + offset: Int, + val length: Int, +) { + private var startOfFrame = false + private var endOfFrame = false + private var frameDependencyTemplateId = 0 + private var frameNumber = 0 + + private var customDtis: List? = null + private var customFdiffs: List? = null + private var customChains: List? = null + + private var localTemplateDependencyStructure: Av1TemplateDependencyStructure? = null + private var templateDependencyStructure: Av1TemplateDependencyStructure? = null + + private var activeDecodeTargetsBitmask: Int? = null + + private val reader = BitReader(buffer, offset, length) + + constructor(ext: RtpPacket.HeaderExtension) : + this(ext.buffer, ext.dataOffset, ext.dataLengthBytes) + + /** Parse those parts of the dependency descriptor that can be parsed statelessly, i.e. without an external + * template dependency structure. The returned object will not be a complete representation of the + * dependency descriptor, because some fields need the external structure to be parseable. + */ + fun parseStateless(): Av1DependencyDescriptorStatelessSubset { + reset() + readMandatoryDescriptorFields() + + if (length > 3) { + val templateDependencyStructurePresent = reader.bitAsBoolean() + + /* activeDecodeTargetsPresent, customDtisFlag, customFdiffsFlag, and customChainsFlag; + * none of these fields are parseable statelessly. + */ + reader.skipBits(4) + + if (templateDependencyStructurePresent) { + localTemplateDependencyStructure = readTemplateDependencyStructure() + } + } + return Av1DependencyDescriptorStatelessSubset( + startOfFrame, + endOfFrame, + frameDependencyTemplateId, + frameNumber, + localTemplateDependencyStructure, + ) + } + + /** Parse the dependency descriptor in the context of [dep], the currently-applicable template dependency + * structure.*/ + fun parse(dep: Av1TemplateDependencyStructure?): Av1DependencyDescriptorHeaderExtension { + reset() + readMandatoryDescriptorFields() + if (length > 3) { + readExtendedDescriptorFields(dep) + } else { + if (dep == null) { + throw Av1DependencyException("No external dependency structure specified for non-first packet") + } + templateDependencyStructure = dep + } + return Av1DependencyDescriptorHeaderExtension( + startOfFrame, + endOfFrame, + frameDependencyTemplateId, + frameNumber, + localTemplateDependencyStructure, + activeDecodeTargetsBitmask, + customDtis, + customFdiffs, + customChains, + templateDependencyStructure!! + ) + } + + private fun reset() = reader.reset() + + private fun readMandatoryDescriptorFields() { + startOfFrame = reader.bitAsBoolean() + endOfFrame = reader.bitAsBoolean() + frameDependencyTemplateId = reader.bits(6) + frameNumber = reader.bits(16) + } + + private fun readExtendedDescriptorFields(dep: Av1TemplateDependencyStructure?) { + val templateDependencyStructurePresent = reader.bitAsBoolean() + val activeDecodeTargetsPresent = reader.bitAsBoolean() + val customDtisFlag = reader.bitAsBoolean() + val customFdiffsFlag = reader.bitAsBoolean() + val customChainsFlag = reader.bitAsBoolean() + + if (templateDependencyStructurePresent) { + localTemplateDependencyStructure = readTemplateDependencyStructure() + templateDependencyStructure = localTemplateDependencyStructure + } else { + if (dep == null) { + throw Av1DependencyException("No external dependency structure specified for non-first packet") + } + templateDependencyStructure = dep + } + if (activeDecodeTargetsPresent) { + activeDecodeTargetsBitmask = reader.bits(templateDependencyStructure!!.decodeTargetCount) + } else if (templateDependencyStructurePresent) { + activeDecodeTargetsBitmask = (1 shl templateDependencyStructure!!.decodeTargetCount) - 1 + } + + if (customDtisFlag) { + customDtis = readFrameDtis() + } + if (customFdiffsFlag) { + customFdiffs = readFrameFdiffs() + } + if (customChainsFlag) { + customChains = readFrameChains() + } + } + + /* Data for template dependency structure */ + private var templateIdOffset: Int = 0 + private val templateInfo = mutableListOf() + private val decodeTargetInfo = mutableListOf() + private val maxRenderResolutions = mutableListOf() + + private var dtCnt = 0 + + private fun resetDependencyStructureInfo() { + /* These fields are assembled incrementally when parsing a dependency structure; reset them + * in case we're running a parser more than once. + */ + templateCnt = 0 + templateInfo.clear() + decodeTargetInfo.clear() + maxRenderResolutions.clear() + } + + private fun readTemplateDependencyStructure(): Av1TemplateDependencyStructure { + resetDependencyStructureInfo() + + templateIdOffset = reader.bits(6) + + val dtCntMinusOne = reader.bits(5) + dtCnt = dtCntMinusOne + 1 + + readTemplateLayers() + readTemplateDtis() + readTemplateFdiffs() + readTemplateChains() + readDecodeTargetLayers() + + val resolutionsPresent = reader.bitAsBoolean() + + if (resolutionsPresent) { + readRenderResolutions() + } + + return Av1TemplateDependencyStructure( + templateIdOffset, + templateInfo.toList(), + decodeTargetInfo.toList(), + maxRenderResolutions.toList(), + maxSpatialId, + maxTemporalId + ) + } + + private var templateCnt = 0 + private var maxSpatialId = 0 + private var maxTemporalId = 0 + + @SuppressFBWarnings( + value = ["SF_SWITCH_NO_DEFAULT"], + justification = "Artifact of generated Kotlin code." + ) + private fun readTemplateLayers() { + var temporalId = 0 + var spatialId = 0 + + var nextLayerIdc: Int + do { + templateInfo.add(templateCnt, TemplateFrameInfo(spatialId, temporalId)) + templateCnt++ + nextLayerIdc = reader.bits(2) + if (nextLayerIdc == 1) { + temporalId++ + if (maxTemporalId < temporalId) { + maxTemporalId = temporalId + } + } else if (nextLayerIdc == 2) { + temporalId = 0 + spatialId++ + } + } while (nextLayerIdc != 3) + + check(templateInfo.size == templateCnt) + + maxSpatialId = spatialId + } + + private fun readTemplateDtis() { + for (templateIndex in 0 until templateCnt) { + for (dtIndex in 0 until dtCnt) { + templateInfo[templateIndex].dti.add(dtIndex, DTI.fromInt(reader.bits(2))) + } + } + } + + private fun readFrameDtis(): List { + return List(templateDependencyStructure!!.decodeTargetCount) { + DTI.fromInt(reader.bits(2)) + } + } + + private fun readTemplateFdiffs() { + for (templateIndex in 0 until templateCnt) { + var fdiffCnt = 0 + var fdiffFollowsFlag = reader.bitAsBoolean() + while (fdiffFollowsFlag) { + val fdiffMinusOne = reader.bits(4) + templateInfo[templateIndex].fdiff.add(fdiffCnt, fdiffMinusOne + 1) + fdiffCnt++ + fdiffFollowsFlag = reader.bitAsBoolean() + } + check(fdiffCnt == templateInfo[templateIndex].fdiffCnt) + } + } + + private fun readFrameFdiffs(): List { + return buildList { + var nextFdiffSize = reader.bits(2) + while (nextFdiffSize != 0) { + val fdiffMinus1 = reader.bits(4 * nextFdiffSize) + add(fdiffMinus1 + 1) + nextFdiffSize = reader.bits(2) + } + } + } + + private fun readTemplateChains() { + val chainCount = reader.ns(dtCnt + 1) + if (chainCount != 0) { + for (dtIndex in 0 until dtCnt) { + decodeTargetInfo.add(dtIndex, DecodeTargetInfo(reader.ns(chainCount))) + } + for (templateIndex in 0 until templateCnt) { + for (chainIndex in 0 until chainCount) { + templateInfo[templateIndex].chains.add(chainIndex, reader.bits(4)) + } + check(templateInfo[templateIndex].chains.size == chainCount) + } + } + } + + private fun readFrameChains(): List { + return List(templateDependencyStructure!!.chainCount) { + reader.bits(8) + } + } + + private fun readDecodeTargetLayers() { + for (dtIndex in 0 until dtCnt) { + var spatialId = 0 + var temporalId = 0 + for (templateIndex in 0 until templateCnt) { + if (templateInfo[templateIndex].dti[dtIndex] != DTI.NOT_PRESENT) { + if (templateInfo[templateIndex].spatialId > spatialId) { + spatialId = templateInfo[templateIndex].spatialId + } + if (templateInfo[templateIndex].temporalId > temporalId) { + temporalId = templateInfo[templateIndex].temporalId + } + } + } + decodeTargetInfo[dtIndex].spatialId = spatialId + decodeTargetInfo[dtIndex].temporalId = temporalId + } + check(decodeTargetInfo.size == dtCnt) + } + + private fun readRenderResolutions() { + for (spatialId in 0..maxSpatialId) { + val widthMinus1 = reader.bits(16) + val heightMinus1 = reader.bits(16) + maxRenderResolutions.add(spatialId, Resolution(widthMinus1 + 1, heightMinus1 + 1)) + } + } +} + +open class FrameInfo( + val spatialId: Int, + val temporalId: Int, + open val dti: List, + open val fdiff: List, + open val chains: List +) : JSONAware { + val fdiffCnt + get() = fdiff.size + + override fun equals(other: Any?): Boolean { + if (other !is FrameInfo) { + return false + } + return other.spatialId == spatialId && + other.temporalId == temporalId && + other.dti == dti && + other.fdiff == fdiff && + other.chains == chains + } + + override fun hashCode(): Int { + var result = spatialId + result = 31 * result + temporalId + result = 31 * result + dti.hashCode() + result = 31 * result + fdiff.hashCode() + result = 31 * result + chains.hashCode() + return result + } + + override fun toString(): String { + return "spatialId=$spatialId, temporalId=$temporalId, dti=$dti, fdiff=$fdiff, chains=$chains" + } + + override fun toJSONString(): String { + return OrderedJsonObject().apply { + put("spatialId", spatialId) + put("temporalId", temporalId) + put("dti", dti) + put("fdiff", fdiff) + put("chains", chains) + }.toJSONString() + } + + /** Whether the frame has a dependency on a frame earlier than this "picture", the other frames of this + * temporal moment. If it doesn't, it's probably part of a keyframe, and not part of the regular structure. + * Note this makes assumptions about the scalability structure. + */ + fun hasInterPictureDependency(): Boolean = fdiff.any { it > spatialId } + + val dtisPresent: List + get() = dti.withIndex().filter { (_, dti) -> dti != DTI.NOT_PRESENT }.map { (i, _) -> i } +} + +/* The only thing this changes from its parent class is to make the lists mutable, so the parent equals() is fine. */ +@SuppressFBWarnings("EQ_DOESNT_OVERRIDE_EQUALS") +class TemplateFrameInfo( + spatialId: Int, + temporalId: Int, + override val dti: MutableList = mutableListOf(), + override val fdiff: MutableList = mutableListOf(), + override val chains: MutableList = mutableListOf() +) : FrameInfo(spatialId, temporalId, dti, fdiff, chains) + +class DecodeTargetInfo( + val protectedBy: Int +) : JSONAware { + /** Todo: only want to be able to set these from the constructor */ + var spatialId: Int = -1 + var temporalId: Int = -1 + + override fun toJSONString(): String { + return OrderedJsonObject().apply { + put("protectedBy", protectedBy) + put("spatialId", spatialId) + put("temporalId", temporalId) + }.toJSONString() + } +} + +data class Resolution( + val width: Int, + val height: Int +) : JSONAware { + override fun toJSONString(): String { + return OrderedJsonObject().apply { + put("width", width) + put("height", height) + }.toJSONString() + } +} + +/** Decode target indication */ +enum class DTI(val dti: Int) { + NOT_PRESENT(0), + DISCARDABLE(1), + SWITCH(2), + REQUIRED(3); + + companion object { + private val map = DTI.values().associateBy(DTI::dti) + fun fromInt(type: Int) = map[type] ?: throw java.lang.IllegalArgumentException("Bad DTI $type") + } + + fun toShortString(): String { + return when (this) { + NOT_PRESENT -> "N" + DISCARDABLE -> "D" + SWITCH -> "S" + REQUIRED -> "R" + } + } +} + +fun List.toShortString(): String { + return joinToString(separator = "") { it.toShortString() } +} + +class Av1DependencyException(msg: String) : RuntimeException(msg) + +fun List.toIndexedMap(): Map = mapIndexed { i, t -> i to t }.toMap() diff --git a/rtp/src/main/kotlin/org/jitsi/rtp/util/BitReader.kt b/rtp/src/main/kotlin/org/jitsi/rtp/util/BitReader.kt new file mode 100644 index 0000000000..d141b7aed4 --- /dev/null +++ b/rtp/src/main/kotlin/org/jitsi/rtp/util/BitReader.kt @@ -0,0 +1,105 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jitsi.rtp.util + +import kotlin.experimental.and + +/** + * Read individual bits, and unaligned sets of bits, from a [ByteArray], with an incrementing offset. + */ +/* TODO: put this in jitsi-utils? */ +class BitReader(val buf: ByteArray, private val byteOffset: Int = 0, private val byteLength: Int = buf.size) { + private var offset = byteOffset * 8 + private val byteBound = byteOffset + byteLength + + /** Read a single bit from the buffer, as a boolean, incrementing the offset. */ + fun bitAsBoolean(): Boolean { + val byteIdx = offset / 8 + val bitIdx = offset % 8 + check(byteIdx < byteBound) { + "offset $offset ($byteIdx/$bitIdx) invalid in buffer of length $byteLength after offset $byteOffset" + } + val byte = buf[byteIdx] + val mask = (1 shl (7 - bitIdx)).toByte() + offset++ + + return (byte and mask) != 0.toByte() + } + + /** Read a single bit from the buffer, as an integer, incrementing the offset. */ + fun bit() = if (bitAsBoolean()) 1 else 0 + + /** Read [n] bits from the buffer, returning them as an unsigned integer. */ + fun bits(n: Int): Int { + require(n < Int.SIZE_BITS) + + var ret = 0 + + /* TODO: optimize this */ + repeat(n) { + ret = ret shl 1 + ret = ret or bit() + } + + return ret + } + + /** Read [n] bits from the buffer, returning them as an unsigned long. */ + fun bitsLong(n: Int): Long { + require(n < Long.SIZE_BITS) + + var ret = 0L + + /* TODO: optimize this */ + repeat(n) { + ret = ret shl 1 + ret = ret or bit().toLong() + } + + return ret + } + + /** Skip forward [n] bits in the buffer. */ + fun skipBits(n: Int) { + offset += n + } + + /** Read a non-symmetric unsigned integer with max *value* [n] from the buffer. + * (Note: *not* the number of bits.) + * See https://aomediacodec.github.io/av1-rtp-spec/#a82-syntax + */ + fun ns(n: Int): Int { + var w = 0 + var x = n + while (x != 0) { + x = x shr 1 + w++ + } + val m = (1 shl w) - n + val v = bits(w - 1) + if (v < m) { + return v + } + val extraBit = bit() + return (v shl 1) - m + extraBit + } + + /** Reset the reader to the beginning of the buffer */ + fun reset() { + offset = byteOffset * 8 + } +} diff --git a/rtp/src/main/kotlin/org/jitsi/rtp/util/BitWriter.kt b/rtp/src/main/kotlin/org/jitsi/rtp/util/BitWriter.kt new file mode 100644 index 0000000000..9601d0db6f --- /dev/null +++ b/rtp/src/main/kotlin/org/jitsi/rtp/util/BitWriter.kt @@ -0,0 +1,73 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jitsi.rtp.util + +import java.util.* +import kotlin.experimental.or + +/** + * Write individual bits, and unaligned sets of bits, to a [ByteArray], with an incrementing offset. + */ +class BitWriter(val buf: ByteArray, val byteOffset: Int = 0, private val byteLength: Int = buf.size) { + private var offset = byteOffset * 8 + private val byteBound = byteOffset + byteLength + + init { + Arrays.fill(buf, byteOffset, byteBound, 0) + } + + fun writeBit(value: Boolean) { + val byteIdx = offset / 8 + val bitIdx = offset % 8 + check(byteIdx < byteBound) { + "offset $offset ($byteIdx/$bitIdx) invalid in buffer of length $byteLength after offset $byteOffset" + } + + if (value) { + buf[byteIdx] = buf[byteIdx] or (1 shl (7 - bitIdx)).toByte() + } + offset++ + } + + fun writeBits(bits: Int, value: Int) { + check(value < (1 shl bits)) { + "value $value cannot be represented in $bits bits" + } + repeat(bits) { i -> + writeBit((value and (1 shl (bits - i - 1))) != 0) + } + } + + fun writeNs(n: Int, v: Int) { + if (n == 1) return + var w = 0 + var x = n + while (x != 0) { + x = x shr 1 + w++ + } + val m = (1 shl w) - n + if (v < m) { + writeBits(w - 1, v) + } else { + writeBits(w, v + m) + } + } + + val remainingBits + get() = byteBound * 8 - offset +} diff --git a/rtp/src/test/kotlin/org/jitsi/rtp/rtp/header_extensions/Av1DependencyDescriptorHeaderExtensionTest.kt b/rtp/src/test/kotlin/org/jitsi/rtp/rtp/header_extensions/Av1DependencyDescriptorHeaderExtensionTest.kt new file mode 100644 index 0000000000..4a69bb9c93 --- /dev/null +++ b/rtp/src/test/kotlin/org/jitsi/rtp/rtp/header_extensions/Av1DependencyDescriptorHeaderExtensionTest.kt @@ -0,0 +1,549 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.rtp.rtp.header_extensions + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.assertions.withClue +import io.kotest.core.spec.IsolationMode +import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe +import jakarta.xml.bind.DatatypeConverter.parseHexBinary + +@SuppressFBWarnings( + value = ["NP_NULL_ON_SOME_PATH_FROM_RETURN_VALUE"], + justification = "Use of pointer after shouldNotBeNull test." +) +class Av1DependencyDescriptorHeaderExtensionTest : ShouldSpec() { + override fun isolationMode(): IsolationMode = IsolationMode.InstancePerLeaf + + /* Headers generated by Chrome 110 sending AV1 in its default configuration - L1T1 */ + val descL1T1 = parseHexBinary("80000180003a410180ef808680") + + val shortDesc = parseHexBinary("400001") + + val descL1T3 = parseHexBinary("800001800214eaa860414d141020842701df010d") + + /* Header generated by Chrome 112 sending AV1 with L3T3 set. */ + val descL3T3 = parseHexBinary( + "d0013481e81485214eafffaaaa863cf0430c10c302afc0aaa0063c00430010c002a000a800060000" + + "40001d954926e082b04a0941b820ac1282503157f974000ca864330e222222eca8655304224230ec" + + "a87753013f00b3027f016704ff02cf" + ) + + val midDescScalable = parseHexBinary("d10146401c") + + val midDescScalable2 = parseHexBinary("c203ce581d30100000") + val longForMid2 = parseHexBinary( + "8003ca80081485214eaaaaa8000600004000100002aa80a8000600004000100002a000a8000600004" + + "00016d549241b5524906d54923157e001974ca864330e222396eca8655304224390eca87753013f00b3027f016704ff02cf" + ) + + val descS3T3 = parseHexBinary( + "c1000180081485214ea000a8000600004000100002a000a8000600004000100002a000a8000600004" + + "0001d954926caa493655248c55fe5d00032a190cc38e58803b2a1954c10e10843b2a1dd4c01dc010803bc0218077c0434" + ) + + val descL3T3Key = parseHexBinary( + "8f008581e81485214eaaaaa8000600004000100002aa80a8000600004000100002a000a80006000040" + + "0016d549241b5524906d54923157e001974ca864330e222396eca8655304224390eca87753013f00b3027f016704ff02cf" + ) + + /* As of Chrome version 111, it doesn't support L3T3_KEY_SHIFT, but it does support L2T2_KEY_SHIFT, so test that. */ + val descL2T2KeyShift = parseHexBinary( + "8700ed80e3061eaa82804028280514d14134518010a091889a09409fc059c13fc0b3c0" + ) + + init { + context("AV1 Dependency Descriptors") { + context("a descriptor with a single-layer dependency structure") { + val ld1r = Av1DependencyDescriptorReader(descL1T1, 0, descL1T1.size) + val ld1 = ld1r.parse(null) + should("be parsed properly") { + ld1.startOfFrame shouldBe true + ld1.endOfFrame shouldBe false + ld1.frameNumber shouldBe 0x0001 + + val structure = ld1.newTemplateDependencyStructure + structure shouldNotBe null + structure!!.decodeTargetCount shouldBe 1 + structure.maxTemporalId shouldBe 0 + structure.maxSpatialId shouldBe 0 + } + should("be parseable statelessly") { + val ld1s = ld1r.parseStateless() + ld1s.startOfFrame shouldBe true + ld1s.endOfFrame shouldBe false + ld1s.frameNumber shouldBe 0x0001 + + val structure = ld1s.newTemplateDependencyStructure + structure shouldNotBe null + structure!!.decodeTargetCount shouldBe 1 + } + should("calculate correct frame info") { + val ld1i = ld1.frameInfo + ld1i.spatialId shouldBe 0 + ld1i.temporalId shouldBe 0 + } + should("Calculate its own length properly") { + ld1.encodedLength shouldBe descL1T1.size + } + should("Be re-encoded to the same bytes") { + val buf = ByteArray(ld1.encodedLength) + ld1.write(buf, 0, buf.size) + buf shouldBe descL1T1 + } + } + context("a descriptor with a scalable dependency structure") { + val ldsr = Av1DependencyDescriptorReader(descL3T3, 0, descL3T3.size) + val lds = ldsr.parse(null) + should("be parsed properly") { + lds.startOfFrame shouldBe true + lds.endOfFrame shouldBe true + lds.frameNumber shouldBe 0x0134 + lds.activeDecodeTargetsBitmask shouldBe 0x1ff + + val structure = lds.newTemplateDependencyStructure + structure shouldNotBe null + structure!!.decodeTargetCount shouldBe 9 + structure.maxTemporalId shouldBe 2 + structure.maxSpatialId shouldBe 2 + } + should("calculate correct frame info") { + val ldsi = lds.frameInfo + ldsi.spatialId shouldBe 0 + ldsi.temporalId shouldBe 0 + } + should("calculate correctly whether layer switching needs keyframes") { + val structure = lds.newTemplateDependencyStructure!! + for (fromS in 0..2) { + for (fromT in 0..2) { + val fromDT = 3 * fromS + fromT + for (toS in 0..2) { + for (toT in 0..2) { + val toDT = 3 * toS + toT + /* With this structure you can switch down spatial layers, or to other temporal + * layers within the same spatial layer, without a keyframe; but switching up + * spatial layers needs a keyframe. + */ + withClue("from DT $fromDT to DT $toDT") { + if (fromS >= toS) { + structure.canSwitchWithoutKeyframe( + fromDt = fromDT, + toDt = toDT + ) shouldBe true + } else { + structure.canSwitchWithoutKeyframe( + fromDt = fromDT, + toDt = toDT + ) shouldBe false + } + } + } + } + } + } + } + should("calculate DTI bitmasks corresponding to a given DT") { + val structure = lds.newTemplateDependencyStructure!! + structure.getDtBitmaskForDt(0) shouldBe 0b000000001 + structure.getDtBitmaskForDt(1) shouldBe 0b000000011 + structure.getDtBitmaskForDt(2) shouldBe 0b000000111 + structure.getDtBitmaskForDt(3) shouldBe 0b000001001 + structure.getDtBitmaskForDt(4) shouldBe 0b000011011 + structure.getDtBitmaskForDt(5) shouldBe 0b000111111 + structure.getDtBitmaskForDt(6) shouldBe 0b001001001 + structure.getDtBitmaskForDt(7) shouldBe 0b011011011 + structure.getDtBitmaskForDt(8) shouldBe 0b111111111 + } + should("Calculate its own length properly") { + lds.encodedLength shouldBe descL3T3.size + } + should("Be re-encoded to the same bytes") { + val buf = ByteArray(lds.encodedLength) + lds.write(buf, 0, buf.size) + buf shouldBe descL3T3 + } + } + context("a descriptor with a K-SVC dependency structure") { + val ldsr = Av1DependencyDescriptorReader(descL3T3Key, 0, descL3T3Key.size) + val lds = ldsr.parse(null) + should("be parsed properly") { + lds.startOfFrame shouldBe true + lds.endOfFrame shouldBe false + lds.frameNumber shouldBe 0x0085 + lds.activeDecodeTargetsBitmask shouldBe 0x1ff + + val structure = lds.newTemplateDependencyStructure + structure shouldNotBe null + structure!!.decodeTargetCount shouldBe 9 + structure.maxTemporalId shouldBe 2 + structure.maxSpatialId shouldBe 2 + } + should("calculate correct frame info") { + val ldsi = lds.frameInfo + ldsi.spatialId shouldBe 0 + ldsi.temporalId shouldBe 0 + } + should("calculate correctly whether layer switching needs keyframes") { + val structure = lds.newTemplateDependencyStructure!! + for (fromS in 0..2) { + for (fromT in 0..2) { + val fromDT = 3 * fromS + fromT + for (toS in 0..2) { + for (toT in 0..2) { + val toDT = 3 * toS + toT + /* With this structure you can switch to other temporal + * layers within the same spatial layer, without a keyframe; but switching + * spatial layers needs a keyframe. + */ + withClue("from DT $fromDT to DT $toDT") { + if (fromS == toS) { + structure.canSwitchWithoutKeyframe( + fromDt = fromDT, + toDt = toDT + ) shouldBe true + } else { + structure.canSwitchWithoutKeyframe( + fromDt = fromDT, + toDt = toDT + ) shouldBe false + } + } + } + } + } + } + } + should("calculate DTI bitmasks corresponding to a given DT") { + val structure = lds.newTemplateDependencyStructure!! + structure.getDtBitmaskForDt(0) shouldBe 0b000000001 + structure.getDtBitmaskForDt(1) shouldBe 0b000000011 + structure.getDtBitmaskForDt(2) shouldBe 0b000000111 + structure.getDtBitmaskForDt(3) shouldBe 0b000001000 + structure.getDtBitmaskForDt(4) shouldBe 0b000011000 + structure.getDtBitmaskForDt(5) shouldBe 0b000111000 + structure.getDtBitmaskForDt(6) shouldBe 0b001000000 + structure.getDtBitmaskForDt(7) shouldBe 0b011000000 + structure.getDtBitmaskForDt(8) shouldBe 0b111000000 + } + should("Calculate its own length properly") { + lds.encodedLength shouldBe descL3T3Key.size + } + should("Be re-encoded to the same bytes") { + val buf = ByteArray(lds.encodedLength) + lds.write(buf, 0, buf.size) + buf shouldBe descL3T3Key + } + } + context("a descriptor with a K-SVC dependency structure with key shift") { + val ldsr = Av1DependencyDescriptorReader(descL2T2KeyShift, 0, descL2T2KeyShift.size) + val lds = ldsr.parse(null) + should("be parsed properly") { + lds.startOfFrame shouldBe true + lds.endOfFrame shouldBe false + lds.frameNumber shouldBe 0x00ed + lds.activeDecodeTargetsBitmask shouldBe 0x0f + + val structure = lds.newTemplateDependencyStructure + structure shouldNotBe null + structure!!.decodeTargetCount shouldBe 4 + structure.maxTemporalId shouldBe 1 + structure.maxSpatialId shouldBe 1 + } + should("calculate correct frame info") { + val ldsi = lds.frameInfo + ldsi.spatialId shouldBe 0 + ldsi.temporalId shouldBe 0 + } + should("calculate correctly whether layer switching needs keyframes") { + val structure = lds.newTemplateDependencyStructure!! + for (fromS in 0..1) { + for (fromT in 0..1) { + val fromDT = 2 * fromS + fromT + for (toS in 0..1) { + for (toT in 0..1) { + val toDT = 2 * toS + toT + /* With this structure you can switch to other temporal + * layers within the same spatial layer, without a keyframe; but switching + * spatial layers needs a keyframe. + */ + withClue("from DT $fromDT to DT $toDT") { + if (fromS == toS) { + structure.canSwitchWithoutKeyframe( + fromDt = fromDT, + toDt = toDT + ) shouldBe true + } else { + structure.canSwitchWithoutKeyframe( + fromDt = fromDT, + toDt = toDT + ) shouldBe false + } + } + } + } + } + } + } + should("calculate DTI bitmasks corresponding to a given DT") { + val structure = lds.newTemplateDependencyStructure!! + structure.getDtBitmaskForDt(0) shouldBe 0b0001 + structure.getDtBitmaskForDt(1) shouldBe 0b0011 + structure.getDtBitmaskForDt(2) shouldBe 0b0100 + structure.getDtBitmaskForDt(3) shouldBe 0b1100 + } + should("Calculate its own length properly") { + lds.encodedLength shouldBe descL2T2KeyShift.size + } + should("Be re-encoded to the same bytes") { + val buf = ByteArray(lds.encodedLength) + lds.write(buf, 0, buf.size) + buf shouldBe descL2T2KeyShift + } + } + context("a descriptor following the dependency structure, specifying decode targets") { + val ldsr = Av1DependencyDescriptorReader(descL3T3, 0, descL3T3.size) + val lds = ldsr.parse(null) + val mdsr = Av1DependencyDescriptorReader(midDescScalable, 0, midDescScalable.size) + val mds = mdsr.parse(lds.newTemplateDependencyStructure) + + should("be parsed properly") { + mds.startOfFrame shouldBe true + mds.endOfFrame shouldBe true + mds.frameNumber shouldBe 0x0146 + + mds.newTemplateDependencyStructure shouldBe null + mds.activeDecodeTargetsBitmask shouldBe 0x7 + } + should("calculate correct frame info") { + val mdsi = mds.frameInfo + mdsi.spatialId shouldBe 0 + mdsi.temporalId shouldBe 1 + } + should("Calculate its own length properly") { + mds.encodedLength shouldBe midDescScalable.size + } + should("Be re-encoded to the same bytes") { + val buf = ByteArray(mds.encodedLength) + mds.write(buf, 0, buf.size) + buf shouldBe midDescScalable + } + } + context("another such descriptor") { + val ldsr = Av1DependencyDescriptorReader(longForMid2, 0, longForMid2.size) + val lds = ldsr.parse(null) + val mdsr = Av1DependencyDescriptorReader(midDescScalable2, 0, midDescScalable2.size) + val mds = mdsr.parse(lds.newTemplateDependencyStructure) + + should("be parsed properly") { + mds.startOfFrame shouldBe true + mds.endOfFrame shouldBe true + mds.frameNumber shouldBe 0x03ce + + mds.newTemplateDependencyStructure shouldBe null + mds.activeDecodeTargetsBitmask shouldBe 0x7 + } + should("calculate correct frame info") { + val mdsi = mds.frameInfo + mdsi.spatialId shouldBe 0 + mdsi.temporalId shouldBe 1 + } + should("Calculate its own length properly") { + mds.encodedLength shouldBe midDescScalable2.size + } + should("Be re-encoded to the same bytes") { + val buf = ByteArray(mds.encodedLength) + mds.write(buf, 0, buf.size) + buf shouldBe midDescScalable2 + } + } + context("a descriptor without a dependency structure") { + val mdsr = Av1DependencyDescriptorReader(midDescScalable, 0, midDescScalable.size) + should("be parseable as the stateless subset") { + val mds = mdsr.parseStateless() + + mds.startOfFrame shouldBe true + mds.endOfFrame shouldBe true + mds.frameNumber shouldBe 0x0146 + + mds.newTemplateDependencyStructure shouldBe null + } + should("fail to parse if the dependency structure is not present") { + shouldThrow { + mdsr.parse(null) + } + } + } + context("a descriptor without extended fields") { + val ld1r = Av1DependencyDescriptorReader(descL1T1, 0, descL1T1.size) + val ld1 = ld1r.parse(null) + val sd1r = Av1DependencyDescriptorReader(shortDesc, 0, shortDesc.size) + val sd1 = sd1r.parse(ld1.newTemplateDependencyStructure) + + should("be parsed properly") { + sd1.startOfFrame shouldBe false + sd1.endOfFrame shouldBe true + sd1.frameNumber shouldBe 0x0001 + + sd1.newTemplateDependencyStructure shouldBe null + sd1.activeDecodeTargetsBitmask shouldBe null + } + should("calculate correct frame info") { + val sd1i = sd1.frameInfo + sd1i.spatialId shouldBe 0 + sd1i.temporalId shouldBe 0 + } + should("Calculate its own length properly") { + sd1.encodedLength shouldBe shortDesc.size + } + should("Be re-encoded to the same bytes") { + val buf = ByteArray(sd1.encodedLength) + sd1.write(buf, 0, buf.size) + buf shouldBe shortDesc + } + } + context("A descriptor with a Temporal-only scalability structure ") { + val ldsr = Av1DependencyDescriptorReader(descL1T3, 0, descL1T3.size) + val lds = ldsr.parse(null) + should("be parsed properly") { + lds.startOfFrame shouldBe true + lds.endOfFrame shouldBe false + lds.frameNumber shouldBe 0x0001 + lds.activeDecodeTargetsBitmask shouldBe 0x7 + + val structure = lds.newTemplateDependencyStructure + structure shouldNotBe null + structure!!.decodeTargetCount shouldBe 3 + structure.maxTemporalId shouldBe 2 + structure.maxSpatialId shouldBe 0 + } + should("calculate correct frame info") { + val ldsi = lds.frameInfo + ldsi.spatialId shouldBe 0 + ldsi.temporalId shouldBe 0 + } + should("calculate correctly whether layer switching needs keyframes") { + val structure = lds.newTemplateDependencyStructure!! + val fromS = 0 + for (fromT in 0..2) { + val fromDT = 3 * fromS + fromT + val toS = 0 + for (toT in 0..2) { + val toDT = 3 * toS + toT + /* With this structure you can switch down spatial layers, or to other temporal + * layers within the same spatial layer, without a keyframe; but switching up + * spatial layers needs a keyframe. + */ + withClue("from DT $fromDT to DT $toDT") { + structure.canSwitchWithoutKeyframe( + fromDt = fromDT, + toDt = toDT + ) shouldBe true + } + } + } + } + should("calculate DTI bitmasks corresponding to a given DT") { + val structure = lds.newTemplateDependencyStructure!! + structure.getDtBitmaskForDt(0) shouldBe 0b001 + structure.getDtBitmaskForDt(1) shouldBe 0b011 + structure.getDtBitmaskForDt(2) shouldBe 0b111 + } + should("Calculate its own length properly") { + lds.encodedLength shouldBe descL1T3.size + } + should("Be re-encoded to the same bytes") { + val buf = ByteArray(lds.encodedLength) + lds.write(buf, 0, buf.size) + buf shouldBe descL1T3 + } + } + context("a descriptor with a simulcast dependency structure") { + val ldsr = Av1DependencyDescriptorReader(descS3T3, 0, descS3T3.size) + val lds = ldsr.parse(null) + should("be parsed properly") { + lds.startOfFrame shouldBe true + lds.endOfFrame shouldBe true + lds.frameNumber shouldBe 0x0001 + lds.activeDecodeTargetsBitmask shouldBe 0x1ff + + val structure = lds.newTemplateDependencyStructure + structure shouldNotBe null + structure!!.decodeTargetCount shouldBe 9 + structure.maxTemporalId shouldBe 2 + structure.maxSpatialId shouldBe 2 + } + should("calculate correct frame info") { + val ldsi = lds.frameInfo + ldsi.spatialId shouldBe 0 + ldsi.temporalId shouldBe 0 + } + should("calculate correctly whether layer switching needs keyframes") { + val structure = lds.newTemplateDependencyStructure!! + for (fromS in 0..2) { + for (fromT in 0..2) { + val fromDT = 3 * fromS + fromT + for (toS in 0..2) { + for (toT in 0..2) { + val toDT = 3 * toS + toT + /* With this structure you can switch to other temporal + * layers within the same spatial layer, without a keyframe; but switching + * spatial layers needs a keyframe. + */ + withClue("from DT $fromDT to DT $toDT") { + if (fromS == toS) { + structure.canSwitchWithoutKeyframe( + fromDt = fromDT, + toDt = toDT + ) shouldBe true + } else { + structure.canSwitchWithoutKeyframe( + fromDt = fromDT, + toDt = toDT + ) shouldBe false + } + } + } + } + } + } + } + should("calculate DTI bitmasks corresponding to a given DT") { + val structure = lds.newTemplateDependencyStructure!! + structure.getDtBitmaskForDt(0) shouldBe 0b000000001 + structure.getDtBitmaskForDt(1) shouldBe 0b000000011 + structure.getDtBitmaskForDt(2) shouldBe 0b000000111 + structure.getDtBitmaskForDt(3) shouldBe 0b000001000 + structure.getDtBitmaskForDt(4) shouldBe 0b000011000 + structure.getDtBitmaskForDt(5) shouldBe 0b000111000 + structure.getDtBitmaskForDt(6) shouldBe 0b001000000 + structure.getDtBitmaskForDt(7) shouldBe 0b011000000 + structure.getDtBitmaskForDt(8) shouldBe 0b111000000 + } + should("Calculate its own length properly") { + lds.encodedLength shouldBe descS3T3.size + } + should("Be re-encoded to the same bytes") { + val buf = ByteArray(lds.encodedLength) + lds.write(buf, 0, buf.size) + buf shouldBe descS3T3 + } + } + } + } +} diff --git a/rtp/src/test/kotlin/org/jitsi/rtp/rtp/header_extensions/DumpAv1DependencyDescriptor.kt b/rtp/src/test/kotlin/org/jitsi/rtp/rtp/header_extensions/DumpAv1DependencyDescriptor.kt new file mode 100644 index 0000000000..b488ad9a20 --- /dev/null +++ b/rtp/src/test/kotlin/org/jitsi/rtp/rtp/header_extensions/DumpAv1DependencyDescriptor.kt @@ -0,0 +1,36 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.rtp.rtp.header_extensions + +import jakarta.xml.bind.DatatypeConverter.parseHexBinary + +fun main(args: Array) { + var structure: Av1TemplateDependencyStructure? = null + var line: String? + while (readLine().also { line = it } != null) { + try { + val descBinary = parseHexBinary(line) + val reader = Av1DependencyDescriptorReader(descBinary, 0, descBinary.size) + val desc = reader.parse(structure) + desc.newTemplateDependencyStructure?.let { structure = it } + println(desc.toJSONString()) + val frameInfo = desc.frameInfo + println(frameInfo) + } catch (e: Exception) { + println(e.message) + } + } +}