diff --git a/atp/client.go b/atp/client.go index 1272c02..0052ba6 100644 --- a/atp/client.go +++ b/atp/client.go @@ -8,6 +8,7 @@ import ( "io" "strings" "sync" + "time" ) var supportedServerVersions = []int64{1, 3} @@ -72,7 +73,7 @@ func NewClientWithLogger( make(chan bool, 5), // Buffer to prevent deadlocks make([]schema.Input, 0), make(map[string]chan schema.Input), - make(map[string]chan ExecutionResult), + make(map[string]*executionEntry), make(map[string]chan<- schema.Input), sync.Mutex{}, false, @@ -89,6 +90,11 @@ func (c *client) Encoder() *cbor.Encoder { return c.encoder } +type executionEntry struct { + result *ExecutionResult + condition sync.Cond +} + type client struct { atpVersion int64 channel ClientChannel @@ -98,9 +104,9 @@ type client struct { encoder *cbor.Encoder doneChannel chan bool runningSteps []schema.Input - runningSignalReceiveLoops map[string]chan schema.Input // Run ID to channel of signals to steps - runningStepResultChannels map[string]chan ExecutionResult // Run ID to channel of results - runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps + runningSignalReceiveLoops map[string]chan schema.Input // Run ID to channel of signals to steps + runningStepResultEntries map[string]*executionEntry // Run ID to results + runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps mutex sync.Mutex readLoopRunning bool done bool @@ -175,7 +181,9 @@ func (c *client) Execute( workStartMsg = RuntimeMessage{RunID: stepData.RunID, MessageID: MessageTypeWorkStart, MessageData: workStartMsg} // Handle signals to the step if receivedSignals != nil { + c.wg.Add(1) go func() { + defer c.wg.Done() c.executeWriteLoop(stepData.RunID, receivedSignals) }() } @@ -215,6 +223,7 @@ func (c *client) handleStepComplete(runID string, receivedSignals chan schema.In func (c *client) Close() error { c.mutex.Lock() if c.done { + c.mutex.Unlock() return nil } c.done = true @@ -235,14 +244,41 @@ func (c *client) Close() error { clientDoneMessage{}, }) if err != nil { - return fmt.Errorf("client with steps '%s' failed to write client done message with error: %w", - c.getRunningStepIDs(), err) + // add a timeout to the wait to prevent it from causing a deadlock. + // 5 seconds is arbitrary, but gives it enough time to exit. + waitedGracefully := waitWithTimeout(time.Second*5, &c.wg) + if waitedGracefully { + return fmt.Errorf("client with step '%s' failed to write client done message with error: %w", + c.getRunningStepIDs(), err) + } else { + panic(fmt.Errorf("potential deadlock after client with step '%s' failed to write client done message with error: %w", + c.getRunningStepIDs(), err)) + } } } c.wg.Wait() return nil } +// waitWithTimeout waits for the provided wait group, aborting the wait if +// the provided timeout expires. +// Returns true if the WaitGroup finished, and false if +// it reached the end of the timeout. +func waitWithTimeout(duration time.Duration, wg *sync.WaitGroup) bool { + // Run a goroutine to do the waiting + doneChannel := make(chan bool, 1) + go func() { + defer close(doneChannel) + wg.Wait() + }() + select { + case <-doneChannel: + return true + case <-time.After(duration): + return false + } +} + func (c *client) getRunningStepIDs() string { if len(c.runningSteps) == 0 { return "No running steps" @@ -259,8 +295,19 @@ func (c *client) executeWriteLoop( runID string, receivedSignals chan schema.Input, ) { - // Add the channel to the client so that it can be kept track of c.mutex.Lock() + if c.done { + c.mutex.Unlock() + // Close() was called, so exit now. + // Failure to exit now may result in this receivedSignals channel not getting + // closed, resulting in this function hanging. + c.logger.Warningf( + "write called loop for run ID %q on done client; skipping receive loop", + runID, + ) + return + } + // Add the channel to the client so that it can be kept track of c.runningSignalReceiveLoops[runID] = receivedSignals c.mutex.Unlock() defer func() { @@ -295,45 +342,113 @@ func (c *client) executeWriteLoop( } } -// sendExecutionResult sends the results to the channel, and closes then removes the channels for the -// step results and the signals. +// sendExecutionResult finalizes the result entry for processing by the client's caller, and +// closes then removes the channels for the signals. +// The caller must have the mutex locked while calling this function. func (c *client) sendExecutionResult(runID string, result ExecutionResult) { - c.logger.Debugf("Providing input for run ID '%s'", runID) - c.mutex.Lock() - resultChannel, found := c.runningStepResultChannels[runID] - c.mutex.Unlock() + c.logger.Debugf("Sending results for run ID '%s'", runID) + resultEntry, found := c.runningStepResultEntries[runID] if found { // Send the result - resultChannel <- result - // Close the channel and remove it to detect incorrectly duplicate results. - close(resultChannel) - c.mutex.Lock() - delete(c.runningStepResultChannels, runID) - c.mutex.Unlock() + resultEntry.result = &result + resultEntry.condition.Signal() } else { - c.logger.Errorf("Step result channel not found for run ID '%s'. This is either a bug in the ATP "+ + c.logger.Errorf("Step result entry not found for run ID '%s'. This is either a bug in the ATP "+ "client, or the plugin erroneously sent a second result.", runID) } // Now close the signal channel, since it's invalid to send a signal after the step is complete. - c.mutex.Lock() - defer c.mutex.Unlock() signalChannel, found := c.runningStepEmittedSignalChannels[runID] if !found { - c.logger.Debugf("Could not find signal output channel for run ID '%s'", runID) return } - close(signalChannel) delete(c.runningStepEmittedSignalChannels, runID) + close(signalChannel) } func (c *client) sendErrorToAll(err error) { result := NewErrorExecutionResult(err) - for runID := range c.runningStepResultChannels { + c.mutex.Lock() + for runID := range c.runningStepResultEntries { c.sendExecutionResult(runID, result) } + c.mutex.Unlock() +} + +func (c *client) handleWorkDoneMessage(runtimeMessage DecodedRuntimeMessage) { + var doneMessage WorkDoneMessage + var result ExecutionResult + if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &doneMessage); err != nil { + c.logger.Errorf("Failed to decode work done message (%v) for run ID '%s' ", err, runtimeMessage.RunID) + result = NewErrorExecutionResult(fmt.Errorf("failed to decode work done message (%w)", err)) + } else { + result = c.processWorkDone(runtimeMessage.RunID, doneMessage) + } + c.mutex.Lock() + c.sendExecutionResult(runtimeMessage.RunID, result) + c.mutex.Unlock() +} + +func (c *client) handleSignalMessage(runtimeMessage DecodedRuntimeMessage) { + var signalMessage SignalMessage + if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil { + c.logger.Errorf("ATP client for run ID '%s' failed to decode signal message: %v", + runtimeMessage.RunID, err) + return + } + c.mutex.Lock() + defer c.mutex.Unlock() // Hold lock until we send to the channel to prevent premature closing of the channel. + signalChannel, found := c.runningStepEmittedSignalChannels[runtimeMessage.RunID] + if !found { + c.logger.Warningf( + "Step with run ID '%s' sent signal '%s'. Ignoring; signal handling is not implemented "+ + "(emittedSignals is nil).", + runtimeMessage.RunID, signalMessage.SignalID) + return + } + c.logger.Debugf("Got signal from step with run ID '%s' with ID '%s'", runtimeMessage.RunID, + signalMessage.SignalID) + signalChannel <- signalMessage.ToInput(runtimeMessage.RunID) +} + +// Returns true if the error is fatal. +func (c *client) handleErrorMessage(runtimeMessage DecodedRuntimeMessage) bool { + var errMessage ErrorMessage + if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &errMessage); err != nil { + c.logger.Errorf("Step with run ID '%s' failed to decode error message: %v", + runtimeMessage.RunID, err) + } + errorMessageStr := errMessage.ToString(runtimeMessage.RunID) + resultMsg := fmt.Errorf("step with run ID %q sent error message: %s", runtimeMessage.RunID, errorMessageStr) + c.logger.Errorf(resultMsg.Error()) + if errMessage.ServerFatal { + c.sendErrorToAll(resultMsg) + return true // It's server fatal, so this is the last message from the server. + } else if errMessage.StepFatal { + if runtimeMessage.RunID == "" { + c.sendErrorToAll(fmt.Errorf("step fatal error missing run id (%w)", resultMsg)) + } else { + c.mutex.Lock() + c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult(resultMsg)) + c.mutex.Unlock() + } + } + return false +} + +func (c *client) hasEntriesRemaining() bool { + c.mutex.Lock() + defer c.mutex.Unlock() + for _, resultEntry := range c.runningStepResultEntries { + // If any result is nil then we're not done. + // Context: There is a fraction of time when the entry is still in the map + // following completion. It is set to a non-nil value when done. + if resultEntry.result == nil { + return true + } + } + return false } -//nolint:funlen func (c *client) executeReadLoop(cborReader *cbor.Decoder) { defer func() { c.mutex.Lock() @@ -346,66 +461,35 @@ func (c *client) executeReadLoop(cborReader *cbor.Decoder) { var runtimeMessage DecodedRuntimeMessage for { if err := cborReader.Decode(&runtimeMessage); err != nil { - c.logger.Errorf("ATP client for steps '%s' failed to read or decode runtime message: %v", c.getRunningStepIDs(), err) + c.logger.Errorf( + "ATP client for steps '%s' failed to read or decode runtime message: %v", + c.getRunningStepIDs(), + err, + ) // This is fatal since the entire structure of the runtime message is invalid. c.sendErrorToAll(fmt.Errorf("failed to read or decode runtime message (%w)", err)) return } switch runtimeMessage.MessageID { case MessageTypeWorkDone: - var doneMessage WorkDoneMessage - if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &doneMessage); err != nil { - c.logger.Errorf("Failed to decode work done message (%v) for run ID '%s' ", err, runtimeMessage.RunID) - c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult( - fmt.Errorf("failed to decode work done message (%w)", err))) - } - c.sendExecutionResult(runtimeMessage.RunID, c.processWorkDone(runtimeMessage.RunID, doneMessage)) + c.handleWorkDoneMessage(runtimeMessage) case MessageTypeSignal: - var signalMessage SignalMessage - if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil { - c.logger.Errorf("ATP client for run ID '%s' failed to decode signal message: %v", - runtimeMessage.RunID, err) - } - signalChannel, found := c.runningStepEmittedSignalChannels[runtimeMessage.RunID] - if !found { - c.logger.Warningf( - "Step with run ID '%s' sent signal '%s'. Ignoring; signal handling is not implemented "+ - "(emittedSignals is nil).", - runtimeMessage.RunID, signalMessage.SignalID) - } else { - c.logger.Debugf("Got signal from step with run ID '%s' with ID '%s'", runtimeMessage.RunID, - signalMessage.SignalID) - signalChannel <- signalMessage.ToInput(runtimeMessage.RunID) - } + c.handleSignalMessage(runtimeMessage) case MessageTypeError: - var errMessage ErrorMessage - if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &errMessage); err != nil { - c.logger.Errorf("Step with run ID '%s' failed to decode error message: %v", - runtimeMessage.RunID, err) - } - c.logger.Errorf("Step with run ID '%s' sent error message: %v", runtimeMessage.RunID, errMessage) - resultMsg := fmt.Errorf("step '%s' sent error message: %s", runtimeMessage.RunID, - errMessage.ToString(runtimeMessage.RunID)) - if errMessage.ServerFatal { - c.sendErrorToAll(resultMsg) - return // It's server fatal, so this is the last message from the server. - } else if errMessage.StepFatal { - if runtimeMessage.RunID == "" { - c.sendErrorToAll(fmt.Errorf("step fatal error missing run id (%w)", resultMsg)) - } else { - c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult(resultMsg)) - } + if c.handleErrorMessage(runtimeMessage) { + return // Fatal } default: - c.logger.Warningf("Step with run ID '%s' sent unknown message type: %s", runtimeMessage.RunID, - runtimeMessage.MessageID) + c.logger.Warningf( + "Step with run ID '%s' sent unknown message type: %d", + runtimeMessage.RunID, + runtimeMessage.MessageID, + ) } - c.mutex.Lock() - if len(c.runningStepResultChannels) == 0 { - c.mutex.Unlock() - return // Done + // The non-error exit condition is having no more entries remaining. + if !c.hasEntriesRemaining() { + return } - c.mutex.Unlock() } } @@ -441,15 +525,19 @@ func (c *client) prepareResultChannels( stepData schema.Input, emittedSignals chan<- schema.Input, ) error { + c.logger.Debugf("Preparing result channels for step with run ID %q", stepData.RunID) c.mutex.Lock() defer c.mutex.Unlock() - _, existing := c.runningStepResultChannels[stepData.RunID] + _, existing := c.runningStepResultEntries[stepData.RunID] if existing { return fmt.Errorf("duplicate run ID given '%s'", stepData.RunID) } // Set up the signal and step results channels - resultChannel := make(chan ExecutionResult) - c.runningStepResultChannels[stepData.RunID] = resultChannel + resultEntry := executionEntry{ + result: nil, + condition: sync.Cond{L: &c.mutex}, + } + c.runningStepResultEntries[stepData.RunID] = &resultEntry if emittedSignals != nil { c.runningStepEmittedSignalChannels[stepData.RunID] = emittedSignals } @@ -465,28 +553,30 @@ func (c *client) prepareResultChannels( return nil } -// getResultV2 works with the channels that communicate with the RuntimeMessage loop. -func (c *client) getResultV2( - stepData schema.Input, -) ExecutionResult { +// getResultV2 communicates with the RuntimeMessage loop to get the ExecutionResult. +func (c *client) getResultV2(stepData schema.Input) ExecutionResult { c.mutex.Lock() - resultChannel, found := c.runningStepResultChannels[stepData.RunID] - c.mutex.Unlock() + defer c.mutex.Unlock() + resultEntry, found := c.runningStepResultEntries[stepData.RunID] if !found { return NewErrorExecutionResult( - fmt.Errorf("could not find result channel for step with run ID '%s'", - stepData.RunID), + fmt.Errorf("could not find result entry for step with run ID '%s'. Existing entries: %v", + stepData.RunID, c.runningStepResultEntries), ) } - // Wait for the result - result, received := <-resultChannel - if !received { - return NewErrorExecutionResult( - fmt.Errorf("did not receive result from results channel in ATP client for step with run ID '%s'", - stepData.RunID), - ) + if resultEntry.result == nil { + // Wait for the result + resultEntry.condition.Wait() } - return result + if resultEntry.result == nil { + panic(fmt.Errorf("did not receive result from results entry in ATP client for step with run ID '%s'", + stepData.RunID)) + } + // Now that we've received the result for this step, remove it from the list of running steps. + // We do this here because the sender cannot tell when the message has been received, and so + // it cannot tell when it is safe to remove the entry from the map. + delete(c.runningStepResultEntries, stepData.RunID) + return *resultEntry.result } func (c *client) processWorkDone( diff --git a/atp/server.go b/atp/server.go index e723b54..9f3b15d 100644 --- a/atp/server.go +++ b/atp/server.go @@ -8,6 +8,7 @@ import ( "io" "os" "sync" + "time" ) // RunATPServer runs an ArcaflowTransportProtocol server with a given schema. @@ -83,12 +84,22 @@ func initializeATPServerSession( func (s *atpServerSession) sendRuntimeMessage(msgID uint32, runID string, message any) error { s.encoderMutex.Lock() + doneChannel := make(chan error, 1) + go func() { + defer close(doneChannel) + doneChannel <- s.cborStdout.Encode(RuntimeMessage{ + MessageID: msgID, + RunID: runID, + MessageData: message, + }) + }() defer s.encoderMutex.Unlock() - return s.cborStdout.Encode(RuntimeMessage{ - MessageID: msgID, - RunID: runID, - MessageData: message, - }) + select { + case err := <-doneChannel: + return err + case <-time.After(time.Second * 60): + return fmt.Errorf("send timeout exceeded while sending message ID %q for run id %q", msgID, runID) + } } func (s *atpServerSession) handleClosure() []*ServerError {