diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt index 2cbf207..b0dcf0b 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt @@ -8,8 +8,11 @@ import kotlinx.collections.immutable.ImmutableSet Base class for TreapMap implementations. Provides the Map operations; derived classes deal with type-specific behavior such as hash collisions. See [Treap] for an overview of all of this. */ -internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractTreapMap> -: TreapMap, Treap() { +internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractTreapMap>( + left: S?, + right: S? +) : TreapMap, Treap(left, right) { + /** Derived classes override to create an apropriate node containing the given entry. */ @@ -136,7 +139,13 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT override fun iterator() = entrySequence().iterator() } - abstract override val keys: TreapSet + override val keys: ImmutableSet + get() = object: AbstractSet(), ImmutableSet { + override val size get() = this@AbstractTreapMap.size + override fun isEmpty() = this@AbstractTreapMap.isEmpty() + override operator fun contains(element: K) = containsKey(element) + override operator fun iterator() = entrySequence().map { it.key }.iterator() + } override val values: ImmutableCollection get() = object: AbstractCollection(), ImmutableCollection { diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt index 2895c15..4900c4e 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt @@ -6,7 +6,10 @@ import com.certora.forkjoin.* Base class for TreapSet implementations. Provides the Set operations; derived classes deal with type-specific behavior such as hash collisions. See `Treap` for an overview of all of this. */ -internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> : TreapSet, Treap() { +internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet>( + left: S?, + right: S? +) : TreapSet, Treap(left, right) { /** Derived classes override to create an apropriate node containing the given element. */ diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt index d877184..a99cc4a 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt @@ -66,7 +66,7 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap> get() = persistentSetOf>() - override val keys: TreapSet get() = treapSetOf() + override val keys: ImmutableSet get() = persistentSetOf() override val values: ImmutableCollection get() = persistentSetOf() @Suppress("Treapability", "UNCHECKED_CAST") diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt index 37a9052..cea3982 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt @@ -15,9 +15,9 @@ internal class HashTreapMap<@Treapable K, V>( override val key: K, override val value: V, override val next: KeyValuePairList.More? = null, - override val left: HashTreapMap? = null, - override val right: HashTreapMap? = null -) : AbstractTreapMap>(), TreapKey.Hashed, KeyValuePairList { + left: HashTreapMap? = null, + right: HashTreapMap? = null +) : AbstractTreapMap>(left, right), TreapKey.Hashed, KeyValuePairList { override fun hashCode() = computeHashCode() @@ -350,30 +350,12 @@ internal class HashTreapMap<@Treapable K, V>( forEachPair { (k, v) -> action(MapEntry(k, v)) } right?.forEachEntry(action) } - - override val keys get() = KeySet(this) - - class KeySet<@Treapable K>( - private val map: HashTreapMap - ) : AbstractHashTreapSet() { - override fun hashCode() = super.hashCode() - - override val element get() = map.key - override val left get() = map.left?.keys - override val right get() = map.right?.keys - override val next get() = map.next?.let { More(it) } - - class More(val mapMore: KeyValuePairList.More) : ElementList, java.io.Serializable { - override val element get() = mapMore.key - override val next get() = mapMore.next?.let { More(it) } - } - } } internal interface KeyValuePairList { abstract val key: K abstract val value: V - abstract val next: KeyValuePairList? + abstract val next: More? operator fun component1() = key operator fun component2() = value diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt index 45a1c08..987e503 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt @@ -8,8 +8,12 @@ package com.certora.collect This is just a simple linked list, so operations on it are either O(N) or O(N^2), but collisions are assumed to be rare enough that these lists will be very small - usually just one element. */ -internal abstract class AbstractHashTreapSet<@Treapable E> - : AbstractTreapSet>(), TreapKey.Hashed, ElementList { +internal class HashTreapSet<@Treapable E>( + override val element: E, + override val next: ElementList.More? = null, + left: HashTreapSet? = null, + right: HashTreapSet? = null +) : AbstractTreapSet>(left, right), TreapKey.Hashed, ElementList { override fun hashCode(): Int = computeHashCode() @@ -41,7 +45,7 @@ internal abstract class AbstractHashTreapSet<@Treapable E> return count } - override fun copyWith(left: AbstractHashTreapSet?, right: AbstractHashTreapSet?) = + override fun copyWith(left: HashTreapSet?, right: HashTreapSet?): HashTreapSet = HashTreapSet(element, next, left, right) fun withElement(element: E) = when { @@ -49,7 +53,7 @@ internal abstract class AbstractHashTreapSet<@Treapable E> else -> HashTreapSet(this.element, ElementList.More(element, this.next), this.left, this.right) } - override fun shallowEquals(that: AbstractHashTreapSet): Boolean { + override fun shallowEquals(that: HashTreapSet): Boolean { forEachNodeElement { if (!that.shallowContains(it)) { return false @@ -80,7 +84,7 @@ internal abstract class AbstractHashTreapSet<@Treapable E> return false } - override fun shallowContainsAll(elements: AbstractHashTreapSet): Boolean { + override fun shallowContainsAll(elements: HashTreapSet): Boolean { elements.forEachNodeElement { if (!this.shallowContains(it)) { return false @@ -89,7 +93,7 @@ internal abstract class AbstractHashTreapSet<@Treapable E> return true } - override fun shallowContainsAny(elements: AbstractHashTreapSet): Boolean { + override fun shallowContainsAny(elements: HashTreapSet): Boolean { elements.forEachNodeElement { if (this.shallowContains(it)) { return true @@ -98,13 +102,13 @@ internal abstract class AbstractHashTreapSet<@Treapable E> return false } - override fun shallowAdd(that: AbstractHashTreapSet): AbstractHashTreapSet { + override fun shallowAdd(that: HashTreapSet): HashTreapSet { // add is only called with a single element check (that.next == null) { "add with multiple elements?" } return this.withElement(that.element) } - override fun shallowUnion(that: AbstractHashTreapSet): AbstractHashTreapSet { + override fun shallowUnion(that: HashTreapSet): HashTreapSet { var result = this that.forEachNodeElement { result = result.withElement(it) @@ -112,7 +116,7 @@ internal abstract class AbstractHashTreapSet<@Treapable E> return result } - override fun shallowDifference(that: AbstractHashTreapSet): AbstractHashTreapSet? { + override fun shallowDifference(that: HashTreapSet): HashTreapSet? { // Fast path for the most common case if (this.next == null) { if (that.shallowContains(this.element)) { @@ -122,7 +126,7 @@ internal abstract class AbstractHashTreapSet<@Treapable E> } } - var result: AbstractHashTreapSet? = null + var result: HashTreapSet? = null var changed = false this.forEachNodeElement { if (!that.shallowContains(it)) { @@ -139,7 +143,7 @@ internal abstract class AbstractHashTreapSet<@Treapable E> } } - override fun shallowIntersect(that: AbstractHashTreapSet): AbstractHashTreapSet? { + override fun shallowIntersect(that: HashTreapSet): HashTreapSet? { // Fast path for the most common case if (this.next == null) { if (that.shallowContains(this.element)) { @@ -149,7 +153,7 @@ internal abstract class AbstractHashTreapSet<@Treapable E> } } - var result: AbstractHashTreapSet? = null + var result: HashTreapSet? = null var changed = false this.forEachNodeElement { if (that.shallowContains(it)) { @@ -166,7 +170,7 @@ internal abstract class AbstractHashTreapSet<@Treapable E> } } - override fun shallowRemove(element: E): AbstractHashTreapSet? { + override fun shallowRemove(element: E): HashTreapSet? { // Fast path for the most common case if (this.next == null) { if (this.element == element) { @@ -176,7 +180,7 @@ internal abstract class AbstractHashTreapSet<@Treapable E> } } - var result: AbstractHashTreapSet? = null + var result: HashTreapSet? = null var changed = false this.forEachNodeElement { if (it != element) { @@ -193,8 +197,8 @@ internal abstract class AbstractHashTreapSet<@Treapable E> } } - override fun shallowRemoveAll(predicate: (E) -> Boolean): AbstractHashTreapSet? { - var result: AbstractHashTreapSet? = null + override fun shallowRemoveAll(predicate: (E) -> Boolean): HashTreapSet? { + var result: HashTreapSet? = null var removed = false this.forEachNodeElement { if (predicate(it)) { @@ -249,17 +253,10 @@ internal abstract class AbstractHashTreapSet<@Treapable E> internal interface ElementList { val element: E - val next: ElementList? + val next: More? class More( override val element: E, - override val next: ElementList? + override val next: More? ) : ElementList, java.io.Serializable } - -internal class HashTreapSet<@Treapable E>( - override val element: E, - override val next: ElementList? = null, - override val left: AbstractHashTreapSet? = null, - override val right: AbstractHashTreapSet? = null -) : AbstractHashTreapSet() diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt index 73362fc..0c3617a 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt @@ -9,9 +9,11 @@ import kotlinx.collections.immutable.PersistentMap internal class SortedTreapMap<@Treapable K, V>( val key: K, val value: V, - override val left: SortedTreapMap? = null, - override val right: SortedTreapMap? = null -) : AbstractTreapMap>(), TreapKey.Sorted { + left: SortedTreapMap? = null, + right: SortedTreapMap? = null +) : AbstractTreapMap>(left, right), TreapKey.Sorted { + + init { check(key is Comparable<*>?) { "SortedTreapMap keys must be Comparable" } } override fun hashCode() = computeHashCode() @@ -161,16 +163,4 @@ internal class SortedTreapMap<@Treapable K, V>( action(this.asEntry()) right?.forEachEntry(action) } - - override val keys get() = KeySet(this) - - class KeySet<@Treapable K>( - private val map: SortedTreapMap - ) : AbstractSortedTreapSet() { - override fun hashCode() = super.hashCode() - - override val treapKey get() = map.key - override val left get() = map.left?.keys - override val right get() = map.right?.keys - } } diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt index 518f049..5704f0a 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt @@ -6,8 +6,13 @@ import kotlinx.collections.immutable.PersistentSet A TreapSet specific to Comparable elements. Iterates in the order defined by the objects. We store one element per Treap node, with the element itself as the Treap key. */ -internal abstract class AbstractSortedTreapSet<@Treapable E> - : AbstractTreapSet>(), TreapKey.Sorted { +internal class SortedTreapSet<@Treapable E>( + override val treapKey: E, + left: SortedTreapSet? = null, + right: SortedTreapSet? = null +) : AbstractTreapSet>(left, right), TreapKey.Sorted { + + init { check(treapKey is Comparable<*>?) { "SortedTreapSet elements must be Comparable" } } override fun hashCode(): Int = computeHashCode() @@ -26,36 +31,29 @@ internal abstract class AbstractSortedTreapSet<@Treapable E> override val self get() = this override fun iterator(): Iterator = this.asTreapSequence().map { it.treapKey }.iterator() - override fun shallowEquals(that: AbstractSortedTreapSet): Boolean = this.compareKeyTo(that) == 0 + override fun shallowEquals(that: SortedTreapSet): Boolean = this.compareKeyTo(that) == 0 override val shallowSize: Int get() = 1 - override fun copyWith(left: AbstractSortedTreapSet?, right: AbstractSortedTreapSet?) = - SortedTreapSet(treapKey, left, right) + override fun copyWith(left: SortedTreapSet?, right: SortedTreapSet?): SortedTreapSet = SortedTreapSet(treapKey, left, right) // Since these are only called for Treap nodes with the same key, and each of our nodes stores a single element, // these are trivial. override fun shallowContains(element: E) = true - override fun shallowContainsAll(elements: AbstractSortedTreapSet) = true - override fun shallowContainsAny(elements: AbstractSortedTreapSet) = true + override fun shallowContainsAll(elements: SortedTreapSet) = true + override fun shallowContainsAny(elements: SortedTreapSet) = true override fun shallowFindEqual(element: E) = treapKey.takeIf { it == element } - override fun shallowAdd(that: AbstractSortedTreapSet) = this - override fun shallowUnion(that: AbstractSortedTreapSet) = this - override fun shallowDifference(that: AbstractSortedTreapSet) = null - override fun shallowIntersect(that: AbstractSortedTreapSet) = this - override fun shallowRemove(element: E) = null - override fun shallowRemoveAll(predicate: (E) -> Boolean) = this.takeIf { !predicate(treapKey) } - override fun shallowComputeHashCode() = treapKey.hashCode() - override fun shallowGetSingleElement() = treapKey - override fun arbitraryOrNull() = treapKey - override fun shallowForEach(action: (element: E) -> Unit) { action(treapKey) } - override fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R) = map(treapKey) + override fun shallowAdd(that: SortedTreapSet) = this + override fun shallowUnion(that: SortedTreapSet) = this + override fun shallowDifference(that: SortedTreapSet) = null + override fun shallowIntersect(that: SortedTreapSet) = this + override fun shallowRemove(element: E): SortedTreapSet? = null + override fun shallowRemoveAll(predicate: (E) -> Boolean): SortedTreapSet? = this.takeIf { !predicate(treapKey) } + override fun shallowComputeHashCode(): Int = treapKey.hashCode() + override fun shallowGetSingleElement(): E = treapKey + override fun arbitraryOrNull(): E? = treapKey + override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) } + override fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey) override fun containsAny(predicate: (E) -> Boolean): Boolean = predicate(treapKey) || left?.containsAny(predicate) == true || right?.containsAny(predicate) == true } - -internal class SortedTreapSet<@Treapable E>( - override val treapKey: E, - override val left: AbstractSortedTreapSet? = null, - override val right: AbstractSortedTreapSet? = null -) : AbstractSortedTreapSet() diff --git a/collect/src/main/kotlin/com/certora/collect/Treap.kt b/collect/src/main/kotlin/com/certora/collect/Treap.kt index f10b171..7a456ea 100644 --- a/collect/src/main/kotlin/com/certora/collect/Treap.kt +++ b/collect/src/main/kotlin/com/certora/collect/Treap.kt @@ -65,9 +65,11 @@ package com.certora.collect - Treaps impose special requirements on keys if they are serialized. See [Treapable]. */ -internal abstract class Treap<@Treapable T, S : Treap> : TreapKey, java.io.Serializable { - abstract val left: S? - abstract val right: S? +internal abstract class Treap<@Treapable T, S : Treap>( + @JvmField val left: S?, + @JvmField val right: S? +) : TreapKey, java.io.Serializable { + abstract val self: S /** @@ -133,11 +135,11 @@ internal abstract class Treap<@Treapable T, S : Treap> : TreapKey, java case any hash functions have changed since this Treap was serialized. */ protected fun readResolve(): Any? { - left?.let { left -> + if (left != null) { check(left.compareKeyTo(this) < 0) { "Treap key comparison logic changed: ${left.treapKey} >= ${this.treapKey}" } check(left.comparePriorityTo(this) < 0) { "Treap key priority hash logic changed: ${left.treapKey} >= ${this.treapKey} "} } - right?.let { right -> + if (right != null) { check(right.compareKeyTo(this) > 0) { "Treap key comparison logic changed: ${right.treapKey} <= ${this.treapKey}" } check(right.comparePriorityTo(this) < 0) { "Treap key priority hash logic changed: ${right.treapKey} >= ${this.treapKey} "} } diff --git a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt index d7b0d5f..7931468 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt @@ -13,8 +13,6 @@ public sealed interface TreapMap : PersistentMap { override fun putAll(m: Map): TreapMap override fun clear(): TreapMap - override val keys: TreapSet - /** A [PersistentMap.Builder] that produces a [TreapMap]. */