Skip to content

Commit

Permalink
TreapMap.keys as TreapSet
Browse files Browse the repository at this point in the history
  • Loading branch information
ericeil committed Dec 21, 2024
1 parent 38a54f5 commit 14c54f9
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 77 deletions.
15 changes: 3 additions & 12 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ import kotlinx.collections.immutable.ImmutableSet
Base class for TreapMap implementations. Provides the Map operations; derived classes deal with type-specific
behavior such as hash collisions. See [Treap] for an overview of all of this.
*/
internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>>(
left: S?,
right: S?
) : TreapMap<K, V>, Treap<K, S>(left, right) {

internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>>
: TreapMap<K, V>, Treap<K, S>() {
/**
Derived classes override to create an apropriate node containing the given entry.
*/
Expand Down Expand Up @@ -139,13 +136,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override fun iterator() = entrySequence().iterator()
}

override val keys: ImmutableSet<K>
get() = object: AbstractSet<K>(), ImmutableSet<K> {
override val size get() = this@AbstractTreapMap.size
override fun isEmpty() = this@AbstractTreapMap.isEmpty()
override operator fun contains(element: K) = containsKey(element)
override operator fun iterator() = entrySequence().map { it.key }.iterator()
}
abstract override val keys: TreapSet<K>

override val values: ImmutableCollection<V>
get() = object: AbstractCollection<V>(), ImmutableCollection<V> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ import com.certora.forkjoin.*
Base class for TreapSet implementations. Provides the Set operations; derived classes deal with type-specific
behavior such as hash collisions. See `Treap` for an overview of all of this.
*/
internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>(
left: S?,
right: S?
) : TreapSet<E>, Treap<E, S>(left, right) {
internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>> : TreapSet<E>, Treap<E, S>() {
/**
Derived classes override to create an apropriate node containing the given element.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
m.asSequence().map { MapEntry(it.key, null to it.value) }

override val entries: ImmutableSet<Map.Entry<K, V>> get() = persistentSetOf<Map.Entry<K, V>>()
override val keys: ImmutableSet<K> get() = persistentSetOf<K>()
override val keys: TreapSet<K> get() = treapSetOf<K>()
override val values: ImmutableCollection<V> get() = persistentSetOf<V>()

@Suppress("Treapability", "UNCHECKED_CAST")
Expand Down
26 changes: 22 additions & 4 deletions collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ internal class HashTreapMap<@Treapable K, V>(
override val key: K,
override val value: V,
override val next: KeyValuePairList.More<K, V>? = null,
left: HashTreapMap<K, V>? = null,
right: HashTreapMap<K, V>? = null
) : AbstractTreapMap<K, V, HashTreapMap<K, V>>(left, right), TreapKey.Hashed<K>, KeyValuePairList<K, V> {
override val left: HashTreapMap<K, V>? = null,
override val right: HashTreapMap<K, V>? = null
) : AbstractTreapMap<K, V, HashTreapMap<K, V>>(), TreapKey.Hashed<K>, KeyValuePairList<K, V> {

override fun hashCode() = computeHashCode()

Expand Down Expand Up @@ -350,12 +350,30 @@ internal class HashTreapMap<@Treapable K, V>(
forEachPair { (k, v) -> action(MapEntry(k, v)) }
right?.forEachEntry(action)
}

override val keys get() = KeySet(this)

class KeySet<@Treapable K>(
private val map: HashTreapMap<K, *>
) : AbstractHashTreapSet<K>() {
override fun hashCode() = super.hashCode()

override val element get() = map.key
override val left get() = map.left?.keys
override val right get() = map.right?.keys
override val next get() = map.next?.let { More(it) }

class More<K>(val mapMore: KeyValuePairList.More<K, *>) : ElementList<K>, java.io.Serializable {
override val element get() = mapMore.key
override val next get() = mapMore.next?.let { More(it) }
}
}
}

internal interface KeyValuePairList<K, V> {
abstract val key: K
abstract val value: V
abstract val next: More<K, V>?
abstract val next: KeyValuePairList<K, V>?
operator fun component1() = key
operator fun component2() = value

Expand Down
47 changes: 25 additions & 22 deletions collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@ package com.certora.collect
This is just a simple linked list, so operations on it are either O(N) or O(N^2), but collisions are assumed to be
rare enough that these lists will be very small - usually just one element.
*/
internal class HashTreapSet<@Treapable E>(
override val element: E,
override val next: ElementList.More<E>? = null,
left: HashTreapSet<E>? = null,
right: HashTreapSet<E>? = null
) : AbstractTreapSet<E, HashTreapSet<E>>(left, right), TreapKey.Hashed<E>, ElementList<E> {
internal abstract class AbstractHashTreapSet<@Treapable E>
: AbstractTreapSet<E, AbstractHashTreapSet<E>>(), TreapKey.Hashed<E>, ElementList<E> {

override fun hashCode(): Int = computeHashCode()

Expand Down Expand Up @@ -45,15 +41,15 @@ internal class HashTreapSet<@Treapable E>(
return count
}

override fun copyWith(left: HashTreapSet<E>?, right: HashTreapSet<E>?): HashTreapSet<E> =
override fun copyWith(left: AbstractHashTreapSet<E>?, right: AbstractHashTreapSet<E>?) =
HashTreapSet(element, next, left, right)

fun withElement(element: E) = when {
this.shallowContains(element) -> this
else -> HashTreapSet(this.element, ElementList.More(element, this.next), this.left, this.right)
}

override fun shallowEquals(that: HashTreapSet<E>): Boolean {
override fun shallowEquals(that: AbstractHashTreapSet<E>): Boolean {
forEachNodeElement {
if (!that.shallowContains(it)) {
return false
Expand Down Expand Up @@ -84,7 +80,7 @@ internal class HashTreapSet<@Treapable E>(
return false
}

override fun shallowContainsAll(elements: HashTreapSet<E>): Boolean {
override fun shallowContainsAll(elements: AbstractHashTreapSet<E>): Boolean {
elements.forEachNodeElement {
if (!this.shallowContains(it)) {
return false
Expand All @@ -93,7 +89,7 @@ internal class HashTreapSet<@Treapable E>(
return true
}

override fun shallowContainsAny(elements: HashTreapSet<E>): Boolean {
override fun shallowContainsAny(elements: AbstractHashTreapSet<E>): Boolean {
elements.forEachNodeElement {
if (this.shallowContains(it)) {
return true
Expand All @@ -102,21 +98,21 @@ internal class HashTreapSet<@Treapable E>(
return false
}

override fun shallowAdd(that: HashTreapSet<E>): HashTreapSet<E> {
override fun shallowAdd(that: AbstractHashTreapSet<E>): AbstractHashTreapSet<E> {
// add is only called with a single element
check (that.next == null) { "add with multiple elements?" }
return this.withElement(that.element)
}

override fun shallowUnion(that: HashTreapSet<E>): HashTreapSet<E> {
override fun shallowUnion(that: AbstractHashTreapSet<E>): AbstractHashTreapSet<E> {
var result = this
that.forEachNodeElement {
result = result.withElement(it)
}
return result
}

override fun shallowDifference(that: HashTreapSet<E>): HashTreapSet<E>? {
override fun shallowDifference(that: AbstractHashTreapSet<E>): AbstractHashTreapSet<E>? {
// Fast path for the most common case
if (this.next == null) {
if (that.shallowContains(this.element)) {
Expand All @@ -126,7 +122,7 @@ internal class HashTreapSet<@Treapable E>(
}
}

var result: HashTreapSet<E>? = null
var result: AbstractHashTreapSet<E>? = null
var changed = false
this.forEachNodeElement {
if (!that.shallowContains(it)) {
Expand All @@ -143,7 +139,7 @@ internal class HashTreapSet<@Treapable E>(
}
}

override fun shallowIntersect(that: HashTreapSet<E>): HashTreapSet<E>? {
override fun shallowIntersect(that: AbstractHashTreapSet<E>): AbstractHashTreapSet<E>? {
// Fast path for the most common case
if (this.next == null) {
if (that.shallowContains(this.element)) {
Expand All @@ -153,7 +149,7 @@ internal class HashTreapSet<@Treapable E>(
}
}

var result: HashTreapSet<E>? = null
var result: AbstractHashTreapSet<E>? = null
var changed = false
this.forEachNodeElement {
if (that.shallowContains(it)) {
Expand All @@ -170,7 +166,7 @@ internal class HashTreapSet<@Treapable E>(
}
}

override fun shallowRemove(element: E): HashTreapSet<E>? {
override fun shallowRemove(element: E): AbstractHashTreapSet<E>? {
// Fast path for the most common case
if (this.next == null) {
if (this.element == element) {
Expand All @@ -180,7 +176,7 @@ internal class HashTreapSet<@Treapable E>(
}
}

var result: HashTreapSet<E>? = null
var result: AbstractHashTreapSet<E>? = null
var changed = false
this.forEachNodeElement {
if (it != element) {
Expand All @@ -197,8 +193,8 @@ internal class HashTreapSet<@Treapable E>(
}
}

override fun shallowRemoveAll(predicate: (E) -> Boolean): HashTreapSet<E>? {
var result: HashTreapSet<E>? = null
override fun shallowRemoveAll(predicate: (E) -> Boolean): AbstractHashTreapSet<E>? {
var result: AbstractHashTreapSet<E>? = null
var removed = false
this.forEachNodeElement {
if (predicate(it)) {
Expand Down Expand Up @@ -253,10 +249,17 @@ internal class HashTreapSet<@Treapable E>(

internal interface ElementList<E> {
val element: E
val next: More<E>?
val next: ElementList<E>?

class More<E>(
override val element: E,
override val next: More<E>?
override val next: ElementList<E>?
) : ElementList<E>, java.io.Serializable
}

internal class HashTreapSet<@Treapable E>(
override val element: E,
override val next: ElementList<E>? = null,
override val left: AbstractHashTreapSet<E>? = null,
override val right: AbstractHashTreapSet<E>? = null
) : AbstractHashTreapSet<E>()
20 changes: 15 additions & 5 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ import kotlinx.collections.immutable.PersistentMap
internal class SortedTreapMap<@Treapable K, V>(
val key: K,
val value: V,
left: SortedTreapMap<K, V>? = null,
right: SortedTreapMap<K, V>? = null
) : AbstractTreapMap<K, V, SortedTreapMap<K, V>>(left, right), TreapKey.Sorted<K> {

init { check(key is Comparable<*>?) { "SortedTreapMap keys must be Comparable" } }
override val left: SortedTreapMap<K, V>? = null,
override val right: SortedTreapMap<K, V>? = null
) : AbstractTreapMap<K, V, SortedTreapMap<K, V>>(), TreapKey.Sorted<K> {

override fun hashCode() = computeHashCode()

Expand Down Expand Up @@ -163,4 +161,16 @@ internal class SortedTreapMap<@Treapable K, V>(
action(this.asEntry())
right?.forEachEntry(action)
}

override val keys get() = KeySet(this)

class KeySet<@Treapable K>(
private val map: SortedTreapMap<K, *>
) : AbstractSortedTreapSet<K>() {
override fun hashCode() = super.hashCode()

override val treapKey get() = map.key
override val left get() = map.left?.keys
override val right get() = map.right?.keys
}
}
46 changes: 24 additions & 22 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,8 @@ import kotlinx.collections.immutable.PersistentSet
A TreapSet specific to Comparable elements. Iterates in the order defined by the objects. We store one element per
Treap node, with the element itself as the Treap key.
*/
internal class SortedTreapSet<@Treapable E>(
override val treapKey: E,
left: SortedTreapSet<E>? = null,
right: SortedTreapSet<E>? = null
) : AbstractTreapSet<E, SortedTreapSet<E>>(left, right), TreapKey.Sorted<E> {

init { check(treapKey is Comparable<*>?) { "SortedTreapSet elements must be Comparable" } }
internal abstract class AbstractSortedTreapSet<@Treapable E>
: AbstractTreapSet<E, AbstractSortedTreapSet<E>>(), TreapKey.Sorted<E> {

override fun hashCode(): Int = computeHashCode()

Expand All @@ -31,29 +26,36 @@ internal class SortedTreapSet<@Treapable E>(
override val self get() = this
override fun iterator(): Iterator<E> = this.asTreapSequence().map { it.treapKey }.iterator()

override fun shallowEquals(that: SortedTreapSet<E>): Boolean = this.compareKeyTo(that) == 0
override fun shallowEquals(that: AbstractSortedTreapSet<E>): Boolean = this.compareKeyTo(that) == 0
override val shallowSize: Int get() = 1

override fun copyWith(left: SortedTreapSet<E>?, right: SortedTreapSet<E>?): SortedTreapSet<E> = SortedTreapSet(treapKey, left, right)
override fun copyWith(left: AbstractSortedTreapSet<E>?, right: AbstractSortedTreapSet<E>?) =
SortedTreapSet(treapKey, left, right)

// Since these are only called for Treap nodes with the same key, and each of our nodes stores a single element,
// these are trivial.
override fun shallowContains(element: E) = true
override fun shallowContainsAll(elements: SortedTreapSet<E>) = true
override fun shallowContainsAny(elements: SortedTreapSet<E>) = true
override fun shallowContainsAll(elements: AbstractSortedTreapSet<E>) = true
override fun shallowContainsAny(elements: AbstractSortedTreapSet<E>) = true
override fun shallowFindEqual(element: E) = treapKey.takeIf { it == element }
override fun shallowAdd(that: SortedTreapSet<E>) = this
override fun shallowUnion(that: SortedTreapSet<E>) = this
override fun shallowDifference(that: SortedTreapSet<E>) = null
override fun shallowIntersect(that: SortedTreapSet<E>) = this
override fun shallowRemove(element: E): SortedTreapSet<E>? = null
override fun shallowRemoveAll(predicate: (E) -> Boolean): SortedTreapSet<E>? = this.takeIf { !predicate(treapKey) }
override fun shallowComputeHashCode(): Int = treapKey.hashCode()
override fun shallowGetSingleElement(): E = treapKey
override fun arbitraryOrNull(): E? = treapKey
override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) }
override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey)
override fun shallowAdd(that: AbstractSortedTreapSet<E>) = this
override fun shallowUnion(that: AbstractSortedTreapSet<E>) = this
override fun shallowDifference(that: AbstractSortedTreapSet<E>) = null
override fun shallowIntersect(that: AbstractSortedTreapSet<E>) = this
override fun shallowRemove(element: E) = null
override fun shallowRemoveAll(predicate: (E) -> Boolean) = this.takeIf { !predicate(treapKey) }
override fun shallowComputeHashCode() = treapKey.hashCode()
override fun shallowGetSingleElement() = treapKey
override fun arbitraryOrNull() = treapKey
override fun shallowForEach(action: (element: E) -> Unit) { action(treapKey) }
override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R) = map(treapKey)

override fun containsAny(predicate: (E) -> Boolean): Boolean =
predicate(treapKey) || left?.containsAny(predicate) == true || right?.containsAny(predicate) == true
}

internal class SortedTreapSet<@Treapable E>(
override val treapKey: E,
override val left: AbstractSortedTreapSet<E>? = null,
override val right: AbstractSortedTreapSet<E>? = null
) : AbstractSortedTreapSet<E>()
12 changes: 5 additions & 7 deletions collect/src/main/kotlin/com/certora/collect/Treap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,9 @@ package com.certora.collect
- Treaps impose special requirements on keys if they are serialized. See [Treapable].
*/
internal abstract class Treap<@Treapable T, S : Treap<T, S>>(
@JvmField val left: S?,
@JvmField val right: S?
) : TreapKey<T>, java.io.Serializable {

internal abstract class Treap<@Treapable T, S : Treap<T, S>> : TreapKey<T>, java.io.Serializable {
abstract val left: S?
abstract val right: S?
abstract val self: S

/**
Expand Down Expand Up @@ -135,11 +133,11 @@ internal abstract class Treap<@Treapable T, S : Treap<T, S>>(
case any hash functions have changed since this Treap was serialized.
*/
protected fun readResolve(): Any? {
if (left != null) {
left?.let { left ->
check(left.compareKeyTo(this) < 0) { "Treap key comparison logic changed: ${left.treapKey} >= ${this.treapKey}" }
check(left.comparePriorityTo(this) < 0) { "Treap key priority hash logic changed: ${left.treapKey} >= ${this.treapKey} "}
}
if (right != null) {
right?.let { right ->
check(right.compareKeyTo(this) > 0) { "Treap key comparison logic changed: ${right.treapKey} <= ${this.treapKey}" }
check(right.comparePriorityTo(this) < 0) { "Treap key priority hash logic changed: ${right.treapKey} >= ${this.treapKey} "}
}
Expand Down
2 changes: 2 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
override fun putAll(m: Map<out K, @UnsafeVariance V>): TreapMap<K, V>
override fun clear(): TreapMap<K, V>

override val keys: TreapSet<K>

/**
A [PersistentMap.Builder] that produces a [TreapMap].
*/
Expand Down

0 comments on commit 14c54f9

Please sign in to comment.