Skip to content

Commit

Permalink
fix upstream test
Browse files Browse the repository at this point in the history
Signed-off-by: Sammy Oina <[email protected]>
  • Loading branch information
SammyOina committed Dec 2, 2024
1 parent 51ddcd4 commit 6b0d623
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 75 deletions.
6 changes: 3 additions & 3 deletions e2e/upstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ import (
"strconv"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/virtee/sev-snp-measure-go/guest"
"github.com/virtee/sev-snp-measure-go/ovmf"
"github.com/virtee/sev-snp-measure-go/vmmtypes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var (
Expand All @@ -47,7 +47,7 @@ func TestCompatibility(t *testing.T) {
err = json.Unmarshal(values, &expectedValues)
require.NoError(err, "unmarshalling values file: %s", err)

ovmfObj, err := ovmf.New(*binaryPath)
ovmfObj, err := ovmf.New(*binaryPath, 0)
require.NoError(err, "creating OVMF object from: %s", err)

for _, entry := range expectedValues {
Expand Down
154 changes: 85 additions & 69 deletions sevsnpmeasure/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ SPDX-License-Identifier: Apache-2.0
package cmd

import (
"encoding/base64"
"encoding/hex"
"fmt"
"os"
Expand Down Expand Up @@ -38,86 +39,21 @@ var (
snpOvmfHash string
dumpVmsa bool
svsmFile string
outputFmt string
)

// Execute executes the root command.
func Execute() error {
return NewRootCmd().Execute()
return newRootCmd().Execute()
}

// newRootCmd creates the root command.
func NewRootCmd() *cobra.Command {
func newRootCmd() *cobra.Command {
rootCmd := &cobra.Command{
Use: "sevsnpmeasure",
Short: "Calculate AMD SEV/SEV-ES/SEV-SNP guest launch measurement",
Long: "Calculate AMD SEV/SEV-ES/SEV-SNP guest launch measurement.",
Run: func(cmd *cobra.Command, args []string) {
if mode == "snp:ovmf-hash" {
hash, err := guest.CalcSnpOvmfHash(ovmfFile)
if err != nil {
fmt.Printf("Error: %v\n", err)
return
}
fmt.Printf("%s\n", hex.EncodeToString(hash))
}

if initrdFile != "" && kernelFile == "" {
fmt.Println("kernel required when initrd is provided")
return
}

if append != "" && kernelFile == "" {
fmt.Println("kernel required when append is provided")
return
}

if mode != guest.SEV.String() && vcpus == 0 {
fmt.Println("vcpus required")
return
}

vmmType := vmmtypes.VMMTypeFromString(vmmtype)
if vmmType == -1 {
fmt.Println("invalid vmm-type")
return
}

vcpuSig, err := vCPUSIG()
if err != nil {
fmt.Printf("Error: %v\n", err)
return
}

sevMode := guest.SevModeFromString(mode)
if sevMode == -1 {
fmt.Println("invalid mode")
return
}

if sevMode == guest.SEV_SNP_SVSM {
if varsFile != "" {
varsInfo, err := os.Stat(varsFile)
if err != nil {
fmt.Printf("Error: %v\n", err)
return
}
varsSize = varsInfo.Size()
}

if varsSize == 0 {
fmt.Println("SNP:SVSM mode requires vars-size")
return
}
}

ld, err := guest.CalcLaunchDigest(sevMode, vcpus, uint64(vcpuSig), ovmfFile, kernelFile, initrdFile, append, guestFeatures, snpOvmfHash, vmmType, dumpVmsa, svsmFile, int(varsSize))
if err != nil {
fmt.Printf("Error: %v\n", err)
return
}

fmt.Printf("%s\n", hex.EncodeToString(ld))
},
RunE: launchMeasurement,
}

rootCmd.Flags().StringVarP(&mode, "mode", "m", "", "Guest mode, either 'snp', 'seves', 'sev', 'snp:ovmf-hash' or 'snp:svsm'.")
Expand All @@ -140,13 +76,66 @@ func NewRootCmd() *cobra.Command {
rootCmd.Flags().StringVar(&snpOvmfHash, "snp-ovmf-hash", "", "Precalculated hash of the OVMF binary (hex string).")
rootCmd.Flags().BoolVar(&dumpVmsa, "dump-vmsa", false, "Write measured VMSAs to vmsa<N>.bin (seves, snp, and snp:svsm modes only).")
rootCmd.Flags().StringVar(&svsmFile, "svsm", "", "Path to the SVSM binary.")
rootCmd.Flags().StringVar(&outputFmt, "output-format", "hex", "Output format, either 'hex' or 'base64'.")
rootCmd.MarkFlagsMutuallyExclusive("svsm", "ovmf")

rootCmd.AddCommand(NewParseCmd())

return rootCmd
}

func launchMeasurement(cmd *cobra.Command, args []string) error {
if mode == "snp:ovmf-hash" {
hash, err := guest.CalcSnpOvmfHash(ovmfFile)
if err != nil {
return err
}
outputMeasurement(hash)
}

if err := validateFlags(); err != nil {
return err
}

vmmType := vmmtypes.VMMTypeFromString(vmmtype)
if vmmType == -1 {

return fmt.Errorf("invalid vmm-type")
}

vcpuSig, err := vCPUSIG()
if err != nil {
return err
}

sevMode := guest.SevModeFromString(mode)
if sevMode == -1 {
return fmt.Errorf("invalid mode")
}

if sevMode == guest.SEV_SNP_SVSM {
if varsFile != "" {
varsInfo, err := os.Stat(varsFile)
if err != nil {
return err
}
varsSize = varsInfo.Size()
}

if varsSize == 0 {
return fmt.Errorf("SNP:SVSM mode requires vars-size")
}
}

ld, err := guest.CalcLaunchDigest(sevMode, vcpus, uint64(vcpuSig), ovmfFile, kernelFile, initrdFile, append, guestFeatures, snpOvmfHash, vmmType, dumpVmsa, svsmFile, int(varsSize))
if err != nil {
return err
}

outputMeasurement(ld)
return nil
}

func vCPUSIG() (int, error) {
if mode == guest.SEV.String() {
return 0, nil
Expand All @@ -160,3 +149,30 @@ func vCPUSIG() (int, error) {
return -1, fmt.Errorf("missing vcpu-type or vcpu-sig or vcpu-family in guest mode %s", mode)
}
}

func validateFlags() error {
if initrdFile != "" && kernelFile == "" {
return fmt.Errorf("kernel required when initrd is provided")
}

if append != "" && kernelFile == "" {
return fmt.Errorf("kernel required when append is provided")
}

if mode != guest.SEV.String() && vcpus == 0 {
return fmt.Errorf("vcpus required")
}

return nil
}

func outputMeasurement(ld []byte) {
switch outputFmt {
case "hex":
fmt.Printf("%s\n", hex.EncodeToString(ld))
case "base64":
fmt.Printf("%s\n", base64.StdEncoding.EncodeToString(ld))
default:
fmt.Printf("Invalid output format: %s\n", outputFmt)
}
}
10 changes: 7 additions & 3 deletions vmsa/vmsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,13 @@ func New(apEip uint32, guestFeatures uint64, vcpuSig uint64, vmmType vmmtypes.VM
if err != nil {
return VMSA{}, err
}
apSaveArea, err := BuildSaveArea(apEip, guestFeatures, vcpuSig, vmmType)
if err != nil {
return VMSA{}, err
var apSaveArea SevEsSaveArea
if apEip != 0 {
apSaveArea, err = BuildSaveArea(apEip, guestFeatures, vcpuSig, vmmType)

if err != nil {
return VMSA{}, err
}
}
return VMSA{BspSaveArea: bspSaveArea, ApSaveArea: apSaveArea}, nil
}
Expand Down

0 comments on commit 6b0d623

Please sign in to comment.