Skip to content

Commit

Permalink
Fix type safety issues in sorted collections (#15)
Browse files Browse the repository at this point in the history
The treap-based collections are split into hashed and sorted variants.
The hashed collections are the most versatile; they'll work for any key
type. The sorted collections are specific to `Comparable` types, taking
advantage of the ordering to avoid having to deal with hash collisions.

We don't expose the notion of sorted vs hashed in the public API.
Instead, we automatically create the correct collection type based on
the element type. This mostly works, but has some issues currently. Take
the following for example:

```kotlin
treapSetOf(1, null)
```

This is allowed statically (`treapSetOf` has no restrictions on the
element type), but currently throws a `NullPointerException` at runtime.
The reason is that the first element is `Comparable`, and so we create a
`SortedTreapSet`, which is effecively a `TreapSet<Comparable<T>>` Then,
when we add the `null`, we fail the runtime check that the element is a
`Comparable<T>`, because it is null.

Similarly, if we start with a `Comparable`, and then add a
non-`Comparable`, we also crash:

```kotlin
data class Foo(a: Int)
treapSetOf<Any>(1, Foo(2)) 
```

Where we went wrong is in making `SortedTreapSet` a
`TreapSet<Comparable<T>>`. Conceptually, `SortedTreapSet` should be a
set of arbitrary values - _that just happens to only contain values that
are `Comparable`_.

We fix this by removing the constraint from `SortedTreapSet` (and
`SortedTreapMap`), and adding some runtime checks to switch to the
hashed collections if a non-`Comparable` key is added.

We also allow `null` as a sorted key, which prevents the (quite
expensive) downgrade to the hashed collections just to get nullability.
  • Loading branch information
ericeil authored Jun 23, 2024
1 parent 1bb0beb commit 01ac23a
Show file tree
Hide file tree
Showing 11 changed files with 310 additions and 99 deletions.
116 changes: 57 additions & 59 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,6 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
*/
abstract fun Map<out K, V>.toTreapMapOrNull(): AbstractTreapMap<K, V, S>?

/**
Converts the given Map to a AbstractTreapMap of the same type as 'this'. May copy the map.
*/
fun Map<out K, V>.toTreapMapIfNotEmpty(): AbstractTreapMap<K, V, S>? =
toTreapMapOrNull() ?: when {
isEmpty() -> null
else -> {
val i = entries.iterator()
var m: AbstractTreapMap<K, V, S> = i.next().let { (k, v) -> new(k, v) }
while (i.hasNext()) {
val (k, v) = i.next()
m = m.put(k, v)
}
m
}
}

/**
Given a map, calls the supplied `action` if the collection is a Treap of the same type as this Treap, otherwise
calls `fallback.` Used to implement optimized operations over two compatible Treaps, with a fallback when
Expand All @@ -64,7 +47,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
/**
Converts the supplied map key to a TreapKey appropriate to this type of AbstractTreapMap (sorted vs. hashed)
*/
abstract fun K.toTreapKey(): TreapKey<K>
abstract fun K.toTreapKey(): TreapKey<K>?

/**
Does this node contain an entry with the given map key?
Expand Down Expand Up @@ -127,24 +110,21 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override fun isEmpty(): Boolean = false

override fun containsKey(key: K) =
self.find(key.toTreapKey())?.shallowContainsKey(key) ?: false
key.toTreapKey()?.let { self.find(it) }?.shallowContainsKey(key) ?: false

override fun containsValue(value: V) = values.contains(value)

override fun get(key: K): V? =
self.find(key.toTreapKey())?.shallowGetValue(key)

override fun put(key: K, value: V): AbstractTreapMap<K, V, S> =
self.add(new(key, value))
key.toTreapKey()?.let { self.find(it) }?.shallowGetValue(key)

override fun putAll(m: Map<out K, V>): TreapMap<K, V> =
m.entries.fold(this as TreapMap<K, V>) { t, e -> t.put(e.key, e.value) }

override fun remove(key: K): TreapMap<K, V> =
self.remove(key.toTreapKey(), key) ?: clear()
override fun remove(key: K): TreapMap<K, V> =
key.toTreapKey()?.let { self.remove(it, key) ?: clear() } ?: this

override fun remove(key: K, value: V): TreapMap<K, V> =
self.removeEntry(key.toTreapKey(), key, value) ?: clear()
key.toTreapKey()?.let { self.removeEntry(it, key, value) ?: clear() } ?: this

override fun clear(): TreapMap<K, V> = treapMapOf<K, V>()

Expand Down Expand Up @@ -295,7 +275,13 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
```
*/
override fun <U> updateEntry(key: K, value: U, merger: (V?, U) -> V?): TreapMap<K, V> {
return self.updateEntry(key.toTreapKey().precompute(), key, value, merger, ::new) ?: clear()
val treapKey = key.toTreapKey()?.precompute()
return if (treapKey == null) {
// The key is not compatible with this map type, so it's definitely not in the map.
merger(null, value)?.let { put(key, it) } ?: this
} else {
self.updateEntry(treapKey, key, value, merger, ::new) ?: clear()
}
}

/**
Expand All @@ -305,47 +291,59 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override fun zip(m: Map<out K, V>) = sequence<Map.Entry<K, Pair<V?, V?>>> {
fun <T> Iterator<T>.nextOrNull() = if (hasNext()) { next() } else { null }

// Iterate over the two maps' treap sequences. We ensure that each sequence uses the same key ordering, by
// converting `m` to a TreapMap of this map's type, if necessary. Note that we can't use entrySequence, because
// HashTreapMap's entrySequence is only partially ordered.
val thisIt = asTreapSequence().iterator()
val that: Treap<K, S>? = m.toTreapMapIfNotEmpty()
val thatIt = that?.asTreapSequence()?.iterator()

var thisCurrent = thisIt.nextOrNull()
var thatCurrent = thatIt?.nextOrNull()

while (thisCurrent != null && thatIt != null && thatCurrent != null) {
val c = thisCurrent.compareKeyTo(thatCurrent)
when {
c < 0 -> {
yieldAll(thisCurrent.shallowZipThisOnly())
thisCurrent = thisIt.nextOrNull()
}
c > 0 -> {
yieldAll(thatCurrent.shallowZipThatOnly())
thatCurrent = thatIt.nextOrNull()
val sequences = getTreapSequencesIfSameType(m)
if (sequences != null) {
// Fast case for when the maps are the same type
val thisIt = sequences.first.iterator()
val thatIt = sequences.second.iterator()

var thisCurrent = thisIt.nextOrNull()
var thatCurrent = thatIt.nextOrNull()

while (thisCurrent != null && thatCurrent != null) {
val c = thisCurrent.compareKeyTo(thatCurrent)
when {
c < 0 -> {
yieldAll(thisCurrent.shallowZipThisOnly())
thisCurrent = thisIt.nextOrNull()
}
c > 0 -> {
yieldAll(thatCurrent.shallowZipThatOnly())
thatCurrent = thatIt.nextOrNull()
}
else -> {
yieldAll(thisCurrent.shallowZip(thatCurrent))
thisCurrent = thisIt.nextOrNull()
thatCurrent = thatIt.nextOrNull()
}
}
else -> {
yieldAll(thisCurrent.shallowZip(thatCurrent))
thisCurrent = thisIt.nextOrNull()
thatCurrent = thatIt.nextOrNull()
}
while (thisCurrent != null) {
yieldAll(thisCurrent.shallowZipThisOnly())
thisCurrent = thisIt.nextOrNull()
}
while (thatCurrent != null) {
yieldAll(thatCurrent.shallowZipThatOnly())
thatCurrent = thatIt.nextOrNull()
}
} else {
// Slower fallback for maps of different types
for ((k, v) in entries) {
yield(MapEntry(k, v to m[k]))
}
for ((k, v) in m.entries) {
if (k !in this@AbstractTreapMap) {
yield(MapEntry(k, null to v))
}
}
}
while (thisCurrent != null) {
yieldAll(thisCurrent.shallowZipThisOnly())
thisCurrent = thisIt.nextOrNull()
}
while (thatIt != null && thatCurrent != null) {
yieldAll(thatCurrent.shallowZipThatOnly())
thatCurrent = thatIt.nextOrNull()
}
}

private fun shallowZipThisOnly() = shallowEntrySequence().map { MapEntry(it.key, it.value to null) }
private fun shallowZipThatOnly() = shallowEntrySequence().map { MapEntry(it.key, null to it.value) }
protected abstract fun shallowZip(that: S): Sequence<Map.Entry<K, Pair<V?, V?>>>
protected abstract fun getTreapSequencesIfSameType(that: Map<out K, V>): Pair<Sequence<S>, Sequence<S>>?


override fun <R : Any> mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R =
notForking(self) { mapReduceImpl(map, reduce) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
/**
Converts the supplied set element to a TreapKey appropriate to this type of AbstractTreapSet (sorted vs. hashed)
*/
abstract fun E.toTreapKey(): TreapKey<E>
abstract fun E.toTreapKey(): TreapKey<E>?

/**
Does this node contain the element?
Expand Down Expand Up @@ -92,7 +92,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
}

override fun contains(element: E): Boolean =
self.find(element.toTreapKey())?.shallowContains(element) ?: false
element.toTreapKey()?.let { self.find(it) }?.shallowContains(element) ?: false

override fun containsAll(elements: Collection<E>): Boolean = elements.useAsTreap(
{ elementsTreap -> self.containsAllKeys(elementsTreap) },
Expand All @@ -110,7 +110,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
)

override fun remove(element: E): TreapSet<E> =
self.remove(element.toTreapKey(), element) ?: clear()
element.toTreapKey()?.let { self.remove(it, element) ?: clear() } ?: this

override fun removeAll(elements: Collection<E>): TreapSet<E> = elements.useAsTreap(
{ elementsTreap -> (self difference elementsTreap) ?: clear() },
Expand All @@ -134,7 +134,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
)

override fun findEqual(element: E): E? =
self.find(element.toTreapKey())?.shallowFindEqual(element)
element.toTreapKey()?.let { self.find(it) }?.shallowFindEqual(element)

@Suppress("UNCHECKED_CAST")
override fun single(): E = getSingleElement() ?: when {
Expand Down
6 changes: 2 additions & 4 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,8 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K

@Suppress("Treapability", "UNCHECKED_CAST")
override fun put(key: K, value: V): TreapMap<K, V> = when (key) {
is PrefersHashTreap -> HashTreapMap(key, value)
is Comparable<*> ->
SortedTreapMap<Comparable<Comparable<*>>, V>(key as Comparable<Comparable<*>>, value) as TreapMap<K, V>
else -> HashTreapMap(key, value)
!is Comparable<*>?, is PrefersHashTreap -> HashTreapMap(key, value)
else -> SortedTreapMap(key, value)
}

@Suppress("UNCHECKED_CAST")
Expand Down
11 changes: 4 additions & 7 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,9 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet<E>,
override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null
override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null

@Suppress("Treapability", "UNCHECKED_CAST")
override fun add(element: E): TreapSet<E> = when (element) {
is PrefersHashTreap -> HashTreapSet(element)
is Comparable<*> ->
SortedTreapSet<Comparable<Comparable<*>>>(element as Comparable<Comparable<*>>) as TreapSet<E>
else -> HashTreapSet(element)
!is Comparable<*>?, is PrefersHashTreap -> HashTreapSet(element)
else -> SortedTreapSet(element as E)
}

@Suppress("UNCHECKED_CAST")
Expand All @@ -42,10 +39,10 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet<E>,
elements is PersistentSet.Builder<*> -> elements.build() as TreapSet<E>
else -> elements.fold(this as TreapSet<E>) { set, element -> set.add(element) }
}

companion object {
private val instance = EmptyTreapSet<Nothing>()
@Suppress("UNCHECKED_CAST")
operator fun <@Treapable E> invoke(): EmptyTreapSet<E> = instance as EmptyTreapSet<E>
}
}
}
13 changes: 12 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ internal class HashTreapMap<@Treapable K, V>(

override fun hashCode() = computeHashCode()

override fun K.toTreapKey() = TreapKey.Hashed.FromKey(this)
override fun K.toTreapKey() = TreapKey.Hashed.fromKey(this)
override fun new(key: K, value: V): HashTreapMap<K, V> = HashTreapMap(key, value)

override fun put(key: K, value: V): TreapMap<K, V> = self.add(new(key, value))

@Suppress("UNCHECKED_CAST")
override fun Map<out K, V>.toTreapMapOrNull() =
this as? HashTreapMap<K, V>
Expand Down Expand Up @@ -88,6 +90,15 @@ internal class HashTreapMap<@Treapable K, V>(
return false
}

protected override fun getTreapSequencesIfSameType(
that: Map<out K, V>
): Pair<Sequence<HashTreapMap<K, V>>, Sequence<HashTreapMap<K, V>>>? {
@Suppress("UNCHECKED_CAST")
return (that as? HashTreapMap<K, V>)?.let {
this.asTreapSequence() to it.asTreapSequence()
}
}

override fun shallowZip(that: HashTreapMap<K, V>): Sequence<Map.Entry<K, Pair<V?, V?>>> = sequence {
forEachPair {
yield(MapEntry(it.key, it.value to that.shallowGetValue(it.key)))
Expand Down
4 changes: 2 additions & 2 deletions collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ internal class HashTreapSet<@Treapable E>(

override fun hashCode(): Int = computeHashCode()

override fun E.toTreapKey() = TreapKey.Hashed.FromKey(this)
override fun E.toTreapKey() = TreapKey.Hashed.fromKey(this)
override fun new(element: E): HashTreapSet<E> = HashTreapSet(element)

override fun add(element: E): TreapSet<E> = self.add(new(element))
Expand Down Expand Up @@ -45,7 +45,7 @@ internal class HashTreapSet<@Treapable E>(
return count
}

override fun copyWith(left: HashTreapSet<E>?, right: HashTreapSet<E>?): HashTreapSet<E> =
override fun copyWith(left: HashTreapSet<E>?, right: HashTreapSet<E>?): HashTreapSet<E> =
HashTreapSet(element, next, left, right)

fun withElement(element: E) = when {
Expand Down
34 changes: 27 additions & 7 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,33 @@ import kotlinx.collections.immutable.PersistentMap
A TreapMap specific to Comparable keys. Iterates in the order defined by the objects. We store one element per
Treap node, with the map key itself as the Treap key, and an additional `value` field
*/
internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
internal class SortedTreapMap<@Treapable K, V>(
val key: K,
val value: V,
left: SortedTreapMap<K, V>? = null,
right: SortedTreapMap<K, V>? = null
) : AbstractTreapMap<K, V, SortedTreapMap<K, V>>(left, right), TreapKey.Sorted<K> {

init { check(key is Comparable<*>?) { "SortedTreapMap keys must be Comparable" } }

override fun hashCode() = computeHashCode()

override fun K.toTreapKey() = TreapKey.Sorted.FromKey(this)
override fun K.toTreapKey() = TreapKey.Sorted.fromKey(this)

override fun new(key: K, value: V): SortedTreapMap<K, V> = SortedTreapMap(key, value)

override fun put(key: K, value: V): TreapMap<K, V> = when (key) {
!is Comparable<*>?, is PrefersHashTreap -> HashTreapMap(key, value) + this
else -> self.add(new(key, value))
}

@Suppress("UNCHECKED_CAST")
override fun Map<out K, V>.toTreapMapOrNull() =
this as? SortedTreapMap<K, V>
?: (this as? PersistentMap.Builder<K, V>)?.build() as? SortedTreapMap<K, V>

override fun getShallowMerger(merger: (K, V?, V?) -> V?): (SortedTreapMap<K, V>?, SortedTreapMap<K, V>?) -> SortedTreapMap<K, V>? = { t1, t2 ->
val k = t1?.key ?: t2?.key as K
val k = (t1 ?: t2)!!.key
val v1 = t1?.value
val v2 = t2?.value
val v = merger(k, v1, v2)
Expand All @@ -37,6 +44,15 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
}
}

protected override fun getTreapSequencesIfSameType(
that: Map<out K, V>
): Pair<Sequence<SortedTreapMap<K, V>>, Sequence<SortedTreapMap<K, V>>>? {
@Suppress("UNCHECKED_CAST")
return (that as? SortedTreapMap<K, V>)?.let {
this.asTreapSequence() to it.asTreapSequence()
}
}

override fun shallowZip(that: SortedTreapMap<K, V>): Sequence<Map.Entry<K, Pair<V, V>>> =
sequenceOf(MapEntry(this.key, this.value to that.value))

Expand Down Expand Up @@ -84,34 +100,38 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
}

fun floorEntry(key: K): Map.Entry<K, V>? {
val cmp = key.compareTo(this.key)
val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this)
return when {
cmp == null -> null
cmp < 0 -> left?.floorEntry(key)
cmp > 0 -> right?.floorEntry(key) ?: this.asEntry()
else -> this.asEntry()
}
}

fun ceilingEntry(key: K): Map.Entry<K, V>? {
val cmp = key.compareTo(this.key)
val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this)
return when {
cmp == null -> null
cmp < 0 -> left?.ceilingEntry(key) ?: this.asEntry()
cmp > 0 -> right?.ceilingEntry(key)
else -> this.asEntry()
}
}

fun lowerEntry(key: K): Map.Entry<K, V>? {
val cmp = key.compareTo(this.key)
val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this)
return when {
cmp == null -> null
cmp > 0 -> right?.lowerEntry(key) ?: this.asEntry()
else -> left?.lowerEntry(key)
}
}

fun higherEntry(key: K): Map.Entry<K, V>? {
val cmp = key.compareTo(this.key)
val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this)
return when {
cmp == null -> null
cmp < 0 -> left?.higherEntry(key) ?: this.asEntry()
else -> right?.higherEntry(key)
}
Expand Down
Loading

0 comments on commit 01ac23a

Please sign in to comment.