diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java index 95000bb42..ec46b6aa2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java @@ -197,8 +197,16 @@ public Uni> postCommand( } // TODO: refactor this code to be cleaner so it assigns on one line EmbeddingProvider embeddingProvider = null; - final VectorConfig.VectorizeConfig vectorizeConfig = - schemaObject.vectorConfig().vectorizeConfig(); + VectorConfig vectorConfig = schemaObject.vectorConfig(); + final VectorConfig.ColumnVectorDefinition columnVectorDefinition = + vectorConfig.columnVectorDefinitions() == null + || vectorConfig.columnVectorDefinitions().isEmpty() + ? null + : vectorConfig.columnVectorDefinitions().get(0); + final VectorConfig.ColumnVectorDefinition.VectorizeConfig vectorizeConfig = + columnVectorDefinition != null + ? columnVectorDefinition.vectorizeConfig() + : null; if (vectorizeConfig != null) { embeddingProvider = embeddingProviderFactory.getConfiguration( @@ -206,7 +214,7 @@ public Uni> postCommand( dataApiRequestInfo.getCassandraToken(), vectorizeConfig.provider(), vectorizeConfig.modelName(), - schemaObject.vectorConfig().vectorSize(), + columnVectorDefinition.vectorSize(), vectorizeConfig.parameters(), vectorizeConfig.authentication(), command.getClass().getSimpleName()); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/exception/SchemaException.java b/src/main/java/io/stargate/sgv2/jsonapi/exception/SchemaException.java index b2a5355c3..ff9ec460f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/exception/SchemaException.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/exception/SchemaException.java @@ -17,6 +17,8 @@ public enum Code implements ErrorCode { COLUMN_DEFINITION_MISSING, COLUMN_TYPE_INCORRECT, COLUMN_TYPE_UNSUPPORTED, + INVALID_CONFIGURATION, + INVALID_VECTORIZE_CONFIGURATION, LIST_TYPE_INCORRECT_DEFINITION, MAP_TYPE_INCORRECT_DEFINITION, MISSING_PRIMARY_KEYS, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java index 3e0439575..f106b87ce 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java @@ -103,8 +103,7 @@ private Uni loadSchemaObject( optionalTable.get(), objectMapper); } - // 04-Sep-2024, tatu: Used to check that API Tables enabled; no longer checked here - return new TableSchemaObject(table); + return TableSchemaObject.from(table, objectMapper); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaObject.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaObject.java index bd147a509..239a67d73 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaObject.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaObject.java @@ -29,8 +29,8 @@ public SchemaObjectName name() { } /** - * Subclasses must always return an instance of VectorConfig, if there is no vector config they - * should return VectorConfig.notEnabledVectorConfig() + * Subclasses must always return VectorConfig, if there is no vector config they should return + * VectorConfig.notEnabledVectorConfig(). * * @return */ diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/TableSchemaObject.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/TableSchemaObject.java index 3e4ef7d62..38c07e5f5 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/TableSchemaObject.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/TableSchemaObject.java @@ -1,22 +1,118 @@ package io.stargate.sgv2.jsonapi.service.cqldriver.executor; +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.data.ByteUtils; +import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; +import com.datastax.oss.driver.api.core.metadata.schema.IndexMetadata; import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; +import com.datastax.oss.driver.api.core.type.VectorType; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.stargate.sgv2.jsonapi.exception.SchemaException; +import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; public class TableSchemaObject extends TableBasedSchemaObject { public static final SchemaObjectType TYPE = SchemaObjectType.TABLE; - public TableSchemaObject(TableMetadata tableMetadata) { + private final VectorConfig vectorConfig; + + private TableSchemaObject(TableMetadata tableMetadata, VectorConfig vectorConfig) { super(TYPE, tableMetadata); + this.vectorConfig = vectorConfig; } @Override public VectorConfig vectorConfig() { - return VectorConfig.notEnabledVectorConfig(); + return vectorConfig; } @Override public IndexUsage newIndexUsage() { return IndexUsage.NO_OP; } + + /** + * Get table schema object from table metadata + * + * @param tableMetadata + * @param objectMapper + * @return + */ + public static TableSchemaObject from(TableMetadata tableMetadata, ObjectMapper objectMapper) { + Map extensions = + (Map) + tableMetadata.getOptions().get(CqlIdentifier.fromInternal("extensions")); + String vectorizeJson = null; + if (extensions != null) { + ByteBuffer vectorizeBuffer = + (ByteBuffer) extensions.get("com.datastax.data-api.vectorize-config"); + vectorizeJson = + vectorizeBuffer != null + ? new String(ByteUtils.getArray(vectorizeBuffer.duplicate()), StandardCharsets.UTF_8) + : null; + } + Map vectorizeConfigMap = + new HashMap<>(); + if (vectorizeJson != null) { + try { + JsonNode vectorizeByColumns = objectMapper.readTree(vectorizeJson); + Iterator> it = vectorizeByColumns.fields(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + try { + VectorConfig.ColumnVectorDefinition.VectorizeConfig vectorizeConfig = + objectMapper.treeToValue( + entry.getValue(), VectorConfig.ColumnVectorDefinition.VectorizeConfig.class); + vectorizeConfigMap.put(entry.getKey(), vectorizeConfig); + } catch (JsonProcessingException | IllegalArgumentException e) { + throw SchemaException.Code.INVALID_VECTORIZE_CONFIGURATION.get( + Map.of("field", entry.getKey())); + } + } + } catch (JsonProcessingException e) { + throw SchemaException.Code.INVALID_CONFIGURATION.get(); + } + } + VectorConfig vectorConfig; + List columnVectorDefinitions = new ArrayList<>(); + for (Map.Entry column : tableMetadata.getColumns().entrySet()) { + if (column.getValue().getType() instanceof VectorType vectorType) { + final Optional index = tableMetadata.getIndex(column.getKey()); + SimilarityFunction similarityFunction = SimilarityFunction.COSINE; + if (index.isPresent()) { + final IndexMetadata indexMetadata = index.get(); + final Map indexOptions = indexMetadata.getOptions(); + final String similarityFunctionValue = indexOptions.get("similarity_function"); + if (similarityFunctionValue != null) { + similarityFunction = SimilarityFunction.fromString(similarityFunctionValue); + } + } + int dimension = vectorType.getDimensions(); + VectorConfig.ColumnVectorDefinition columnVectorDefinition = + new VectorConfig.ColumnVectorDefinition( + column.getKey().asInternal(), + dimension, + similarityFunction, + vectorizeConfigMap.get(column.getKey().asInternal())); + columnVectorDefinitions.add(columnVectorDefinition); + } + } + if (columnVectorDefinitions.isEmpty()) { + vectorConfig = VectorConfig.notEnabledVectorConfig(); + } else { + vectorConfig = new VectorConfig(true, Collections.unmodifiableList(columnVectorDefinitions)); + } + return new TableSchemaObject(tableMetadata, vectorConfig); + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java index dfb52b543..229d6d0de 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java @@ -2,66 +2,106 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import java.util.List; import java.util.Map; /** * incorporates vectorizeConfig into vectorConfig * - * @param vectorEnabled - * @param vectorSize - * @param similarityFunction - * @param vectorizeConfig + * @param vectorEnabled - If vector field is available for the table + * @param columnVectorDefinitions - List of columnVectorDefinitions each with respect to a + * column/field */ public record VectorConfig( - boolean vectorEnabled, - int vectorSize, - SimilarityFunction similarityFunction, - VectorizeConfig vectorizeConfig) { + boolean vectorEnabled, List columnVectorDefinitions) { // TODO: this is an immutable record, this can be singleton // TODO: Remove the use of NULL for the objects like vectorizeConfig public static VectorConfig notEnabledVectorConfig() { - return new VectorConfig(false, -1, null, null); + return new VectorConfig(false, null); } - // convert a vector jsonNode from table comment to vectorConfig - public static VectorConfig fromJson(JsonNode jsonNode, ObjectMapper objectMapper) { - // dimension, similarityFunction, must exist - int dimension = jsonNode.get("dimension").asInt(); - SimilarityFunction similarityFunction = - SimilarityFunction.fromString(jsonNode.get("metric").asText()); + /** + * Configuration for a column, In case of collection this will be of size one + * + * @param fieldName + * @param vectorSize + * @param similarityFunction + * @param vectorizeConfig + */ + public record ColumnVectorDefinition( + String fieldName, + int vectorSize, + SimilarityFunction similarityFunction, + VectorizeConfig vectorizeConfig) { - VectorizeConfig vectorizeConfig = null; - // construct vectorizeConfig - JsonNode vectorizeServiceNode = jsonNode.get("service"); - if (vectorizeServiceNode != null) { - // provider, modelName, must exist - String provider = vectorizeServiceNode.get("provider").asText(); - String modelName = vectorizeServiceNode.get("modelName").asText(); - // construct VectorizeConfig.authentication, can be null - JsonNode vectorizeServiceAuthenticationNode = vectorizeServiceNode.get("authentication"); - Map vectorizeServiceAuthentication = - vectorizeServiceAuthenticationNode == null - ? null - : objectMapper.convertValue(vectorizeServiceAuthenticationNode, Map.class); - // construct VectorizeConfig.parameters, can be null - JsonNode vectorizeServiceParameterNode = vectorizeServiceNode.get("parameters"); - Map vectorizeServiceParameter = - vectorizeServiceParameterNode == null - ? null - : objectMapper.convertValue(vectorizeServiceParameterNode, Map.class); - vectorizeConfig = - new VectorizeConfig( - provider, modelName, vectorizeServiceAuthentication, vectorizeServiceParameter); + // convert a vector jsonNode from comment option to vectorConfig, used for collection + public static ColumnVectorDefinition fromJson(JsonNode jsonNode, ObjectMapper objectMapper) { + // dimension, similarityFunction, must exist + int dimension = jsonNode.get("dimension").asInt(); + SimilarityFunction similarityFunction = + SimilarityFunction.fromString(jsonNode.get("metric").asText()); + + return fromJson( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + dimension, + similarityFunction, + jsonNode, + objectMapper); } - return new VectorConfig(true, dimension, similarityFunction, vectorizeConfig); - } + // convert a vector jsonNode from table extension to vectorConfig, used for tables + public static ColumnVectorDefinition fromJson( + String fieldName, + int dimension, + SimilarityFunction similarityFunction, + JsonNode jsonNode, + ObjectMapper objectMapper) { + VectorizeConfig vectorizeConfig = null; + // construct vectorizeConfig + JsonNode vectorizeServiceNode = jsonNode.get("service"); + if (vectorizeServiceNode != null) { + vectorizeConfig = VectorizeConfig.fromJson(vectorizeServiceNode, objectMapper); + } + return new ColumnVectorDefinition(fieldName, dimension, similarityFunction, vectorizeConfig); + } - public record VectorizeConfig( - String provider, - String modelName, - Map authentication, - Map parameters) {} + /** + * Represent the vectorize configuration defined for a column + * + * @param provider + * @param modelName + * @param authentication + * @param parameters + */ + public record VectorizeConfig( + String provider, + String modelName, + Map authentication, + Map parameters) { + + protected static VectorizeConfig fromJson( + JsonNode vectorizeServiceNode, ObjectMapper objectMapper) { + // provider, modelName, must exist + String provider = vectorizeServiceNode.get("provider").asText(); + String modelName = vectorizeServiceNode.get("modelName").asText(); + // construct VectorizeConfig.authentication, can be null + JsonNode vectorizeServiceAuthenticationNode = vectorizeServiceNode.get("authentication"); + Map vectorizeServiceAuthentication = + vectorizeServiceAuthenticationNode == null + ? null + : objectMapper.convertValue(vectorizeServiceAuthenticationNode, Map.class); + // construct VectorizeConfig.parameters, can be null + JsonNode vectorizeServiceParameterNode = vectorizeServiceNode.get("parameters"); + Map vectorizeServiceParameter = + vectorizeServiceParameterNode == null + ? null + : objectMapper.convertValue(vectorizeServiceParameterNode, Map.class); + return new VectorizeConfig( + provider, modelName, vectorizeServiceAuthentication, vectorizeServiceParameter); + } + } + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java index 0557ae329..1eca64f39 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java @@ -14,6 +14,7 @@ import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; import java.util.ArrayList; import java.util.HashMap; @@ -110,11 +111,16 @@ public Uni vectorize(List documents) { .onItem() .transform( vectorData -> { + final VectorConfig vectorConfig = schemaObject.vectorConfig(); + // This will be the first element for collection + final VectorConfig.ColumnVectorDefinition collectionVectorDefinition = + vectorConfig.columnVectorDefinitions().get(0); + // check if we get back the same number of vectors that we asked for if (vectorData.size() != vectorizeTexts.size()) { throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( "Embedding provider '%s' didn't return the expected number of embeddings. Expect: '%d'. Actual: '%d'", - schemaObject.vectorConfig().vectorizeConfig().provider(), + collectionVectorDefinition.vectorizeConfig().provider(), vectorizeTexts.size(), vectorData.size()); } @@ -125,11 +131,11 @@ public Uni vectorize(List documents) { JsonNode document = documents.get(position); float[] vector = vectorData.get(vectorPosition); // check if all vectors have the expected size - if (vector.length != schemaObject.vectorConfig().vectorSize()) { + if (vector.length != collectionVectorDefinition.vectorSize()) { throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( "Embedding provider '%s' did not return expected embedding length. Expect: '%d'. Actual: '%d'", - schemaObject.vectorConfig().vectorizeConfig().provider(), - schemaObject.vectorConfig().vectorSize(), + collectionVectorDefinition.vectorizeConfig().provider(), + collectionVectorDefinition.vectorSize(), vector.length); } final ArrayNode arrayNode = nodeFactory.arrayNode(vector.length); @@ -168,13 +174,17 @@ public Uni vectorize(String vectorizeContent) { .onItem() .transform( vectorData -> { + final VectorConfig vectorConfig = schemaObject.vectorConfig(); + // This will be the first element for collection + final VectorConfig.ColumnVectorDefinition collectionVectorDefinition = + vectorConfig.columnVectorDefinitions().get(0); float[] vector = vectorData.get(0); // check if vector have the expected size - if (vector.length != schemaObject.vectorConfig().vectorSize()) { + if (vector.length != collectionVectorDefinition.vectorSize()) { throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( "Embedding provider '%s' did not return expected embedding length. Expect: '%d'. Actual: '%d'", - schemaObject.vectorConfig().vectorizeConfig().provider(), - schemaObject.vectorConfig().vectorSize(), + collectionVectorDefinition.vectorizeConfig().provider(), + collectionVectorDefinition.vectorSize(), vector.length); } return vector; @@ -211,12 +221,16 @@ public Uni vectorize(SortClause sortClause) { .transform( vectorData -> { float[] vector = vectorData.get(0); + final VectorConfig vectorConfig = schemaObject.vectorConfig(); + // This will be the first element for collection + final VectorConfig.ColumnVectorDefinition collectionVectorDefinition = + vectorConfig.columnVectorDefinitions().get(0); // check if vector have the expected size - if (vector.length != schemaObject.vectorConfig().vectorSize()) { + if (vector.length != collectionVectorDefinition.vectorSize()) { throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( "Embedding provider '%s' did not return expected embedding length. Expect: '%d'. Actual: '%d'", - schemaObject.vectorConfig().vectorizeConfig().provider(), - schemaObject.vectorConfig().vectorSize(), + collectionVectorDefinition.vectorizeConfig().provider(), + collectionVectorDefinition.vectorSize(), vector.length); } sortExpressions.clear(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/processor/MeteredCommandProcessor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/processor/MeteredCommandProcessor.java index a2d2a581f..5cb882f59 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/processor/MeteredCommandProcessor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/processor/MeteredCommandProcessor.java @@ -243,6 +243,7 @@ private Tags getCustomTags( result.errors().get(0).fieldsForMetricsTag().getOrDefault("errorCode", UNKNOWN_VALUE); errorCodeTag = Tag.of(jsonApiMetricsConfig.errorCode(), errorCode); } + Tag vectorEnabled = commandContext.schemaObject().vectorConfig().vectorEnabled() ? Tag.of(jsonApiMetricsConfig.vectorEnabled(), "true") diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTableCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTableCommandResolver.java index 7f25c46ed..b117cbcad 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTableCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTableCommandResolver.java @@ -11,6 +11,7 @@ import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.SchemaException; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.KeyspaceSchemaObject; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.operation.*; import io.stargate.sgv2.jsonapi.service.operation.Operation; import io.stargate.sgv2.jsonapi.service.operation.tables.CreateTableAttemptBuilder; @@ -39,7 +40,7 @@ public Operation resolveKeyspaceCommand( .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getApiDataType())); List partitionKeys = Arrays.stream(command.definition().primaryKey().keys()).toList(); - Map vectorizeConfigMap = + Map vectorizeConfigMap = command.definition().columns().entrySet().stream() .filter( e -> @@ -50,9 +51,15 @@ public Operation resolveKeyspaceCommand( Map.Entry::getKey, e -> { ComplexTypes.VectorType vectorType = ((ComplexTypes.VectorType) e.getValue()); - final VectorizeConfig vectorConfig = vectorType.getVectorConfig(); - validateVectorize.validateService(vectorConfig, vectorType.getDimension()); - return vectorConfig; + final VectorizeConfig vectorizeConfig = vectorType.getVectorConfig(); + validateVectorize.validateService(vectorizeConfig, vectorType.getDimension()); + VectorConfig.ColumnVectorDefinition.VectorizeConfig dbVectorConfig = + new VectorConfig.ColumnVectorDefinition.VectorizeConfig( + vectorizeConfig.provider(), + vectorizeConfig.modelName(), + vectorizeConfig.authentication(), + vectorizeConfig.parameters()); + return dbVectorConfig; })); if (partitionKeys.isEmpty()) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java index 3c054d51c..2fe10c8e6 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java @@ -17,6 +17,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.*; import io.stargate.sgv2.jsonapi.service.projection.IndexingProjector; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -220,7 +221,14 @@ private static CollectionSchemaObject createCollectionSettings( collectionName, tableMetadata, IdConfig.defaultIdConfig(), - new VectorConfig(true, vectorSize, function, null), + new VectorConfig( + true, + List.of( + new VectorConfig.ColumnVectorDefinition( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + vectorSize, + function, + null))), null); } else { return new CollectionSchemaObject( @@ -272,6 +280,22 @@ private static CollectionSchemaObject createCollectionSettings( } } + // convert a vector jsonNode from cql table comment to vectorConfig, used for collection + private static VectorConfig.ColumnVectorDefinition fromJson( + JsonNode jsonNode, ObjectMapper objectMapper) { + // dimension, similarityFunction, must exist + int dimension = jsonNode.get("dimension").asInt(); + SimilarityFunction similarityFunction = + SimilarityFunction.fromString(jsonNode.get("metric").asText()); + + return VectorConfig.ColumnVectorDefinition.fromJson( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + dimension, + similarityFunction, + jsonNode, + objectMapper); + } + public static CreateCollectionCommand collectionSettingToCreateCollectionCommand( CollectionSchemaObject collectionSetting) { @@ -279,25 +303,28 @@ public static CreateCollectionCommand collectionSettingToCreateCollectionCommand CreateCollectionCommand.Options options = null; CreateCollectionCommand.Options.VectorSearchConfig vectorSearchConfig = null; CreateCollectionCommand.Options.IndexingConfig indexingConfig = null; - // populate the vectorSearchConfig - if (collectionSetting.vectorConfig().vectorEnabled()) { + // populate the vectorSearchConfig, Default will be the index 0 since there is only one vector + // column supported for collection + final VectorConfig vectorConfig = collectionSetting.vectorConfig(); + if (vectorConfig.vectorEnabled()) { + // This will be size 1 for collection + VectorConfig.ColumnVectorDefinition vectorConfigColumn = + vectorConfig.columnVectorDefinitions().get(0); VectorizeConfig vectorizeConfig = null; - if (collectionSetting.vectorConfig().vectorizeConfig() != null) { - Map authentication = - collectionSetting.vectorConfig().vectorizeConfig().authentication(); - Map parameters = - collectionSetting.vectorConfig().vectorizeConfig().parameters(); + if (vectorConfigColumn.vectorizeConfig() != null) { + Map authentication = vectorConfigColumn.vectorizeConfig().authentication(); + Map parameters = vectorConfigColumn.vectorizeConfig().parameters(); vectorizeConfig = new VectorizeConfig( - collectionSetting.vectorConfig().vectorizeConfig().provider(), - collectionSetting.vectorConfig().vectorizeConfig().modelName(), + vectorConfigColumn.vectorizeConfig().provider(), + vectorConfigColumn.vectorizeConfig().modelName(), authentication == null ? null : Map.copyOf(authentication), parameters == null ? null : Map.copyOf(parameters)); } vectorSearchConfig = new CreateCollectionCommand.Options.VectorSearchConfig( - collectionSetting.vectorConfig().vectorSize(), - collectionSetting.vectorConfig().similarityFunction().name().toLowerCase(), + vectorConfigColumn.vectorSize(), + vectorConfigColumn.similarityFunction().name().toLowerCase(), vectorizeConfig); } // populate the indexingConfig @@ -331,11 +358,11 @@ public CollectionIndexingConfig indexingConfig() { // TODO: these helper functions break encapsulation for very little benefit public SimilarityFunction similarityFunction() { - return vectorConfig().similarityFunction(); + return vectorConfig().columnVectorDefinitions().get(0).similarityFunction(); } public boolean isVectorEnabled() { - return vectorConfig() != null && vectorConfig().vectorEnabled(); + return vectorConfig().vectorEnabled(); } // TODO: the overrides below were auto added when migrating from a record to a class, not sure diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java index 0cf67b0be..753827934 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java @@ -3,9 +3,11 @@ import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.config.constants.TableCommentConstants; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import java.util.List; /** * schema_version 0 is before we introduce schema_version into the C* table comment of data api @@ -26,7 +28,15 @@ public CollectionSchemaObject readCollectionSettings( int vectorSize, SimilarityFunction function) { - VectorConfig vectorConfig = new VectorConfig(vectorEnabled, vectorSize, function, null); + VectorConfig vectorConfig = + new VectorConfig( + vectorEnabled, + List.of( + new VectorConfig.ColumnVectorDefinition( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + vectorSize, + function, + null))); CollectionIndexingConfig indexingConfig = null; JsonNode indexing = commentConfigNode.path(TableCommentConstants.COLLECTION_INDEXING_KEY); if (!indexing.isMissingNode()) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV1Reader.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV1Reader.java index de0c3a7e6..9d577a9e9 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV1Reader.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV1Reader.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.stargate.sgv2.jsonapi.config.constants.TableCommentConstants; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; +import java.util.List; /** * schema_version 1 sample: @@ -25,7 +26,9 @@ public CollectionSchemaObject readCollectionSettings( VectorConfig vectorConfig = VectorConfig.notEnabledVectorConfig(); JsonNode vector = collectionOptionsNode.path(TableCommentConstants.COLLECTION_VECTOR_KEY); if (!vector.isMissingNode()) { - vectorConfig = VectorConfig.fromJson(vector, objectMapper); + VectorConfig.ColumnVectorDefinition columnVectorDefinition = + VectorConfig.ColumnVectorDefinition.fromJson(vector, objectMapper); + vectorConfig = new VectorConfig(true, List.of(columnVectorDefinition)); } // construct collectionSettings IndexingConfig CollectionIndexingConfig indexingConfig = null; diff --git a/src/main/resources/errors.yaml b/src/main/resources/errors.yaml index 7ab61d80d..60414daf2 100644 --- a/src/main/resources/errors.yaml +++ b/src/main/resources/errors.yaml @@ -371,6 +371,18 @@ request-errors: "modelName": "NV-Embed-QA" } } + - scope: SCHEMA + code: INVALID_CONFIGURATION + title: Unable to parse configuration, schema invalid. + body: |- + Unable to parse configuration, schema invalid. + + - scope: SCHEMA + code: INVALID_VECTORIZE_CONFIGURATION + title: Unable to parse vectorize configuration, schema invalid. + body: |- + Unable to parse vectorize configuration, schema invalid for field ${field}. + # ================================================================================================================ # Server Errors diff --git a/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java b/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java index f13ab7a04..e990ea1f9 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java @@ -1,11 +1,13 @@ package io.stargate.sgv2.jsonapi; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.*; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.IdConfig; +import java.util.List; import org.apache.commons.lang3.RandomStringUtils; /** @@ -34,7 +36,14 @@ public final class TestConstants { SCHEMA_OBJECT_NAME, null, IdConfig.defaultIdConfig(), - new VectorConfig(true, -1, SimilarityFunction.COSINE, null), + new VectorConfig( + true, + List.of( + new VectorConfig.ColumnVectorDefinition( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + -1, + SimilarityFunction.COSINE, + null))), null); public static final KeyspaceSchemaObject KEYSPACE_SCHEMA_OBJECT = diff --git a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/CqlFixture.java b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/CqlFixture.java index 28d0e3b66..6a267769e 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/CqlFixture.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/CqlFixture.java @@ -1,6 +1,7 @@ package io.stargate.sgv2.jsonapi.fixtures; import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; +import com.fasterxml.jackson.databind.ObjectMapper; import io.stargate.sgv2.jsonapi.fixtures.data.DefaultData; import io.stargate.sgv2.jsonapi.fixtures.data.FixtureData; import io.stargate.sgv2.jsonapi.fixtures.identifiers.BaseFixtureIdentifiers; @@ -62,7 +63,7 @@ public CqlFixture( this.cqlData = cqlData; this.tableFixture = tableFixture; this.tableMetadata = tableFixture.tableMetadata(identifiers); - this.tableSchemaObject = new TableSchemaObject(tableMetadata); + this.tableSchemaObject = TableSchemaObject.from(tableMetadata, new ObjectMapper()); } public FixtureIdentifiers identifiers() { diff --git a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/SchemaObjectTestData.java b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/SchemaObjectTestData.java index b13fde1ef..2dbfaebf3 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/SchemaObjectTestData.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/SchemaObjectTestData.java @@ -1,5 +1,6 @@ package io.stargate.sgv2.jsonapi.fixtures.testdata; +import com.fasterxml.jackson.databind.ObjectMapper; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; public class SchemaObjectTestData extends TestDataSuplier { @@ -9,6 +10,6 @@ public SchemaObjectTestData(TestData testData) { } public TableSchemaObject emptyTableSchemaObject() { - return new TableSchemaObject(testData.tableMetadata().empty()); + return TableSchemaObject.from(testData.tableMetadata().empty(), new ObjectMapper()); } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/WhereAnalyzerTestData.java b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/WhereAnalyzerTestData.java index c734859c1..531388e26 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/WhereAnalyzerTestData.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/WhereAnalyzerTestData.java @@ -10,6 +10,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; +import com.fasterxml.jackson.databind.ObjectMapper; import io.stargate.sgv2.jsonapi.exception.FilterException; import io.stargate.sgv2.jsonapi.exception.WarningException; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; @@ -58,8 +59,9 @@ public WhereAnalyzerFixture( this.message = message; this.tableMetadata = tableMetadata; - this.analyzer = new WhereCQLClauseAnalyzer(new TableSchemaObject(tableMetadata)); - this.tableSchemaObject = new TableSchemaObject(tableMetadata); + this.analyzer = + new WhereCQLClauseAnalyzer(TableSchemaObject.from(tableMetadata, new ObjectMapper())); + this.tableSchemaObject = TableSchemaObject.from(tableMetadata, new ObjectMapper()); this.expression = new LogicalExpressionTestData.ExpressionBuilder<>(this, expression, tableMetadata); } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/DefaultDriverExceptionHandlerTestData.java b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/DefaultDriverExceptionHandlerTestData.java index d80645866..4db187633 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/DefaultDriverExceptionHandlerTestData.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/DefaultDriverExceptionHandlerTestData.java @@ -2,6 +2,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.internal.core.metadata.schema.DefaultTableMetadata; +import com.fasterxml.jackson.databind.ObjectMapper; import java.util.List; import java.util.Map; import java.util.UUID; @@ -32,6 +33,6 @@ public DefaultDriverExceptionHandlerTestData() { Map.of(), Map.of(), Map.of()); - TABLE_SCHEMA_OBJECT = new TableSchemaObject(tableMetadata); + TABLE_SCHEMA_OBJECT = TableSchemaObject.from(tableMetadata, new ObjectMapper()); } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java index 1ac981fc7..e61fd4510 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java @@ -13,6 +13,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; @@ -232,9 +233,13 @@ public void testWithUnmatchedVectorSize() { IdConfig.defaultIdConfig(), new VectorConfig( true, - 4, - SimilarityFunction.COSINE, - new VectorConfig.VectorizeConfig("custom", "custom", null, null)), + List.of( + new VectorConfig.ColumnVectorDefinition( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + 4, + SimilarityFunction.COSINE, + new VectorConfig.ColumnVectorDefinition.VectorizeConfig( + "custom", "custom", null, null)))), null); List documents = new ArrayList<>(); for (int i = 0; i < 2; i++) { diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java index ac27f0f20..df37dbd63 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java @@ -4,6 +4,7 @@ import io.stargate.sgv2.jsonapi.TestConstants; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; @@ -21,9 +22,13 @@ public class TestEmbeddingProvider extends EmbeddingProvider { IdConfig.defaultIdConfig(), new VectorConfig( true, - 3, - SimilarityFunction.COSINE, - new VectorConfig.VectorizeConfig("custom", "custom", null, null)), + List.of( + new VectorConfig.ColumnVectorDefinition( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + 3, + SimilarityFunction.COSINE, + new VectorConfig.ColumnVectorDefinition.VectorizeConfig( + "custom", "custom", null, null)))), null), new TestEmbeddingProvider(), "testCommand", diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/FindCollectionOperationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/FindCollectionOperationTest.java index 76c1d8154..7cab15ccb 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/FindCollectionOperationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/FindCollectionOperationTest.java @@ -25,6 +25,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.model.command.CommandResult; import io.stargate.sgv2.jsonapi.api.model.command.CommandStatus; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.exception.mappers.ThrowableToErrorMapper; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; @@ -82,7 +83,14 @@ public void init() { SCHEMA_OBJECT_NAME, null, IdConfig.defaultIdConfig(), - new VectorConfig(true, -1, SimilarityFunction.COSINE, null), + new VectorConfig( + true, + List.of( + new VectorConfig.ColumnVectorDefinition( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + -1, + SimilarityFunction.COSINE, + null))), null), null, "testCommand", diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/InsertCollectionOperationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/InsertCollectionOperationTest.java index 1ad31e4b2..d9f88307b 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/InsertCollectionOperationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/InsertCollectionOperationTest.java @@ -20,6 +20,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.model.command.CommandResult; import io.stargate.sgv2.jsonapi.api.model.command.CommandStatus; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor; @@ -98,7 +99,14 @@ public void init() { SCHEMA_OBJECT_NAME, null, IdConfig.defaultIdConfig(), - new VectorConfig(true, -1, SimilarityFunction.COSINE, null), + new VectorConfig( + true, + List.of( + new VectorConfig.ColumnVectorDefinition( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + -1, + SimilarityFunction.COSINE, + null))), null), null, "testCommand", diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/ReadAndUpdateCollectionOperationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/ReadAndUpdateCollectionOperationTest.java index 570d48b86..a0b782b37 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/ReadAndUpdateCollectionOperationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/ReadAndUpdateCollectionOperationTest.java @@ -23,6 +23,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.CommandStatus; import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateOperator; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.cqldriver.serializer.CQLBindValues; @@ -117,7 +118,14 @@ public void init() { SCHEMA_OBJECT_NAME, null, IdConfig.defaultIdConfig(), - new VectorConfig(true, -1, SimilarityFunction.COSINE, null), + new VectorConfig( + true, + List.of( + new VectorConfig.ColumnVectorDefinition( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + -1, + SimilarityFunction.COSINE, + null))), null), null, "testCommand", diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolverWithVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolverWithVectorizerTest.java index e1a7de40d..503b7b863 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolverWithVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolverWithVectorizerTest.java @@ -19,6 +19,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.impl.UpdateOneCommand; import io.stargate.sgv2.jsonapi.api.request.DataApiRequestInfo; import io.stargate.sgv2.jsonapi.config.OperationsConfig; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; @@ -41,6 +42,7 @@ import io.stargate.sgv2.jsonapi.service.updater.DocumentUpdater; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; import jakarta.inject.Inject; +import java.util.List; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -83,7 +85,14 @@ class Resolve { new SchemaObjectName(KEYSPACE_NAME, COLLECTION_NAME), null, IdConfig.defaultIdConfig(), - new VectorConfig(true, -1, SimilarityFunction.COSINE, null), + new VectorConfig( + true, + List.of( + new VectorConfig.ColumnVectorDefinition( + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, + -1, + SimilarityFunction.COSINE, + null))), null), null, null,