Skip to content

Commit

Permalink
Add forEachEntry and arbitraryOrNull (#17)
Browse files Browse the repository at this point in the history
- `forEachEntry`: Fast enumeration of map entries via a simple recursive
walk
- `arbitraryOrNull`: Quickly get a single arbitrary element/entry from a
set/map.
  • Loading branch information
ericeil authored Dec 9, 2024
1 parent 01ac23a commit 889dfe2
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 11 deletions.
4 changes: 4 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
override fun remove(key: K): TreapMap<K, V> = this
override fun remove(key: K, value: V): TreapMap<K, V> = this

override fun arbitraryOrNull(): Map.Entry<K, V>? = null

override fun forEachEntry(action: (Map.Entry<K, V>) -> Unit): Unit {}

override fun <R : Any> updateValues(
transform: (K, V) -> R?
): TreapMap<K, R> = treapMapOf()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet<E>,
override fun retainAll(elements: Collection<E>): TreapSet<E> = this
override fun single(): E = throw NoSuchElementException("Empty set.")
override fun singleOrNull(): E? = null
override fun arbitraryOrNull(): E? = null
override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null
override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null

Expand Down
8 changes: 8 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ internal class HashTreapMap<@Treapable K, V>(
this as? HashTreapMap<K, V>
?: (this as? PersistentMap.Builder<K, V>)?.build() as? HashTreapMap<K, V>

override fun arbitraryOrNull(): Map.Entry<K, V>? = MapEntry(key, value)

override fun getShallowMerger(merger: (K, V?, V?) -> V?): (HashTreapMap<K, V>?, HashTreapMap<K, V>?) -> HashTreapMap<K, V>? = { t1, t2 ->
var newPairs: KeyValuePairList.More<K, V>? = null
t1?.forEachPair { (k, v1) ->
Expand Down Expand Up @@ -300,6 +302,12 @@ internal class HashTreapMap<@Treapable K, V>(
}
return result!!
}

override fun forEachEntry(action: (Map.Entry<K, V>) -> Unit) {
left?.forEachEntry(action)
forEachPair { (k, v) -> action(MapEntry(k, v)) }
right?.forEachEntry(action)
}
}

internal interface KeyValuePairList<K, V> {
Expand Down
2 changes: 2 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ internal class HashTreapSet<@Treapable E>(

override fun shallowGetSingleElement(): E? = element.takeIf { next == null }

override fun arbitraryOrNull(): E? = element

override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R {
var result: R? = null
forEachNodeElement {
Expand Down
8 changes: 8 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ internal class SortedTreapMap<@Treapable K, V>(
this as? SortedTreapMap<K, V>
?: (this as? PersistentMap.Builder<K, V>)?.build() as? SortedTreapMap<K, V>

override fun arbitraryOrNull(): Map.Entry<K, V>? = MapEntry(key, value)

override fun getShallowMerger(merger: (K, V?, V?) -> V?): (SortedTreapMap<K, V>?, SortedTreapMap<K, V>?) -> SortedTreapMap<K, V>? = { t1, t2 ->
val k = (t1 ?: t2)!!.key
val v1 = t1?.value
Expand Down Expand Up @@ -141,4 +143,10 @@ internal class SortedTreapMap<@Treapable K, V>(
fun lastEntry(): Map.Entry<K, V>? = right?.lastEntry() ?: this.asEntry()

override fun <R : Any> shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R = map(key, value)

override fun forEachEntry(action: (Map.Entry<K, V>) -> Unit) {
left?.forEachEntry(action)
action(this.asEntry())
right?.forEachEntry(action)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ internal class SortedTreapSet<@Treapable E>(
override fun shallowRemoveAll(predicate: (E) -> Boolean): SortedTreapSet<E>? = this.takeIf { !predicate(treapKey) }
override fun shallowComputeHashCode(): Int = treapKey.hashCode()
override fun shallowGetSingleElement(): E = treapKey
override fun arbitraryOrNull(): E? = treapKey
override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) }
override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey)
}
7 changes: 7 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
@Suppress("Treapability")
override fun builder(): Builder<K, @UnsafeVariance V> = TreapMapBuilder(this)

/**
Returns an arbitrary entry from the map, or null if the map is empty.
*/
public fun arbitraryOrNull(): Map.Entry<K, V>?

public fun forEachEntry(action: (Map.Entry<K, V>) -> Unit): Unit

public fun merge(
m: Map<K, V>,
merger: (K, V?, V?) -> V?
Expand Down
5 changes: 5 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ public sealed interface TreapSet<out T> : PersistentSet<T> {
*/
public fun singleOrNull(): T?

/**
Returns an arbitrary element from the set, or null if the set is empty.
*/
public fun arbitraryOrNull(): T?

/**
If this set contains an element that compares equal to the specified [element], returns that element instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import kotlinx.serialization.DeserializationStrategy
/** Tests for [HashTreapMap]. */
class HashTreapMapTest: TreapMapTest() {
override fun makeKey(value: Int, code: Int) = HashTestKey(value, code)
override fun makeMap(): MutableMap<TestKey?, Any?> = treapMapOf<TestKey?, Any?>().builder()
override fun makeMap(): TreapMap.Builder<TestKey?, Any?> = treapMapOf<TestKey?, Any?>().builder()
override fun makeBaseline(): MutableMap<TestKey?, Any?> = HashMap()
override fun makeMap(other: Map<TestKey?,Any?>): MutableMap<TestKey?, Any?> = makeMap().apply { putAll(other) }
override fun makeMap(other: Map<TestKey?,Any?>): TreapMap.Builder<TestKey?, Any?> = makeMap().apply { putAll(other) }
override fun makeBaseline(other: Map<TestKey?,Any?>): MutableMap<TestKey?, Any?> = HashMap(other)
override fun makeMapOfInts(): TreapMap<Int?, Int?> = treapMapOf<Int?, Int?>()
override fun makeMapOfInts(other: Map<Int?, Int?>) = makeMapOfInts().apply { putAll(other) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import java.util.TreeMap
class SortedTreapMapTest: TreapMapTest() {
override fun makeKey(value: Int, code: Int) = ComparableTestKey(value, code)
override val allowNullKeys = false
override fun makeMap(): MutableMap<TestKey?, Any?> = treapMapOf<ComparableTestKey, Any?>().builder() as MutableMap<TestKey?, Any?>
override fun makeMap(): TreapMap.Builder<TestKey?, Any?> = treapMapOf<ComparableTestKey, Any?>().builder() as TreapMap.Builder<TestKey?, Any?>
override fun makeBaseline(): MutableMap<TestKey?, Any?> = TreeMap()
override fun makeMap(other: Map<TestKey?,Any?>): MutableMap<TestKey?, Any?> = makeMap().apply { putAll(other) }
override fun makeMap(other: Map<TestKey?,Any?>): TreapMap.Builder<TestKey?, Any?> = makeMap().apply { putAll(other) }
override fun makeBaseline(other: Map<TestKey?,Any?>): MutableMap<TestKey?, Any?> = TreeMap(other)
override fun makeMapOfInts(): TreapMap<Int?, Int?> = treapMapOf<Int, Int?>() as TreapMap<Int?, Int?>
override fun makeMapOfInts(other: Map<Int?, Int?>) = makeMapOfInts().apply { putAll(other) }
Expand Down
45 changes: 40 additions & 5 deletions collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ abstract class TreapMapTest {
abstract fun makeKey(value: Int, code: Int = value.hashCode()): TestKey

open val allowNullKeys = true
abstract fun makeMap(): MutableMap<TestKey?, Any?>
abstract fun makeMap(): TreapMap.Builder<TestKey?, Any?>
abstract fun makeBaseline(): MutableMap<TestKey?, Any?>
abstract fun makeMap(other: Map<TestKey?,Any?>): MutableMap<TestKey?, Any?>
abstract fun makeMap(other: Map<TestKey?,Any?>): TreapMap.Builder<TestKey?, Any?>
abstract fun makeBaseline(other: Map<TestKey?,Any?>): MutableMap<TestKey?, Any?>

open fun assertOrderedIteration(expected: Iterator<*>, actual: Iterator<*>) { }

fun assertVeryEqual(expected: Map<*,*>, actual: Map<*,*>) {
fun assertVeryEqual(expected: Map<*,*>, actual: TreapMap<*,*>) {
assertEquals(expected, actual)
assertTrue(actual.equals(expected))
assertEquals(expected.hashCode(), actual.hashCode())
Expand All @@ -41,9 +41,27 @@ abstract class TreapMapTest {
val actualValues = actual.values
assertEquals(expectedValues.size, actualValues.size)
assertOrderedIteration(expectedValues.iterator(), actualValues.iterator())

val actualForEachEntries = mutableListOf<Map.Entry<*, *>>()
actual.forEach {
actualForEachEntries += it
}
assertOrderedIteration(expected.entries.iterator(), actualForEachEntries.iterator())

val actualForEachEntryEntries = mutableListOf<Map.Entry<*, *>>()
actual.forEachEntry {
actualForEachEntryEntries += it
}
assertOrderedIteration(expected.entries.iterator(), actualForEachEntryEntries.iterator())

assertEquals(actualForEachEntries, actualForEachEntryEntries)
}

fun assertVeryEqual(expected: Map<*, *>, actual: TreapMap.Builder<*, *>) {
assertVeryEqual(expected, actual.build())
}

fun <TResult> assertEqualMutation(baseline: MutableMap<TestKey?, Any?>, map: MutableMap<TestKey?,Any?>, action: MutableMap<TestKey?,Any?>.() -> TResult) {
fun <TResult> assertEqualMutation(baseline: MutableMap<TestKey?, Any?>, map: TreapMap.Builder<TestKey?,Any?>, action: MutableMap<TestKey?,Any?>.() -> TResult) {
assertEqualResult(baseline, map, action)
assertVeryEqual(baseline, map)
}
Expand Down Expand Up @@ -350,7 +368,7 @@ abstract class TreapMapTest {
@Suppress("UNCHECKED_CAST")
val db = Json.decodeFromString(getBaseDeserializer()!!, bs) as Map<TestKey?, Any?>
@Suppress("UNCHECKED_CAST")
val dm = Json.decodeFromString(getDeserializer()!!, ms) as Map<TestKey?, Any?>
val dm = Json.decodeFromString(getDeserializer()!!, ms) as TreapMap<TestKey?, Any?>

assertVeryEqual(db, dm)
}
Expand Down Expand Up @@ -432,4 +450,21 @@ abstract class TreapMapTest {
testMapOf(1 to 2, 2 to 3).zip(testMapOf(1 to 3, 2 to 4)).toSet()
)
}

@Test
fun arbitraryOrNull() {
val m = makeMap()
assertNull(m.build().arbitraryOrNull())

m[makeKey(1, 1)] = 1
assertEquals(1, m.build().arbitraryOrNull()!!.value)

m[makeKey(2, 1)] = 2
assertTrue(m.build().arbitraryOrNull()!!.value in 1..2)

for (it in 3..100) {
m[makeKey(it)] = it
}
assertTrue(m.build().arbitraryOrNull()!!.value in 1..100)
}
}
21 changes: 19 additions & 2 deletions collect/src/test/kotlin/com/certora/collect/TreapSetTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ abstract class TreapSetTest {
abstract fun makeKey(value: Int, code: Int = value.hashCode()): TestKey
open val nullKeysAllowed: Boolean get() = true

fun makeSet(): MutableSet<TestKey?> = treapSetOf<TestKey?>().builder()
fun makeSet(): TreapSet.Builder<TestKey?> = treapSetOf<TestKey?>().builder()
abstract fun makeBaseline(): MutableSet<TestKey?>

fun makeSet(other: Collection<TestKey?>): MutableSet<TestKey?> = makeSet().also { it += other }
fun makeSet(other: Collection<TestKey?>): TreapSet.Builder<TestKey?> = makeSet().also { it += other }
fun makeBaseline(other: Collection<TestKey?>): MutableSet<TestKey?> = makeSet().also { it += other }

open fun assertOrderedIteration(expected: Iterator<*>, actual: Iterator<*>) {}
Expand Down Expand Up @@ -411,5 +411,22 @@ abstract class TreapSetTest {
assertVeryEqual(s, rs)
assertVeryEqual(b, rs)
}

@Test
fun arbitraryOrNull() {
val s = makeSet()
assertNull(s.build().arbitraryOrNull())

s += makeKey(1, 1)
assertEquals(makeKey(1, 1), s.build().arbitraryOrNull())

s += makeKey(2, 1)
assertTrue(s.build().arbitraryOrNull() in (1..2).map { makeKey(it) })

for (it in 3..100) {
s += makeKey(it)
}
assertTrue(s.build().arbitraryOrNull() in (1..100).map { makeKey(it) })
}
}

0 comments on commit 889dfe2

Please sign in to comment.