-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SchemaObject changes to support multiple vector configs #1499
Changes from 8 commits
42d0fb5
c016640
d590679
aa656ef
02bc38e
24da975
ecabca3
bc30acb
550ce05
0a24bd0
5220f57
2683c66
e7e92cb
78ba059
366a503
27f3bef
11b589e
6ad43dc
dd2f2cb
5e45661
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,98 @@ | ||
package io.stargate.sgv2.jsonapi.service.cqldriver.executor; | ||
|
||
import com.datastax.oss.driver.api.core.CqlIdentifier; | ||
import com.datastax.oss.driver.api.core.data.ByteUtils; | ||
import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; | ||
import com.datastax.oss.driver.api.core.metadata.schema.IndexMetadata; | ||
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; | ||
import com.datastax.oss.driver.api.core.type.VectorType; | ||
import com.fasterxml.jackson.core.JsonProcessingException; | ||
import com.fasterxml.jackson.core.type.TypeReference; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; | ||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
public class TableSchemaObject extends TableBasedSchemaObject { | ||
|
||
public static final SchemaObjectType TYPE = SchemaObjectType.TABLE; | ||
|
||
public TableSchemaObject(TableMetadata tableMetadata) { | ||
private final List<VectorConfig> vectorConfigs; | ||
|
||
private TableSchemaObject(TableMetadata tableMetadata, List<VectorConfig> vectorConfigs) { | ||
super(TYPE, tableMetadata); | ||
this.vectorConfigs = vectorConfigs; | ||
} | ||
|
||
@Override | ||
public VectorConfig vectorConfig() { | ||
return VectorConfig.notEnabledVectorConfig(); | ||
public List<VectorConfig> vectorConfigs() { | ||
return vectorConfigs; | ||
} | ||
|
||
@Override | ||
public IndexUsage newIndexUsage() { | ||
return IndexUsage.NO_OP; | ||
} | ||
|
||
/** | ||
* Get table schema object from table metadata | ||
* | ||
* @param tableMetadata | ||
* @param objectMapper | ||
* @return | ||
*/ | ||
public static TableSchemaObject getTableSettings( | ||
maheshrajamani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
TableMetadata tableMetadata, ObjectMapper objectMapper) { | ||
Map<String, String> extensions = | ||
(Map<String, String>) | ||
tableMetadata.getOptions().get(CqlIdentifier.fromInternal("extensions")); | ||
String vectorize = extensions != null ? extensions.get("vectorize") : null; | ||
Map<String, VectorConfig.VectorizeConfig> resultMap = new HashMap<>(); | ||
if (vectorize != null) { | ||
String vectorizeJson = new String(ByteUtils.fromHexString(vectorize).array()); | ||
maheshrajamani marked this conversation as resolved.
Show resolved
Hide resolved
maheshrajamani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Define the TypeReference for Map<String, VectorConfig.VectorizeConfig> | ||
TypeReference<Map<String, VectorConfig.VectorizeConfig>> typeRef = | ||
new TypeReference<Map<String, VectorConfig.VectorizeConfig>>() {}; | ||
|
||
// Convert JSON string to Map | ||
try { | ||
resultMap = objectMapper.readValue(vectorizeJson, typeRef); | ||
} catch (JsonProcessingException e) { | ||
throw new RuntimeException(e); | ||
maheshrajamani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
List<VectorConfig> vectorConfigs = new ArrayList<>(); | ||
for (Map.Entry<CqlIdentifier, ColumnMetadata> column : tableMetadata.getColumns().entrySet()) { | ||
if (column.getValue().getType() instanceof VectorType vectorType) { | ||
final Optional<IndexMetadata> index = tableMetadata.getIndex(column.getKey()); | ||
SimilarityFunction similarityFunction = SimilarityFunction.COSINE; | ||
if (index.isPresent()) { | ||
final IndexMetadata indexMetadata = index.get(); | ||
final Map<String, String> indexOptions = indexMetadata.getOptions(); | ||
final String similarityFunctionValue = indexOptions.get("similarity_function"); | ||
if (similarityFunctionValue != null) { | ||
similarityFunction = SimilarityFunction.fromString(similarityFunctionValue); | ||
tatu-at-datastax marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
int dimension = vectorType.getDimensions(); | ||
VectorConfig vectorConfig = | ||
new VectorConfig( | ||
true, | ||
column.getKey().asInternal(), | ||
dimension, | ||
similarityFunction, | ||
resultMap.get(column.getKey().asInternal())); | ||
vectorConfigs.add(vectorConfig); | ||
maheshrajamani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
if (vectorConfigs.isEmpty()) { | ||
vectorConfigs.add(VectorConfig.notEnabledVectorConfig()); | ||
} | ||
return new TableSchemaObject(tableMetadata, Collections.unmodifiableList(vectorConfigs)); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
|
||
import com.fasterxml.jackson.databind.JsonNode; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; | ||
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; | ||
import java.util.Map; | ||
|
||
|
@@ -15,27 +16,56 @@ | |
*/ | ||
public record VectorConfig( | ||
boolean vectorEnabled, | ||
String fieldName, | ||
maheshrajamani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
int vectorSize, | ||
SimilarityFunction similarityFunction, | ||
VectorizeConfig vectorizeConfig) { | ||
|
||
// TODO: this is an immutable record, this can be singleton | ||
// TODO: Remove the use of NULL for the objects like vectorizeConfig | ||
public static VectorConfig notEnabledVectorConfig() { | ||
return new VectorConfig(false, -1, null, null); | ||
return new VectorConfig(false, null, -1, null, null); | ||
Review comment: What are the semantics of [inline code lost in extraction; presumably the `null` passed as the field name]? Reply: It is given as null if there is no vector field set for the table or collection. |
||
} | ||
|
||
// convert a vector jsonNode from table comment to vectorConfig | ||
// convert a vector jsonNode from table comment to vectorConfig, used for collection | ||
Review comment: ... not used for tables? So only for collections? (If so, the comment should state that.) Review comment: I guess it means "used when reading config for Collections (not Tables)". Reply: The "table" there denotes the CQL table. FWIW, I will just say "comment" to avoid confusion. |
||
public static VectorConfig fromJson(JsonNode jsonNode, ObjectMapper objectMapper) { | ||
// dimension, similarityFunction, must exist | ||
int dimension = jsonNode.get("dimension").asInt(); | ||
SimilarityFunction similarityFunction = | ||
SimilarityFunction.fromString(jsonNode.get("metric").asText()); | ||
|
||
return fromJson( | ||
DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, | ||
dimension, | ||
similarityFunction, | ||
jsonNode, | ||
objectMapper); | ||
} | ||
|
||
// convert a vector jsonNode from table extension to vectorConfig, used for tables | ||
public static VectorConfig fromJson( | ||
String fieldName, | ||
int dimension, | ||
SimilarityFunction similarityFunction, | ||
JsonNode jsonNode, | ||
ObjectMapper objectMapper) { | ||
VectorizeConfig vectorizeConfig = null; | ||
// construct vectorizeConfig | ||
JsonNode vectorizeServiceNode = jsonNode.get("service"); | ||
if (vectorizeServiceNode != null) { | ||
vectorizeConfig = VectorizeConfig.fromJson(vectorizeServiceNode, objectMapper); | ||
} | ||
return new VectorConfig(true, fieldName, dimension, similarityFunction, vectorizeConfig); | ||
} | ||
|
||
public record VectorizeConfig( | ||
maheshrajamani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
String provider, | ||
String modelName, | ||
Map<String, String> authentication, | ||
Map<String, Object> parameters) { | ||
|
||
protected static VectorizeConfig fromJson( | ||
JsonNode vectorizeServiceNode, ObjectMapper objectMapper) { | ||
// provider, modelName, must exist | ||
String provider = vectorizeServiceNode.get("provider").asText(); | ||
String modelName = vectorizeServiceNode.get("modelName").asText(); | ||
|
@@ -51,17 +81,8 @@ public static VectorConfig fromJson(JsonNode jsonNode, ObjectMapper objectMapper | |
vectorizeServiceParameterNode == null | ||
? null | ||
Review comment: Could use [inline code lost in extraction]. Reply: Didn't understand this recommendation. |
||
: objectMapper.convertValue(vectorizeServiceParameterNode, Map.class); | ||
vectorizeConfig = | ||
new VectorizeConfig( | ||
provider, modelName, vectorizeServiceAuthentication, vectorizeServiceParameter); | ||
return new VectorizeConfig( | ||
provider, modelName, vectorizeServiceAuthentication, vectorizeServiceParameter); | ||
} | ||
|
||
return new VectorConfig(true, dimension, similarityFunction, vectorizeConfig); | ||
} | ||
|
||
public record VectorizeConfig( | ||
String provider, | ||
String modelName, | ||
Map<String, String> authentication, | ||
Map<String, Object> parameters) {} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; | ||
import io.stargate.sgv2.jsonapi.exception.JsonApiException; | ||
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject; | ||
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; | ||
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; | ||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
|
@@ -30,6 +31,7 @@ public class DataVectorizer { | |
private final JsonNodeFactory nodeFactory; | ||
private final EmbeddingCredentials embeddingCredentials; | ||
private final SchemaObject schemaObject; | ||
private final VectorConfig vectorConfig; | ||
|
||
/** | ||
* Constructor | ||
|
@@ -49,6 +51,9 @@ public DataVectorizer( | |
this.nodeFactory = nodeFactory; | ||
this.embeddingCredentials = embeddingCredentials; | ||
this.schemaObject = schemaObject; | ||
// This is getting element at 0 since only one vector is stored in the schema. | ||
// This logic needs to be changed to handle multiple vectors columns for tables, | ||
vectorConfig = schemaObject.vectorConfigs().get(0); | ||
Review comment: Should the class Javadoc mention whether it is only being used for Collections? Reply: This class has all field names hard-coded for collections; it needs to be adapted for API tables. |
||
} | ||
|
||
/** | ||
|
@@ -114,7 +119,7 @@ public Uni<Boolean> vectorize(List<JsonNode> documents) { | |
if (vectorData.size() != vectorizeTexts.size()) { | ||
throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( | ||
"Embedding provider '%s' didn't return the expected number of embeddings. Expect: '%d'. Actual: '%d'", | ||
schemaObject.vectorConfig().vectorizeConfig().provider(), | ||
vectorConfig.vectorizeConfig().provider(), | ||
vectorizeTexts.size(), | ||
vectorData.size()); | ||
} | ||
|
@@ -125,11 +130,11 @@ public Uni<Boolean> vectorize(List<JsonNode> documents) { | |
JsonNode document = documents.get(position); | ||
float[] vector = vectorData.get(vectorPosition); | ||
// check if all vectors have the expected size | ||
if (vector.length != schemaObject.vectorConfig().vectorSize()) { | ||
if (vector.length != vectorConfig.vectorSize()) { | ||
throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( | ||
"Embedding provider '%s' did not return expected embedding length. Expect: '%d'. Actual: '%d'", | ||
schemaObject.vectorConfig().vectorizeConfig().provider(), | ||
schemaObject.vectorConfig().vectorSize(), | ||
vectorConfig.vectorizeConfig().provider(), | ||
vectorConfig.vectorSize(), | ||
vector.length); | ||
} | ||
final ArrayNode arrayNode = nodeFactory.arrayNode(vector.length); | ||
|
@@ -170,11 +175,11 @@ public Uni<float[]> vectorize(String vectorizeContent) { | |
vectorData -> { | ||
float[] vector = vectorData.get(0); | ||
// check if vector have the expected size | ||
if (vector.length != schemaObject.vectorConfig().vectorSize()) { | ||
if (vector.length != vectorConfig.vectorSize()) { | ||
throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( | ||
"Embedding provider '%s' did not return expected embedding length. Expect: '%d'. Actual: '%d'", | ||
schemaObject.vectorConfig().vectorizeConfig().provider(), | ||
schemaObject.vectorConfig().vectorSize(), | ||
vectorConfig.vectorizeConfig().provider(), | ||
vectorConfig.vectorSize(), | ||
vector.length); | ||
} | ||
return vector; | ||
|
@@ -212,11 +217,11 @@ public Uni<Boolean> vectorize(SortClause sortClause) { | |
vectorData -> { | ||
float[] vector = vectorData.get(0); | ||
// check if vector have the expected size | ||
if (vector.length != schemaObject.vectorConfig().vectorSize()) { | ||
if (vector.length != vectorConfig.vectorSize()) { | ||
throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( | ||
"Embedding provider '%s' did not return expected embedding length. Expect: '%d'. Actual: '%d'", | ||
schemaObject.vectorConfig().vectorizeConfig().provider(), | ||
schemaObject.vectorConfig().vectorSize(), | ||
vectorConfig.vectorizeConfig().provider(), | ||
vectorConfig.vectorSize(), | ||
vector.length); | ||
} | ||
sortExpressions.clear(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is changed to get(0) because collections have one config. Will need to change this when we work on vectorize support for tables.