-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Implement key batching for cassandra online store in go feature server. #165
Changes from 7 commits
118f947
e7f7c0f
2c3f1de
535741b
4f4302f
e4ca438
7f9427e
d7335a5
ba060c7
749eb1c
f36aa4c
285a398
e3d44f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ import ( | |
"encoding/hex" | ||
"errors" | ||
"fmt" | ||
"math" | ||
"os" | ||
"strings" | ||
"sync" | ||
|
@@ -32,6 +33,9 @@ type CassandraOnlineStore struct { | |
session *gocql.Session | ||
|
||
config *registry.RepoConfig | ||
|
||
// The number of keys to include in a single CQL query for retrieval from the database | ||
keyBatchSize int | ||
} | ||
|
||
type CassandraConfig struct { | ||
|
@@ -43,6 +47,7 @@ type CassandraConfig struct { | |
loadBalancingPolicy gocql.HostSelectionPolicy | ||
connectionTimeoutMillis int64 | ||
requestTimeoutMillis int64 | ||
keyBatchSize int | ||
} | ||
|
||
func parseStringField(config map[string]any, fieldName string, defaultValue string) (string, error) { | ||
|
@@ -155,6 +160,13 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, | |
} | ||
cassandraConfig.requestTimeoutMillis = int64(requestTimeoutMillis.(float64)) | ||
|
||
keyBatchSize, ok := onlineStoreConfig["key_batch_size"] | ||
if !ok { | ||
keyBatchSize = 5.0 | ||
vanitabhagwat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
log.Warn().Msg("key_batch_size not specified, defaulting to batches of size 5") | ||
vanitabhagwat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
cassandraConfig.keyBatchSize = int(keyBatchSize.(float64)) | ||
|
||
return &cassandraConfig, nil | ||
} | ||
|
||
|
@@ -175,8 +187,9 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online | |
|
||
store.clusterConfigs.PoolConfig.HostSelectionPolicy = cassandraConfig.loadBalancingPolicy | ||
|
||
if cassandraConfig.username != "" && cassandraConfig.password != "" { | ||
log.Warn().Msg("username/password not defined, will not be using authentication") | ||
if cassandraConfig.username == "" || cassandraConfig.password == "" { | ||
log.Warn().Msg("username and/or password not defined, will not be using authentication") | ||
} else { | ||
store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{ | ||
Username: cassandraConfig.username, | ||
Password: cassandraConfig.password, | ||
|
@@ -202,14 +215,24 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online | |
return nil, fmt.Errorf("unable to connect to the ScyllaDB database") | ||
} | ||
store.session = createdSession | ||
|
||
if cassandraConfig.keyBatchSize <= 0 || cassandraConfig.keyBatchSize > 100 { | ||
return nil, fmt.Errorf("key_batch_size must be between 1 and 100") | ||
} else if cassandraConfig.keyBatchSize == 1 { | ||
log.Info().Msg("key batching is disabled") | ||
} else { | ||
log.Info().Msgf("key batching is enabled with a batch size of %d", cassandraConfig.keyBatchSize) | ||
} | ||
store.keyBatchSize = cassandraConfig.keyBatchSize | ||
|
||
return &store, nil | ||
} | ||
|
||
func (c *CassandraOnlineStore) getFqTableName(tableName string) string { | ||
return fmt.Sprintf(`"%s"."%s_%s"`, c.clusterConfigs.Keyspace, c.project, tableName) | ||
} | ||
|
||
func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string) string { | ||
func (c *CassandraOnlineStore) getSingleKeyCQLStatement(tableName string, featureNames []string) string { | ||
// this prevents fetching unnecessary features | ||
quotedFeatureNames := make([]string, len(featureNames)) | ||
for i, featureName := range featureNames { | ||
|
@@ -223,6 +246,26 @@ func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames [] | |
) | ||
} | ||
|
||
func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, featureNames []string, nkeys int) string { | ||
// this prevents fetching unnecessary features | ||
quotedFeatureNames := make([]string, len(featureNames)) | ||
for i, featureName := range featureNames { | ||
quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName) | ||
} | ||
|
||
keyPlaceholders := make([]string, nkeys) | ||
for i := 0; i < nkeys; i++ { | ||
keyPlaceholders[i] = "?" | ||
} | ||
|
||
return fmt.Sprintf( | ||
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`, | ||
tableName, | ||
strings.Join(keyPlaceholders, ","), | ||
strings.Join(quotedFeatureNames, ","), | ||
) | ||
} | ||
|
||
func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.EntityKey) ([]any, map[string]int, error) { | ||
cassandraKeys := make([]any, len(entityKeys)) | ||
cassandraKeyToEntityIndex := make(map[string]int) | ||
|
@@ -237,7 +280,8 @@ func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.Enti | |
} | ||
return cassandraKeys, cassandraKeyToEntityIndex, nil | ||
} | ||
func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { | ||
|
||
func (c *CassandraOnlineStore) UnbatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { | ||
uniqueNames := make(map[string]int32) | ||
for _, fvName := range featureViewNames { | ||
uniqueNames[fvName] = 0 | ||
|
@@ -265,7 +309,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ | |
|
||
// Prepare the query | ||
tableName := c.getFqTableName(featureViewName) | ||
cqlStatement := c.getCQLStatement(tableName, featureNames) | ||
cqlStatement := c.getSingleKeyCQLStatement(tableName, featureNames) | ||
|
||
var waitGroup sync.WaitGroup | ||
waitGroup.Add(len(serializedEntityKeys)) | ||
|
@@ -372,6 +416,156 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ | |
return results, nil | ||
} | ||
|
||
func (c *CassandraOnlineStore) BatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { | ||
uniqueNames := make(map[string]int32) | ||
for _, fvName := range featureViewNames { | ||
uniqueNames[fvName] = 0 | ||
} | ||
if len(uniqueNames) != 1 { | ||
return nil, fmt.Errorf("rejecting OnlineRead as more than 1 feature view was tried to be read at once") | ||
} | ||
|
||
serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys) | ||
|
||
if err != nil { | ||
return nil, fmt.Errorf("error when serializing entity keys for Cassandra") | ||
} | ||
results := make([][]FeatureData, len(entityKeys)) | ||
for i := range results { | ||
results[i] = make([]FeatureData, len(featureNames)) | ||
} | ||
|
||
featureNamesToIdx := make(map[string]int) | ||
for idx, name := range featureNames { | ||
featureNamesToIdx[name] = idx | ||
} | ||
|
||
featureViewName := featureViewNames[0] | ||
|
||
// Prepare the query | ||
tableName := c.getFqTableName(featureViewName) | ||
|
||
// Key batching | ||
nKeys := len(serializedEntityKeys) | ||
batchSize := c.keyBatchSize | ||
nBatches := int(math.Ceil(float64(nKeys) / float64(batchSize))) | ||
|
||
batches := make([][]any, nBatches) | ||
nAssigned := 0 | ||
for i := 0; i < nBatches; i++ { | ||
thisBatchSize := int(math.Min(float64(batchSize), float64(nKeys-nAssigned))) | ||
nAssigned += thisBatchSize | ||
batches[i] = make([]any, thisBatchSize) | ||
for j := 0; j < thisBatchSize; j++ { | ||
batches[i][j] = serializedEntityKeys[i*batchSize+j] | ||
} | ||
} | ||
|
||
var waitGroup sync.WaitGroup | ||
waitGroup.Add(nBatches) | ||
|
||
errorsChannel := make(chan error, nBatches) | ||
var prevBatchLength int | ||
var cqlStatement string | ||
for _, batch := range batches { | ||
go func(keyBatch []any) { | ||
defer waitGroup.Done() | ||
|
||
// this caches the previous batch query if it had the same number of keys | ||
if len(keyBatch) != prevBatchLength { | ||
cqlStatement = c.getMultiKeyCQLStatement(tableName, featureNames, len(keyBatch)) | ||
} | ||
|
||
iter := c.session.Query(cqlStatement, keyBatch...).WithContext(ctx).Iter() | ||
|
||
scanner := iter.Scanner() | ||
var entityKey string | ||
var featureName string | ||
var eventTs time.Time | ||
var valueStr []byte | ||
var deserializedValue types.Value | ||
// key 1: entityKey - key 2: featureName | ||
batchFeatures := make(map[string]map[string]FeatureData) | ||
for scanner.Next() { | ||
err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr) | ||
if err != nil { | ||
errorsChannel <- errors.New("could not read row in query for (entity key, feature name, value, event ts)") | ||
return | ||
} | ||
if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil { | ||
errorsChannel <- errors.New("error converting parsed Cassandra Value to types.Value") | ||
return | ||
} | ||
|
||
if deserializedValue.Val != nil { | ||
if batchFeatures[entityKey] == nil { | ||
batchFeatures[entityKey] = make(map[string]FeatureData) | ||
} | ||
batchFeatures[entityKey][featureName] = FeatureData{ | ||
Reference: serving.FeatureReferenceV2{ | ||
FeatureViewName: featureViewName, | ||
FeatureName: featureName, | ||
}, | ||
Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}, | ||
Value: types.Value{ | ||
Val: deserializedValue.Val, | ||
}, | ||
} | ||
} | ||
} | ||
|
||
if err := scanner.Err(); err != nil { | ||
errorsChannel <- errors.New("failed to scan features: " + err.Error()) | ||
return | ||
} | ||
|
||
for _, serializedEntityKey := range keyBatch { | ||
for _, featName := range featureNames { | ||
keyString := serializedEntityKey.(string) | ||
featureData, ok := batchFeatures[keyString][featName] | ||
if !ok { | ||
featureData = FeatureData{ | ||
Reference: serving.FeatureReferenceV2{ | ||
FeatureViewName: featureViewName, | ||
FeatureName: featName, | ||
}, | ||
Value: types.Value{ | ||
Val: &types.Value_NullVal{ | ||
NullVal: types.Null_NULL, | ||
}, | ||
}, | ||
} | ||
} | ||
results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = featureData | ||
} | ||
} | ||
}(batch) | ||
} | ||
// wait until all concurrent batch queries are done | ||
waitGroup.Wait() | ||
close(errorsChannel) | ||
|
||
var collectedErrors []error | ||
for err := range errorsChannel { | ||
if err != nil { | ||
collectedErrors = append(collectedErrors, err) | ||
} | ||
} | ||
if len(collectedErrors) > 0 { | ||
return nil, errors.Join(collectedErrors...) | ||
} | ||
|
||
return results, nil | ||
} | ||
|
||
func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { | ||
if c.keyBatchSize == 1 { | ||
return c.UnbatchedKeysOnlineRead(ctx, entityKeys, featureViewNames, featureNames) | ||
} else { | ||
return c.BatchedKeysOnlineRead(ctx, entityKeys, featureViewNames, featureNames) | ||
} | ||
Comment on lines
+562
to
+566
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need separate implementations? Usually IN clause with 1 value is same as = There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My understanding is that, using = avoids query planning overhead even though its not that significant. Also, with the current logic where key batching is done, using it for single key queries would be a overkill? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It shouldn't be. There is so much duplicated code between them at the moment. We may need to refactor a bit to avoid code duplication. |
||
} | ||
|
||
func (c *CassandraOnlineStore) Destruct() { | ||
c.session.Close() | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may need to add
key_batch_size
field in Cassandra Config in Python as well. It may fail to deserialize when feature_store.yaml has this field for materialization.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done