diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index 897b2e13a6..eeaabec3e0 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" "fmt" + "math" "os" "strings" "sync" @@ -32,6 +33,9 @@ type CassandraOnlineStore struct { session *gocql.Session config *registry.RepoConfig + + // The number of keys to include in a single CQL query for retrieval from the database + keyBatchSize int } type CassandraConfig struct { @@ -43,6 +47,7 @@ type CassandraConfig struct { loadBalancingPolicy gocql.HostSelectionPolicy connectionTimeoutMillis int64 requestTimeoutMillis int64 + keyBatchSize int } func parseStringField(config map[string]any, fieldName string, defaultValue string) (string, error) { @@ -155,6 +160,13 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, } cassandraConfig.requestTimeoutMillis = int64(requestTimeoutMillis.(float64)) + keyBatchSize, ok := onlineStoreConfig["key_batch_size"] + if !ok { + keyBatchSize = 10.0 + log.Warn().Msg("key_batch_size not specified, defaulting to batches of size 10") + } + cassandraConfig.keyBatchSize = int(keyBatchSize.(float64)) + return &cassandraConfig, nil } @@ -175,8 +187,9 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online store.clusterConfigs.PoolConfig.HostSelectionPolicy = cassandraConfig.loadBalancingPolicy - if cassandraConfig.username != "" && cassandraConfig.password != "" { - log.Warn().Msg("username/password not defined, will not be using authentication") + if cassandraConfig.username == "" || cassandraConfig.password == "" { + log.Warn().Msg("username and/or password not defined, will not be using authentication") + } else { store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{ Username: cassandraConfig.username, Password: cassandraConfig.password, @@ -202,6 +215,16 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online return nil, fmt.Errorf("unable to connect to the ScyllaDB database") } store.session = createdSession + + if cassandraConfig.keyBatchSize <= 0 || cassandraConfig.keyBatchSize > 100 { + return nil, fmt.Errorf("key_batch_size must be greater than zero and less than 100") + } else if cassandraConfig.keyBatchSize == 1 { + log.Info().Msg("key batching is disabled") + } else { + log.Info().Msgf("key batching is enabled with a batch size of %d", cassandraConfig.keyBatchSize) + } + store.keyBatchSize = cassandraConfig.keyBatchSize + return &store, nil } @@ -209,7 +232,7 @@ func (c *CassandraOnlineStore) getFqTableName(tableName string) string { return fmt.Sprintf(`"%s"."%s_%s"`, c.clusterConfigs.Keyspace, c.project, tableName) } -func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string) string { +func (c *CassandraOnlineStore) getSingleKeyCQLStatement(tableName string, featureNames []string) string { // this prevents fetching unnecessary features quotedFeatureNames := make([]string, len(featureNames)) for i, featureName := range featureNames { @@ -223,6 +246,26 @@ func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames [] ) } +func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, featureNames []string, nkeys int) string { + // this prevents fetching unnecessary features + quotedFeatureNames := make([]string, len(featureNames)) + for i, featureName := range featureNames { + quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName) + } + + keyPlaceholders := make([]string, nkeys) + for i := 0; i < nkeys; i++ { + keyPlaceholders[i] = "?" + } + + return fmt.Sprintf( + `SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`, + tableName, + strings.Join(keyPlaceholders, ","), + strings.Join(quotedFeatureNames, ","), + ) +} + func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.EntityKey) ([]any, map[string]int, error) { cassandraKeys := make([]any, len(entityKeys)) cassandraKeyToEntityIndex := make(map[string]int) @@ -237,7 +280,8 @@ func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.Enti } return cassandraKeys, cassandraKeyToEntityIndex, nil } -func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { + +func (c *CassandraOnlineStore) UnbatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { uniqueNames := make(map[string]int32) for _, fvName := range featureViewNames { uniqueNames[fvName] = 0 @@ -265,7 +309,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ // Prepare the query tableName := c.getFqTableName(featureViewName) - cqlStatement := c.getCQLStatement(tableName, featureNames) + cqlStatement := c.getSingleKeyCQLStatement(tableName, featureNames) var waitGroup sync.WaitGroup waitGroup.Add(len(serializedEntityKeys)) @@ -372,6 +416,156 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ return results, nil } +func (c *CassandraOnlineStore) BatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { + uniqueNames := make(map[string]int32) + for _, fvName := range featureViewNames { + uniqueNames[fvName] = 0 + } + if len(uniqueNames) != 1 { + return nil, fmt.Errorf("rejecting OnlineRead as more than 1 feature view was tried to be read at once") + } + + serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys) + + if err != nil { + return nil, fmt.Errorf("error when serializing entity keys for Cassandra") + } + results := make([][]FeatureData, len(entityKeys)) + for i := range results { + results[i] = make([]FeatureData, len(featureNames)) + } + + featureNamesToIdx := make(map[string]int) + for idx, name := range featureNames { + featureNamesToIdx[name] = idx + } + + featureViewName := featureViewNames[0] + + // Prepare the query + tableName := c.getFqTableName(featureViewName) + + // Key batching + nKeys := len(serializedEntityKeys) + batchSize := c.keyBatchSize + nBatches := int(math.Ceil(float64(nKeys) / float64(batchSize))) + + batches := make([][]any, nBatches) + nAssigned := 0 + for i := 0; i < nBatches; i++ { + thisBatchSize := int(math.Min(float64(batchSize), float64(nKeys-nAssigned))) + nAssigned += thisBatchSize + batches[i] = make([]any, thisBatchSize) + for j := 0; j < thisBatchSize; j++ { + batches[i][j] = serializedEntityKeys[i*batchSize+j] + } + } + + var waitGroup sync.WaitGroup + waitGroup.Add(nBatches) + + errorsChannel := make(chan error, nBatches) + var prevBatchLength int + var cqlStatement string + for _, batch := range batches { + go func(keyBatch []any) { + defer waitGroup.Done() + + // this caches the previous batch query if it had the same number of keys + if len(keyBatch) != prevBatchLength { + cqlStatement = c.getMultiKeyCQLStatement(tableName, featureNames, len(keyBatch)) + } + + iter := c.session.Query(cqlStatement, keyBatch...).WithContext(ctx).Iter() + + scanner := iter.Scanner() + var entityKey string + var featureName string + var eventTs time.Time + var valueStr []byte + var deserializedValue types.Value + // key 1: entityKey - key 2: featureName + batchFeatures := make(map[string]map[string]FeatureData) + for scanner.Next() { + err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr) + if err != nil { + errorsChannel <- errors.New("could not read row in query for (entity key, feature name, value, event ts)") + return + } + if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil { + errorsChannel <- errors.New("error converting parsed Cassandra Value to types.Value") + return + } + + if deserializedValue.Val != nil { + if batchFeatures[entityKey] == nil { + batchFeatures[entityKey] = make(map[string]FeatureData) + } + batchFeatures[entityKey][featureName] = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureName, + }, + Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}, + Value: types.Value{ + Val: deserializedValue.Val, + }, + } + } + } + + if err := scanner.Err(); err != nil { + errorsChannel <- errors.New("failed to scan features: " + err.Error()) + return + } + + for _, serializedEntityKey := range keyBatch { + for _, featName := range featureNames { + keyString := serializedEntityKey.(string) + featureData, ok := batchFeatures[keyString][featName] + if !ok { + featureData = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featName, + }, + Value: types.Value{ + Val: &types.Value_NullVal{ + NullVal: types.Null_NULL, + }, + }, + } + } + results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = featureData + } + } + }(batch) + } + // wait until all concurrent single-key queries are done + waitGroup.Wait() + close(errorsChannel) + + var collectedErrors []error + for err := range errorsChannel { + if err != nil { + collectedErrors = append(collectedErrors, err) + } + } + if len(collectedErrors) > 0 { + return nil, errors.Join(collectedErrors...) + } + + return results, nil +} + +func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { + if c.keyBatchSize == 1 { + return c.UnbatchedKeysOnlineRead(ctx, entityKeys, featureViewNames, featureNames) + } else { + return c.BatchedKeysOnlineRead(ctx, entityKeys, featureViewNames, featureNames) + } +} + func (c *CassandraOnlineStore) Destruct() { c.session.Close() } diff --git a/go/internal/feast/onlinestore/cassandraonlinestore_test.go b/go/internal/feast/onlinestore/cassandraonlinestore_test.go index 67a9eea548..19c53506b3 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore_test.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore_test.go @@ -60,17 +60,28 @@ func TestGetFqTableName(t *testing.T) { assert.Equal(t, `"scylladb"."dummy_project_dummy_fv"`, fqTableName) } -func TestGetCQLStatement(t *testing.T) { +func TestGetSingleKeyCQLStatement(t *testing.T) { store := CassandraOnlineStore{} fqTableName := `"scylladb"."dummy_project_dummy_fv"` - cqlStatement := store.getCQLStatement(fqTableName, []string{"feat1", "feat2"}) + cqlStatement := store.getSingleKeyCQLStatement(fqTableName, []string{"feat1", "feat2"}) assert.Equal(t, `SELECT "entity_key", "feature_name", "event_ts", "value" FROM "scylladb"."dummy_project_dummy_fv" WHERE "entity_key" = ? AND "feature_name" IN ('feat1','feat2')`, cqlStatement, ) } +func TestGetMultiKeyCQLStatement(t *testing.T) { + store := CassandraOnlineStore{} + fqTableName := `"scylladb"."dummy_project_dummy_fv"` + + cqlStatement := store.getMultiKeyCQLStatement(fqTableName, []string{"feat1", "feat2"}, 5) + assert.Equal(t, + `SELECT "entity_key", "feature_name", "event_ts", "value" FROM "scylladb"."dummy_project_dummy_fv" WHERE "entity_key" IN (?,?,?,?,?) AND "feature_name" IN ('feat1','feat2')`, + cqlStatement, + ) +} + func TestOnlineRead_RejectsDifferentFeatureViewsInSameRead(t *testing.T) { store := CassandraOnlineStore{} _, err := store.OnlineRead(context.TODO(), nil, []string{"fv1", "fv2"}, []string{"feat1", "feat2"}) diff --git a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py index 1998de464a..b3c3d955f3 100644 --- a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py +++ b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py @@ -162,6 +162,9 @@ class CassandraOnlineStoreConfig(FeastConfigBaseModel): key_ttl_seconds: Optional[StrictInt] = None """Default TTL (in seconds) to apply to all tables if not specified in FeatureView. Value 0 or None means No TTL.""" + key_batch_size: Optional[StrictInt] = 10 + """In Go Feature Server, this configuration is used to query tables with multiple keys at a time using IN clause based on the size specified. Value 1 means key batching is disabled. Valid values are 1 to 100.""" + class CassandraLoadBalancingPolicy(FeastConfigBaseModel): """ Configuration block related to the Cluster's load-balancing policy.