diff --git a/pkg/load/pull.go b/pkg/load/pull.go index 4f669190..526c858b 100644 --- a/pkg/load/pull.go +++ b/pkg/load/pull.go @@ -58,7 +58,13 @@ func ImagePullPrivileged(ctx context.Context, dockerapi docker.APIClient, imageN return err } } else { - if err := printPull(ctx, responseBody, logger); err != nil { + ctx, cancel := context.WithCancel(ctx) + defer cancel() // Used to stop the go routine if printPull returns error early. + + msgCh := make(chan Message, 4096) + go decode(ctx, responseBody, msgCh) + err := printPull(ctx, msgCh, logger) + if err != nil { return err } } @@ -156,7 +162,52 @@ type PullProgress struct { Vtx *client.VertexStatus } -func printPull(_ context.Context, rc io.Reader, l progress.SubLogger) error { +type Message struct { + msg *jsonmessage.JSONMessage + err error +} + +// decode reads the body of the response from Docker and decodes it into JSON messages as fast +// as it can. It does not block on the channel and prefers to drop messages if the channel is full +// to prevent Docker from blocking on the pull. +func decode(ctx context.Context, r io.Reader, msgCh chan<- Message) { + defer close(msgCh) + + dec := json.NewDecoder(r) + for { + select { + case <-ctx.Done(): + select { + case msgCh <- Message{err: ctx.Err()}: + default: + } + return + default: + } + + var msg jsonmessage.JSONMessage + if err := dec.Decode(&msg); err != nil { + if err == io.EOF { + return + } + + select { + case msgCh <- Message{err: err}: + default: + } + } + + // If we block here it is possible for Docker to block on the pull. + select { + case msgCh <- Message{msg: &msg}: + default: + } + } +} + +// printPull will convert the messages to useful on screen content. +// we want to read as fast as possible as docker will block if the body buffer becomes too full. +func printPull(ctx context.Context, msgCh <-chan Message, l progress.SubLogger) error { started := map[string]PullProgress{} defer func() { @@ -170,26 +221,29 @@ func printPull(_ context.Context, rc io.Reader, l progress.SubLogger) error { } }() - dec := json.NewDecoder(rc) - var ( - parsedError error - jm jsonmessage.JSONMessage + msg Message + ok bool ) for { - if err := dec.Decode(&jm); err != nil { - if parsedError != nil { - return parsedError + select { + case <-ctx.Done(): + return ctx.Err() + case msg, ok = <-msgCh: + if !ok { + return nil } - if err == io.EOF { - break - } - return err } + if msg.err != nil { + return msg.err + } + + jm := msg.msg + if jm.Error != nil { - parsedError = jm.Error + return jm.Error } if jm.ID == "" { @@ -270,5 +324,4 @@ func printPull(_ context.Context, rc io.Reader, l progress.SubLogger) error { l.SetStatus(st.Vtx) } - return nil }