diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt index cebdcba..b0dcf0b 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt @@ -67,6 +67,8 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT Applies a merge function to all entries in this Treap node. */ abstract fun getShallowMerger(merger: (K, V?, V?) -> V?): (S?, S?) -> S? + abstract fun getShallowUnionMerger(merger: (K, V, V) -> V): (S, S) -> S + abstract fun getShallowIntersectMerger(merger: (K, V, V) -> V): (S, S) -> S? private fun containsEntry(entry: Map.Entry): Boolean { val key = entry.key @@ -152,6 +154,52 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT override operator fun iterator() = entrySequence().map { it.value }.iterator() } + override fun union(m: Map, merger: (K, V, V) -> V): TreapMap = + m.useAsTreap( + { otherTreap -> self.unionWith(otherTreap, getShallowUnionMerger(merger)) ?: clear() }, + { fallbackUnion(m, merger) } + ) + + override fun parallelUnion(m: Map, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap = + m.useAsTreap( + { otherTreap -> self.parallelUnionWith(otherTreap, parallelThresholdLog2, getShallowUnionMerger(merger)) ?: clear() }, + { fallbackUnion(m, merger) } + ) + + private fun fallbackUnion(m: Map, merger: (K, V, V) -> V): TreapMap { + var newThis = this as TreapMap + for ((k, v) in m.entries) { + if (k in this) { + newThis = newThis + (k to merger(k, this[k]!!, v)) + } else { + newThis = newThis + (k to v) + } + } + return newThis + } + + override fun intersect(m: Map, merger: (K, V, V) -> V): TreapMap = + m.useAsTreap( + { otherTreap -> self.intersectWith(otherTreap, getShallowIntersectMerger(merger)) ?: clear() }, + { fallbackIntersect(m, merger) } + ) + + override fun parallelIntersect(m: Map, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap = + m.useAsTreap( + { otherTreap -> self.parallelIntersectWith(otherTreap, parallelThresholdLog2, getShallowIntersectMerger(merger)) ?: clear() }, + { fallbackIntersect(m, merger) } + ) + + private fun fallbackIntersect(m: Map, merger: (K, V, V) -> V): TreapMap { + var newThis = clear() + for ((k, v) in m.entries) { + if (k in this) { + newThis = newThis + (k to merger(k, this[k]!!, v)) + } + } + return newThis + } + /** Merges the entries in `m` with the entries in this AbstractTreapMap, applying the "merger" function to get the new values for each key. @@ -491,3 +539,107 @@ private fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.merge return newThis?.with(newLeft, newRight) ?: (newLeft join newRight) } +internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.unionWith( + that: S?, + shallowUnion: (S, S) -> S +): S? = + notForking(this to that) { + unionWithImpl(that, shallowUnion) + } + +internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.parallelUnionWith( + that: S?, + parallelThresholdLog2: Int, + shallowUnion: (S, S) -> S +): S? = + maybeForking( + this to that, + { + it.first.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) && + it.second.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) + } + ) { + unionWithImpl(that, shallowUnion) + } + +context(ThresholdForker>) +private fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.unionWithImpl( + that: S?, + shallowUnion: (S, S) -> S +): S? { + val (newLeft, newRight, newThis) = when { + this == null -> return that + that == null -> return this + this.comparePriorityTo(that) >= 0 -> { + val thatSplit = that.split(this) + fork( + this to that, + { this.left.unionWithImpl(thatSplit.left, shallowUnion) }, + { this.right.unionWithImpl(thatSplit.right, shallowUnion) }, + { thatSplit.duplicate?.let { shallowUnion(this, it) } ?: this } + ) + } + else -> { + val thisSplit = this.split(that) + fork( + this to that, + { thisSplit.left.unionWithImpl(that.left, shallowUnion) }, + { thisSplit.right.unionWithImpl(that.right, shallowUnion) }, + { thisSplit.duplicate?.let { shallowUnion(it, that) } ?: that } + ) + } + } + return newThis.with(newLeft, newRight) +} + +internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.intersectWith( + that: S?, + shallowIntersect: (S, S) -> S? +): S? = + notForking(this to that) { + intersectWithImpl(that, shallowIntersect) + } + +internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.parallelIntersectWith( + that: S?, + parallelThresholdLog2: Int, + shallowIntersect: (S, S) -> S? +): S? = + maybeForking( + this to that, + { + it.first.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) && + it.second.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) + } + ) { + intersectWithImpl(that, shallowIntersect) + } + +context(ThresholdForker>) +private fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.intersectWithImpl( + that: S?, + shallowIntersect: (S, S) -> S? +): S? { + val (newLeft, newRight, newThis) = when { + this == null || that == null -> return null + this.comparePriorityTo(that) >= 0 -> { + val thatSplit = that.split(this) + fork( + this to that, + { this.left.intersectWithImpl(thatSplit.left, shallowIntersect) }, + { this.right.intersectWithImpl(thatSplit.right, shallowIntersect) }, + { thatSplit.duplicate?.let { shallowIntersect(this, it) } } + ) + } + else -> { + val thisSplit = this.split(that) + fork( + this to that, + { thisSplit.left.intersectWithImpl(that.left, shallowIntersect) }, + { thisSplit.right.intersectWithImpl(that.right, shallowIntersect) }, + { thisSplit.duplicate?.let { shallowIntersect(it, that) } } + ) + } + } + return newThis?.with(newLeft, newRight) ?: (newLeft join newRight) +} diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt index bdb0862..a99cc4a 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt @@ -41,6 +41,12 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap put(key, v) } + override fun union(m: Map, merger: (K, V, V) -> V): TreapMap = putAll(m) + override fun parallelUnion(m: Map, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap = putAll(m) + + override fun intersect(m: Map, merger: (K, V, V) -> V): TreapMap = this + override fun parallelIntersect(m: Map, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap = this + override fun merge(m: Map, merger: (K, V?, V?) -> V?): TreapMap { var map: TreapMap = this for ((key, value) in m) { diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt index 081dee2..cea3982 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt @@ -66,6 +66,48 @@ internal class HashTreapMap<@Treapable K, V>( } } + override fun getShallowUnionMerger( + merger: (K, V, V) -> V + ): (HashTreapMap, HashTreapMap) -> HashTreapMap = { t1, t2 -> + var newPairs: KeyValuePairList.More? = null + t1.forEachPair { (k, v1) -> + val v = if (t2.shallowContainsKey(k)) { + @Suppress("UNCHECKED_CAST") + merger(k, v1, t2.shallowGetValue(k) as V) + } else { + v1 + } + newPairs = KeyValuePairList.More(k, v, newPairs) + } + t2.forEachPair { (k, v2) -> + if (!t1.shallowContainsKey(k)) { + newPairs = KeyValuePairList.More(k, v2, newPairs) + } + } + newPairs!!.let { firstPair -> + val newNode = HashTreapMap(firstPair.key, firstPair.value, firstPair.next, t1.left, t1.right) + if (newNode.shallowEquals(t1)) { t1 } else { newNode } + } + } + + override fun getShallowIntersectMerger( + merger: (K, V, V) -> V + ): (HashTreapMap, HashTreapMap) -> HashTreapMap? = { t1, t2 -> + var newPairs: KeyValuePairList.More? = null + t1.forEachPair { (k, v1) -> + if (t2.shallowContainsKey(k)) { + @Suppress("UNCHECKED_CAST") + val v2 = t2.shallowGetValue(k) as V + val v = merger(k, v1, v2) + newPairs = KeyValuePairList.More(k, v, newPairs) + } + } + newPairs?.let { firstPair -> + val newNode = HashTreapMap(firstPair.key, firstPair.value, firstPair.next, t1.left, t1.right) + if (newNode.shallowEquals(t1)) { t1 } else { newNode } + } + } + private inline fun KeyValuePairList?.forEachPair(action: (KeyValuePairList) -> Unit) { var current = this while (current != null) { diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt index 2d0cdac..0c3617a 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt @@ -33,6 +33,20 @@ internal class SortedTreapMap<@Treapable K, V>( override fun arbitraryOrNull(): Map.Entry? = MapEntry(key, value) + override fun getShallowUnionMerger( + merger: (K, V, V) -> V + ): (SortedTreapMap, SortedTreapMap) -> SortedTreapMap = { t1, t2 -> + val v = merger(t1.key, t1.value, t2.value) + SortedTreapMap(t1.key, v, t1.left, t1.right) + } + + override fun getShallowIntersectMerger( + merger: (K, V, V) -> V + ): (SortedTreapMap, SortedTreapMap) -> SortedTreapMap? = { t1, t2 -> + val v = merger(t1.key, t1.value, t2.value) + SortedTreapMap(t1.key, v, t1.left, t1.right) + } + override fun getShallowMerger(merger: (K, V?, V?) -> V?): (SortedTreapMap?, SortedTreapMap?) -> SortedTreapMap? = { t1, t2 -> val k = (t1 ?: t2)!!.key val v1 = t1?.value diff --git a/collect/src/main/kotlin/com/certora/collect/Treap.kt b/collect/src/main/kotlin/com/certora/collect/Treap.kt index 6580046..7a456ea 100644 --- a/collect/src/main/kotlin/com/certora/collect/Treap.kt +++ b/collect/src/main/kotlin/com/certora/collect/Treap.kt @@ -174,7 +174,9 @@ internal fun <@Treapable T, S : Treap> Treap?.split(key: TreapKey } } } -internal class Split<@Treapable T, S : Treap>(var left: S?, var right: S?, var duplicate: S?) +internal class Split<@Treapable T, S : Treap>(var left: S?, var right: S?, var duplicate: S?) { + override fun toString(): String = "Split(left=$left, right=$right, duplicate=$duplicate)" +} /** diff --git a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt index 6c5dc0e..7931468 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt @@ -28,13 +28,92 @@ public sealed interface TreapMap : PersistentMap { */ public fun arbitraryOrNull(): Map.Entry? + /** + Calls [action] for each entry in the map. + + Traverses the treap without allocating temporary storage, which may be more efficient than `entries.forEach`. + */ public fun forEachEntry(action: (Map.Entry) -> Unit): Unit + /** + Produces a new map containing the keys from this map and another map [m]. + + If a key is present in just one of the maps, the resulting map will contain the key with the corresponding value + from that map. If a key is present in both maps, [merger] is called with the key, the value from this map, and + the value from [m], in that order, and the returned value will appear in the resulting map. + */ + public fun union( + m: Map, + merger: (K, V, V) -> V + ): TreapMap + + /** + Produces a new map containing the keys from this map and another map [m]. + + If a key is present in just one of the maps, the resulting map will contain the key with the corresponding value + from that map. If a key is present in both maps, [merger] is called with the key, the value from this map, and + the value from [m], in that order, and the returned value will appear in the resulting map. + + Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2. + */ + public fun parallelUnion( + m: Map, + parallelThresholdLog2: Int = 4, + merger: (K, V, V) -> V + ): TreapMap + + /** + Produces a new map containing the keys that are present in both this map and another map [m]. + + For each key, the resulting map will contain the key with the value returned by [merger], which is called with + the key, the value from this map, and the value from [m], in that order. + */ + public fun intersect( + m: Map, + merger: (K, V, V) -> V + ): TreapMap + + /** + Produces a new map containing the keys that are present in both this map and another map [m]. + + For each key, the resulting map will contain the key with the value returned by [merger], which is called with + the key, the value from this map, and the value from [m], in that order. + + Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2. + */ + public fun parallelIntersect( + m: Map, + parallelThresholdLog2: Int = 4, + merger: (K, V, V) -> V + ): TreapMap + + /** + Produces a new [TreapMap] with updated entries, by applying supplied [merger] to each entry of this map and + another map [m]. + + The [merger] function is called for each key that is present in either map, with the key, the value from this + map, and the value from [m], in that order, as arguments. If the key is not present in one of the maps, the + corresponding [merger] argument will be `null`. + + If the [merger] function returns null, the key is not added to the resulting map. + */ public fun merge( m: Map, merger: (K, V?, V?) -> V? ): TreapMap + /** + Produces a new [TreapMap] with updated entries, by applying supplied [merger] to each entry of this map and + another map [m]. + + The [merger] function is called for each key that is present in either map, with the key, the value from this + map, and the value from [m], in that order, as arguments. If the key is not present in one of the maps, the + corresponding [merger] argument will be `null`. + + If the [merger] function returns null, the key is not added to the resulting map. + + Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2. + */ public fun parallelMerge( m: Map, parallelThresholdLog2: Int = 4, @@ -68,12 +147,23 @@ public sealed interface TreapMap : PersistentMap { transform: (K, V) -> R? ): TreapMap + /** + Produces a new [TreapMap] with the entry for the specified [key] updated via [merger]. + + [merger] is called with the current value for the key (or null if the key is absent), and supplied [value] + argument. If the [merger] function returns null, the key will be absent from the resulting map. Otherwise + the resulting map will contain the key with the value returned by the [merger] function. + */ public fun updateEntry( key: K, value: U, merger: (V?, U) -> V? ): TreapMap + /** + Produces a sequence from the entries of this map and another map. For each key, the result is an entry mapping + the key to a pair of values. Each value may be null, if the key is not present in the corresponding map. + */ public fun zip( m: Map ): Sequence>> diff --git a/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt b/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt index 3bece70..e130f6d 100644 --- a/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt +++ b/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt @@ -423,6 +423,66 @@ abstract class TreapMapTest { m1.merge(m2, merger3)) } + @Test + fun union() { + assertEquals(testMapOf(), testMapOf().union(testMapOf()) { _, a, _ -> a }) + assertEquals(testMapOf(1 to 2), testMapOf(1 to 2).union(testMapOf()) { _, a, _ -> a }) + assertEquals(testMapOf(1 to 2), testMapOf().union(testMapOf(1 to 2)) { _, a, _ -> a }) + assertEquals( + testMapOf(1 to 2, 2 to 3, 3 to 4), + testMapOf(1 to 2, 2 to 3).union(testMapOf(2 to 3, 3 to 4)) { _, a, _ -> a } + ) + + val m1 = testMapOf(2 to 2, 3 to 3) + val m2 = testMapOf(3 to 4) + assertEquals( + mapOf(2 to 2, 3 to 3), + m1.union(m2) { _, a, _ -> a } + ) + assertEquals( + mapOf(2 to 2, 3 to 4), + m2.union(m1) { _, a, _ -> a } + ) + assertEquals( + mapOf(2 to 2, 3 to 4), + m1.union(m2) { _, _, b -> b } + ) + assertEquals( + mapOf(2 to 2, 3 to 3), + m2.union(m1) { _, _, b -> b } + ) + } + + @Test + fun intersect() { + assertEquals(testMapOf(), testMapOf().intersect(testMapOf()) { _, a, _ -> a }) + assertEquals(testMapOf(), testMapOf(1 to 2).intersect(testMapOf()) { _, a, _ -> a }) + assertEquals(testMapOf(), testMapOf().intersect(testMapOf(1 to 2)) { _, a, _ -> a }) + assertEquals( + testMapOf(2 to 3), + testMapOf(1 to 2, 2 to 3).intersect(testMapOf(2 to 3, 3 to 4)) { _, a, _ -> a } + ) + + val m1 = testMapOf(2 to 2, 3 to 3) + val m2 = testMapOf(3 to 4) + assertEquals( + mapOf(3 to 3), + m1.intersect(m2) { _, a, _ -> a } + ) + assertEquals( + mapOf(3 to 4), + m2.intersect(m1) { _, a, _ -> a } + ) + assertEquals( + mapOf(3 to 4), + m1.intersect(m2) { _, _, b -> b } + ) + assertEquals( + mapOf(3 to 3), + m2.intersect(m1) { _, _, b -> b } + ) + } + @Test fun zip() { assertEquals(