diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt index 5a9e90f..cebdcba 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt @@ -25,23 +25,6 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT */ abstract fun Map.toTreapMapOrNull(): AbstractTreapMap? - /** - Converts the given Map to a AbstractTreapMap of the same type as 'this'. May copy the map. - */ - fun Map.toTreapMapIfNotEmpty(): AbstractTreapMap? = - toTreapMapOrNull() ?: when { - isEmpty() -> null - else -> { - val i = entries.iterator() - var m: AbstractTreapMap = i.next().let { (k, v) -> new(k, v) } - while (i.hasNext()) { - val (k, v) = i.next() - m = m.put(k, v) - } - m - } - } - /** Given a map, calls the supplied `action` if the collection is a Treap of the same type as this Treap, otherwise calls `fallback.` Used to implement optimized operations over two compatible Treaps, with a fallback when @@ -64,7 +47,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT /** Converts the supplied map key to a TreapKey appropriate to this type of AbstractTreapMap (sorted vs. hashed) */ - abstract fun K.toTreapKey(): TreapKey + abstract fun K.toTreapKey(): TreapKey? /** Does this node contain an entry with the given map key? @@ -127,24 +110,21 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT override fun isEmpty(): Boolean = false override fun containsKey(key: K) = - self.find(key.toTreapKey())?.shallowContainsKey(key) ?: false + key.toTreapKey()?.let { self.find(it) }?.shallowContainsKey(key) ?: false override fun containsValue(value: V) = values.contains(value) override fun get(key: K): V? = - self.find(key.toTreapKey())?.shallowGetValue(key) - - override fun put(key: K, value: V): AbstractTreapMap = - self.add(new(key, value)) + key.toTreapKey()?.let { self.find(it) }?.shallowGetValue(key) override fun putAll(m: Map): TreapMap = m.entries.fold(this as TreapMap) { t, e -> t.put(e.key, e.value) } - override fun remove(key: K): TreapMap = - self.remove(key.toTreapKey(), key) ?: clear() + override fun remove(key: K): TreapMap = + key.toTreapKey()?.let { self.remove(it, key) ?: clear() } ?: this override fun remove(key: K, value: V): TreapMap = - self.removeEntry(key.toTreapKey(), key, value) ?: clear() + key.toTreapKey()?.let { self.removeEntry(it, key, value) ?: clear() } ?: this override fun clear(): TreapMap = treapMapOf() @@ -295,7 +275,13 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT ``` */ override fun updateEntry(key: K, value: U, merger: (V?, U) -> V?): TreapMap { - return self.updateEntry(key.toTreapKey().precompute(), key, value, merger, ::new) ?: clear() + val treapKey = key.toTreapKey()?.precompute() + return if (treapKey == null) { + // The key is not compatible with this map type, so it's definitely not in the map. + merger(null, value)?.let { put(key, it) } ?: this + } else { + self.updateEntry(treapKey, key, value, merger, ::new) ?: clear() + } } /** @@ -305,47 +291,59 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT override fun zip(m: Map) = sequence>> { fun Iterator.nextOrNull() = if (hasNext()) { next() } else { null } - // Iterate over the two maps' treap sequences. We ensure that each sequence uses the same key ordering, by - // converting `m` to a TreapMap of this map's type, if necessary. Note that we can't use entrySequence, because - // HashTreapMap's entrySequence is only partially ordered. - val thisIt = asTreapSequence().iterator() - val that: Treap? = m.toTreapMapIfNotEmpty() - val thatIt = that?.asTreapSequence()?.iterator() - - var thisCurrent = thisIt.nextOrNull() - var thatCurrent = thatIt?.nextOrNull() - - while (thisCurrent != null && thatIt != null && thatCurrent != null) { - val c = thisCurrent.compareKeyTo(thatCurrent) - when { - c < 0 -> { - yieldAll(thisCurrent.shallowZipThisOnly()) - thisCurrent = thisIt.nextOrNull() - } - c > 0 -> { - yieldAll(thatCurrent.shallowZipThatOnly()) - thatCurrent = thatIt.nextOrNull() + val sequences = getTreapSequencesIfSameType(m) + if (sequences != null) { + // Fast case for when the maps are the same type + val thisIt = sequences.first.iterator() + val thatIt = sequences.second.iterator() + + var thisCurrent = thisIt.nextOrNull() + var thatCurrent = thatIt.nextOrNull() + + while (thisCurrent != null && thatCurrent != null) { + val c = thisCurrent.compareKeyTo(thatCurrent) + when { + c < 0 -> { + yieldAll(thisCurrent.shallowZipThisOnly()) + thisCurrent = thisIt.nextOrNull() + } + c > 0 -> { + yieldAll(thatCurrent.shallowZipThatOnly()) + thatCurrent = thatIt.nextOrNull() + } + else -> { + yieldAll(thisCurrent.shallowZip(thatCurrent)) + thisCurrent = thisIt.nextOrNull() + thatCurrent = thatIt.nextOrNull() + } } - else -> { - yieldAll(thisCurrent.shallowZip(thatCurrent)) - thisCurrent = thisIt.nextOrNull() - thatCurrent = thatIt.nextOrNull() + } + while (thisCurrent != null) { + yieldAll(thisCurrent.shallowZipThisOnly()) + thisCurrent = thisIt.nextOrNull() + } + while (thatCurrent != null) { + yieldAll(thatCurrent.shallowZipThatOnly()) + thatCurrent = thatIt.nextOrNull() + } + } else { + // Slower fallback for maps of different types + for ((k, v) in entries) { + yield(MapEntry(k, v to m[k])) + } + for ((k, v) in m.entries) { + if (k !in this@AbstractTreapMap) { + yield(MapEntry(k, null to v)) } } } - while (thisCurrent != null) { - yieldAll(thisCurrent.shallowZipThisOnly()) - thisCurrent = thisIt.nextOrNull() - } - while (thatIt != null && thatCurrent != null) { - yieldAll(thatCurrent.shallowZipThatOnly()) - thatCurrent = thatIt.nextOrNull() - } } private fun shallowZipThisOnly() = shallowEntrySequence().map { MapEntry(it.key, it.value to null) } private fun shallowZipThatOnly() = shallowEntrySequence().map { MapEntry(it.key, null to it.value) } protected abstract fun shallowZip(that: S): Sequence>> + protected abstract fun getTreapSequencesIfSameType(that: Map): Pair, Sequence>? + override fun mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R = notForking(self) { mapReduceImpl(map, reduce) } diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt index 52fe41d..4900c4e 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt @@ -39,7 +39,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> /** Converts the supplied set element to a TreapKey appropriate to this type of AbstractTreapSet (sorted vs. hashed) */ - abstract fun E.toTreapKey(): TreapKey + abstract fun E.toTreapKey(): TreapKey? /** Does this node contain the element? @@ -92,7 +92,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> } override fun contains(element: E): Boolean = - self.find(element.toTreapKey())?.shallowContains(element) ?: false + element.toTreapKey()?.let { self.find(it) }?.shallowContains(element) ?: false override fun containsAll(elements: Collection): Boolean = elements.useAsTreap( { elementsTreap -> self.containsAllKeys(elementsTreap) }, @@ -110,7 +110,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> ) override fun remove(element: E): TreapSet = - self.remove(element.toTreapKey(), element) ?: clear() + element.toTreapKey()?.let { self.remove(it, element) ?: clear() } ?: this override fun removeAll(elements: Collection): TreapSet = elements.useAsTreap( { elementsTreap -> (self difference elementsTreap) ?: clear() }, @@ -134,7 +134,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> ) override fun findEqual(element: E): E? = - self.find(element.toTreapKey())?.shallowFindEqual(element) + element.toTreapKey()?.let { self.find(it) }?.shallowFindEqual(element) @Suppress("UNCHECKED_CAST") override fun single(): E = getSingleElement() ?: when { diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt index 5975ae1..f022933 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt @@ -61,10 +61,8 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap = when (key) { - is PrefersHashTreap -> HashTreapMap(key, value) - is Comparable<*> -> - SortedTreapMap>, V>(key as Comparable>, value) as TreapMap - else -> HashTreapMap(key, value) + !is Comparable<*>?, is PrefersHashTreap -> HashTreapMap(key, value) + else -> SortedTreapMap(key, value) } @Suppress("UNCHECKED_CAST") diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt index 3628343..4b880b9 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt @@ -27,12 +27,9 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet, override fun mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null override fun parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null - @Suppress("Treapability", "UNCHECKED_CAST") override fun add(element: E): TreapSet = when (element) { - is PrefersHashTreap -> HashTreapSet(element) - is Comparable<*> -> - SortedTreapSet>>(element as Comparable>) as TreapSet - else -> HashTreapSet(element) + !is Comparable<*>?, is PrefersHashTreap -> HashTreapSet(element) + else -> SortedTreapSet(element as E) } @Suppress("UNCHECKED_CAST") @@ -42,10 +39,10 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet, elements is PersistentSet.Builder<*> -> elements.build() as TreapSet else -> elements.fold(this as TreapSet) { set, element -> set.add(element) } } - + companion object { private val instance = EmptyTreapSet() @Suppress("UNCHECKED_CAST") operator fun <@Treapable E> invoke(): EmptyTreapSet = instance as EmptyTreapSet } -} \ No newline at end of file +} diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt index 5666af8..abddbed 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt @@ -21,9 +21,11 @@ internal class HashTreapMap<@Treapable K, V>( override fun hashCode() = computeHashCode() - override fun K.toTreapKey() = TreapKey.Hashed.FromKey(this) + override fun K.toTreapKey() = TreapKey.Hashed.fromKey(this) override fun new(key: K, value: V): HashTreapMap = HashTreapMap(key, value) + override fun put(key: K, value: V): TreapMap = self.add(new(key, value)) + @Suppress("UNCHECKED_CAST") override fun Map.toTreapMapOrNull() = this as? HashTreapMap @@ -88,6 +90,15 @@ internal class HashTreapMap<@Treapable K, V>( return false } + protected override fun getTreapSequencesIfSameType( + that: Map + ): Pair>, Sequence>>? { + @Suppress("UNCHECKED_CAST") + return (that as? HashTreapMap)?.let { + this.asTreapSequence() to it.asTreapSequence() + } + } + override fun shallowZip(that: HashTreapMap): Sequence>> = sequence { forEachPair { yield(MapEntry(it.key, it.value to that.shallowGetValue(it.key))) diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt index 8346603..fd819f0 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt @@ -17,7 +17,7 @@ internal class HashTreapSet<@Treapable E>( override fun hashCode(): Int = computeHashCode() - override fun E.toTreapKey() = TreapKey.Hashed.FromKey(this) + override fun E.toTreapKey() = TreapKey.Hashed.fromKey(this) override fun new(element: E): HashTreapSet = HashTreapSet(element) override fun add(element: E): TreapSet = self.add(new(element)) @@ -45,7 +45,7 @@ internal class HashTreapSet<@Treapable E>( return count } - override fun copyWith(left: HashTreapSet?, right: HashTreapSet?): HashTreapSet = + override fun copyWith(left: HashTreapSet?, right: HashTreapSet?): HashTreapSet = HashTreapSet(element, next, left, right) fun withElement(element: E) = when { diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt index e447172..ae13992 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt @@ -6,26 +6,33 @@ import kotlinx.collections.immutable.PersistentMap A TreapMap specific to Comparable keys. Iterates in the order defined by the objects. We store one element per Treap node, with the map key itself as the Treap key, and an additional `value` field */ -internal class SortedTreapMap<@Treapable K : Comparable, V>( +internal class SortedTreapMap<@Treapable K, V>( val key: K, val value: V, left: SortedTreapMap? = null, right: SortedTreapMap? = null ) : AbstractTreapMap>(left, right), TreapKey.Sorted { + init { check(key is Comparable<*>?) { "SortedTreapMap keys must be Comparable" } } + override fun hashCode() = computeHashCode() - override fun K.toTreapKey() = TreapKey.Sorted.FromKey(this) + override fun K.toTreapKey() = TreapKey.Sorted.fromKey(this) override fun new(key: K, value: V): SortedTreapMap = SortedTreapMap(key, value) + override fun put(key: K, value: V): TreapMap = when (key) { + !is Comparable<*>?, is PrefersHashTreap -> HashTreapMap(key, value) + this + else -> self.add(new(key, value)) + } + @Suppress("UNCHECKED_CAST") override fun Map.toTreapMapOrNull() = this as? SortedTreapMap ?: (this as? PersistentMap.Builder)?.build() as? SortedTreapMap override fun getShallowMerger(merger: (K, V?, V?) -> V?): (SortedTreapMap?, SortedTreapMap?) -> SortedTreapMap? = { t1, t2 -> - val k = t1?.key ?: t2?.key as K + val k = (t1 ?: t2)!!.key val v1 = t1?.value val v2 = t2?.value val v = merger(k, v1, v2) @@ -37,6 +44,15 @@ internal class SortedTreapMap<@Treapable K : Comparable, V>( } } + protected override fun getTreapSequencesIfSameType( + that: Map + ): Pair>, Sequence>>? { + @Suppress("UNCHECKED_CAST") + return (that as? SortedTreapMap)?.let { + this.asTreapSequence() to it.asTreapSequence() + } + } + override fun shallowZip(that: SortedTreapMap): Sequence>> = sequenceOf(MapEntry(this.key, this.value to that.value)) @@ -84,8 +100,9 @@ internal class SortedTreapMap<@Treapable K : Comparable, V>( } fun floorEntry(key: K): Map.Entry? { - val cmp = key.compareTo(this.key) + val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this) return when { + cmp == null -> null cmp < 0 -> left?.floorEntry(key) cmp > 0 -> right?.floorEntry(key) ?: this.asEntry() else -> this.asEntry() @@ -93,8 +110,9 @@ internal class SortedTreapMap<@Treapable K : Comparable, V>( } fun ceilingEntry(key: K): Map.Entry? { - val cmp = key.compareTo(this.key) + val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this) return when { + cmp == null -> null cmp < 0 -> left?.ceilingEntry(key) ?: this.asEntry() cmp > 0 -> right?.ceilingEntry(key) else -> this.asEntry() @@ -102,16 +120,18 @@ internal class SortedTreapMap<@Treapable K : Comparable, V>( } fun lowerEntry(key: K): Map.Entry? { - val cmp = key.compareTo(this.key) + val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this) return when { + cmp == null -> null cmp > 0 -> right?.lowerEntry(key) ?: this.asEntry() else -> left?.lowerEntry(key) } } fun higherEntry(key: K): Map.Entry? { - val cmp = key.compareTo(this.key) + val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this) return when { + cmp == null -> null cmp < 0 -> left?.higherEntry(key) ?: this.asEntry() else -> right?.higherEntry(key) } diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt index 40f31c9..d683b71 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt @@ -6,19 +6,21 @@ import kotlinx.collections.immutable.PersistentSet A TreapSet specific to Comparable elements. Iterates in the order defined by the objects. We store one element per Treap node, with the element itself as the Treap key. */ -internal class SortedTreapSet<@Treapable E : Comparable>( +internal class SortedTreapSet<@Treapable E>( override val treapKey: E, left: SortedTreapSet? = null, right: SortedTreapSet? = null ) : AbstractTreapSet>(left, right), TreapKey.Sorted { + init { check(treapKey is Comparable<*>?) { "SortedTreapSet elements must be Comparable" } } + override fun hashCode(): Int = computeHashCode() - override fun E.toTreapKey() = TreapKey.Sorted.FromKey(this) + override fun E.toTreapKey() = TreapKey.Sorted.fromKey(this) override fun new(element: E): SortedTreapSet = SortedTreapSet(element) - override fun add(element: E): TreapSet = when { - element is PrefersHashTreap -> HashTreapSet(element as E) + this + override fun add(element: E): TreapSet = when(element) { + !is Comparable<*>?, is PrefersHashTreap -> HashTreapSet(element) + this else -> self.add(new(element)) } diff --git a/collect/src/main/kotlin/com/certora/collect/TreapKey.kt b/collect/src/main/kotlin/com/certora/collect/TreapKey.kt index 2c501da..f5b454d 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapKey.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapKey.kt @@ -64,17 +64,33 @@ internal interface TreapKey<@Treapable K> { /** A TreapKey whose underlying key implements Comparable. This allows us to sort the Treap naturally. */ - interface Sorted<@Treapable K : Comparable> : TreapKey { + interface Sorted<@Treapable K> : TreapKey { abstract override val treapKey: K - // Note that we must never compare a Hashed key with a Sorted key. We'd check that here, but this is extremely - // perf-critical code. - override fun compareKeyTo(that: TreapKey) = this.treapKey.compareTo(that.treapKey) + override fun compareKeyTo(that: TreapKey): Int { + val thisTreapKey = this.treapKey + val thatTreapKey = that.treapKey + return when { + thisTreapKey === thatTreapKey -> 0 + thisTreapKey == null -> -1 + thatTreapKey == null -> 1 + else -> { + @Suppress("UNCHECKED_CAST") + (thisTreapKey as Comparable).compareTo(thatTreapKey) + } + } + } - override fun precompute() = FromKey(treapKey) + override fun precompute() = fromKey(treapKey)!! - class FromKey<@Treapable K : Comparable>(override val treapKey: K) : Sorted { - override val treapPriority = super.treapPriority // precompute the priority + companion object { + fun <@Treapable K> fromKey(key: K): Sorted? = when (key) { + is Comparable<*>? -> object : Sorted { + override val treapKey = key + override val treapPriority = super.treapPriority // precompute the priority + } + else -> null + } } } @@ -98,11 +114,14 @@ internal interface TreapKey<@Treapable K> { } } - override fun precompute() = FromKey(treapKey) + override fun precompute() = fromKey(treapKey) - class FromKey<@Treapable K>(override val treapKey: K) : Hashed { - override val treapKeyHashCode = treapKey.hashCode() // precompute the hash code - override val treapPriority = super.treapPriority // precompute the priority + companion object { + fun <@Treapable K> fromKey(key: K): Hashed = object : Hashed { + override val treapKey = key + override val treapKeyHashCode = treapKey.hashCode() // precompute the hash code + override val treapPriority = super.treapPriority // precompute the priority + } } } } diff --git a/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt b/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt index 0eb8e4e..28b4bc2 100644 --- a/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt +++ b/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt @@ -191,6 +191,70 @@ abstract class TreapMapTest { } } + @Test + fun addNullMapAtEnd() { + val s = makeMap() + s.putAll(treapMapOf(makeKey(0) to null)) + s.putAll(treapMapOf(null to null)) + assertEquals(s, mapOf(null to null, makeKey(0) to null)) + } + + @Test + fun addNullMapAtStart() { + val s = makeMap() + s.putAll(treapMapOf(null to null)) + s.putAll(treapMapOf(makeKey(0) to null)) + assertEquals(s, mapOf(null to null, makeKey(0) to null)) + } + + @Test + fun addHashKeyAtEnd() { + val s = makeMap() + s.put(makeKey(0), null) + s.put(null, null) + assertEquals(s, mapOf(null to null, makeKey(0) to null)) + } + + @Test + fun addNullKeyAtStart() { + val s = makeMap() + s.put(null, null) + s.put(makeKey(0), null) + assertEquals(s, mapOf(null to null, makeKey(0) to null)) + } + + @Test + fun addHashKeyAtStart() { + val s = makeMap() + s.put(HashTestKey(0), null) + s.put(makeKey(0), null) + assertEquals(s, mapOf(HashTestKey(0) to null, makeKey(0) to null)) + } + + @Test + fun addHashMapAtEnd() { + val s = makeMap() + s.putAll(treapMapOf(makeKey(0) to null)) + s.putAll(treapMapOf(HashTestKey(0) to null)) + assertEquals(s, mapOf(HashTestKey(0) to null, makeKey(0) to null)) + } + + @Test + fun addHashMapAtStart() { + val s = makeMap() + s.putAll(treapMapOf(HashTestKey(0) to null)) + s.putAll(treapMapOf(makeKey(0) to null)) + assertEquals(s, mapOf(HashTestKey(0) to null, makeKey(0) to null)) + } + + @Test + fun addNullKeyAtEnd() { + val s = makeMap() + s.put(makeKey(0), null) + s.put(HashTestKey(0), null) + assertEquals(s, mapOf(HashTestKey(0) to null, makeKey(0) to null)) + } + @Test fun copyConstructorEmpty() { val empty = mapOf() diff --git a/collect/src/test/kotlin/com/certora/collect/TreapSetTest.kt b/collect/src/test/kotlin/com/certora/collect/TreapSetTest.kt index 04bb170..f3af578 100644 --- a/collect/src/test/kotlin/com/certora/collect/TreapSetTest.kt +++ b/collect/src/test/kotlin/com/certora/collect/TreapSetTest.kt @@ -187,6 +187,108 @@ abstract class TreapSetTest { } } + @Test + fun addNullSetAtEnd() { + val s = makeSet() + s.addAll(treapSetOf(makeKey(0))) + s.addAll(treapSetOf(null)) + assertEquals(s, setOf(null, makeKey(0))) + } + + @Test + fun addNullSetAtStart() { + val s = makeSet() + s.addAll(treapSetOf(null)) + s.addAll(treapSetOf(makeKey(0))) + assertEquals(s, setOf(null, makeKey(0))) + } + + @Test + fun addNullElementAtEnd() { + val s = makeSet() + s.add(makeKey(0)) + s.add(null) + assertEquals(s, setOf(null, makeKey(0))) + } + + @Test + fun addNullElementAtStart() { + val s = makeSet() + s.add(null) + s.add(makeKey(0)) + assertEquals(s, setOf(null, makeKey(0))) + } + + @Test + fun addHashSetAtEnd() { + val s = makeSet() + s.addAll(treapSetOf(makeKey(0))) + s.addAll(treapSetOf(HashTestKey(0))) + assertEquals(s, setOf(HashTestKey(0), makeKey(0))) + } + + @Test + fun addHashSetAtStart() { + val s = makeSet() + s.addAll(treapSetOf(HashTestKey(0))) + s.addAll(treapSetOf(makeKey(0))) + assertEquals(s, setOf(HashTestKey(0), makeKey(0))) + } + + @Test + fun addHashElementAtEnd() { + val s = makeSet() + s.add(makeKey(0)) + s.add(HashTestKey(0)) + assertEquals(s, setOf(HashTestKey(0), makeKey(0))) + } + + @Test + fun addHashElementAtStart() { + val s = makeSet() + s.add(HashTestKey(0)) + s.add(makeKey(0)) + assertEquals(s, setOf(HashTestKey(0), makeKey(0))) + } + + @Test + fun nullQueries() { + val withNull = makeSet() + withNull.add(makeKey(0)) + withNull.add(null) + + val withoutNull = makeSet() + withNull.add(makeKey(0)) + + assertTrue(null in withNull) + assertFalse(null in withoutNull) + + assertTrue(withNull.containsAll(treapSetOf(null))) + assertFalse(withoutNull.containsAll(treapSetOf(null))) + + assertTrue(withNull.remove(null)) + assertFalse(withoutNull.remove(null)) + } + + @Test + fun hashQueries() { + val withHash = makeSet() + withHash.add(makeKey(0)) + withHash.add(HashTestKey(0)) + + val withoutHash = makeSet() + withHash.add(makeKey(0)) + + assertTrue(HashTestKey(0) in withHash) + assertFalse(HashTestKey(0) in withoutHash) + + assertTrue(withHash.containsAll(treapSetOf(HashTestKey(0)))) + assertFalse(withoutHash.containsAll(treapSetOf(HashTestKey(0)))) + + assertTrue(withHash.remove(HashTestKey(0))) + assertFalse(withoutHash.remove(HashTestKey(0))) + } + @Test fun retainAll() { val b = makeBaseline()