Skip to content

Commit

Permalink
There is now a difference between local and global DistanceAttribute …
Browse files Browse the repository at this point in the history
…and ScoreAttributes.
  • Loading branch information
ppanopticon committed Feb 1, 2025
1 parent e7696b6 commit d00cbf4
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import org.vitrivr.engine.core.model.descriptor.vector.FloatVectorDescriptor
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.proximity.ProximityQuery
import org.vitrivr.engine.core.model.retrievable.attributes.DistanceAttribute
import org.vitrivr.engine.core.model.retrievable.attributes.ScoreAttribute

/**
* [DenseRetriever] implementation for proximity-based retrieval on float vector embeddings.
Expand All @@ -26,16 +25,12 @@ import org.vitrivr.engine.core.model.retrievable.attributes.ScoreAttribute
* @author Fynn Faber
* @version 1.0.0
*/
class DenseRetriever<C : ContentElement<*>>(field: Schema.Field<C, FloatVectorDescriptor>, query: ProximityQuery<*>, context: QueryContext, val correspondence: CorrespondenceFunction) :
AbstractRetriever<C, FloatVectorDescriptor>(field, query, context) {
class DenseRetriever<C : ContentElement<*>>(field: Schema.Field<C, FloatVectorDescriptor>, query: ProximityQuery<*>, context: QueryContext, val correspondence: CorrespondenceFunction) : AbstractRetriever<C, FloatVectorDescriptor>(field, query, context) {
override fun toFlow(scope: CoroutineScope) = flow {
this@DenseRetriever.reader.queryAndJoin(this@DenseRetriever.query).forEach {
val distance = it.filteredAttribute<DistanceAttribute>()
if (distance != null) {
it.addAttribute(this@DenseRetriever.correspondence(distance))
} else {
this@DenseRetriever.logger.warn { "No distance attribute found for descriptor ${it.id}." }
it.addAttribute(ScoreAttribute.Similarity(0.0f))
val distances = it.filteredAttributes<DistanceAttribute>()
for (d in distances) {
it.addAttribute(d)
}
emit(it)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,8 @@ import org.vitrivr.engine.core.model.retrievable.attributes.ScoreAttribute
* @version 1.0.0
*/
class BoundedCorrespondence(private val min: Float = 0.0f, private val max: Float = 1.0f) : CorrespondenceFunction {
override fun invoke(distance: DistanceAttribute): ScoreAttribute.Similarity = ScoreAttribute.Similarity((this.max - distance.distance) / (this.max - this.min))
override fun invoke(distance: DistanceAttribute): ScoreAttribute.Similarity = when(distance) {
is DistanceAttribute.Global -> ScoreAttribute.Similarity((this.max - distance.distance) / (this.max - this.min))
is DistanceAttribute.Local -> ScoreAttribute.Similarity((this.max - distance.distance) / (this.max - this.min), distance.descriptorId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ interface CorrespondenceFunction {
* Computes the score for a given [DistanceAttribute].
*
* @param distance [DistanceAttribute] for which to compute the score.
* @return [ScoreAttribute.Similarity] for the given [DistanceAttribute].
* @return [ScoreAttribute] for the given [DistanceAttribute].
*/
operator fun invoke(distance: DistanceAttribute): ScoreAttribute.Similarity
operator fun invoke(distance: DistanceAttribute): ScoreAttribute
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,8 @@ import org.vitrivr.engine.core.model.retrievable.attributes.ScoreAttribute
* @version 1.0.0
*/
class LinearCorrespondence(private val max: Float) : CorrespondenceFunction {
override fun invoke(distance: DistanceAttribute): ScoreAttribute.Similarity = ScoreAttribute.Similarity(1.0f - (distance.distance / this.max))
override fun invoke(distance: DistanceAttribute): ScoreAttribute.Similarity = when(distance) {
is DistanceAttribute.Global -> ScoreAttribute.Similarity(1.0f - (distance.distance / this.max))
is DistanceAttribute.Local -> ScoreAttribute.Similarity(1.0f - (distance.distance / this.max), distance.descriptorId)
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,33 @@
package org.vitrivr.engine.core.model.retrievable.attributes

import org.vitrivr.engine.core.model.descriptor.Descriptor
import org.vitrivr.engine.core.model.descriptor.DescriptorId
import org.vitrivr.engine.core.model.retrievable.Retrievable
import kotlin.math.min

/**
* A [MergingRetrievableAttribute] that contains a distance value.
* A [DistanceAttribute] that contains a distance value.
*
* @author Luca Rossetto
* @version 1.1.0
*/
data class DistanceAttribute(val distance: Float, val descriptorId: DescriptorId? = null) : MergingRetrievableAttribute {
override fun merge(other: MergingRetrievableAttribute): DistanceAttribute = DistanceAttribute(
min(this.distance, (other as? DistanceAttribute)?.distance ?: Float.POSITIVE_INFINITY)
)
sealed interface DistanceAttribute: RetrievableAttribute {
/** The distance value associated with this [DistanceAttribute]. */
val distance: Float

/**
* A global [DistanceAttribute].
*
* It is used to store a global distance value for a [Retrievable].
*/
data class Global(override val distance: Float): DistanceAttribute, MergingRetrievableAttribute {
override fun merge(other: MergingRetrievableAttribute) = Global(
min(this.distance, (other as? DistanceAttribute)?.distance ?: Float.POSITIVE_INFINITY)
)
}

/**
* A local [DistanceAttribute]. It is used to store a local distance value specific for a [Descriptor].
*/
data class Local(override val distance: Float, val descriptorId: DescriptorId): DistanceAttribute
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.vitrivr.engine.core.model.retrievable.attributes

import org.vitrivr.engine.core.model.descriptor.DescriptorId
import kotlin.math.max

/**
Expand All @@ -11,15 +12,15 @@ import kotlin.math.max
* @author Ralph Gasser
* @version 1.1.0
*/
sealed interface ScoreAttribute : MergingRetrievableAttribute {
sealed interface ScoreAttribute : RetrievableAttribute {

/** The score associated with this [ScoreAttribute]. */
val score: Float

/**
* A similarity score. Strictly bound between 0 and 1.
* A global similarity score. Strictly bound between 0 and 1.
*/
data class Similarity(override val score: Float): ScoreAttribute {
data class Similarity(override val score: Float, val descriptorId: DescriptorId? = null): ScoreAttribute, MergingRetrievableAttribute {
init {
require(score in 0f..1f) { "Similarity score '$score' outside of valid range (0, 1)" }
}
Expand All @@ -30,9 +31,9 @@ sealed interface ScoreAttribute : MergingRetrievableAttribute {
}

/**
* An unbound score. Unbounded and can be any value >= 0.
* A global unbound score. Unbounded and can be any value >= 0.
*/
data class Unbound(override val score: Float): ScoreAttribute {
data class Unbound(override val score: Float, val descriptorId: DescriptorId? = null): ScoreAttribute, MergingRetrievableAttribute {
init {
require(this.score >= 0f) { "Score '$score' outside of valid range (>= 0)." }
}
Expand All @@ -41,3 +42,4 @@ sealed interface ScoreAttribute : MergingRetrievableAttribute {
)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@ internal class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*,
/* Fetch descriptors */
val descriptors = this.connection.client.query(cottontailQuery).asSequence().map { tuple ->
val scoreIndex = tuple.indexOf(DISTANCE_COLUMN_NAME)
tupleToDescriptor(tuple) to if (scoreIndex > -1) {
tuple.asDouble(DISTANCE_COLUMN_NAME)?.let { DistanceAttribute(it.toFloat()) }
val descriptor = tupleToDescriptor(tuple)
descriptor to if (scoreIndex > -1) {
tuple.asDouble(DISTANCE_COLUMN_NAME)?.let { DistanceAttribute.Local(it.toFloat(), descriptor.retrievableId!!) }
} else {
null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class VectorJsonlReader(
return queue.map {
val retrieved = retrievables[it.first.retrievableId]!!
retrieved.addDescriptor(it.first)
retrieved.addAttribute(DistanceAttribute(it.second))
retrieved.addAttribute(DistanceAttribute.Local(it.second, it.first.retrievableId!!))
retrieved as Retrieved
}.asSequence()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.vitrivr.engine.core.model.descriptor.vector.VectorDescriptor.Companio
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.Query
import org.vitrivr.engine.core.model.query.proximity.ProximityQuery
import org.vitrivr.engine.core.model.retrievable.RetrievableId
import org.vitrivr.engine.core.model.retrievable.Retrieved
import org.vitrivr.engine.core.model.retrievable.attributes.DistanceAttribute
import org.vitrivr.engine.database.pgvector.*
Expand All @@ -19,7 +20,7 @@ import java.util.*
* An abstract implementation of a [DescriptorReader] for Cottontail DB.
*
* @author Ralph Gasser
* @version 1.0.0
* @version 1.1.0
*/
class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, connection: PgVectorConnection) : AbstractDescriptorReader<VectorDescriptor<*, *>>(field, connection) {
/**
Expand Down Expand Up @@ -115,32 +116,37 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con
* @return [Sequence] of [VectorDescriptor]s.
*/
private fun queryAndJoinProximity(query: ProximityQuery<*>): Sequence<Retrieved> {
val descriptors = mutableListOf<Pair<VectorDescriptor<*, *>, Float>>()
val statement =
"SELECT $DESCRIPTOR_ID_COLUMN_NAME, $RETRIEVABLE_ID_COLUMN_NAME, $VECTOR_ATTRIBUTE_NAME, $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME FROM \"${tableName.lowercase()}\" ORDER BY $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? ${query.order} LIMIT ${query.k}"
val fetched = mutableMapOf<RetrievableId, MutableList<Pair<VectorDescriptor<*,*>,Float>>>()
val statement = "SELECT $DESCRIPTOR_ID_COLUMN_NAME, $RETRIEVABLE_ID_COLUMN_NAME, $VECTOR_ATTRIBUTE_NAME, $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME FROM \"${tableName.lowercase()}\" ORDER BY $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? ${query.order} LIMIT ${query.k}"
this@VectorDescriptorReader.connection.jdbc.prepareStatement(statement).use { stmt ->
stmt.setValue(1, query.value)
stmt.setValue(2, query.value)
stmt.executeQuery().use { result ->
while (result.next()) {
descriptors.add(this@VectorDescriptorReader.rowToDescriptor(result) to result.getFloat(DISTANCE_COLUMN_NAME))
val d = this@VectorDescriptorReader.rowToDescriptor(result)
fetched.compute(d.retrievableId!!) { _, v ->
if (v == null) {
mutableListOf(d to result.getFloat(DISTANCE_COLUMN_NAME))
} else {
v.add(d to result.getFloat(DISTANCE_COLUMN_NAME))
v
}
}
}
}

/* Fetch retrievable ids. */
val retrievables = this.connection.getRetrievableReader().getAll(descriptors.mapNotNull { it.first.retrievableId }.toSet()).map { it.id to it }.toMap()
return descriptors.asSequence().mapNotNull { (descriptor, distance) ->
val retrievable = retrievables[descriptor.retrievableId]
if (retrievable != null) {
return this.connection.getRetrievableReader().getAll(fetched.keys).map { retrievable ->
val descriptors = fetched[retrievable.id] ?: emptyList()
for ((descriptor, distance) in descriptors) {
if (query.fetchVector) {
retrievable.addDescriptor(descriptor)
}
retrievable.addAttribute(DistanceAttribute(distance))
retrievable as Retrieved
} else {
null
retrievable.addAttribute(DistanceAttribute.Local(distance, descriptor.id))
}
retrievable as Retrieved
}

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import org.vitrivr.engine.core.model.retrievable.RetrievableId
import org.vitrivr.engine.core.model.retrievable.attributes.ScoreAttribute
import org.vitrivr.engine.core.operators.Operator
import org.vitrivr.engine.core.operators.general.Aggregator
import java.util.*
import kotlin.math.pow

class WeightedScoreFusion(
Expand Down

0 comments on commit d00cbf4

Please sign in to comment.