From 38a54f595f87a3468b3cc4553b249c84ede19dc4 Mon Sep 17 00:00:00 2001 From: Eric Eilebrecht Date: Thu, 19 Dec 2024 11:47:36 -0800 Subject: [PATCH] TreapMap union and intersection (#19) `TreapMap.merge` is a very flexible way to merge two `TreapMap`s, but it has a couple of drawbacks: 1) It always visits every entry in both maps 2) It doesn't allow `null` values to be added to the result map (because `merge` treats `null` as meaning "remove this entry") The first issue can be a major performance problem if you're merging two large maps with a relatively small key intersection, and you only need to apply custom merge logic to the keys in the intersection. This turns out to be very common. To address this, we add new methods to `TreapMap`: ```kotlin public fun union( m: Map<K, V>, merger: (K, V, V) -> V ): TreapMap<K, V> public fun intersect( m: Map<K, V>, merger: (K, V, V) -> V ): TreapMap<K, V> ``` These apply the `merger` function only to the intersection of the two maps' keys, and don't permit `merger` to discard entries. `union` preserves all non-intersecting entries from both maps, while `intersect` discards them. We also add parallel versions of both functions, following the example of `merge`/`parallelMerge`. See https://github.com/Certora/EVMVerifier/pull/6819 for an example of where this can be used to get some significant performance benefits. 
--- .../com/certora/collect/AbstractTreapMap.kt | 152 ++++++++++++++++++ .../com/certora/collect/EmptyTreapMap.kt | 6 + .../com/certora/collect/HashTreapMap.kt | 42 +++++ .../com/certora/collect/SortedTreapMap.kt | 14 ++ .../main/kotlin/com/certora/collect/Treap.kt | 4 +- .../kotlin/com/certora/collect/TreapMap.kt | 90 +++++++++++ .../com/certora/collect/TreapMapTest.kt | 60 +++++++ 7 files changed, 367 insertions(+), 1 deletion(-) diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt index cebdcba..b0dcf0b 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt @@ -67,6 +67,8 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT Applies a merge function to all entries in this Treap node. */ abstract fun getShallowMerger(merger: (K, V?, V?) -> V?): (S?, S?) -> S? + abstract fun getShallowUnionMerger(merger: (K, V, V) -> V): (S, S) -> S + abstract fun getShallowIntersectMerger(merger: (K, V, V) -> V): (S, S) -> S? 
private fun containsEntry(entry: Map.Entry): Boolean { val key = entry.key @@ -152,6 +154,52 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT override operator fun iterator() = entrySequence().map { it.value }.iterator() } + override fun union(m: Map, merger: (K, V, V) -> V): TreapMap = + m.useAsTreap( + { otherTreap -> self.unionWith(otherTreap, getShallowUnionMerger(merger)) ?: clear() }, + { fallbackUnion(m, merger) } + ) + + override fun parallelUnion(m: Map, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap = + m.useAsTreap( + { otherTreap -> self.parallelUnionWith(otherTreap, parallelThresholdLog2, getShallowUnionMerger(merger)) ?: clear() }, + { fallbackUnion(m, merger) } + ) + + private fun fallbackUnion(m: Map, merger: (K, V, V) -> V): TreapMap { + var newThis = this as TreapMap + for ((k, v) in m.entries) { + if (k in this) { + newThis = newThis + (k to merger(k, this[k]!!, v)) + } else { + newThis = newThis + (k to v) + } + } + return newThis + } + + override fun intersect(m: Map, merger: (K, V, V) -> V): TreapMap = + m.useAsTreap( + { otherTreap -> self.intersectWith(otherTreap, getShallowIntersectMerger(merger)) ?: clear() }, + { fallbackIntersect(m, merger) } + ) + + override fun parallelIntersect(m: Map, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap = + m.useAsTreap( + { otherTreap -> self.parallelIntersectWith(otherTreap, parallelThresholdLog2, getShallowIntersectMerger(merger)) ?: clear() }, + { fallbackIntersect(m, merger) } + ) + + private fun fallbackIntersect(m: Map, merger: (K, V, V) -> V): TreapMap { + var newThis = clear() + for ((k, v) in m.entries) { + if (k in this) { + newThis = newThis + (k to merger(k, this[k]!!, v)) + } + } + return newThis + } + /** Merges the entries in `m` with the entries in this AbstractTreapMap, applying the "merger" function to get the new values for each key. 
@@ -491,3 +539,107 @@ private fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.merge return newThis?.with(newLeft, newRight) ?: (newLeft join newRight) } +internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.unionWith( + that: S?, + shallowUnion: (S, S) -> S +): S? = + notForking(this to that) { + unionWithImpl(that, shallowUnion) + } + +internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.parallelUnionWith( + that: S?, + parallelThresholdLog2: Int, + shallowUnion: (S, S) -> S +): S? = + maybeForking( + this to that, + { + it.first.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) && + it.second.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) + } + ) { + unionWithImpl(that, shallowUnion) + } + +context(ThresholdForker>) +private fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.unionWithImpl( + that: S?, + shallowUnion: (S, S) -> S +): S? { + val (newLeft, newRight, newThis) = when { + this == null -> return that + that == null -> return this + this.comparePriorityTo(that) >= 0 -> { + val thatSplit = that.split(this) + fork( + this to that, + { this.left.unionWithImpl(thatSplit.left, shallowUnion) }, + { this.right.unionWithImpl(thatSplit.right, shallowUnion) }, + { thatSplit.duplicate?.let { shallowUnion(this, it) } ?: this } + ) + } + else -> { + val thisSplit = this.split(that) + fork( + this to that, + { thisSplit.left.unionWithImpl(that.left, shallowUnion) }, + { thisSplit.right.unionWithImpl(that.right, shallowUnion) }, + { thisSplit.duplicate?.let { shallowUnion(it, that) } ?: that } + ) + } + } + return newThis.with(newLeft, newRight) +} + +internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.intersectWith( + that: S?, + shallowIntersect: (S, S) -> S? +): S? 
= + notForking(this to that) { + intersectWithImpl(that, shallowIntersect) + } + +internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.parallelIntersectWith( + that: S?, + parallelThresholdLog2: Int, + shallowIntersect: (S, S) -> S? +): S? = + maybeForking( + this to that, + { + it.first.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) && + it.second.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) + } + ) { + intersectWithImpl(that, shallowIntersect) + } + +context(ThresholdForker>) +private fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.intersectWithImpl( + that: S?, + shallowIntersect: (S, S) -> S? +): S? { + val (newLeft, newRight, newThis) = when { + this == null || that == null -> return null + this.comparePriorityTo(that) >= 0 -> { + val thatSplit = that.split(this) + fork( + this to that, + { this.left.intersectWithImpl(thatSplit.left, shallowIntersect) }, + { this.right.intersectWithImpl(thatSplit.right, shallowIntersect) }, + { thatSplit.duplicate?.let { shallowIntersect(this, it) } } + ) + } + else -> { + val thisSplit = this.split(that) + fork( + this to that, + { thisSplit.left.intersectWithImpl(that.left, shallowIntersect) }, + { thisSplit.right.intersectWithImpl(that.right, shallowIntersect) }, + { thisSplit.duplicate?.let { shallowIntersect(it, that) } } + ) + } + } + return newThis?.with(newLeft, newRight) ?: (newLeft join newRight) +} diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt index bdb0862..a99cc4a 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt @@ -41,6 +41,12 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap put(key, v) } + override fun union(m: Map, merger: (K, V, V) -> V): TreapMap = putAll(m) + override fun parallelUnion(m: Map, parallelThresholdLog2: Int, merger: (K, V, V) -> V): 
TreapMap = putAll(m) + + override fun intersect(m: Map, merger: (K, V, V) -> V): TreapMap = this + override fun parallelIntersect(m: Map, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap = this + override fun merge(m: Map, merger: (K, V?, V?) -> V?): TreapMap { var map: TreapMap = this for ((key, value) in m) { diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt index 081dee2..cea3982 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt @@ -66,6 +66,48 @@ internal class HashTreapMap<@Treapable K, V>( } } + override fun getShallowUnionMerger( + merger: (K, V, V) -> V + ): (HashTreapMap, HashTreapMap) -> HashTreapMap = { t1, t2 -> + var newPairs: KeyValuePairList.More? = null + t1.forEachPair { (k, v1) -> + val v = if (t2.shallowContainsKey(k)) { + @Suppress("UNCHECKED_CAST") + merger(k, v1, t2.shallowGetValue(k) as V) + } else { + v1 + } + newPairs = KeyValuePairList.More(k, v, newPairs) + } + t2.forEachPair { (k, v2) -> + if (!t1.shallowContainsKey(k)) { + newPairs = KeyValuePairList.More(k, v2, newPairs) + } + } + newPairs!!.let { firstPair -> + val newNode = HashTreapMap(firstPair.key, firstPair.value, firstPair.next, t1.left, t1.right) + if (newNode.shallowEquals(t1)) { t1 } else { newNode } + } + } + + override fun getShallowIntersectMerger( + merger: (K, V, V) -> V + ): (HashTreapMap, HashTreapMap) -> HashTreapMap? = { t1, t2 -> + var newPairs: KeyValuePairList.More? 
= null + t1.forEachPair { (k, v1) -> + if (t2.shallowContainsKey(k)) { + @Suppress("UNCHECKED_CAST") + val v2 = t2.shallowGetValue(k) as V + val v = merger(k, v1, v2) + newPairs = KeyValuePairList.More(k, v, newPairs) + } + } + newPairs?.let { firstPair -> + val newNode = HashTreapMap(firstPair.key, firstPair.value, firstPair.next, t1.left, t1.right) + if (newNode.shallowEquals(t1)) { t1 } else { newNode } + } + } + private inline fun KeyValuePairList?.forEachPair(action: (KeyValuePairList) -> Unit) { var current = this while (current != null) { diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt index 2d0cdac..0c3617a 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt @@ -33,6 +33,20 @@ internal class SortedTreapMap<@Treapable K, V>( override fun arbitraryOrNull(): Map.Entry? = MapEntry(key, value) + override fun getShallowUnionMerger( + merger: (K, V, V) -> V + ): (SortedTreapMap, SortedTreapMap) -> SortedTreapMap = { t1, t2 -> + val v = merger(t1.key, t1.value, t2.value) + SortedTreapMap(t1.key, v, t1.left, t1.right) + } + + override fun getShallowIntersectMerger( + merger: (K, V, V) -> V + ): (SortedTreapMap, SortedTreapMap) -> SortedTreapMap? = { t1, t2 -> + val v = merger(t1.key, t1.value, t2.value) + SortedTreapMap(t1.key, v, t1.left, t1.right) + } + override fun getShallowMerger(merger: (K, V?, V?) -> V?): (SortedTreapMap?, SortedTreapMap?) -> SortedTreapMap? 
= { t1, t2 -> val k = (t1 ?: t2)!!.key val v1 = t1?.value diff --git a/collect/src/main/kotlin/com/certora/collect/Treap.kt b/collect/src/main/kotlin/com/certora/collect/Treap.kt index 6580046..7a456ea 100644 --- a/collect/src/main/kotlin/com/certora/collect/Treap.kt +++ b/collect/src/main/kotlin/com/certora/collect/Treap.kt @@ -174,7 +174,9 @@ internal fun <@Treapable T, S : Treap> Treap?.split(key: TreapKey } } } -internal class Split<@Treapable T, S : Treap>(var left: S?, var right: S?, var duplicate: S?) +internal class Split<@Treapable T, S : Treap>(var left: S?, var right: S?, var duplicate: S?) { + override fun toString(): String = "Split(left=$left, right=$right, duplicate=$duplicate)" +} /** diff --git a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt index 6c5dc0e..7931468 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt @@ -28,13 +28,92 @@ public sealed interface TreapMap : PersistentMap { */ public fun arbitraryOrNull(): Map.Entry? + /** + Calls [action] for each entry in the map. + + Traverses the treap without allocating temporary storage, which may be more efficient than `entries.forEach`. + */ public fun forEachEntry(action: (Map.Entry) -> Unit): Unit + /** + Produces a new map containing the keys from this map and another map [m]. + + If a key is present in just one of the maps, the resulting map will contain the key with the corresponding value + from that map. If a key is present in both maps, [merger] is called with the key, the value from this map, and + the value from [m], in that order, and the returned value will appear in the resulting map. + */ + public fun union( + m: Map, + merger: (K, V, V) -> V + ): TreapMap + + /** + Produces a new map containing the keys from this map and another map [m]. 
+ + If a key is present in just one of the maps, the resulting map will contain the key with the corresponding value + from that map. If a key is present in both maps, [merger] is called with the key, the value from this map, and + the value from [m], in that order, and the returned value will appear in the resulting map. + + Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2. + */ + public fun parallelUnion( + m: Map, + parallelThresholdLog2: Int = 4, + merger: (K, V, V) -> V + ): TreapMap + + /** + Produces a new map containing the keys that are present in both this map and another map [m]. + + For each key, the resulting map will contain the key with the value returned by [merger], which is called with + the key, the value from this map, and the value from [m], in that order. + */ + public fun intersect( + m: Map, + merger: (K, V, V) -> V + ): TreapMap + + /** + Produces a new map containing the keys that are present in both this map and another map [m]. + + For each key, the resulting map will contain the key with the value returned by [merger], which is called with + the key, the value from this map, and the value from [m], in that order. + + Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2. + */ + public fun parallelIntersect( + m: Map, + parallelThresholdLog2: Int = 4, + merger: (K, V, V) -> V + ): TreapMap + + /** + Produces a new [TreapMap] with updated entries, by applying supplied [merger] to each entry of this map and + another map [m]. + + The [merger] function is called for each key that is present in either map, with the key, the value from this + map, and the value from [m], in that order, as arguments. If the key is not present in one of the maps, the + corresponding [merger] argument will be `null`. + + If the [merger] function returns null, the key is not added to the resulting map. + */ public fun merge( m: Map, merger: (K, V?, V?) -> V? 
): TreapMap + /** + Produces a new [TreapMap] with updated entries, by applying supplied [merger] to each entry of this map and + another map [m]. + + The [merger] function is called for each key that is present in either map, with the key, the value from this + map, and the value from [m], in that order, as arguments. If the key is not present in one of the maps, the + corresponding [merger] argument will be `null`. + + If the [merger] function returns null, the key is not added to the resulting map. + + Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2. + */ public fun parallelMerge( m: Map, parallelThresholdLog2: Int = 4, @@ -68,12 +147,23 @@ public sealed interface TreapMap : PersistentMap { transform: (K, V) -> R? ): TreapMap + /** + Produces a new [TreapMap] with the entry for the specified [key] updated via [merger]. + + [merger] is called with the current value for the key (or null if the key is absent), and supplied [value] + argument. If the [merger] function returns null, the key will be absent from the resulting map. Otherwise + the resulting map will contain the key with the value returned by the [merger] function. + */ public fun updateEntry( key: K, value: U, merger: (V?, U) -> V? ): TreapMap + /** + Produces a sequence from the entries of this map and another map. For each key, the result is an entry mapping + the key to a pair of values. Each value may be null, if the key is not present in the corresponding map. 
+ */ public fun zip( m: Map ): Sequence>> diff --git a/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt b/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt index 3bece70..e130f6d 100644 --- a/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt +++ b/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt @@ -423,6 +423,66 @@ abstract class TreapMapTest { m1.merge(m2, merger3)) } + @Test + fun union() { + assertEquals(testMapOf(), testMapOf().union(testMapOf()) { _, a, _ -> a }) + assertEquals(testMapOf(1 to 2), testMapOf(1 to 2).union(testMapOf()) { _, a, _ -> a }) + assertEquals(testMapOf(1 to 2), testMapOf().union(testMapOf(1 to 2)) { _, a, _ -> a }) + assertEquals( + testMapOf(1 to 2, 2 to 3, 3 to 4), + testMapOf(1 to 2, 2 to 3).union(testMapOf(2 to 3, 3 to 4)) { _, a, _ -> a } + ) + + val m1 = testMapOf(2 to 2, 3 to 3) + val m2 = testMapOf(3 to 4) + assertEquals( + mapOf(2 to 2, 3 to 3), + m1.union(m2) { _, a, _ -> a } + ) + assertEquals( + mapOf(2 to 2, 3 to 4), + m2.union(m1) { _, a, _ -> a } + ) + assertEquals( + mapOf(2 to 2, 3 to 4), + m1.union(m2) { _, _, b -> b } + ) + assertEquals( + mapOf(2 to 2, 3 to 3), + m2.union(m1) { _, _, b -> b } + ) + } + + @Test + fun intersect() { + assertEquals(testMapOf(), testMapOf().intersect(testMapOf()) { _, a, _ -> a }) + assertEquals(testMapOf(), testMapOf(1 to 2).intersect(testMapOf()) { _, a, _ -> a }) + assertEquals(testMapOf(), testMapOf().intersect(testMapOf(1 to 2)) { _, a, _ -> a }) + assertEquals( + testMapOf(2 to 3), + testMapOf(1 to 2, 2 to 3).intersect(testMapOf(2 to 3, 3 to 4)) { _, a, _ -> a } + ) + + val m1 = testMapOf(2 to 2, 3 to 3) + val m2 = testMapOf(3 to 4) + assertEquals( + mapOf(3 to 3), + m1.intersect(m2) { _, a, _ -> a } + ) + assertEquals( + mapOf(3 to 4), + m2.intersect(m1) { _, a, _ -> a } + ) + assertEquals( + mapOf(3 to 4), + m1.intersect(m2) { _, _, b -> b } + ) + assertEquals( + mapOf(3 to 3), + m2.intersect(m1) { _, _, b -> b } + ) + } + @Test fun zip() 
{ assertEquals(