Skip to content

Commit

Permalink
fix(sztp): send ssh key when onboarding completed
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Glimcher <[email protected]>
  • Loading branch information
glimchb committed Jun 7, 2024
1 parent 2c61991 commit 8f455c5
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
2 changes: 1 addition & 1 deletion sztp-agent/pkg/secureagent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ func TestAgent_SetProgressJson(t *testing.T) {
ProgressJSON: tt.fields.ProgressJSON,
}
a.SetProgressJSON(tt.args.p)
if ! reflect.DeepEqual(a.GetProgressJSON(), tt.args.p) {
if !reflect.DeepEqual(a.GetProgressJSON(), tt.args.p) {
t.Errorf("SetProgressJson = %v, want %v", a.GetProgressJSON(), tt.args.p)
}
})
Expand Down
43 changes: 32 additions & 11 deletions sztp-agent/pkg/secureagent/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (a *Agent) RunCommandDaemon() error {
if err != nil {
return err
}
// _ = a.doReportProgress(ProgressTypeBootstrapComplete)
_ = a.doReportProgress(ProgressTypeBootstrapComplete, true)
return nil
}

Expand All @@ -88,10 +88,10 @@ func (a *Agent) getBootstrapURL() error {
return nil
}

func (a *Agent) doReportProgress(s ProgressType) error {
func (a *Agent) doReportProgress(s ProgressType, needssh bool) error {
log.Println("[INFO] Starting the Report Progress request.")
url := strings.ReplaceAll(a.GetBootstrapURL(), "get-bootstrapping-data", "report-progress")
a.SetProgressJSON(ProgressJSON{
p := ProgressJSON{
IetfSztpBootstrapServerInput: struct {
ProgressType string `json:"progress-type"`
Message string `json:"message"`
Expand All @@ -105,7 +105,28 @@ func (a *Agent) doReportProgress(s ProgressType) error {
ProgressType: s.String(),
Message: "message sent via JSON",
},
})
}
if needssh {
// TODO: generate real key here
encodedKey := base64.StdEncoding.EncodeToString([]byte("mysshpass"))
p.IetfSztpBootstrapServerInput.SSHHostKeys = struct {
SSHHostKey []struct {
Algorithm string `json:"algorithm"`
KeyData string `json:"key-data"`
} `json:"ssh-host-key,omitempty"`
}{
SSHHostKey: []struct {
Algorithm string `json:"algorithm"`
KeyData string `json:"key-data"`
}{
{
Algorithm: "ssh-rsa",
KeyData: encodedKey,
},
},
}
}
a.SetProgressJSON(p)
inputJSON, _ := json.Marshal(a.GetProgressJSON())
res, err := a.doTLSRequest(string(inputJSON), url, true)
if err != nil {
Expand Down Expand Up @@ -150,7 +171,7 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error {
return err
}
log.Println("[INFO] Response retrieved successfully")
_ = a.doReportProgress(ProgressTypeBootstrapInitiated)
_ = a.doReportProgress(ProgressTypeBootstrapInitiated, false)
crypto := res.IetfSztpBootstrapServerOutput.ConveyedInformation
newVal, err := base64.StdEncoding.DecodeString(crypto)
if err != nil {
Expand Down Expand Up @@ -190,7 +211,7 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error {
//nolint:funlen
func (a *Agent) downloadAndValidateImage() error {
log.Printf("[INFO] Starting the Download Image: %v", a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI)
_ = a.doReportProgress(ProgressTypeBootImageInitiated)
_ = a.doReportProgress(ProgressTypeBootImageInitiated, false)
// Download the image from DownloadURI and save it to a file
a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference = fmt.Sprintf("%8d", time.Now().Unix())
for i, item := range a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI {
Expand Down Expand Up @@ -251,7 +272,7 @@ func (a *Agent) downloadAndValidateImage() error {
return errors.New("Checksum mismatch")
}
log.Println("[INFO] Checksum verified successfully")
_ = a.doReportProgress(ProgressTypeBootImageComplete)
_ = a.doReportProgress(ProgressTypeBootImageComplete, false)
return nil
default:
return errors.New("Unsupported hash algorithm")
Expand All @@ -262,7 +283,7 @@ func (a *Agent) downloadAndValidateImage() error {

func (a *Agent) copyConfigurationFile() error {
log.Println("[INFO] Starting the Copy Configuration.")
_ = a.doReportProgress(ProgressTypeConfigInitiated)
_ = a.doReportProgress(ProgressTypeConfigInitiated, false)
// Copy the configuration file to the device
file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + "-config")
if err != nil {
Expand All @@ -283,7 +304,7 @@ func (a *Agent) copyConfigurationFile() error {
return err
}
log.Println("[INFO] Configuration file copied successfully")
_ = a.doReportProgress(ProgressTypeConfigComplete)
_ = a.doReportProgress(ProgressTypeConfigComplete, false)
return nil
}

Expand All @@ -303,7 +324,7 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error {
reportEnd = ProgressTypePreScriptComplete
}
log.Println("[INFO] Starting the " + scriptName + "-configuration.")
_ = a.doReportProgress(reportStart)
_ = a.doReportProgress(reportStart, false)
file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + scriptName + "configuration.sh")

Check failure on line 328 in sztp-agent/pkg/secureagent/daemon.go

View workflow job for this annotation

GitHub Actions / golangci

G304: Potential file inclusion via variable (gosec)
if err != nil {
log.Println("[ERROR] creating the "+scriptName+"-configuration script", err.Error())
Expand All @@ -330,7 +351,7 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error {
return err
}
log.Println(string(out)) // remove it
_ = a.doReportProgress(reportEnd)
_ = a.doReportProgress(reportEnd, false)
log.Println("[INFO] " + scriptName + "-Configuration script executed successfully")
return nil
}
2 changes: 1 addition & 1 deletion sztp-agent/pkg/secureagent/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ func TestAgent_doReportProgress(t *testing.T) {
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
ProgressJSON: tt.fields.ProgressJSON,
}
if err := a.doReportProgress(ProgressTypeBootstrapInitiated); (err != nil) != tt.wantErr {
if err := a.doReportProgress(ProgressTypeBootstrapInitiated, false); (err != nil) != tt.wantErr {
t.Errorf("doReportProgress() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down

0 comments on commit 8f455c5

Please sign in to comment.