Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SchemaObject changes to support multiple vector configs #1499

Merged
merged 20 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
42d0fb5
Made the SchemaObject store and return List of VectorConfig. This is …
maheshrajamani Oct 5, 2024
c016640
Changes for TableSchemaObject to deserialize vector and vectorize inf…
maheshrajamani Oct 7, 2024
d590679
Changed the vector config to unmodifiable list
maheshrajamani Oct 7, 2024
aa656ef
Fixed the comments
maheshrajamani Oct 7, 2024
02bc38e
Fixed the comments
maheshrajamani Oct 7, 2024
24da975
Use $vectorize field name from constant `DocumentConstants.Fields.VEC…
maheshrajamani Oct 7, 2024
ecabca3
Use $vectorize field name from constant `DocumentConstants.Fields.VEC…
maheshrajamani Oct 7, 2024
bc30acb
Use $vectorize field name from constant `DocumentConstants.Fields.VEC…
maheshrajamani Oct 7, 2024
550ce05
Updated the code based on review.
maheshrajamani Oct 7, 2024
0a24bd0
Merge branch 'main' into schema-outject-multiple-vectorize
maheshrajamani Oct 7, 2024
5220f57
Merge branch 'main' of github.com:stargate/data-api into schema-outje…
maheshrajamani Oct 7, 2024
2683c66
Fixed the vectorize config deserializer
maheshrajamani Oct 7, 2024
e7e92cb
Changes based on review
maheshrajamani Oct 7, 2024
78ba059
Merge branch 'schema-outject-multiple-vectorize' of github.com:starga…
maheshrajamani Oct 7, 2024
366a503
Merge branch 'main' into schema-outject-multiple-vectorize
maheshrajamani Oct 7, 2024
27f3bef
Resolve the merge compile error
maheshrajamani Oct 7, 2024
11b589e
Resolve the merge compile error
maheshrajamani Oct 7, 2024
6ad43dc
Resolve the merge compile error
maheshrajamani Oct 7, 2024
dd2f2cb
Formatted the file
maheshrajamani Oct 7, 2024
5e45661
Fix for IT
maheshrajamani Oct 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,15 @@ public Uni<RestResponse<CommandResult>> postCommand(
// TODO: refactor this code to be cleaner so it assigns on one line
EmbeddingProvider embeddingProvider = null;
final VectorConfig.VectorizeConfig vectorizeConfig =
schemaObject.vectorConfig().vectorizeConfig();
schemaObject.vectorConfigs().get(0).vectorizeConfig();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is changed to get(0) because collections have one config. Will need to change this when we work on vectorize support for tables.

if (vectorizeConfig != null) {
embeddingProvider =
embeddingProviderFactory.getConfiguration(
dataApiRequestInfo.getTenantId(),
dataApiRequestInfo.getCassandraToken(),
vectorizeConfig.provider(),
vectorizeConfig.modelName(),
schemaObject.vectorConfig().vectorSize(),
schemaObject.vectorConfigs().get(0).vectorSize(),
vectorizeConfig.parameters(),
vectorizeConfig.authentication(),
command.getClass().getSimpleName());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor;

import java.util.List;

public class DatabaseSchemaObject extends SchemaObject {

public static final SchemaObjectType TYPE = SchemaObjectType.DATABASE;
Expand All @@ -9,8 +11,8 @@ public DatabaseSchemaObject() {
}

@Override
public VectorConfig vectorConfig() {
return VectorConfig.notEnabledVectorConfig();
public List<VectorConfig> vectorConfigs() {
return List.of(VectorConfig.notEnabledVectorConfig());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor;

import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject;
import java.util.List;

public class KeyspaceSchemaObject extends SchemaObject {

Expand Down Expand Up @@ -39,8 +40,8 @@ public static KeyspaceSchemaObject fromSchemaObject(TableSchemaObject table) {
}

@Override
public VectorConfig vectorConfig() {
return VectorConfig.notEnabledVectorConfig();
public List<VectorConfig> vectorConfigs() {
return List.of(VectorConfig.notEnabledVectorConfig());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private Uni<SchemaObject> loadSchemaObject(
}

// 04-Sep-2024, tatu: Used to check that API Tables enabled; no longer checked here
maheshrajamani marked this conversation as resolved.
Show resolved Hide resolved
return new TableSchemaObject(table);
return TableSchemaObject.getTableSettings(table, objectMapper);
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor;

import java.util.List;

/** A Collection or Table the command works on */
public abstract class SchemaObject {

Expand Down Expand Up @@ -29,12 +31,13 @@ public SchemaObjectName name() {
}

/**
* Subclasses must always return an instance of VectorConfig, if there is no vector config they
* should return VectorConfig.notEnabledVectorConfig()
* Subclasses must always return List of VectorConfig, if there is no vector config they should
* return VectorConfig.notEnabledVectorConfig(). This needs to be a list to support tables with
* multiple vector columns.
*
* @return
*/
public abstract VectorConfig vectorConfig();
public abstract List<VectorConfig> vectorConfigs();

/**
* Call to get an instance of the appropriate {@link IndexUsage} for this schema object
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,98 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor;

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.data.ByteUtils;
import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata;
import com.datastax.oss.driver.api.core.metadata.schema.IndexMetadata;
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class TableSchemaObject extends TableBasedSchemaObject {

public static final SchemaObjectType TYPE = SchemaObjectType.TABLE;

public TableSchemaObject(TableMetadata tableMetadata) {
private final List<VectorConfig> vectorConfigs;

private TableSchemaObject(TableMetadata tableMetadata, List<VectorConfig> vectorConfigs) {
super(TYPE, tableMetadata);
this.vectorConfigs = vectorConfigs;
}

@Override
public VectorConfig vectorConfig() {
return VectorConfig.notEnabledVectorConfig();
public List<VectorConfig> vectorConfigs() {
return vectorConfigs;
}

@Override
public IndexUsage newIndexUsage() {
return IndexUsage.NO_OP;
}

/**
* Get table schema object from table metadata
*
* @param tableMetadata
* @param objectMapper
* @return
*/
public static TableSchemaObject getTableSettings(
maheshrajamani marked this conversation as resolved.
Show resolved Hide resolved
TableMetadata tableMetadata, ObjectMapper objectMapper) {
Map<String, String> extensions =
(Map<String, String>)
tableMetadata.getOptions().get(CqlIdentifier.fromInternal("extensions"));
String vectorize = extensions != null ? extensions.get("vectorize") : null;
Map<String, VectorConfig.VectorizeConfig> resultMap = new HashMap<>();
if (vectorize != null) {
String vectorizeJson = new String(ByteUtils.fromHexString(vectorize).array());
maheshrajamani marked this conversation as resolved.
Show resolved Hide resolved
maheshrajamani marked this conversation as resolved.
Show resolved Hide resolved
// Define the TypeReference for Map<String, VectorConfig.VectorizeConfig>
TypeReference<Map<String, VectorConfig.VectorizeConfig>> typeRef =
new TypeReference<Map<String, VectorConfig.VectorizeConfig>>() {};

// Convert JSON string to Map
try {
resultMap = objectMapper.readValue(vectorizeJson, typeRef);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
maheshrajamani marked this conversation as resolved.
Show resolved Hide resolved
}
}

List<VectorConfig> vectorConfigs = new ArrayList<>();
for (Map.Entry<CqlIdentifier, ColumnMetadata> column : tableMetadata.getColumns().entrySet()) {
if (column.getValue().getType() instanceof VectorType vectorType) {
final Optional<IndexMetadata> index = tableMetadata.getIndex(column.getKey());
SimilarityFunction similarityFunction = SimilarityFunction.COSINE;
if (index.isPresent()) {
final IndexMetadata indexMetadata = index.get();
final Map<String, String> indexOptions = indexMetadata.getOptions();
final String similarityFunctionValue = indexOptions.get("similarity_function");
if (similarityFunctionValue != null) {
similarityFunction = SimilarityFunction.fromString(similarityFunctionValue);
tatu-at-datastax marked this conversation as resolved.
Show resolved Hide resolved
}
}
int dimension = vectorType.getDimensions();
VectorConfig vectorConfig =
new VectorConfig(
true,
column.getKey().asInternal(),
dimension,
similarityFunction,
resultMap.get(column.getKey().asInternal()));
vectorConfigs.add(vectorConfig);
maheshrajamani marked this conversation as resolved.
Show resolved Hide resolved
}
}
if (vectorConfigs.isEmpty()) {
vectorConfigs.add(VectorConfig.notEnabledVectorConfig());
}
return new TableSchemaObject(tableMetadata, Collections.unmodifiableList(vectorConfigs));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import java.util.Map;

Expand All @@ -15,27 +16,56 @@
*/
public record VectorConfig(
boolean vectorEnabled,
String fieldName,
maheshrajamani marked this conversation as resolved.
Show resolved Hide resolved
int vectorSize,
SimilarityFunction similarityFunction,
VectorizeConfig vectorizeConfig) {

// TODO: this is an immutable record, this can be singleton
// TODO: Remove the use of NULL for the objects like vectorizeConfig
public static VectorConfig notEnabledVectorConfig() {
return new VectorConfig(false, -1, null, null);
return new VectorConfig(false, null, -1, null, null);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the semantics of null vectorField? I assume we need a placeholder of some kind... ?
Put another way, how are we using this placeholder (how is it checked and so on).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gave null if the there is no vector field set for the table or collection

}

// convert a vector jsonNode from table comment to vectorConfig
// convert a vector jsonNode from table comment to vectorConfig, used for collection
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... not used for tables? So only collection? (if so, comment should state that)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it means "used when reading config for Collections (not Tables)"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The table there denotes cql table. fwiw I will just say comment to avoid confusion

public static VectorConfig fromJson(JsonNode jsonNode, ObjectMapper objectMapper) {
// dimension, similarityFunction, must exist
int dimension = jsonNode.get("dimension").asInt();
SimilarityFunction similarityFunction =
SimilarityFunction.fromString(jsonNode.get("metric").asText());

return fromJson(
DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD,
dimension,
similarityFunction,
jsonNode,
objectMapper);
}

// convert a vector jsonNode from table extension to vectorConfig, used for tables
public static VectorConfig fromJson(
String fieldName,
int dimension,
SimilarityFunction similarityFunction,
JsonNode jsonNode,
ObjectMapper objectMapper) {
VectorizeConfig vectorizeConfig = null;
// construct vectorizeConfig
JsonNode vectorizeServiceNode = jsonNode.get("service");
if (vectorizeServiceNode != null) {
vectorizeConfig = VectorizeConfig.fromJson(vectorizeServiceNode, objectMapper);
}
return new VectorConfig(true, fieldName, dimension, similarityFunction, vectorizeConfig);
}

public record VectorizeConfig(
maheshrajamani marked this conversation as resolved.
Show resolved Hide resolved
String provider,
String modelName,
Map<String, String> authentication,
Map<String, Object> parameters) {

protected static VectorizeConfig fromJson(
JsonNode vectorizeServiceNode, ObjectMapper objectMapper) {
// provider, modelName, must exist
String provider = vectorizeServiceNode.get("provider").asText();
String modelName = vectorizeServiceNode.get("modelName").asText();
Expand All @@ -51,17 +81,8 @@ public static VectorConfig fromJson(JsonNode jsonNode, ObjectMapper objectMapper
vectorizeServiceParameterNode == null
? null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could use Map.of() unless we want to use null as marker of some kind.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't understand this recommendation.

: objectMapper.convertValue(vectorizeServiceParameterNode, Map.class);
vectorizeConfig =
new VectorizeConfig(
provider, modelName, vectorizeServiceAuthentication, vectorizeServiceParameter);
return new VectorizeConfig(
provider, modelName, vectorizeServiceAuthentication, vectorizeServiceParameter);
}

return new VectorConfig(true, dimension, similarityFunction, vectorizeConfig);
}

public record VectorizeConfig(
String provider,
String modelName,
Map<String, String> authentication,
Map<String, Object> parameters) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig;
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider;
import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -30,6 +31,7 @@ public class DataVectorizer {
private final JsonNodeFactory nodeFactory;
private final EmbeddingCredentials embeddingCredentials;
private final SchemaObject schemaObject;
private final VectorConfig vectorConfig;

/**
* Constructor
Expand All @@ -49,6 +51,9 @@ public DataVectorizer(
this.nodeFactory = nodeFactory;
this.embeddingCredentials = embeddingCredentials;
this.schemaObject = schemaObject;
// This is getting element at 0 since only one vector is stored in the schema.
// This logic needs to be changed to handle multiple vectors columns for tables,
vectorConfig = schemaObject.vectorConfigs().get(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should class Javadoc probably mention if it is only being used for Collections?
I assume it is because only single vector is defined as per comment here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class has all field names hard coded for collections, this needs to be adapted for api tables.

}

/**
Expand Down Expand Up @@ -114,7 +119,7 @@ public Uni<Boolean> vectorize(List<JsonNode> documents) {
if (vectorData.size() != vectorizeTexts.size()) {
throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException(
"Embedding provider '%s' didn't return the expected number of embeddings. Expect: '%d'. Actual: '%d'",
schemaObject.vectorConfig().vectorizeConfig().provider(),
vectorConfig.vectorizeConfig().provider(),
vectorizeTexts.size(),
vectorData.size());
}
Expand All @@ -125,11 +130,11 @@ public Uni<Boolean> vectorize(List<JsonNode> documents) {
JsonNode document = documents.get(position);
float[] vector = vectorData.get(vectorPosition);
// check if all vectors have the expected size
if (vector.length != schemaObject.vectorConfig().vectorSize()) {
if (vector.length != vectorConfig.vectorSize()) {
throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException(
"Embedding provider '%s' did not return expected embedding length. Expect: '%d'. Actual: '%d'",
schemaObject.vectorConfig().vectorizeConfig().provider(),
schemaObject.vectorConfig().vectorSize(),
vectorConfig.vectorizeConfig().provider(),
vectorConfig.vectorSize(),
vector.length);
}
final ArrayNode arrayNode = nodeFactory.arrayNode(vector.length);
Expand Down Expand Up @@ -170,11 +175,11 @@ public Uni<float[]> vectorize(String vectorizeContent) {
vectorData -> {
float[] vector = vectorData.get(0);
// check if vector have the expected size
if (vector.length != schemaObject.vectorConfig().vectorSize()) {
if (vector.length != vectorConfig.vectorSize()) {
throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException(
"Embedding provider '%s' did not return expected embedding length. Expect: '%d'. Actual: '%d'",
schemaObject.vectorConfig().vectorizeConfig().provider(),
schemaObject.vectorConfig().vectorSize(),
vectorConfig.vectorizeConfig().provider(),
vectorConfig.vectorSize(),
vector.length);
}
return vector;
Expand Down Expand Up @@ -212,11 +217,11 @@ public Uni<Boolean> vectorize(SortClause sortClause) {
vectorData -> {
float[] vector = vectorData.get(0);
// check if vector have the expected size
if (vector.length != schemaObject.vectorConfig().vectorSize()) {
if (vector.length != vectorConfig.vectorSize()) {
throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException(
"Embedding provider '%s' did not return expected embedding length. Expect: '%d'. Actual: '%d'",
schemaObject.vectorConfig().vectorizeConfig().provider(),
schemaObject.vectorConfig().vectorSize(),
vectorConfig.vectorizeConfig().provider(),
vectorConfig.vectorSize(),
vector.length);
}
sortExpressions.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
import io.stargate.sgv2.jsonapi.config.CommandLevelLoggingConfig;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Produces;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -243,8 +245,12 @@ private <T extends SchemaObject> Tags getCustomTags(
result.errors().get(0).fieldsForMetricsTag().getOrDefault("errorCode", UNKNOWN_VALUE);
errorCodeTag = Tag.of(jsonApiMetricsConfig.errorCode(), errorCode);
}
final Optional<VectorConfig> first =
commandContext.schemaObject().vectorConfigs().stream()
.filter(a -> a.vectorEnabled())
.findFirst();
Tag vectorEnabled =
commandContext.schemaObject().vectorConfig().vectorEnabled()
first.isPresent()
? Tag.of(jsonApiMetricsConfig.vectorEnabled(), "true")
: Tag.of(jsonApiMetricsConfig.vectorEnabled(), "false");
JsonApiMetricsConfig.SortType sortType = getVectorTypeTag(command);
Expand Down
Loading