Skip to content

Commit

Permalink
refactor: unify http client creation
Browse files Browse the repository at this point in the history
Signed-off-by: Suraj Shirvankar <[email protected]>
  • Loading branch information
h0lyalg0rithm committed Aug 5, 2024
1 parent 4f7a99e commit 15002b5
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 50 deletions.
3 changes: 2 additions & 1 deletion sztp-agent/cmd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ func Daemon() *cobra.Command {
return fmt.Errorf("must not be folder: %q", filePath)
}
}
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandDaemon()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/disable.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func Disable() *cobra.Command {
Use: "disable",
Short: "Run the disable command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandDisable()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/enable.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func Enable() *cobra.Command {
Use: "enable",
Short: "Run the enable command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandEnable()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ func Run() *cobra.Command {
return fmt.Errorf("must not be folder: %q", filePath)
}
}
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommand()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func Status() *cobra.Command {
Use: "status",
Short: "Run the status command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandStatus()
},
}
Expand Down
39 changes: 37 additions & 2 deletions sztp-agent/pkg/secureagent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ Copyright (C) 2022 Red Hat.
// Package secureagent implements the secure agent
package secureagent

import (
"crypto/tls"
"crypto/x509"
"net/http"
"os"
)

const (
CONTENT_TYPE_YANG = "application/yang-data+json"
OS_RELEASE_FILE = "/etc/os-release"
Expand Down Expand Up @@ -68,6 +75,11 @@ type BootstrapServerErrorOutput struct {
} `json:"ietf-restconf:errors"`
}

type HttpClient interface {
Get(uri string) (*http.Response, error)
Do(req *http.Request) (*http.Response, error)
}

// Agent is the basic structure to define an agent instance
type Agent struct {
InputBootstrapURL string // Bootstrap complete URL given by USER
Expand All @@ -83,10 +95,10 @@ type Agent struct {
ProgressJSON ProgressJSON // ProgressJson structure
BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure
BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure

HttpClient HttpClient
}

func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string) *Agent {
func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string, httpClient HttpClient) *Agent {
return &Agent{
InputBootstrapURL: bootstrapURL,
BootstrapURL: "",
Expand All @@ -101,6 +113,7 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP
ProgressJSON: ProgressJSON{},
BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{},
BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{},
HttpClient: httpClient,
}
}

Expand Down Expand Up @@ -171,3 +184,25 @@ func (a *Agent) SetContentTypeReq(ct string) {
func (a *Agent) SetProgressJSON(p ProgressJSON) {
a.ProgressJSON = p
}

func NewHttpClient(bootstrapTrustAnchorCert string, deviceEndEntityCert string, devicePrivateKey string) http.Client {
caCert, _ := os.ReadFile(bootstrapTrustAnchorCert)
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(deviceEndEntityCert, devicePrivateKey)
client := http.Client{
CheckRedirect: func(r *http.Request, _ []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true, // TODO: remove skip verify

Check failure

Code scanning / CodeQL

Disabled TLS certificate check High

InsecureSkipVerify should not be used in production code.
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
return client
}
5 changes: 4 additions & 1 deletion sztp-agent/pkg/secureagent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Copyright (C) 2022 Red Hat.
package secureagent

import (
"net/http"
"reflect"
"testing"
)
Expand Down Expand Up @@ -829,6 +830,7 @@ func TestNewAgent(t *testing.T) {
deviceEndEntityCert string
bootstrapTrustAnchorCert string
}
client := http.Client{}
tests := []struct {
name string
args args
Expand Down Expand Up @@ -856,12 +858,13 @@ func TestNewAgent(t *testing.T) {
ContentTypeReq: "application/yang-data+json",
InputJSONContent: generateInputJSONContent(),
DhcpLeaseFile: "TestDhcpLeaseFile",
HttpClient: &client,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert); !reflect.DeepEqual(got, tt.want) {
if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, &client); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewAgent() = %v, want %v", got, tt.want)
}
})
Expand Down
25 changes: 1 addition & 24 deletions sztp-agent/pkg/secureagent/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,13 @@ package secureagent

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"os/exec"
Expand Down Expand Up @@ -196,27 +193,7 @@ func (a *Agent) downloadAndValidateImage() error {
return err
}

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

check := http.Client{
CheckRedirect: func(r *http.Request, _ []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true, // TODO: remove skip verify
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}

response, err := check.Get(item)
response, err := a.HttpClient.Get(item)
if err != nil {
return err
}
Expand Down
6 changes: 5 additions & 1 deletion sztp-agent/pkg/secureagent/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ func TestAgent_doReqBootstrap(t *testing.T) {
ContentTypeReq: tt.fields.ContentTypeReq,
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
HttpClient: &http.Client{},
}
if err := a.doRequestBootstrapServerOnboardingInfo(); (err != nil) != tt.wantErr {
t.Errorf("doRequestBootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
Expand All @@ -359,7 +360,6 @@ func TestAgent_downloadAndValidateImage(t *testing.T) {
}
}))
defer svr.Close()

type fields struct {
BootstrapURL string
SerialNumber string
Expand Down Expand Up @@ -638,6 +638,7 @@ func TestAgent_downloadAndValidateImage(t *testing.T) {
},
}
for _, tt := range tests {
deleteTempTestFile(ARTIFACTS_PATH + "/imageOK")
t.Run(tt.name, func(t *testing.T) {
a := &Agent{
BootstrapURL: tt.fields.BootstrapURL,
Expand All @@ -652,6 +653,7 @@ func TestAgent_downloadAndValidateImage(t *testing.T) {
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: svr.Client(),
}
if err := a.downloadAndValidateImage(); (err != nil) != tt.wantErr {
t.Errorf("downloadAndValidateImage() error = %v, wantErr %v", err, tt.wantErr)
Expand Down Expand Up @@ -807,6 +809,7 @@ func TestAgent_copyConfigurationFile(t *testing.T) {
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: &http.Client{},
}
if err := a.copyConfigurationFile(); (err != nil) != tt.wantErr {
t.Errorf("copyConfigurationFile() error = %v, wantErr %v", err, tt.wantErr)
Expand Down Expand Up @@ -1024,6 +1027,7 @@ func TestAgent_launchScriptsConfiguration(t *testing.T) {
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: &http.Client{},
}
if err := a.launchScriptsConfiguration(tt.args.typeOf); (err != nil) != tt.wantErr {
t.Errorf("launchScriptsConfiguration() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
1 change: 1 addition & 0 deletions sztp-agent/pkg/secureagent/progress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ func TestAgent_doReportProgress(t *testing.T) {
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
ProgressJSON: tt.fields.ProgressJSON,
HttpClient: &http.Client{},
}
if err := a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated"); (err != nil) != tt.wantErr {
t.Errorf("doReportProgress() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
18 changes: 1 addition & 17 deletions sztp-agent/pkg/secureagent/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@ package secureagent

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"os"
"strconv"
"strings"
)
Expand All @@ -38,20 +35,7 @@ func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapSe
r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword())
r.Header.Add("Content-Type", a.GetContentTypeReq())

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{ //nolint:gosec
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
res, err := client.Do(r)
res, err := a.HttpClient.Do(r)
if err != nil {
log.Println("Error doing the request", err.Error())
return nil, err
Expand Down

0 comments on commit 15002b5

Please sign in to comment.