Skip to content

Commit

Permalink
Refactor client, and add test
Browse files Browse the repository at this point in the history
Makes it so the client will not deadlock if the channel is not closed, and closes missing code coverage
  • Loading branch information
jaredoconnell committed Jul 29, 2024
1 parent 1ff05a7 commit f627790
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 17 deletions.
50 changes: 33 additions & 17 deletions atp/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package atp

import (
"context"
"fmt"
"github.com/fxamacker/cbor/v2"
"go.arcalot.io/log/v2"
Expand Down Expand Up @@ -36,7 +37,13 @@ type Client interface {
// ReadSchema reads the schema from the ATP server.
ReadSchema() (*schema.SchemaSchema, error)
// Execute executes a step with a given context and returns the resulting output. Assumes you called ReadSchema first.
Execute(input schema.Input, receivedSignals <-chan schema.Input, emittedSignals chan<- schema.Input) ExecutionResult
// Params:
// - input: The step input for the run.
// - signalsToStep: A channel to send signals from the client to the plugin.
// - signalsFromStep: A channel to receive signals from the plugin to the client.
// It is recommended to close the signalsToStep channel when either Execute is done or it is known that no more signals
// will be sent to the plugin.
Execute(input schema.Input, signalsToStep <-chan schema.Input, signalsFromStep chan<- schema.Input) ExecutionResult
Close() error
Encoder() *cbor.Encoder
Decoder() *cbor.Decoder
Expand Down Expand Up @@ -66,20 +73,22 @@ func NewClientWithLogger(
if logger == nil {
logger = log.NewLogger(log.LevelDebug, log.NewNOOPLogger())
}
ctx, cancel := context.WithCancel(context.Background())
return &client{
-1, // unknown
channel,
decMode,
logger,
decMode.NewDecoder(channel),
cbor.NewEncoder(channel),
make(chan bool, 5), // Buffer to prevent deadlocks
make([]schema.Input, 0),
make(map[string]*executionEntry),
make(map[string]chan<- schema.Input),
sync.Mutex{},
false,
false,
ctx,
cancel,
sync.WaitGroup{},
}
}
Expand All @@ -99,18 +108,19 @@ type executionEntry struct {

type client struct {
atpVersion int64
channel ClientChannel
rawChannels ClientChannel
decMode cbor.DecMode
logger log.Logger
decoder *cbor.Decoder
encoder *cbor.Encoder
doneChannel chan bool
runningSteps []schema.Input
runningStepResultEntries map[string]*executionEntry // Run ID to results
runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps
mutex sync.Mutex
readLoopRunning bool
readLoopRunning bool // To prevent duplicate loops across multiple step executions.
done bool
context context.Context
cancelFunc context.CancelFunc
wg sync.WaitGroup // For the read loop.
}

Expand Down Expand Up @@ -165,8 +175,8 @@ func (c *client) validateVersion(serverVersion int64) error {

func (c *client) Execute(
stepData schema.Input,
receivedSignals <-chan schema.Input,
emittedSignals chan<- schema.Input,
signalsToStep <-chan schema.Input,
signalsFromStep chan<- schema.Input,
) ExecutionResult {
c.logger.Debugf("Executing plugin step %s/%s...", stepData.RunID, stepData.ID)
if len(stepData.RunID) == 0 {
Expand All @@ -177,20 +187,20 @@ func (c *client) Execute(
StepID: stepData.ID,
Config: stepData.InputData,
}
cborReader := c.decMode.NewDecoder(c.channel)
cborReader := c.decMode.NewDecoder(c.rawChannels)
if c.atpVersion > 1 {
// Wrap it in a runtime message.
workStartMsg = RuntimeMessage{RunID: stepData.RunID, MessageID: MessageTypeWorkStart, MessageData: workStartMsg}
// Handle signals to the step
if receivedSignals != nil {
if signalsToStep != nil {
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.executeWriteLoop(stepData.RunID, receivedSignals)
c.executeWriteLoop(stepData.RunID, signalsToStep)
}()
}
// Setup channels for ATP v2
err := c.prepareResultChannels(cborReader, stepData, emittedSignals)
err := c.prepareResultChannels(cborReader, stepData, signalsFromStep)
if err != nil {
return NewErrorExecutionResult(err)
}
Expand All @@ -206,6 +216,7 @@ func (c *client) Execute(

// Close Tells the client that it's done, and can stop listening for more requests.
func (c *client) Close() error {
c.cancelFunc()
c.mutex.Lock()
if c.done {
c.mutex.Unlock()
Expand Down Expand Up @@ -271,14 +282,12 @@ func (c *client) getRunningStepIDs() string {
// Listen for received signals, and send them over ATP if available.
func (c *client) executeWriteLoop(
runID string,
receivedSignals <-chan schema.Input,
signalsToStep <-chan schema.Input,
) {
c.mutex.Lock()
if c.done {
c.mutex.Unlock()
// Close() was called, so exit now.
// Failure to exit now may result in this receivedSignals channel not getting
// closed, resulting in this function hanging.
c.logger.Warningf(
"write called loop for run ID %q on done client; skipping receive loop",
runID,
Expand All @@ -289,9 +298,16 @@ func (c *client) executeWriteLoop(

// Looped select that gets signals
for {
signal, ok := <-receivedSignals
if !ok {
c.logger.Infof("ATP signal loop done")
var signal schema.Input
var ok bool
select {
case signal, ok = <-signalsToStep:
if !ok {
c.logger.Debugf("ATP signal loop done; channel closed")
return
}
case <-c.context.Done():
c.logger.Debugf("ATP signal loop exited; context closed")
return
}
c.logger.Debugf("Sending signal with ID '%s' to step with run ID '%s'", signal.ID, signal.RunID)
Expand Down
60 changes: 60 additions & 0 deletions atp/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,66 @@ func TestProtocol_Client_Execute(t *testing.T) {
wg.Wait()
}

func TestProtocol_Client_Execute_With_Signals(t *testing.T) {
testExecuteWithChannels(true, t)
}

func TestProtocol_Client_Execute_With_Signals_Unclosed(t *testing.T) {
testExecuteWithChannels(false, t)
}

func testExecuteWithChannels(closeChannel bool, t *testing.T) {
// Client ReadSchema and Execute happy path with signal handlers passed
// into the Execute call.
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(2)
stdinReader, stdinWriter := io.Pipe()
stdoutReader, stdoutWriter := io.Pipe()

go func() {
defer wg.Done()
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
)
assert.Equals(t, len(errors), 0)
}()

go func() {
defer wg.Done()
cli := atp.NewClientWithLogger(channel{
Reader: stdoutReader,
Writer: stdinWriter,
Context: nil,
cancel: cancel,
}, log.NewTestLogger(t))

_, err := cli.ReadSchema()
assert.NoError(t, err)
toStepChan := make(chan schema.Input)
fromStepChan := make(chan schema.Input)

result := cli.Execute(
schema.Input{
RunID: t.Name(),
ID: "hello-world",
InputData: map[string]any{"name": "Arca Lot"},
}, toStepChan, fromStepChan)
if closeChannel {
close(toStepChan)
}
assert.NoError(t, cli.Close())
assert.NoError(t, result.Error)
assert.Equals(t, result.OutputID, "success")
assert.Equals(t, result.OutputData.(map[any]any)["message"].(string), "Hello, Arca Lot!")
}()

wg.Wait()
}

//nolint:funlen
func TestProtocol_Client_ATP_v1(t *testing.T) {
// Client ReadSchema and Execute atp v1 happy path.
Expand Down

0 comments on commit f627790

Please sign in to comment.