Skip to content

Commit

Permalink
TreapMap union and intersection (#19)
Browse files Browse the repository at this point in the history
`TreapMap.merge` is a very flexible way to merge two `TreapMap`s, but it
has a couple of drawbacks:

1) It always visits every entry in both maps 
2) It doesn't allow `null` values to be added to the result map (because
`merge` treats `null` as meaning "remove this entry")

The first issue can be a major performance problem if you're merging two
large maps with a relatively small key intersection, and you only need
to apply custom merge logic to the keys in the intersection. This turns
out to be very common.

To address this, we add new methods to `TreapMap`:

```kotlin
    public fun union(
        m: Map<K, V>,
        merger: (K, V, V) -> V
    ): TreapMap<K, V>

    public fun intersect(
        m: Map<K, V>,
        merger: (K, V, V) -> V
    ): TreapMap<K, V>
```

These apply the `merger` function only to the intersection of the two
maps' keys, and don't permit `merger` to discard entries. `union`
preserves all non-intersecting entries from both maps, while `intersect`
discards them.

We also add parallel versions of both functions, following the example
of `merge`/`parallelMerge`.

See Certora/EVMVerifier#6819 for an example of
where this can be used to get some significant performance benefits.
  • Loading branch information
ericeil authored Dec 19, 2024
1 parent a5f3410 commit 38a54f5
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 1 deletion.
152 changes: 152 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
Applies a merge function to all entries in this Treap node.
*/
abstract fun getShallowMerger(merger: (K, V?, V?) -> V?): (S?, S?) -> S?
abstract fun getShallowUnionMerger(merger: (K, V, V) -> V): (S, S) -> S
abstract fun getShallowIntersectMerger(merger: (K, V, V) -> V): (S, S) -> S?

private fun containsEntry(entry: Map.Entry<K, V>): Boolean {
val key = entry.key
Expand Down Expand Up @@ -152,6 +154,52 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override operator fun iterator() = entrySequence().map { it.value }.iterator()
}

override fun union(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.unionWith(otherTreap, getShallowUnionMerger(merger)) ?: clear() },
{ fallbackUnion(m, merger) }
)

override fun parallelUnion(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.parallelUnionWith(otherTreap, parallelThresholdLog2, getShallowUnionMerger(merger)) ?: clear() },
{ fallbackUnion(m, merger) }
)

private fun fallbackUnion(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> {
var newThis = this as TreapMap<K, V>
for ((k, v) in m.entries) {
if (k in this) {
newThis = newThis + (k to merger(k, this[k]!!, v))
} else {
newThis = newThis + (k to v)
}
}
return newThis
}

override fun intersect(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.intersectWith(otherTreap, getShallowIntersectMerger(merger)) ?: clear() },
{ fallbackIntersect(m, merger) }
)

override fun parallelIntersect(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.parallelIntersectWith(otherTreap, parallelThresholdLog2, getShallowIntersectMerger(merger)) ?: clear() },
{ fallbackIntersect(m, merger) }
)

private fun fallbackIntersect(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> {
var newThis = clear()
for ((k, v) in m.entries) {
if (k in this) {
newThis = newThis + (k to merger(k, this[k]!!, v))
}
}
return newThis
}

/**
Merges the entries in `m` with the entries in this AbstractTreapMap, applying the "merger" function to get the
new values for each key.
Expand Down Expand Up @@ -491,3 +539,107 @@ private fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.merge
return newThis?.with(newLeft, newRight) ?: (newLeft join newRight)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.unionWith(
that: S?,
shallowUnion: (S, S) -> S
): S? =
notForking(this to that) {
unionWithImpl(that, shallowUnion)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.parallelUnionWith(
that: S?,
parallelThresholdLog2: Int,
shallowUnion: (S, S) -> S
): S? =
maybeForking(
this to that,
{
it.first.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) &&
it.second.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1)
}
) {
unionWithImpl(that, shallowUnion)
}

context(ThresholdForker<Pair<S?, S?>>)
private fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.unionWithImpl(
that: S?,
shallowUnion: (S, S) -> S
): S? {
val (newLeft, newRight, newThis) = when {
this == null -> return that
that == null -> return this
this.comparePriorityTo(that) >= 0 -> {
val thatSplit = that.split(this)
fork(
this to that,
{ this.left.unionWithImpl(thatSplit.left, shallowUnion) },
{ this.right.unionWithImpl(thatSplit.right, shallowUnion) },
{ thatSplit.duplicate?.let { shallowUnion(this, it) } ?: this }
)
}
else -> {
val thisSplit = this.split(that)
fork(
this to that,
{ thisSplit.left.unionWithImpl(that.left, shallowUnion) },
{ thisSplit.right.unionWithImpl(that.right, shallowUnion) },
{ thisSplit.duplicate?.let { shallowUnion(it, that) } ?: that }
)
}
}
return newThis.with(newLeft, newRight)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.intersectWith(
that: S?,
shallowIntersect: (S, S) -> S?
): S? =
notForking(this to that) {
intersectWithImpl(that, shallowIntersect)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.parallelIntersectWith(
that: S?,
parallelThresholdLog2: Int,
shallowIntersect: (S, S) -> S?
): S? =
maybeForking(
this to that,
{
it.first.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) &&
it.second.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1)
}
) {
intersectWithImpl(that, shallowIntersect)
}

context(ThresholdForker<Pair<S?, S?>>)
private fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.intersectWithImpl(
that: S?,
shallowIntersect: (S, S) -> S?
): S? {
val (newLeft, newRight, newThis) = when {
this == null || that == null -> return null
this.comparePriorityTo(that) >= 0 -> {
val thatSplit = that.split(this)
fork(
this to that,
{ this.left.intersectWithImpl(thatSplit.left, shallowIntersect) },
{ this.right.intersectWithImpl(thatSplit.right, shallowIntersect) },
{ thatSplit.duplicate?.let { shallowIntersect(this, it) } }
)
}
else -> {
val thisSplit = this.split(that)
fork(
this to that,
{ thisSplit.left.intersectWithImpl(that.left, shallowIntersect) },
{ thisSplit.right.intersectWithImpl(that.right, shallowIntersect) },
{ thisSplit.duplicate?.let { shallowIntersect(it, that) } }
)
}
}
return newThis?.with(newLeft, newRight) ?: (newLeft join newRight)
}
6 changes: 6 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
else -> put(key, v)
}

override fun union(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> = putAll(m)
override fun parallelUnion(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> = putAll(m)

override fun intersect(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> = this
override fun parallelIntersect(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> = this

override fun merge(m: Map<K, V>, merger: (K, V?, V?) -> V?): TreapMap<K, V> {
var map: TreapMap<K, V> = this
for ((key, value) in m) {
Expand Down
42 changes: 42 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,48 @@ internal class HashTreapMap<@Treapable K, V>(
}
}

override fun getShallowUnionMerger(
merger: (K, V, V) -> V
): (HashTreapMap<K, V>, HashTreapMap<K, V>) -> HashTreapMap<K, V> = { t1, t2 ->
var newPairs: KeyValuePairList.More<K, V>? = null
t1.forEachPair { (k, v1) ->
val v = if (t2.shallowContainsKey(k)) {
@Suppress("UNCHECKED_CAST")
merger(k, v1, t2.shallowGetValue(k) as V)
} else {
v1
}
newPairs = KeyValuePairList.More(k, v, newPairs)
}
t2.forEachPair { (k, v2) ->
if (!t1.shallowContainsKey(k)) {
newPairs = KeyValuePairList.More(k, v2, newPairs)
}
}
newPairs!!.let { firstPair ->
val newNode = HashTreapMap(firstPair.key, firstPair.value, firstPair.next, t1.left, t1.right)
if (newNode.shallowEquals(t1)) { t1 } else { newNode }
}
}

override fun getShallowIntersectMerger(
merger: (K, V, V) -> V
): (HashTreapMap<K, V>, HashTreapMap<K, V>) -> HashTreapMap<K, V>? = { t1, t2 ->
var newPairs: KeyValuePairList.More<K, V>? = null
t1.forEachPair { (k, v1) ->
if (t2.shallowContainsKey(k)) {
@Suppress("UNCHECKED_CAST")
val v2 = t2.shallowGetValue(k) as V
val v = merger(k, v1, v2)
newPairs = KeyValuePairList.More(k, v, newPairs)
}
}
newPairs?.let { firstPair ->
val newNode = HashTreapMap(firstPair.key, firstPair.value, firstPair.next, t1.left, t1.right)
if (newNode.shallowEquals(t1)) { t1 } else { newNode }
}
}

private inline fun KeyValuePairList<K, V>?.forEachPair(action: (KeyValuePairList<K, V>) -> Unit) {
var current = this
while (current != null) {
Expand Down
14 changes: 14 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ internal class SortedTreapMap<@Treapable K, V>(

override fun arbitraryOrNull(): Map.Entry<K, V>? = MapEntry(key, value)

override fun getShallowUnionMerger(
merger: (K, V, V) -> V
): (SortedTreapMap<K, V>, SortedTreapMap<K, V>) -> SortedTreapMap<K, V> = { t1, t2 ->
val v = merger(t1.key, t1.value, t2.value)
SortedTreapMap(t1.key, v, t1.left, t1.right)
}

override fun getShallowIntersectMerger(
merger: (K, V, V) -> V
): (SortedTreapMap<K, V>, SortedTreapMap<K, V>) -> SortedTreapMap<K, V>? = { t1, t2 ->
val v = merger(t1.key, t1.value, t2.value)
SortedTreapMap(t1.key, v, t1.left, t1.right)
}

override fun getShallowMerger(merger: (K, V?, V?) -> V?): (SortedTreapMap<K, V>?, SortedTreapMap<K, V>?) -> SortedTreapMap<K, V>? = { t1, t2 ->
val k = (t1 ?: t2)!!.key
val v1 = t1?.value
Expand Down
4 changes: 3 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/Treap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ internal fun <@Treapable T, S : Treap<T, S>> Treap<T, S>?.split(key: TreapKey<T>
}
}
}
internal class Split<@Treapable T, S : Treap<T, S>>(var left: S?, var right: S?, var duplicate: S?)
internal class Split<@Treapable T, S : Treap<T, S>>(var left: S?, var right: S?, var duplicate: S?) {
override fun toString(): String = "Split(left=$left, right=$right, duplicate=$duplicate)"
}


/**
Expand Down
90 changes: 90 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,92 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
*/
public fun arbitraryOrNull(): Map.Entry<K, V>?

/**
Calls [action] for each entry in the map.
Traverses the treap without allocating temporary storage, which may be more efficient than `entries.forEach`.
*/
public fun forEachEntry(action: (Map.Entry<K, V>) -> Unit): Unit

/**
Produces a new map containing the keys from this map and another map [m].
If a key is present in just one of the maps, the resulting map will contain the key with the corresponding value
from that map. If a key is present in both maps, [merger] is called with the key, the value from this map, and
the value from [m], in that order, and the returned value will appear in the resulting map.
*/
public fun union(
m: Map<K, V>,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new map containing the keys from this map and another map [m].
If a key is present in just one of the maps, the resulting map will contain the key with the corresponding value
from that map. If a key is present in both maps, [merger] is called with the key, the value from this map, and
the value from [m], in that order, and the returned value will appear in the resulting map.
Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
*/
public fun parallelUnion(
m: Map<K, V>,
parallelThresholdLog2: Int = 4,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new map containing the keys that are present in both this map and another map [m].
For each key, the resulting map will contain the key with the value returned by [merger], which is called with
the key, the value from this map, and the value from [m], in that order.
*/
public fun intersect(
m: Map<K, V>,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new map containing the keys that are present in both this map and another map [m].
For each key, the resulting map will contain the key with the value returned by [merger], which is called with
the key, the value from this map, and the value from [m], in that order.
Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
*/
public fun parallelIntersect(
m: Map<K, V>,
parallelThresholdLog2: Int = 4,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new [TreapMap] with updated entries, by applying supplied [merger] to each entry of this map and
another map [m].
The [merger] function is called for each key that is present in either map, with the key, the value from this
map, and the value from [m], in that order, as arguments. If the key is not present in one of the maps, the
corresponding [merger] argument will be `null`.
If the [merger] function returns null, the key is not added to the resulting map.
*/
public fun merge(
m: Map<K, V>,
merger: (K, V?, V?) -> V?
): TreapMap<K, V>

/**
Produces a new [TreapMap] with updated entries, by applying supplied [merger] to each entry of this map and
another map [m].
The [merger] function is called for each key that is present in either map, with the key, the value from this
map, and the value from [m], in that order, as arguments. If the key is not present in one of the maps, the
corresponding [merger] argument will be `null`.
If the [merger] function returns null, the key is not added to the resulting map.
Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
*/
public fun parallelMerge(
m: Map<K, V>,
parallelThresholdLog2: Int = 4,
Expand Down Expand Up @@ -68,12 +147,23 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
transform: (K, V) -> R?
): TreapMap<K, R>

/**
Produces a new [TreapMap] with the entry for the specified [key] updated via [merger].
[merger] is called with the current value for the key (or null if the key is absent), and supplied [value]
argument. If the [merger] function returns null, the key will be absent from the resulting map. Otherwise
the resulting map will contain the key with the value returned by the [merger] function.
*/
public fun <U> updateEntry(
key: K,
value: U,
merger: (V?, U) -> V?
): TreapMap<K, V>

/**
Produces a sequence from the entries of this map and another map. For each key, the result is an entry mapping
the key to a pair of values. Each value may be null, if the key is not present in the corresponding map.
*/
public fun zip(
m: Map<out K, V>
): Sequence<Map.Entry<K, Pair<V?, V?>>>
Expand Down
Loading

0 comments on commit 38a54f5

Please sign in to comment.