diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt index b0dcf0b..2cbf207 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt @@ -8,11 +8,8 @@ import kotlinx.collections.immutable.ImmutableSet Base class for TreapMap implementations. Provides the Map operations; derived classes deal with type-specific behavior such as hash collisions. See [Treap] for an overview of all of this. */ -internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractTreapMap>( - left: S?, - right: S? -) : TreapMap, Treap(left, right) { - +internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractTreapMap> +: TreapMap, Treap() { /** Derived classes override to create an apropriate node containing the given entry. */ @@ -139,13 +136,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT override fun iterator() = entrySequence().iterator() } - override val keys: ImmutableSet - get() = object: AbstractSet(), ImmutableSet { - override val size get() = this@AbstractTreapMap.size - override fun isEmpty() = this@AbstractTreapMap.isEmpty() - override operator fun contains(element: K) = containsKey(element) - override operator fun iterator() = entrySequence().map { it.key }.iterator() - } + abstract override val keys: TreapSet override val values: ImmutableCollection get() = object: AbstractCollection(), ImmutableCollection { diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt index 4900c4e..2895c15 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt @@ -6,10 +6,7 @@ import com.certora.forkjoin.* Base class for TreapSet implementations. Provides the Set operations; derived classes deal with type-specific behavior such as hash collisions. See `Treap` for an overview of all of this. */ -internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet>( - left: S?, - right: S? -) : TreapSet, Treap(left, right) { +internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> : TreapSet, Treap() { /** Derived classes override to create an apropriate node containing the given element. */ diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt index a99cc4a..d877184 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt @@ -66,7 +66,7 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap> get() = persistentSetOf>() - override val keys: ImmutableSet get() = persistentSetOf() + override val keys: TreapSet get() = treapSetOf() override val values: ImmutableCollection get() = persistentSetOf() @Suppress("Treapability", "UNCHECKED_CAST") diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt index cea3982..37a9052 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt @@ -15,9 +15,9 @@ internal class HashTreapMap<@Treapable K, V>( override val key: K, override val value: V, override val next: KeyValuePairList.More? = null, - left: HashTreapMap? = null, - right: HashTreapMap? = null -) : AbstractTreapMap>(left, right), TreapKey.Hashed, KeyValuePairList { + override val left: HashTreapMap? = null, + override val right: HashTreapMap? = null +) : AbstractTreapMap>(), TreapKey.Hashed, KeyValuePairList { override fun hashCode() = computeHashCode() @@ -350,12 +350,30 @@ internal class HashTreapMap<@Treapable K, V>( forEachPair { (k, v) -> action(MapEntry(k, v)) } right?.forEachEntry(action) } + + override val keys get() = KeySet(this) + + class KeySet<@Treapable K>( + private val map: HashTreapMap + ) : AbstractHashTreapSet() { + override fun hashCode() = super.hashCode() + + override val element get() = map.key + override val left get() = map.left?.keys + override val right get() = map.right?.keys + override val next get() = map.next?.let { More(it) } + + class More(val mapMore: KeyValuePairList.More) : ElementList, java.io.Serializable { + override val element get() = mapMore.key + override val next get() = mapMore.next?.let { More(it) } + } + } } internal interface KeyValuePairList { abstract val key: K abstract val value: V - abstract val next: More? + abstract val next: KeyValuePairList? operator fun component1() = key operator fun component2() = value diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt index 987e503..45a1c08 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt @@ -8,12 +8,8 @@ package com.certora.collect This is just a simple linked list, so operations on it are either O(N) or O(N^2), but collisions are assumed to be rare enough that these lists will be very small - usually just one element. */ -internal class HashTreapSet<@Treapable E>( - override val element: E, - override val next: ElementList.More? = null, - left: HashTreapSet? = null, - right: HashTreapSet? = null -) : AbstractTreapSet>(left, right), TreapKey.Hashed, ElementList { +internal abstract class AbstractHashTreapSet<@Treapable E> + : AbstractTreapSet>(), TreapKey.Hashed, ElementList { override fun hashCode(): Int = computeHashCode() @@ -45,7 +41,7 @@ internal class HashTreapSet<@Treapable E>( return count } - override fun copyWith(left: HashTreapSet?, right: HashTreapSet?): HashTreapSet = + override fun copyWith(left: AbstractHashTreapSet?, right: AbstractHashTreapSet?) = HashTreapSet(element, next, left, right) fun withElement(element: E) = when { @@ -53,7 +49,7 @@ internal class HashTreapSet<@Treapable E>( else -> HashTreapSet(this.element, ElementList.More(element, this.next), this.left, this.right) } - override fun shallowEquals(that: HashTreapSet): Boolean { + override fun shallowEquals(that: AbstractHashTreapSet): Boolean { forEachNodeElement { if (!that.shallowContains(it)) { return false @@ -84,7 +80,7 @@ internal class HashTreapSet<@Treapable E>( return false } - override fun shallowContainsAll(elements: HashTreapSet): Boolean { + override fun shallowContainsAll(elements: AbstractHashTreapSet): Boolean { elements.forEachNodeElement { if (!this.shallowContains(it)) { return false @@ -93,7 +89,7 @@ internal class HashTreapSet<@Treapable E>( return true } - override fun shallowContainsAny(elements: HashTreapSet): Boolean { + override fun shallowContainsAny(elements: AbstractHashTreapSet): Boolean { elements.forEachNodeElement { if (this.shallowContains(it)) { return true @@ -102,13 +98,13 @@ internal class HashTreapSet<@Treapable E>( return false } - override fun shallowAdd(that: HashTreapSet): HashTreapSet { + override fun shallowAdd(that: AbstractHashTreapSet): AbstractHashTreapSet { // add is only called with a single element check (that.next == null) { "add with multiple elements?" } return this.withElement(that.element) } - override fun shallowUnion(that: HashTreapSet): HashTreapSet { + override fun shallowUnion(that: AbstractHashTreapSet): AbstractHashTreapSet { var result = this that.forEachNodeElement { result = result.withElement(it) @@ -116,7 +112,7 @@ internal class HashTreapSet<@Treapable E>( return result } - override fun shallowDifference(that: HashTreapSet): HashTreapSet? { + override fun shallowDifference(that: AbstractHashTreapSet): AbstractHashTreapSet? { // Fast path for the most common case if (this.next == null) { if (that.shallowContains(this.element)) { @@ -126,7 +122,7 @@ internal class HashTreapSet<@Treapable E>( } } - var result: HashTreapSet? = null + var result: AbstractHashTreapSet? = null var changed = false this.forEachNodeElement { if (!that.shallowContains(it)) { @@ -143,7 +139,7 @@ internal class HashTreapSet<@Treapable E>( } } - override fun shallowIntersect(that: HashTreapSet): HashTreapSet? { + override fun shallowIntersect(that: AbstractHashTreapSet): AbstractHashTreapSet? { // Fast path for the most common case if (this.next == null) { if (that.shallowContains(this.element)) { @@ -153,7 +149,7 @@ internal class HashTreapSet<@Treapable E>( } } - var result: HashTreapSet? = null + var result: AbstractHashTreapSet? = null var changed = false this.forEachNodeElement { if (that.shallowContains(it)) { @@ -170,7 +166,7 @@ internal class HashTreapSet<@Treapable E>( } } - override fun shallowRemove(element: E): HashTreapSet? { + override fun shallowRemove(element: E): AbstractHashTreapSet? { // Fast path for the most common case if (this.next == null) { if (this.element == element) { @@ -180,7 +176,7 @@ internal class HashTreapSet<@Treapable E>( } } - var result: HashTreapSet? = null + var result: AbstractHashTreapSet? = null var changed = false this.forEachNodeElement { if (it != element) { @@ -197,8 +193,8 @@ internal class HashTreapSet<@Treapable E>( } } - override fun shallowRemoveAll(predicate: (E) -> Boolean): HashTreapSet? { - var result: HashTreapSet? = null + override fun shallowRemoveAll(predicate: (E) -> Boolean): AbstractHashTreapSet? { + var result: AbstractHashTreapSet? = null var removed = false this.forEachNodeElement { if (predicate(it)) { @@ -253,10 +249,17 @@ internal class HashTreapSet<@Treapable E>( internal interface ElementList { val element: E - val next: More? + val next: ElementList? class More( override val element: E, - override val next: More? + override val next: ElementList? ) : ElementList, java.io.Serializable } + +internal class HashTreapSet<@Treapable E>( + override val element: E, + override val next: ElementList? = null, + override val left: AbstractHashTreapSet? = null, + override val right: AbstractHashTreapSet? = null +) : AbstractHashTreapSet() diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt index 0c3617a..73362fc 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt @@ -9,11 +9,9 @@ import kotlinx.collections.immutable.PersistentMap internal class SortedTreapMap<@Treapable K, V>( val key: K, val value: V, - left: SortedTreapMap? = null, - right: SortedTreapMap? = null -) : AbstractTreapMap>(left, right), TreapKey.Sorted { - - init { check(key is Comparable<*>?) { "SortedTreapMap keys must be Comparable" } } + override val left: SortedTreapMap? = null, + override val right: SortedTreapMap? = null +) : AbstractTreapMap>(), TreapKey.Sorted { override fun hashCode() = computeHashCode() @@ -163,4 +161,16 @@ internal class SortedTreapMap<@Treapable K, V>( action(this.asEntry()) right?.forEachEntry(action) } + + override val keys get() = KeySet(this) + + class KeySet<@Treapable K>( + private val map: SortedTreapMap + ) : AbstractSortedTreapSet() { + override fun hashCode() = super.hashCode() + + override val treapKey get() = map.key + override val left get() = map.left?.keys + override val right get() = map.right?.keys + } } diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt index 5704f0a..518f049 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt @@ -6,13 +6,8 @@ import kotlinx.collections.immutable.PersistentSet A TreapSet specific to Comparable elements. Iterates in the order defined by the objects. We store one element per Treap node, with the element itself as the Treap key. */ -internal class SortedTreapSet<@Treapable E>( - override val treapKey: E, - left: SortedTreapSet? = null, - right: SortedTreapSet? = null -) : AbstractTreapSet>(left, right), TreapKey.Sorted { - - init { check(treapKey is Comparable<*>?) { "SortedTreapSet elements must be Comparable" } } +internal abstract class AbstractSortedTreapSet<@Treapable E> + : AbstractTreapSet>(), TreapKey.Sorted { override fun hashCode(): Int = computeHashCode() @@ -31,29 +26,36 @@ internal class SortedTreapSet<@Treapable E>( override val self get() = this override fun iterator(): Iterator = this.asTreapSequence().map { it.treapKey }.iterator() - override fun shallowEquals(that: SortedTreapSet): Boolean = this.compareKeyTo(that) == 0 + override fun shallowEquals(that: AbstractSortedTreapSet): Boolean = this.compareKeyTo(that) == 0 override val shallowSize: Int get() = 1 - override fun copyWith(left: SortedTreapSet?, right: SortedTreapSet?): SortedTreapSet = SortedTreapSet(treapKey, left, right) + override fun copyWith(left: AbstractSortedTreapSet?, right: AbstractSortedTreapSet?) = + SortedTreapSet(treapKey, left, right) // Since these are only called for Treap nodes with the same key, and each of our nodes stores a single element, // these are trivial. override fun shallowContains(element: E) = true - override fun shallowContainsAll(elements: SortedTreapSet) = true - override fun shallowContainsAny(elements: SortedTreapSet) = true + override fun shallowContainsAll(elements: AbstractSortedTreapSet) = true + override fun shallowContainsAny(elements: AbstractSortedTreapSet) = true override fun shallowFindEqual(element: E) = treapKey.takeIf { it == element } - override fun shallowAdd(that: SortedTreapSet) = this - override fun shallowUnion(that: SortedTreapSet) = this - override fun shallowDifference(that: SortedTreapSet) = null - override fun shallowIntersect(that: SortedTreapSet) = this - override fun shallowRemove(element: E): SortedTreapSet? = null - override fun shallowRemoveAll(predicate: (E) -> Boolean): SortedTreapSet? = this.takeIf { !predicate(treapKey) } - override fun shallowComputeHashCode(): Int = treapKey.hashCode() - override fun shallowGetSingleElement(): E = treapKey - override fun arbitraryOrNull(): E? = treapKey - override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) } - override fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey) + override fun shallowAdd(that: AbstractSortedTreapSet) = this + override fun shallowUnion(that: AbstractSortedTreapSet) = this + override fun shallowDifference(that: AbstractSortedTreapSet) = null + override fun shallowIntersect(that: AbstractSortedTreapSet) = this + override fun shallowRemove(element: E) = null + override fun shallowRemoveAll(predicate: (E) -> Boolean) = this.takeIf { !predicate(treapKey) } + override fun shallowComputeHashCode() = treapKey.hashCode() + override fun shallowGetSingleElement() = treapKey + override fun arbitraryOrNull() = treapKey + override fun shallowForEach(action: (element: E) -> Unit) { action(treapKey) } + override fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R) = map(treapKey) override fun containsAny(predicate: (E) -> Boolean): Boolean = predicate(treapKey) || left?.containsAny(predicate) == true || right?.containsAny(predicate) == true } + +internal class SortedTreapSet<@Treapable E>( + override val treapKey: E, + override val left: AbstractSortedTreapSet? = null, + override val right: AbstractSortedTreapSet? = null +) : AbstractSortedTreapSet() diff --git a/collect/src/main/kotlin/com/certora/collect/Treap.kt b/collect/src/main/kotlin/com/certora/collect/Treap.kt index 7a456ea..f10b171 100644 --- a/collect/src/main/kotlin/com/certora/collect/Treap.kt +++ b/collect/src/main/kotlin/com/certora/collect/Treap.kt @@ -65,11 +65,9 @@ package com.certora.collect - Treaps impose special requirements on keys if they are serialized. See [Treapable]. */ -internal abstract class Treap<@Treapable T, S : Treap>( - @JvmField val left: S?, - @JvmField val right: S? -) : TreapKey, java.io.Serializable { - +internal abstract class Treap<@Treapable T, S : Treap> : TreapKey, java.io.Serializable { + abstract val left: S? + abstract val right: S? abstract val self: S /** @@ -135,11 +133,11 @@ internal abstract class Treap<@Treapable T, S : Treap>( case any hash functions have changed since this Treap was serialized. */ protected fun readResolve(): Any? { - if (left != null) { + left?.let { left -> check(left.compareKeyTo(this) < 0) { "Treap key comparison logic changed: ${left.treapKey} >= ${this.treapKey}" } check(left.comparePriorityTo(this) < 0) { "Treap key priority hash logic changed: ${left.treapKey} >= ${this.treapKey} "} } - if (right != null) { + right?.let { right -> check(right.compareKeyTo(this) > 0) { "Treap key comparison logic changed: ${right.treapKey} <= ${this.treapKey}" } check(right.comparePriorityTo(this) < 0) { "Treap key priority hash logic changed: ${right.treapKey} >= ${this.treapKey} "} } diff --git a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt index 7931468..d7b0d5f 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt @@ -13,6 +13,8 @@ public sealed interface TreapMap : PersistentMap { override fun putAll(m: Map): TreapMap override fun clear(): TreapMap + override val keys: TreapSet + /** A [PersistentMap.Builder] that produces a [TreapMap]. */