From e7d85d4a8fe70530abf153002fb510ea5fcac4a3 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Fri, 7 Jun 2024 16:19:47 +0200 Subject: [PATCH 1/8] feat: add a swupdater package --- go.mod | 5 +- go.sum | 14 ++--- pkg/swupdater/swupdater.go | 108 +++++++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 11 deletions(-) create mode 100644 pkg/swupdater/swupdater.go diff --git a/go.mod b/go.mod index d1f12aa..5164633 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,9 @@ go 1.21 require ( alexejk.io/go-xmlrpc v0.4.0 + github.com/gorilla/websocket v1.5.1 github.com/spf13/cobra v1.7.0 + github.com/stretchr/testify v1.8.4 ) require ( @@ -12,7 +14,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/objx v0.5.0 // indirect - github.com/stretchr/testify v1.8.4 // indirect + golang.org/x/net v0.17.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 95a778e..bad1fcc 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,10 @@ alexejk.io/go-xmlrpc v0.4.0 h1:HvaeCuACuF2NBJruG90AJKc5JHRGj9vKxu2ltJntQR4= alexejk.io/go-xmlrpc v0.4.0/go.mod h1:M7f2OzqvZIWrN1LftR4uwW/bLpxrFoQYjWfm4gQB4JA= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -13,16 +14,11 @@ github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/swupdater/swupdater.go b/pkg/swupdater/swupdater.go new file mode 100644 index 0000000..554fc31 --- /dev/null +++ b/pkg/swupdater/swupdater.go @@ -0,0 +1,108 @@ +package swupdater + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +type SWUpdater struct { + hostName string + port int + path string + urlUpload string + urlStatus string + done chan error +} + +func NewSWUpdater(hostName, path string, port int) *SWUpdater { + return &SWUpdater{ + hostName: hostName, + port: port, + path: path, + urlUpload: fmt.Sprintf("http://%s:%d%s/upload", hostName, port, path), + urlStatus: fmt.Sprintf("ws://%s:%d%s/ws", hostName, port, path), + done: make(chan error), + } +} + +func (s *SWUpdater) upload(image io.Reader, timeout time.Duration) error { + req, err := http.NewRequest("POST", s.urlUpload, image) + if err != nil { + return fmt.Errorf("cannot create request: %w", err) + } + + client := &http.Client{Timeout: timeout} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("cannot upload software image: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("cannot upload software image: status code %d", resp.StatusCode) + } + + return nil +} + +func (s *SWUpdater) waitForFinished() { + c, _, err := websocket.DefaultDialer.Dial(s.urlStatus, nil) + if err != nil { + s.done <- fmt.Errorf("cannot connect to websocket: %w", err) + return + } + defer c.Close() + + for { + _, message, err := c.ReadMessage() + if err != nil { + s.done <- fmt.Errorf("cannot read message from websocket: %w", err) + return + } + + data := make(map[string]string) + err = json.Unmarshal(message, &data) + if err != nil { + continue + } + + if data["type"] != "message" { + continue + } + + if data["text"] == "SWUPDATE successful" { + s.done <- nil + return + } + if data["text"] == "Installation failed" { + s.done <- errors.New("installation failed") + return + } + } +} + +func (s *SWUpdater) Update(image io.Reader, timeout time.Duration) error { + go s.waitForFinished() + go func() { + err := s.upload(image, timeout) + if err != nil { + s.done <- err + } + }() + + select { + case err := <-s.done: + if err != nil { + return err + } + return nil + case <-time.After(timeout): + return errors.New("timeout") + } +} From b86dc1dffd56fea7e523e7c653a9439e90bfaef9 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Mon, 10 Jun 2024 17:45:04 +0200 Subject: [PATCH 2/8] fix: upload the image as multipart This is quite memory hungry because the whole file is copied to memory first --- cmd/ovp8xx/cmd/swupdate.go | 60 ++++++++++++++++++++++++++++ pkg/swupdater/swupdater.go | 80 +++++++++++++++++++++++++------------- 2 files changed, 113 insertions(+), 27 deletions(-) create mode 100644 cmd/ovp8xx/cmd/swupdate.go diff --git a/cmd/ovp8xx/cmd/swupdate.go b/cmd/ovp8xx/cmd/swupdate.go new file mode 100644 index 0000000..04fbac8 --- /dev/null +++ b/cmd/ovp8xx/cmd/swupdate.go @@ -0,0 +1,60 @@ +/* +Copyright © 2023 Christian Ege +*/ +package cmd + +import ( + "fmt" + "time" + + "github.com/graugans/go-ovp8xx/pkg/swupdater" + "github.com/spf13/cobra" +) + +func swupdateCommand(cmd *cobra.Command, args []string) error { + var err error + host, err := rootCmd.PersistentFlags().GetString("ip") + if err != nil { + return fmt.Errorf("cannot get host: %w", err) + } + + port, err := cmd.Flags().GetUint16("port") + if err != nil { + return fmt.Errorf("cannot get port: %w", err) + } + + filename, err := cmd.Flags().GetString("file") + if err != nil { + return fmt.Errorf("cannot get filename: %w", err) + } + + timeout, err := cmd.Flags().GetDuration("timeout") + if err != nil { + return fmt.Errorf("cannot get timeout: %w", err) + } + + fmt.Printf("Updating firmware on %s:%d with file %s (%v)\n", host, port, filename, timeout) + + swu := swupdater.NewSWUpdater(host, port) + + err = swu.Update(filename, timeout) + if err != nil { + return fmt.Errorf("software update failed: %w", err) + } + + return nil +} + +// swupdateCmd represents the swupdate command +var swupdateCmd = &cobra.Command{ + Use: "swupdate", + Short: "Update the firmware on the device", + RunE: swupdateCommand, +} + +func init() { + rootCmd.AddCommand(swupdateCmd) + swupdateCmd.Flags().String("file", "", "A file conatining the firmware image") + swupdateCmd.Flags().Uint16("port", 8080, "Port number for SWUpdate") + swupdateCmd.Flags().Duration("timeout", 5*time.Minute, "The timeout for the upload") +} diff --git a/pkg/swupdater/swupdater.go b/pkg/swupdater/swupdater.go index 554fc31..10ab820 100644 --- a/pkg/swupdater/swupdater.go +++ b/pkg/swupdater/swupdater.go @@ -1,11 +1,15 @@ package swupdater import ( + "bytes" "encoding/json" "errors" "fmt" "io" + "mime/multipart" "net/http" + "os" + "strconv" "time" "github.com/gorilla/websocket" @@ -13,30 +17,52 @@ import ( type SWUpdater struct { hostName string - port int - path string + port uint16 urlUpload string urlStatus string - done chan error } -func NewSWUpdater(hostName, path string, port int) *SWUpdater { +func NewSWUpdater(hostName string, port uint16) *SWUpdater { return &SWUpdater{ hostName: hostName, port: port, - path: path, - urlUpload: fmt.Sprintf("http://%s:%d%s/upload", hostName, port, path), - urlStatus: fmt.Sprintf("ws://%s:%d%s/ws", hostName, port, path), - done: make(chan error), + urlUpload: fmt.Sprintf("http://%s:%d/upload", hostName, port), + urlStatus: fmt.Sprintf("ws://%s:%d/ws", hostName, port), } } +func (s *SWUpdater) upload(filename string, timeout time.Duration) error { + image, err := os.Open(filename) + if err != nil { + return fmt.Errorf("cannot open file: %w", err) + } + fmt.Printf("Uploading software image to %s\n", s.urlUpload) + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + part, err := writer.CreateFormFile("file", filename) + if err != nil { + return fmt.Errorf("cannot create form file: %w", err) + } -func (s *SWUpdater) upload(image io.Reader, timeout time.Duration) error { - req, err := http.NewRequest("POST", s.urlUpload, image) + _, err = io.Copy(part, image) + if err != nil { + return fmt.Errorf("cannot write to form file: %w", err) + } + + err = writer.Close() + if err != nil { + return fmt.Errorf("cannot close multipart writer: %w", err) + } + + req, err := http.NewRequest("POST", s.urlUpload, bytes.NewReader(body.Bytes())) if err != nil { return fmt.Errorf("cannot create request: %w", err) } + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Content-Length", strconv.Itoa(body.Len())) + client := &http.Client{Timeout: timeout} resp, err := client.Do(req) if err != nil { @@ -51,10 +77,10 @@ func (s *SWUpdater) upload(image io.Reader, timeout time.Duration) error { return nil } -func (s *SWUpdater) waitForFinished() { +func (s *SWUpdater) waitForFinished(done chan error) { c, _, err := websocket.DefaultDialer.Dial(s.urlStatus, nil) if err != nil { - s.done <- fmt.Errorf("cannot connect to websocket: %w", err) + done <- fmt.Errorf("cannot connect to websocket: %w", err) return } defer c.Close() @@ -62,44 +88,44 @@ func (s *SWUpdater) waitForFinished() { for { _, message, err := c.ReadMessage() if err != nil { - s.done <- fmt.Errorf("cannot read message from websocket: %w", err) + done <- fmt.Errorf("cannot read message from websocket: %w", err) return } data := make(map[string]string) err = json.Unmarshal(message, &data) if err != nil { - continue + done <- fmt.Errorf("cannot unmarshal message: %w", err) + return } - + fmt.Println("Raw JSON: ", data) if data["type"] != "message" { continue } if data["text"] == "SWUPDATE successful" { - s.done <- nil + done <- nil return } if data["text"] == "Installation failed" { - s.done <- errors.New("installation failed") + done <- errors.New("installation failed") return } } } -func (s *SWUpdater) Update(image io.Reader, timeout time.Duration) error { - go s.waitForFinished() - go func() { - err := s.upload(image, timeout) - if err != nil { - s.done <- err - } - }() +func (s *SWUpdater) Update(filename string, timeout time.Duration) error { + done := make(chan error) + go s.waitForFinished(done) + err := s.upload(filename, timeout) + if err != nil { + return fmt.Errorf("cannot upload software image: %w", err) + } select { - case err := <-s.done: + case err := <-done: if err != nil { - return err + return fmt.Errorf("update failed: %w", err) } return nil case <-time.After(timeout): From ab78f774cf3b7b1e67d217daf15ff78f58a8115c Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Tue, 11 Jun 2024 09:14:36 +0200 Subject: [PATCH 3/8] fix: only send the basefile name --- pkg/swupdater/swupdater.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/swupdater/swupdater.go b/pkg/swupdater/swupdater.go index 10ab820..0c854c9 100644 --- a/pkg/swupdater/swupdater.go +++ b/pkg/swupdater/swupdater.go @@ -9,6 +9,7 @@ import ( "mime/multipart" "net/http" "os" + "path/filepath" "strconv" "time" @@ -40,7 +41,7 @@ func (s *SWUpdater) upload(filename string, timeout time.Duration) error { body := &bytes.Buffer{} writer := multipart.NewWriter(body) - part, err := writer.CreateFormFile("file", filename) + part, err := writer.CreateFormFile("file", filepath.Base(filename)) if err != nil { return fmt.Errorf("cannot create form file: %w", err) } From eda352bacbd3b2cb26d8b29d5e670ad5b0d2811f Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Thu, 13 Jun 2024 14:22:21 +0200 Subject: [PATCH 4/8] fix: the multipart file upload It looks like that the current SWUpdate in combination with the mongoos Webserver, does not support chunked uploads --- cmd/ovp8xx/cmd/swupdate.go | 8 +++- go.mod | 1 + go.sum | 2 + pkg/swupdater/swupdater.go | 98 +++++++++++++++++++++----------------- 4 files changed, 64 insertions(+), 45 deletions(-) diff --git a/cmd/ovp8xx/cmd/swupdate.go b/cmd/ovp8xx/cmd/swupdate.go index 04fbac8..80a50ce 100644 --- a/cmd/ovp8xx/cmd/swupdate.go +++ b/cmd/ovp8xx/cmd/swupdate.go @@ -5,6 +5,7 @@ package cmd import ( "fmt" + "path/filepath" "time" "github.com/graugans/go-ovp8xx/pkg/swupdater" @@ -33,7 +34,12 @@ func swupdateCommand(cmd *cobra.Command, args []string) error { return fmt.Errorf("cannot get timeout: %w", err) } - fmt.Printf("Updating firmware on %s:%d with file %s (%v)\n", host, port, filename, timeout) + fmt.Printf("Updating firmware on %s:%d with file %s (%v)\n", + host, + port, + filepath.Base(filename), + timeout, + ) swu := swupdater.NewSWUpdater(host, port) diff --git a/go.mod b/go.mod index 5164633..2e7859e 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/technoweenie/multipartstreamer v1.0.1 // indirect golang.org/x/net v0.17.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bad1fcc..491a028 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/technoweenie/multipartstreamer v1.0.1 h1:XRztA5MXiR1TIRHxH2uNxXxaIkKQDeX7m2XsSOlQEnM= +github.com/technoweenie/multipartstreamer v1.0.1/go.mod h1:jNVxdtShOxzAsukZwTSw6MDx5eUJoiEBsSvzDU9uzog= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/pkg/swupdater/swupdater.go b/pkg/swupdater/swupdater.go index 0c854c9..696ebd1 100644 --- a/pkg/swupdater/swupdater.go +++ b/pkg/swupdater/swupdater.go @@ -1,28 +1,27 @@ package swupdater import ( - "bytes" "encoding/json" "errors" "fmt" - "io" - "mime/multipart" "net/http" "os" - "path/filepath" - "strconv" + "strings" "time" "github.com/gorilla/websocket" + "github.com/technoweenie/multipartstreamer" ) +// SWUpdater represents a software updater. type SWUpdater struct { - hostName string - port uint16 - urlUpload string - urlStatus string + hostName string // The hostname of the updater. + port uint16 // The port number of the updater. + urlUpload string // The URL for uploading software updates. + urlStatus string // The URL for checking the status of software updates. } +// NewSWUpdater creates a new instance of SWUpdater with the specified host name and port. func NewSWUpdater(hostName string, port uint16) *SWUpdater { return &SWUpdater{ hostName: hostName, @@ -31,53 +30,61 @@ func NewSWUpdater(hostName string, port uint16) *SWUpdater { urlStatus: fmt.Sprintf("ws://%s:%d/ws", hostName, port), } } -func (s *SWUpdater) upload(filename string, timeout time.Duration) error { - image, err := os.Open(filename) - if err != nil { - return fmt.Errorf("cannot open file: %w", err) - } - fmt.Printf("Uploading software image to %s\n", s.urlUpload) - body := &bytes.Buffer{} - writer := multipart.NewWriter(body) - - part, err := writer.CreateFormFile("file", filepath.Base(filename)) - if err != nil { - return fmt.Errorf("cannot create form file: %w", err) - } +// Upload performs the upload of the specified file. +// The filename parameter specifies the name of the file to be uploaded. +// Returns an error if the upload fails. +func (s *SWUpdater) upload(filename string) error { + fmt.Printf("Uploading software image to %s\n", s.urlUpload) + const fieldname string = "file" - _, err = io.Copy(part, image) + file, err := os.Open(filename) if err != nil { - return fmt.Errorf("cannot write to form file: %w", err) + return fmt.Errorf("cannot open file: %w", err) } + defer file.Close() - err = writer.Close() + fileInfo, err := file.Stat() if err != nil { - return fmt.Errorf("cannot close multipart writer: %w", err) + return fmt.Errorf("cannot get file info: %w", err) } - req, err := http.NewRequest("POST", s.urlUpload, bytes.NewReader(body.Bytes())) - if err != nil { - return fmt.Errorf("cannot create request: %w", err) - } + ms := multipartstreamer.New() + ms.WriteReader(fieldname, filename, fileInfo.Size(), file) - req.Header.Set("Content-Type", writer.FormDataContentType()) - req.Header.Set("Content-Length", strconv.Itoa(body.Len())) + req, _ := http.NewRequest("POST", s.urlUpload, nil) + ms.SetupRequest(req) - client := &http.Client{Timeout: timeout} - resp, err := client.Do(req) + resp, err := http.DefaultClient.Do(req) if err != nil { - return fmt.Errorf("cannot upload software image: %w", err) + return fmt.Errorf("cannot send request: %w", err) } defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("cannot upload software image: status code %d", resp.StatusCode) - } - - return nil + return err } +// waitForFinished waits for the SWUpdater process to finish by listening to a WebSocket connection. +// It continuously reads messages from the WebSocket and checks for specific conditions to determine +// if the SWUpdater process has completed successfully or has failed. +// +// Parameters: +// - done: A channel used to signal the completion of the SWUpdater process. If the process finishes +// successfully, nil is sent to the channel. If the process fails, an error is sent to the channel. +// +// Returns: +// +// None +// +// Example usage: +// +// done := make(chan error) +// go s.waitForFinished(done) +// err := <-done +// if err != nil { +// // Handle error +// } else { +// // SWUpdater process completed successfully +// } func (s *SWUpdater) waitForFinished(done chan error) { c, _, err := websocket.DefaultDialer.Dial(s.urlStatus, nil) if err != nil { @@ -104,21 +111,24 @@ func (s *SWUpdater) waitForFinished(done chan error) { continue } - if data["text"] == "SWUPDATE successful" { + if strings.Contains(data["text"], "SWUPDATE successful") { done <- nil return } - if data["text"] == "Installation failed" { + if strings.Contains(data["text"], "Installation failed") { done <- errors.New("installation failed") return } } } +// Update uploads a software image and waits for the update process to finish. +// It takes a filename string and a timeout duration as parameters. +// It returns an error if the upload fails, or if the operation times out. func (s *SWUpdater) Update(filename string, timeout time.Duration) error { done := make(chan error) go s.waitForFinished(done) - err := s.upload(filename, timeout) + err := s.upload(filename) if err != nil { return fmt.Errorf("cannot upload software image: %w", err) } From 3f3b1c4e6b8d2af1fce552b3aa76d81d2d1d6678 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Fri, 14 Jun 2024 08:39:44 +0200 Subject: [PATCH 5/8] feat: Add connection timeout flag to swupdate command --- cmd/ovp8xx/cmd/swupdate.go | 31 ++++++++++++-- pkg/swupdater/swupdater.go | 86 ++++++++++++++++++++++++++++---------- 2 files changed, 91 insertions(+), 26 deletions(-) diff --git a/cmd/ovp8xx/cmd/swupdate.go b/cmd/ovp8xx/cmd/swupdate.go index 80a50ce..43abb0b 100644 --- a/cmd/ovp8xx/cmd/swupdate.go +++ b/cmd/ovp8xx/cmd/swupdate.go @@ -34,6 +34,11 @@ func swupdateCommand(cmd *cobra.Command, args []string) error { return fmt.Errorf("cannot get timeout: %w", err) } + connectionTimeout, err := cmd.Flags().GetDuration("online") + if err != nil { + return fmt.Errorf("cannot get timeout: %w", err) + } + fmt.Printf("Updating firmware on %s:%d with file %s (%v)\n", host, port, @@ -41,13 +46,30 @@ func swupdateCommand(cmd *cobra.Command, args []string) error { timeout, ) - swu := swupdater.NewSWUpdater(host, port) + // notifications is a channel used to receive SWUpdaterNotification events. + // It has a buffer size of 10 to allow for asynchronous processing. + notifications := make(chan swupdater.SWUpdaterNotification, 10) - err = swu.Update(filename, timeout) - if err != nil { + // Print the messages as they come + go func() { + for n := range notifications { + if value, ok := n["swupdater"]; ok { + fmt.Println(value) + } + if value, ok := n["text"]; ok && n["type"] == "message" { + fmt.Println(value) + } + } + }() + + // Create a new SWUpdater instance with the specified host, port, and notifications. + swu := swupdater.NewSWUpdater(host, port, notifications) + if err = swu.Update(filename, + connectionTimeout, + timeout, + ); err != nil { return fmt.Errorf("software update failed: %w", err) } - return nil } @@ -63,4 +85,5 @@ func init() { swupdateCmd.Flags().String("file", "", "A file conatining the firmware image") swupdateCmd.Flags().Uint16("port", 8080, "Port number for SWUpdate") swupdateCmd.Flags().Duration("timeout", 5*time.Minute, "The timeout for the upload") + swupdateCmd.Flags().Duration("online", 2*time.Minute, "The time to wait for the device to become available") } diff --git a/pkg/swupdater/swupdater.go b/pkg/swupdater/swupdater.go index 696ebd1..a0dafc1 100644 --- a/pkg/swupdater/swupdater.go +++ b/pkg/swupdater/swupdater.go @@ -15,19 +15,24 @@ import ( // SWUpdater represents a software updater. type SWUpdater struct { - hostName string // The hostname of the updater. - port uint16 // The port number of the updater. - urlUpload string // The URL for uploading software updates. - urlStatus string // The URL for checking the status of software updates. + hostName string // The hostname of the updater. + port uint16 // The port number of the updater. + urlUpload string // The URL for uploading software updates. + urlStatus string // The URL for checking the status of software updates. + notifications chan SWUpdaterNotification // A channel for receiving notifications. + ws *websocket.Conn } +type SWUpdaterNotification map[string]string + // NewSWUpdater creates a new instance of SWUpdater with the specified host name and port. -func NewSWUpdater(hostName string, port uint16) *SWUpdater { +func NewSWUpdater(hostName string, port uint16, notifications chan SWUpdaterNotification) *SWUpdater { return &SWUpdater{ - hostName: hostName, - port: port, - urlUpload: fmt.Sprintf("http://%s:%d/upload", hostName, port), - urlStatus: fmt.Sprintf("ws://%s:%d/ws", hostName, port), + hostName: hostName, + port: port, + urlUpload: fmt.Sprintf("http://%s:%d/upload", hostName, port), + urlStatus: fmt.Sprintf("ws://%s:%d/ws", hostName, port), + notifications: notifications, } } @@ -35,7 +40,7 @@ func NewSWUpdater(hostName string, port uint16) *SWUpdater { // The filename parameter specifies the name of the file to be uploaded. // Returns an error if the upload fails. func (s *SWUpdater) upload(filename string) error { - fmt.Printf("Uploading software image to %s\n", s.urlUpload) + s.statusUpdate(fmt.Sprintf("Uploading software image to %s\n", s.urlUpload)) const fieldname string = "file" file, err := os.Open(filename) @@ -86,31 +91,27 @@ func (s *SWUpdater) upload(filename string) error { // // SWUpdater process completed successfully // } func (s *SWUpdater) waitForFinished(done chan error) { - c, _, err := websocket.DefaultDialer.Dial(s.urlStatus, nil) - if err != nil { - done <- fmt.Errorf("cannot connect to websocket: %w", err) - return - } - defer c.Close() for { - _, message, err := c.ReadMessage() + _, message, err := s.ws.ReadMessage() if err != nil { done <- fmt.Errorf("cannot read message from websocket: %w", err) return } - data := make(map[string]string) + data := make(SWUpdaterNotification) err = json.Unmarshal(message, &data) if err != nil { done <- fmt.Errorf("cannot unmarshal message: %w", err) return } - fmt.Println("Raw JSON: ", data) + // Send notification to channel + if s.notifications != nil { + s.notifications <- data + } if data["type"] != "message" { continue } - if strings.Contains(data["text"], "SWUPDATE successful") { done <- nil return @@ -122,11 +123,52 @@ func (s *SWUpdater) waitForFinished(done chan error) { } } +func (s *SWUpdater) connect() error { + var err error + s.ws, _, err = websocket.DefaultDialer.Dial(s.urlStatus, nil) + if err != nil { + return fmt.Errorf("unable to connect to the status socket: %w", err) + } + return err +} + +func (s *SWUpdater) disconnect() { + s.ws.Close() +} + +// statusUpdate updates the status of the SWUpdater. +// It sends a notification to the channel with the provided status. +func (s *SWUpdater) statusUpdate(status string) { + notification := make(SWUpdaterNotification) + notification["swupdater"] = status + // Send notification to channel + if s.notifications != nil { + s.notifications <- notification + } +} + // Update uploads a software image and waits for the update process to finish. // It takes a filename string and a timeout duration as parameters. // It returns an error if the upload fails, or if the operation times out. -func (s *SWUpdater) Update(filename string, timeout time.Duration) error { +func (s *SWUpdater) Update(filename string, connectionTimeout, timeout time.Duration) error { done := make(chan error) + start := time.Now() + s.statusUpdate("Waiting for the Device to become ready...") + // Retry connection until successful or connectionTimeout occurs + for { + err := s.connect() + if err == nil { + s.statusUpdate("Device is ready now") + break + } + if time.Since(start) > connectionTimeout { + return fmt.Errorf("connection timeout: %w", err) + } + time.Sleep(3 * time.Second) // wait for a second before retrying + } + defer s.disconnect() + + s.statusUpdate("Starting the Software Update process...") go s.waitForFinished(done) err := s.upload(filename) if err != nil { @@ -140,6 +182,6 @@ func (s *SWUpdater) Update(filename string, timeout time.Duration) error { } return nil case <-time.After(timeout): - return errors.New("timeout") + return errors.New("a timeout occurred while waiting for the update to finish") } } From 086ee1dadc338060436d00502dbaf84ef245302b Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Fri, 14 Jun 2024 08:55:39 +0200 Subject: [PATCH 6/8] ci: fix a go-cilinter finding about not handled error --- pkg/swupdater/swupdater.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/swupdater/swupdater.go b/pkg/swupdater/swupdater.go index a0dafc1..0e228a2 100644 --- a/pkg/swupdater/swupdater.go +++ b/pkg/swupdater/swupdater.go @@ -55,7 +55,10 @@ func (s *SWUpdater) upload(filename string) error { } ms := multipartstreamer.New() - ms.WriteReader(fieldname, filename, fileInfo.Size(), file) + err = ms.WriteReader(fieldname, filename, fileInfo.Size(), file) + if err != nil { + return fmt.Errorf("cannot write reader: %w", err) + } req, _ := http.NewRequest("POST", s.urlUpload, nil) ms.SetupRequest(req) From bc34ba604a74ed5efbe14caf0b8e7a2e9c6e5e87 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Fri, 14 Jun 2024 10:04:09 +0200 Subject: [PATCH 7/8] fix: the handling of the notifications --- cmd/ovp8xx/cmd/swupdate.go | 25 +++++++++++++++++++------ pkg/swupdater/swupdater.go | 1 + 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/cmd/ovp8xx/cmd/swupdate.go b/cmd/ovp8xx/cmd/swupdate.go index 43abb0b..95cba6a 100644 --- a/cmd/ovp8xx/cmd/swupdate.go +++ b/cmd/ovp8xx/cmd/swupdate.go @@ -6,6 +6,7 @@ package cmd import ( "fmt" "path/filepath" + "sync" "time" "github.com/graugans/go-ovp8xx/pkg/swupdater" @@ -24,10 +25,11 @@ func swupdateCommand(cmd *cobra.Command, args []string) error { return fmt.Errorf("cannot get port: %w", err) } - filename, err := cmd.Flags().GetString("file") - if err != nil { - return fmt.Errorf("cannot get filename: %w", err) + // Check if filename is provided as a positional argument + if len(args) < 1 { + return fmt.Errorf("no filename provided") } + filename := args[0] timeout, err := cmd.Flags().GetDuration("timeout") if err != nil { @@ -50,6 +52,9 @@ func swupdateCommand(cmd *cobra.Command, args []string) error { // It has a buffer size of 10 to allow for asynchronous processing. notifications := make(chan swupdater.SWUpdaterNotification, 10) + var wg sync.WaitGroup + wg.Add(1) + // Print the messages as they come go func() { for n := range notifications { @@ -59,7 +64,9 @@ func swupdateCommand(cmd *cobra.Command, args []string) error { if value, ok := n["text"]; ok && n["type"] == "message" { fmt.Println(value) } + } + wg.Done() // Decrease counter when goroutine completes }() // Create a new SWUpdater instance with the specified host, port, and notifications. @@ -70,19 +77,25 @@ func swupdateCommand(cmd *cobra.Command, args []string) error { ); err != nil { return fmt.Errorf("software update failed: %w", err) } + + wg.Wait() // Wait for all goroutines to finish return nil } // swupdateCmd represents the swupdate command var swupdateCmd = &cobra.Command{ - Use: "swupdate", + Use: "swupdate [filename]", Short: "Update the firmware on the device", - RunE: swupdateCommand, + Long: `The swupdate command is used to update the firmware on the device. + +It takes a filename as a positional argument, which is the path to the firmware file to be uploaded. + +The command establishes a connection to the device, uploads the firmware file, and waits for the update process to complete.`, + RunE: swupdateCommand, } func init() { rootCmd.AddCommand(swupdateCmd) - swupdateCmd.Flags().String("file", "", "A file conatining the firmware image") swupdateCmd.Flags().Uint16("port", 8080, "Port number for SWUpdate") swupdateCmd.Flags().Duration("timeout", 5*time.Minute, "The timeout for the upload") swupdateCmd.Flags().Duration("online", 2*time.Minute, "The time to wait for the device to become available") diff --git a/pkg/swupdater/swupdater.go b/pkg/swupdater/swupdater.go index 0e228a2..6538533 100644 --- a/pkg/swupdater/swupdater.go +++ b/pkg/swupdater/swupdater.go @@ -180,6 +180,7 @@ func (s *SWUpdater) Update(filename string, connectionTimeout, timeout time.Dura select { case err := <-done: + close(s.notifications) // Close the channel to signal the end of notifications if err != nil { return fmt.Errorf("update failed: %w", err) } From d45a2f2cbe6238dad07e830bd69cc0af722995f3 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Fri, 14 Jun 2024 10:07:30 +0200 Subject: [PATCH 8/8] build: fix the import module for swupdater --- cmd/ovp8xx/cmd/swupdate.go | 2 +- go.mod | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/ovp8xx/cmd/swupdate.go b/cmd/ovp8xx/cmd/swupdate.go index 95cba6a..0f1548f 100644 --- a/cmd/ovp8xx/cmd/swupdate.go +++ b/cmd/ovp8xx/cmd/swupdate.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/graugans/go-ovp8xx/pkg/swupdater" + "github.com/graugans/go-ovp8xx/v2/pkg/swupdater" "github.com/spf13/cobra" ) diff --git a/go.mod b/go.mod index 444eb3b..55f8953 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gorilla/websocket v1.5.1 github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.4 + github.com/technoweenie/multipartstreamer v1.0.1 ) require ( @@ -14,7 +15,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/technoweenie/multipartstreamer v1.0.1 // indirect golang.org/x/net v0.17.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect )