Skip to content

Commit

Permalink
Found a potentially more efficient way to maintain order on Retrievab…
Browse files Browse the repository at this point in the history
…leReader.getAll
  • Loading branch information
ppanopticon committed Feb 6, 2025
1 parent bed5f6a commit 773b5c5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,15 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con
}

/* Fetch retrievable ids. */
val retrievables = mutableMapOf<RetrievableId,Retrieved>()
this.connection.getRetrievableReader().getAll(descriptors.keys).forEach { retrievable ->
return this.connection.getRetrievableReader().getAll(descriptors.keys).map { retrievable ->
for ((descriptor, distance) in descriptors[retrievable.id] ?: emptyList()) {
if (query.fetchVector) {
retrievable.addDescriptor(descriptor)
}
retrievable.addAttribute(DistanceAttribute.Local(distance, descriptor.id))
}
retrievables[retrievable.id] = retrievable as Retrieved
retrievable as Retrieved
}

/* Returns new sequence of retrieved objects. */
return descriptors.keys.asSequence().mapNotNull { retrievables[it] }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class RetrievableReader(override val connection: PgVectorConnection): Retrievabl
override fun getAll(ids: Iterable<RetrievableId>): Sequence<Retrievable> = sequence {
try {
val values = ids.map { it }.toTypedArray()
this@RetrievableReader.connection.jdbc.prepareStatement("SELECT * FROM $RETRIEVABLE_ENTITY_NAME WHERE $RETRIEVABLE_ID_COLUMN_NAME = ANY (?)").use { statement ->
this@RetrievableReader.connection.jdbc.prepareStatement("WITH x(ids) AS VALUES(?) SELECT ${RETRIEVABLE_ENTITY_NAME}.* FROM $RETRIEVABLE_ENTITY_NAME, x WHERE $RETRIEVABLE_ID_COLUMN_NAME = ANY (x.ids) ORDER BY array_position(x.ids, ${RETRIEVABLE_ID_COLUMN_NAME})").use { statement ->
statement.setArray(1, this@RetrievableReader.connection.jdbc.createArrayOf("uuid", values))
statement.executeQuery().use { result ->
while (result.next()) {
Expand All @@ -75,7 +75,7 @@ class RetrievableReader(override val connection: PgVectorConnection): Retrievabl
}
}
} catch (e: Exception) {
LOGGER.error(e) { "Failed to check for retrievables due to SQL error." }
LOGGER.error(e) { "Failed to fetch retrievables due to SQL error." }
}
}

Expand Down

0 comments on commit 773b5c5

Please sign in to comment.