From ee7a259d8b830a08f15f5440f1f1917744d0b0a2 Mon Sep 17 00:00:00 2001 From: luke-lombardi <33990301+luke-lombardi@users.noreply.github.com> Date: Wed, 11 Dec 2024 17:10:28 -0500 Subject: [PATCH] Fix: Use heartbeats to correct for request token drift (#777) - Uses key events + request heartbeats to correct for token drift after a particular gateway crashed unexpectedly The consequence of this is that if a gateway that was actively handling a request crashes or is forcibly terminated, there will be a 30 delay before any active containers that were handling requests have the token count incremented. This should fix "bricked" containers that had an inaccurately low token count. --- pkg/abstractions/endpoint/buffer.go | 55 +++++++++++++++++++-------- pkg/abstractions/endpoint/endpoint.go | 8 ++-- pkg/abstractions/endpoint/task.go | 7 +++- 3 files changed, 50 insertions(+), 20 deletions(-) diff --git a/pkg/abstractions/endpoint/buffer.go b/pkg/abstractions/endpoint/buffer.go index c6ae83048..101c9ac01 100644 --- a/pkg/abstractions/endpoint/buffer.go +++ b/pkg/abstractions/endpoint/buffer.go @@ -58,6 +58,8 @@ type RequestBuffer struct { availableContainersLock sync.RWMutex maxTokens int isASGI bool + keyEventManager *common.KeyEventManager + keyEventChan chan common.KeyEvent } func NewRequestBuffer( @@ -67,12 +69,13 @@ func NewRequestBuffer( stubId string, size int, containerRepo repository.ContainerRepository, + keyEventManager *common.KeyEventManager, stubConfig *types.StubConfigV1, tailscale *network.Tailscale, tsConfig types.TailscaleConfig, isASGI bool, ) *RequestBuffer { - b := &RequestBuffer{ + rb := &RequestBuffer{ ctx: ctx, rdb: rdb, workspace: workspace, @@ -82,6 +85,8 @@ func NewRequestBuffer( availableContainers: []container{}, availableContainersLock: sync.RWMutex{}, containerRepo: containerRepo, + keyEventManager: keyEventManager, + keyEventChan: make(chan common.KeyEvent), httpClient: &http.Client{}, tailscale: tailscale, tsConfig: tsConfig, @@ -91,13 +96,38 @@ func NewRequestBuffer( if stubConfig.ConcurrentRequests > 1 && isASGI { // Floor is set to the number of workers - b.maxTokens = max(int(stubConfig.ConcurrentRequests), b.maxTokens) + rb.maxTokens = max(int(stubConfig.ConcurrentRequests), rb.maxTokens) } - go b.discoverContainers() - go b.processRequests() + go rb.discoverContainers() + go rb.processRequests() - return b + // Listen for heartbeat key events + go rb.keyEventManager.ListenForPattern(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, "*", "*"), rb.keyEventChan) + go rb.handleHeartbeatEvents() + + return rb +} + +func (rb *RequestBuffer) handleHeartbeatEvents() { + for { + select { + case event := <-rb.keyEventChan: + operation := event.Operation + + switch operation { + case common.KeyOperationSet, common.KeyOperationHSet, common.KeyOperationDel, common.KeyOperationExpire: + // Do nothing + case common.KeyOperationExpired: + if parts := strings.Split(event.Key, ":"); len(parts) >= 2 { + taskId, containerId := parts[len(parts)-2], parts[len(parts)-1] + rb.releaseRequestToken(containerId, taskId) + } + } + case <-rb.ctx.Done(): + return + } + } } func (rb *RequestBuffer) ForwardRequest(ctx echo.Context, task *EndpointTask) error { @@ -294,12 +324,7 @@ func (rb *RequestBuffer) acquireRequestToken(containerId string) error { return nil } -func (rb *RequestBuffer) releaseRequestToken(containerId string) error { - // TODO: if a gateway crashes before releasing the token, it could lead to a drift - // in the count of available request tokens for a particular container. To handle this - // we could move the release logic to the task implementation (e.g. task.Complete), so that - // it handles the release of the token and is not tied to a specific gateway - +func (rb *RequestBuffer) releaseRequestToken(containerId, taskId string) error { tokenKey := Keys.endpointRequestTokens(rb.workspace.Name, rb.stubId, containerId) err := rb.rdb.Incr(rb.ctx, tokenKey).Err() @@ -312,7 +337,7 @@ func (rb *RequestBuffer) releaseRequestToken(containerId string) error { return err } - return nil + return rb.rdb.Del(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, taskId, containerId)).Err() } func (rb *RequestBuffer) getHttpClient(address string) (*http.Client, error) { @@ -492,7 +517,7 @@ func (rb *RequestBuffer) heartBeat(req *request, containerId string) { ticker := time.NewTicker(endpointRequestHeartbeatInterval) defer ticker.Stop() - rb.rdb.Set(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, req.task.msg.TaskId), containerId, endpointRequestHeartbeatInterval) + rb.rdb.Set(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, req.task.msg.TaskId, containerId), 1, endpointRequestHeartbeatInterval) for { select { case <-ctx.Done(): @@ -500,7 +525,7 @@ func (rb *RequestBuffer) heartBeat(req *request, containerId string) { case <-rb.ctx.Done(): return case <-ticker.C: - rb.rdb.Set(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, req.task.msg.TaskId), containerId, endpointRequestHeartbeatInterval) + rb.rdb.Set(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, req.task.msg.TaskId, containerId), 1, endpointRequestHeartbeatInterval) } } } @@ -510,7 +535,7 @@ func (rb *RequestBuffer) afterRequest(req *request, containerId string) { req.done <- true }() - defer rb.releaseRequestToken(containerId) + defer rb.releaseRequestToken(containerId, req.task.msg.TaskId) // Set keep warm lock if rb.stubConfig.KeepWarmSeconds == 0 { diff --git a/pkg/abstractions/endpoint/endpoint.go b/pkg/abstractions/endpoint/endpoint.go index fa711ea4b..f6a772498 100644 --- a/pkg/abstractions/endpoint/endpoint.go +++ b/pkg/abstractions/endpoint/endpoint.go @@ -272,7 +272,7 @@ func (es *HttpEndpointService) getOrCreateEndpointInstance(ctx context.Context, instance.isASGI = true } - instance.buffer = NewRequestBuffer(autoscaledInstance.Ctx, es.rdb, &stub.Workspace, stubId, requestBufferSize, es.containerRepo, stubConfig, es.tailscale, es.config.Tailscale, instance.isASGI) + instance.buffer = NewRequestBuffer(autoscaledInstance.Ctx, es.rdb, &stub.Workspace, stubId, requestBufferSize, es.containerRepo, es.keyEventManager, stubConfig, es.tailscale, es.config.Tailscale, instance.isASGI) // Embed autoscaled instance struct instance.AutoscaledInstance = autoscaledInstance @@ -314,7 +314,7 @@ var ( endpointKeepWarmLock string = "endpoint:%s:%s:keep_warm_lock:%s" endpointInstanceLock string = "endpoint:%s:%s:instance_lock" endpointRequestTokens string = "endpoint:%s:%s:request_tokens:%s" - endpointRequestHeartbeat string = "endpoint:%s:%s:request_heartbeat:%s" + endpointRequestHeartbeat string = "endpoint:%s:%s:request_heartbeat:%s:%s" endpointServeLock string = "endpoint:%s:%s:serve_lock" ) @@ -330,8 +330,8 @@ func (k *keys) endpointRequestTokens(workspaceName, stubId, containerId string) return fmt.Sprintf(endpointRequestTokens, workspaceName, stubId, containerId) } -func (k *keys) endpointRequestHeartbeat(workspaceName, stubId, taskId string) string { - return fmt.Sprintf(endpointRequestHeartbeat, workspaceName, stubId, taskId) +func (k *keys) endpointRequestHeartbeat(workspaceName, stubId, taskId, containerId string) string { + return fmt.Sprintf(endpointRequestHeartbeat, workspaceName, stubId, taskId, containerId) } func (k *keys) endpointServeLock(workspaceName, stubId string) string { diff --git a/pkg/abstractions/endpoint/task.go b/pkg/abstractions/endpoint/task.go index a3d0f7174..1dda9241e 100644 --- a/pkg/abstractions/endpoint/task.go +++ b/pkg/abstractions/endpoint/task.go @@ -74,7 +74,12 @@ func (t *EndpointTask) Cancel(ctx context.Context, reason types.TaskCancellation } func (t *EndpointTask) HeartBeat(ctx context.Context) (bool, error) { - heartbeatKey := Keys.endpointRequestHeartbeat(t.msg.WorkspaceName, t.msg.StubId, t.msg.TaskId) + task, err := t.es.backendRepo.GetTask(ctx, t.msg.TaskId) + if err != nil { + return false, err + } + + heartbeatKey := Keys.endpointRequestHeartbeat(t.msg.WorkspaceName, t.msg.StubId, t.msg.TaskId, task.ContainerId) exists, err := t.es.rdb.Exists(ctx, heartbeatKey).Result() if err != nil { return false, fmt.Errorf("failed to retrieve endpoint heartbeat key <%v>: %w", heartbeatKey, err)