Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature: skip image download if it exists
Browse files Browse the repository at this point in the history
Signed-off-by: Suraj Shirvankar <[email protected]>
h0lyalg0rithm committed Aug 5, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 4f7a99e commit 2257429
Showing 11 changed files with 245 additions and 96 deletions.
3 changes: 2 additions & 1 deletion sztp-agent/cmd/daemon.go
Original file line number Diff line number Diff line change
@@ -59,7 +59,8 @@ func Daemon() *cobra.Command {
return fmt.Errorf("must not be folder: %q", filePath)
}
}
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandDaemon()
},
}
3 changes: 2 additions & 1 deletion sztp-agent/cmd/disable.go
Original file line number Diff line number Diff line change
@@ -34,7 +34,8 @@ func Disable() *cobra.Command {
Use: "disable",
Short: "Run the disable command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandDisable()
},
}
3 changes: 2 additions & 1 deletion sztp-agent/cmd/enable.go
Original file line number Diff line number Diff line change
@@ -34,7 +34,8 @@ func Enable() *cobra.Command {
Use: "enable",
Short: "Run the enable command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandEnable()
},
}
3 changes: 2 additions & 1 deletion sztp-agent/cmd/run.go
Original file line number Diff line number Diff line change
@@ -59,7 +59,8 @@ func Run() *cobra.Command {
return fmt.Errorf("must not be folder: %q", filePath)
}
}
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommand()
},
}
3 changes: 2 additions & 1 deletion sztp-agent/cmd/status.go
Original file line number Diff line number Diff line change
@@ -34,7 +34,8 @@ func Status() *cobra.Command {
Use: "status",
Short: "Run the status command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandStatus()
},
}
39 changes: 37 additions & 2 deletions sztp-agent/pkg/secureagent/agent.go
Original file line number Diff line number Diff line change
@@ -9,6 +9,13 @@ Copyright (C) 2022 Red Hat.
// Package secureagent implements the secure agent
package secureagent

import (
"crypto/tls"
"crypto/x509"
"net/http"
"os"
)

const (
CONTENT_TYPE_YANG = "application/yang-data+json"
OS_RELEASE_FILE = "/etc/os-release"
@@ -68,6 +75,11 @@ type BootstrapServerErrorOutput struct {
} `json:"ietf-restconf:errors"`
}

type HttpGetter interface {
Get(uri string) (*http.Response, error)
Do(req *http.Request) (*http.Response, error)
}

// Agent is the basic structure to define an agent instance
type Agent struct {
InputBootstrapURL string // Bootstrap complete URL given by USER
@@ -83,10 +95,10 @@ type Agent struct {
ProgressJSON ProgressJSON // ProgressJson structure
BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure
BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure

HttpClient HttpGetter
}

func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string) *Agent {
func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string, httpClient HttpGetter) *Agent {
return &Agent{
InputBootstrapURL: bootstrapURL,
BootstrapURL: "",
@@ -101,6 +113,7 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP
ProgressJSON: ProgressJSON{},
BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{},
BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{},
HttpClient: httpClient,
}
}

@@ -171,3 +184,25 @@ func (a *Agent) SetContentTypeReq(ct string) {
func (a *Agent) SetProgressJSON(p ProgressJSON) {
a.ProgressJSON = p
}

func NewHttpClient(bootstrapTrustAnchorCert string, deviceEndEntityCert string, devicePrivateKey string) http.Client {
caCert, _ := os.ReadFile(bootstrapTrustAnchorCert)
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(deviceEndEntityCert, devicePrivateKey)
client := http.Client{
CheckRedirect: func(r *http.Request, _ []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true, // TODO: remove skip verify

Check failure

Code scanning / CodeQL

Disabled TLS certificate check High

InsecureSkipVerify should not be used in production code.
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
return client
}
5 changes: 4 additions & 1 deletion sztp-agent/pkg/secureagent/agent_test.go
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ Copyright (C) 2022 Red Hat.
package secureagent

import (
"net/http"
"reflect"
"testing"
)
@@ -829,6 +830,7 @@ func TestNewAgent(t *testing.T) {
deviceEndEntityCert string
bootstrapTrustAnchorCert string
}
client := http.Client{}
tests := []struct {
name string
args args
@@ -856,12 +858,13 @@ func TestNewAgent(t *testing.T) {
ContentTypeReq: "application/yang-data+json",
InputJSONContent: generateInputJSONContent(),
DhcpLeaseFile: "TestDhcpLeaseFile",
HttpClient: &client,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert); !reflect.DeepEqual(got, tt.want) {
if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, &client); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewAgent() = %v, want %v", got, tt.want)
}
})
157 changes: 87 additions & 70 deletions sztp-agent/pkg/secureagent/daemon.go
Original file line number Diff line number Diff line change
@@ -10,16 +10,13 @@ package secureagent

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"os/exec"
@@ -181,88 +178,108 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error {
return errri
}

//nolint:funlen
func (a *Agent) downloadAndValidateImage() error {
log.Printf("[INFO] Starting the Download Image: %v", a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI)
_ = a.doReportProgress(ProgressTypeBootImageInitiated, "BootImage Initiated")
// Download the image from DownloadURI and save it to a file
a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference = fmt.Sprintf("%8d", time.Now().Unix())
for i, item := range a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI {
// TODO: maybe need to file download to a function in util.go
log.Printf("[INFO] Downloading Image %v", item)
// Create a empty file
file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + filepath.Base(item))
if err != nil {
return err
}
func (a *Agent) downloadArtifact(uri string) (*os.File, error) {
file, err := os.CreateTemp("", filepath.Base(uri))
if err != nil {
return nil, err
}

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())
response, err := a.HttpClient.Get(uri)
if err != nil {
return nil, err
}
sizeorigin, _ := strconv.Atoi(response.Header.Get("Content-Length"))
downloadSize := int64(sizeorigin)
log.Printf("[INFO] Downloading the image with size: %v", downloadSize)

check := http.Client{
CheckRedirect: func(r *http.Request, _ []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true, // TODO: remove skip verify
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
if response.StatusCode != 200 {
return nil, errors.New("received non 200 response code")
}
size, err := io.Copy(file, response.Body)
if err != nil {
return nil, err
}
defer func() {
if err := file.Close(); err != nil {
log.Println("[ERROR] Error when closing:", err)
}
}()
defer func() {
if err := response.Body.Close(); err != nil {
log.Println("[ERROR] Error when closing:", err)
}
}()
log.Printf("[INFO] Downloaded file: %s with size: %d", file.Name(), size)
return file, nil
}

func (a *Agent) validateImage(filePath string, algorithm string, expected string) error {
switch algorithm {
case "ietf-sztp-conveyed-info:sha-256":
checksum, err := calculateSHA256File(filePath)

response, err := check.Get(item)
if err != nil {
return err
log.Println("[ERROR] Could not calculate checksum", err)
}
log.Println("calculated: " + checksum)
log.Println("expected : " + expected)

sizeorigin, _ := strconv.Atoi(response.Header.Get("Content-Length"))
downloadSize := int64(sizeorigin)
log.Printf("[INFO] Downloading the image with size: %v", downloadSize)

if response.StatusCode != 200 {
return errors.New("received non 200 response code")
if checksum != expected {
return errors.New("checksum mismatch")
}
size, err := io.Copy(file, response.Body)
if err != nil {
return err
log.Println("[INFO] Checksum verified successfully")
return nil
default:
return errors.New("unsupported hash algorithm")
}
}

func (a *Agent) artifactExists(item string, algorithm string, expected string) bool {
filePath := ARTIFACTS_PATH + filepath.Base(item)
_, err := os.Stat(filePath)
if err != nil {
return false
}
err = a.validateImage(filePath, algorithm, expected)
return err == nil
}

//nolint:funlen
func (a *Agent) downloadAndValidateImage() error {
bootImage := a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage
log.Printf("[INFO] Starting the Download Image: %v", bootImage.DownloadURI)
_ = a.doReportProgress(ProgressTypeBootImageInitiated, "BootImage Initiated")
// Download the image from DownloadURI and save it to a file
for i, item := range bootImage.DownloadURI {

Check failure on line 253 in sztp-agent/pkg/secureagent/daemon.go

GitHub Actions / golangci

unnecessary leading newline (whitespace)

if len(bootImage.ImageVerification) <= i {
return errors.New("Invalid verification")
}
defer func() {
if err := file.Close(); err != nil {
log.Println("[ERROR] Error when closing:", err)
}
}()
defer func() {
if err := response.Body.Close(); err != nil {
log.Println("[ERROR] Error when closing:", err)
}
}()

log.Printf("[INFO] Downloaded file: %s with size: %d", ARTIFACTS_PATH+a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference+filepath.Base(item), size)
log.Println("[INFO] Verify the file checksum: ", ARTIFACTS_PATH+a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference+filepath.Base(item))
switch a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.ImageVerification[i].HashAlgorithm {
case "ietf-sztp-conveyed-info:sha-256":
filePath := ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + filepath.Base(item)
checksum, err := calculateSHA256File(filePath)
original := strings.ReplaceAll(a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.ImageVerification[i].HashValue, ":", "")
imageVerification := bootImage.ImageVerification[i]
expected := strings.ReplaceAll(imageVerification.HashValue, ":", "")
algorithm := imageVerification.HashAlgorithm

if a.artifactExists(item, algorithm, expected) {
log.Printf("[INFO] Image %v already exists", item)
} else {
log.Printf("[INFO] Downloading Image %v", item)
file, err := a.downloadArtifact(item)
if err != nil {
log.Println("[ERROR] Could not calculate checksum", err)
return err
}
log.Println("calculated: " + checksum)
log.Println("expected : " + original)
if checksum != original {
return errors.New("checksum mismatch")
log.Println("[INFO] Verify the file checksum: ", file.Name())
err = a.validateImage(file.Name(), algorithm, expected)
if err != nil {
return err
}
log.Println("[INFO] Checksum verified successfully")

log.Printf("[INFO] Moving file %s to %s", file.Name(), ARTIFACTS_PATH+filepath.Base(item))
os.Rename(file.Name(), ARTIFACTS_PATH+filepath.Base(item))

Check failure on line 278 in sztp-agent/pkg/secureagent/daemon.go

GitHub Actions / golangci

Error return value of `os.Rename` is not checked (errcheck)
_ = a.doReportProgress(ProgressTypeBootImageComplete, "BootImage Complete")

return nil
default:
return errors.New("unsupported hash algorithm")

}

Check failure on line 283 in sztp-agent/pkg/secureagent/daemon.go

GitHub Actions / golangci

unnecessary trailing newline (whitespace)
}
return nil
106 changes: 105 additions & 1 deletion sztp-agent/pkg/secureagent/daemon_test.go
Original file line number Diff line number Diff line change
@@ -7,10 +7,12 @@
import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)

@@ -341,6 +343,7 @@
ContentTypeReq: tt.fields.ContentTypeReq,
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
HttpClient: &http.Client{},
}
if err := a.doRequestBootstrapServerOnboardingInfo(); (err != nil) != tt.wantErr {
t.Errorf("doRequestBootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
@@ -349,6 +352,19 @@
}
}

type MockClient struct {
GetFunc func(uri string) (*http.Response, error)
DoFunc func(req *http.Request) (*http.Response, error)
}

func (m *MockClient) Do(req *http.Request) (*http.Response, error) {
return m.DoFunc(req)
}

func (m *MockClient) Get(uri string) (*http.Response, error) {
return m.GetFunc(uri)
}

//nolint:funlen
func TestAgent_downloadAndValidateImage(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -359,7 +375,6 @@
}
}))
defer svr.Close()

type fields struct {
BootstrapURL string
SerialNumber string
@@ -638,6 +653,7 @@
},
}
for _, tt := range tests {
os.Remove(ARTIFACTS_PATH + "/imageOK")

Check failure on line 656 in sztp-agent/pkg/secureagent/daemon_test.go

GitHub Actions / golangci

Error return value of `os.Remove` is not checked (errcheck)
t.Run(tt.name, func(t *testing.T) {
a := &Agent{
BootstrapURL: tt.fields.BootstrapURL,
@@ -652,12 +668,98 @@
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: svr.Client(),
}
if err := a.downloadAndValidateImage(); (err != nil) != tt.wantErr {
t.Errorf("downloadAndValidateImage() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
calls := 0
httpClient := MockClient{
GetFunc: func(uri string) (*http.Response, error) {

Check failure on line 680 in sztp-agent/pkg/secureagent/daemon_test.go

GitHub Actions / golangci

unused-parameter: parameter 'uri' seems to be unused, consider removing or renaming it as _ (revive)
calls = calls + 1

Check failure on line 681 in sztp-agent/pkg/secureagent/daemon_test.go

GitHub Actions / golangci

assignOp: replace `calls = calls + 1` with `calls++` (gocritic)
return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil
},
DoFunc: func(req *http.Request) (*http.Response, error) {

Check failure on line 684 in sztp-agent/pkg/secureagent/daemon_test.go

GitHub Actions / golangci

unused-parameter: parameter 'req' seems to be unused, consider removing or renaming it as _ (revive)
return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil
},
}
a := &Agent{
BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{
IetfSztpConveyedInfoOnboardingInformation: struct {
InfoTimestampReference string
BootImage struct {
DownloadURI []string `json:"download-uri"`
ImageVerification []struct {
HashAlgorithm string `json:"hash-algorithm"`
HashValue string `json:"hash-value"`
} `json:"image-verification"`
} `json:"boot-image"`
PreConfigurationScript string `json:"pre-configuration-script"`
ConfigurationHandling string `json:"configuration-handling"`
Configuration string `json:"configuration"`
PostConfigurationScript string `json:"post-configuration-script"`
}{
InfoTimestampReference: "TIMESTAMP",
BootImage: struct {
DownloadURI []string `json:"download-uri"`
ImageVerification []struct {
HashAlgorithm string `json:"hash-algorithm"`
HashValue string `json:"hash-value"`
} `json:"image-verification"`
}{
DownloadURI: []string{svr.URL + "/imageOK"},
ImageVerification: []struct {
HashAlgorithm string `json:"hash-algorithm"`
HashValue string `json:"hash-value"`
}{{
HashAlgorithm: "ietf-sztp-conveyed-info:sha-256",
HashValue: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
}},
},
PreConfigurationScript: "",
ConfigurationHandling: "",
Configuration: "",
PostConfigurationScript: "",
},
},
BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{},
HttpClient: &httpClient,
}
t.Run("OK case with cached file", func(t *testing.T) {
calls = 0
os.Remove(ARTIFACTS_PATH + "/imageOK")

Check failure on line 732 in sztp-agent/pkg/secureagent/daemon_test.go

GitHub Actions / golangci

Error return value of `os.Remove` is not checked (errcheck)
// Initiate cache download
err := a.downloadAndValidateImage()
if err != nil {
t.Errorf("downloadAndValidateImage() error = %v", err)
}
err = a.downloadAndValidateImage()
if err != nil {
t.Errorf("downloadAndValidateImage() error = %v", err)
}
if calls != 1 {
t.Errorf("downloadAndValidateImage() called httpclient more than 1 times, Called %d", calls)
}
})
t.Run("OK case with cached file with different signature", func(t *testing.T) {
calls = 0
os.Remove(ARTIFACTS_PATH + "/imageOK")

Check failure on line 748 in sztp-agent/pkg/secureagent/daemon_test.go

GitHub Actions / golangci

Error return value of `os.Remove` is not checked (errcheck)
// Initiate cache download
err := a.downloadAndValidateImage()
if err != nil {
t.Errorf("downloadAndValidateImage() error = %v", err)
}
_ = os.WriteFile(ARTIFACTS_PATH+"/imageOK", []byte("test"), 0666)

Check failure on line 754 in sztp-agent/pkg/secureagent/daemon_test.go

GitHub Actions / golangci

G306: Expect WriteFile permissions to be 0600 or less (gosec)
err = a.downloadAndValidateImage()
if err != nil {
t.Errorf("downloadAndValidateImage() error = %v", err)
}
if calls != 2 {
t.Errorf("downloadAndValidateImage() should call httpclient two time, Called %d", calls)
}
})
}

// nolint:funlen
@@ -807,6 +909,7 @@
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: &http.Client{},
}
if err := a.copyConfigurationFile(); (err != nil) != tt.wantErr {
t.Errorf("copyConfigurationFile() error = %v, wantErr %v", err, tt.wantErr)
@@ -1024,6 +1127,7 @@
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: &http.Client{},
}
if err := a.launchScriptsConfiguration(tt.args.typeOf); (err != nil) != tt.wantErr {
t.Errorf("launchScriptsConfiguration() error = %v, wantErr %v", err, tt.wantErr)
1 change: 1 addition & 0 deletions sztp-agent/pkg/secureagent/progress_test.go
Original file line number Diff line number Diff line change
@@ -158,6 +158,7 @@ func TestAgent_doReportProgress(t *testing.T) {
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
ProgressJSON: tt.fields.ProgressJSON,
HttpClient: &http.Client{},
}
if err := a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated"); (err != nil) != tt.wantErr {
t.Errorf("doReportProgress() error = %v, wantErr %v", err, tt.wantErr)
18 changes: 1 addition & 17 deletions sztp-agent/pkg/secureagent/tls.go
Original file line number Diff line number Diff line change
@@ -10,14 +10,11 @@ package secureagent

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"os"
"strconv"
"strings"
)
@@ -38,20 +35,7 @@ func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapSe
r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword())
r.Header.Add("Content-Type", a.GetContentTypeReq())

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{ //nolint:gosec
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
res, err := client.Do(r)
res, err := a.HttpClient.Do(r)
if err != nil {
log.Println("Error doing the request", err.Error())
return nil, err

0 comments on commit 2257429

Please sign in to comment.