diff --git a/.github/workflows/ci-go-cover.yml b/.github/workflows/ci-go-cover.yml index 94b48854..9dce92ca 100644 --- a/.github/workflows/ci-go-cover.yml +++ b/.github/workflows/ci-go-cover.yml @@ -26,7 +26,7 @@ jobs: steps: - uses: actions/setup-go@v2 with: - go-version: "1.18" + go-version: "1.19" - name: Checkout code uses: actions/checkout@v2 - name: Install mockgen diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml index 665e7048..92368c61 100644 --- a/.github/workflows/linters.yml +++ b/.github/workflows/linters.yml @@ -10,7 +10,7 @@ jobs: steps: - uses: actions/setup-go@v2 with: - go-version: "1.18" + go-version: "1.19" - name: Checkout code uses: actions/checkout@v2 - name: Install golangci-lint diff --git a/.golangci.yml b/.golangci.yml index 44b32208..51cc9e70 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -22,9 +22,9 @@ linters-settings: - octalLiteral - paramTypeCombine - whyNoLint - - wrapperFunc + - wrapperFunc gofmt: - simplify: false + simplify: false goimports: golint: min-confidence: 0 @@ -40,24 +40,20 @@ linters-settings: linters: disable-all: true enable: - - deadcode - errcheck - goconst - gocyclo - gofmt - goimports - - golint - gosec - govet - ineffassign - - maligned - misspell + - revive - staticcheck - - structcheck - typecheck - unconvert - unused - - varcheck issues: @@ -72,7 +68,7 @@ issues: - goconst - dupl - gomnd - - lll + - lll - path: doc\.go linters: - goimports diff --git a/Makefile b/Makefile index de492694..f5e8bbdb 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,8 @@ GOPKG := github.com/veraison/corim/corim GOPKG += github.com/veraison/corim/comid GOPKG += github.com/veraison/corim/cots GOPKG += github.com/veraison/corim/cocli/cmd +GOPKG += github.com/veraison/corim/encoding +GOPKG += github.com/veraison/corim/extensions MOCKGEN := $(shell go env GOPATH)/bin/mockgen INTERFACES := cocli/cmd/isubmitter.go diff --git a/README.md b/README.md index 7dc87de8..8dd87998 100644 --- a/README.md +++ b/README.md @@ -19,3 +19,16 @@ Before requesting a PR (and routinely during the dev/test cycle), you are encour make presubmit ``` and check its output to make sure your code coverage figures are in line with the set target and that there are no newly introduced lint problems. + +## Extending CoRIM/CoMID + +The CoRIM specification provides a mechanism for adding extensions to the base +CoRIM schema. The `corim` and `comid` structs which can be extended, embed an +`Extensions` object that allows registering a wrapper structure defining +extension fields. For field types that can be extended, i.e. `type choice`, +extensions can be implemented by calling an appropriate registration function +and giving it a new type or a value (for enums). + +Please see [extensions documentation](extensions/README.md) for details. + + diff --git a/cocli/cmd/corimCreate_test.go b/cocli/cmd/corimCreate_test.go index 33f112b8..b2dfb770 100644 --- a/cocli/cmd/corimCreate_test.go +++ b/cocli/cmd/corimCreate_test.go @@ -117,7 +117,7 @@ func Test_CorimCreateCmd_with_a_bad_comid(t *testing.T) { cmd.SetArgs(args) err = cmd.Execute() - assert.EqualError(t, err, `error loading CoMID from bad-comid.cbor: cbor: unexpected "break" code`) + assert.EqualError(t, err, `error loading CoMID from bad-comid.cbor: expected map (CBOR Major Type 5), found Major Type 7`) } func Test_CorimCreateCmd_with_an_invalid_comid(t *testing.T) { @@ -138,7 +138,7 @@ func Test_CorimCreateCmd_with_an_invalid_comid(t *testing.T) { cmd.SetArgs(args) err = cmd.Execute() - assert.EqualError(t, err, `error adding CoMID from invalid-comid.cbor (check its validity using the "comid validate" sub-command)`) + assert.EqualError(t, err, `error loading CoMID from invalid-comid.cbor: missing mandatory field "Triples" (4)`) } func Test_CorimCreateCmd_with_a_bad_coswid(t *testing.T) { diff --git a/cocli/cmd/corimDisplay_test.go b/cocli/cmd/corimDisplay_test.go index 7f87cf25..5e1ccb47 100644 --- a/cocli/cmd/corimDisplay_test.go +++ b/cocli/cmd/corimDisplay_test.go @@ -76,7 +76,7 @@ func Test_CorimDisplayCmd_invalid_signed_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error decoding signed CoRIM from invalid.cbor: failed validation of unsigned CoRIM: empty id") + assert.EqualError(t, err, `error decoding signed CoRIM from invalid.cbor: failed CBOR decoding of unsigned CoRIM: unexpected EOF`) } func Test_CorimDisplayCmd_ok_top_level_view(t *testing.T) { diff --git a/cocli/cmd/corimExtract_test.go b/cocli/cmd/corimExtract_test.go index 2ec748d1..8b476c98 100644 --- a/cocli/cmd/corimExtract_test.go +++ b/cocli/cmd/corimExtract_test.go @@ -76,7 +76,7 @@ func Test_CorimExtractCmd_invalid_signed_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error decoding signed CoRIM from invalid.cbor: failed validation of unsigned CoRIM: empty id") + assert.EqualError(t, err, `error decoding signed CoRIM from invalid.cbor: failed CBOR decoding of unsigned CoRIM: unexpected EOF`) } func Test_CorimExtractCmd_ok_save_to_default_dir(t *testing.T) { diff --git a/cocli/cmd/corimSign_test.go b/cocli/cmd/corimSign_test.go index 27956a4f..bc460376 100644 --- a/cocli/cmd/corimSign_test.go +++ b/cocli/cmd/corimSign_test.go @@ -91,7 +91,7 @@ func Test_CorimSignCmd_bad_unsigned_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error decoding unsigned CoRIM from bad.txt: unexpected EOF") + assert.EqualError(t, err, "error decoding unsigned CoRIM from bad.txt: expected map (CBOR Major Type 5), found Major Type 3") } func Test_CorimSignCmd_invalid_unsigned_corim(t *testing.T) { @@ -109,7 +109,7 @@ func Test_CorimSignCmd_invalid_unsigned_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error validating CoRIM: tags validation failed: no tags") + assert.EqualError(t, err, `error decoding unsigned CoRIM from invalid.cbor: missing mandatory field "Tags" (1)`) } func Test_CorimSignCmd_non_existent_meta_file(t *testing.T) { diff --git a/comid/attestverifkey_test.go b/comid/attestverifkey_test.go index 659a0abc..20c791d9 100644 --- a/comid/attestverifkey_test.go +++ b/comid/attestverifkey_test.go @@ -23,16 +23,17 @@ func TestAttestVerifKey_Valid_empty(t *testing.T) { testerr: "environment validation failed: environment must not be empty", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewUEIDInstance(TestUEID)}, verifkey: CryptoKeys{}, testerr: "verification keys validation failed: no keys to validate", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewUEIDInstance(TestUEID)}, verifkey: CryptoKeys{&invalidKey}, testerr: "verification keys validation failed: invalid key at index 0: key value not set", }, } + for _, tv := range tvs { av := AttestVerifKey{Environment: tv.env, VerifKeys: tv.verifkey} err := av.Valid() diff --git a/comid/cbor.go b/comid/cbor.go index c44d518d..ab8a9233 100644 --- a/comid/cbor.go +++ b/comid/cbor.go @@ -4,6 +4,7 @@ package comid import ( + "fmt" "reflect" cbor "github.com/fxamacker/cbor/v2" @@ -12,16 +13,14 @@ import ( var ( em, emError = initCBOREncMode() dm, dmError = initCBORDecMode() -) -func comidTags() cbor.TagSet { - comidTagsMap := map[uint64]interface{}{ + comidTagsMap = map[uint64]interface{}{ 32: TaggedURI(""), 37: TaggedUUID{}, 111: TaggedOID{}, // CoMID tags 550: TaggedUEID{}, - //551: To Do see: https://github.com/veraison/corim/issues/32 + 551: TaggedInt(0), 552: TaggedSVN(0), 553: TaggedMinSVN(0), 554: TaggedPKIXBase64Key(""), @@ -37,7 +36,9 @@ func comidTags() cbor.TagSet { 601: TaggedPSARefValID{}, 602: TaggedCCAPlatformConfigID(""), } +) +func comidTags() cbor.TagSet { opts := cbor.TagOptions{ EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired, @@ -69,6 +70,28 @@ func initCBORDecMode() (dm cbor.DecMode, err error) { return decOpt.DecModeWithTags(comidTags()) } +func registerCOMIDTag(tag uint64, t interface{}) error { + if _, exists := comidTagsMap[tag]; exists { + return fmt.Errorf("tag %d is already registered", tag) + } + + comidTagsMap[tag] = t + + var err error + + em, err = initCBOREncMode() + if err != nil { + return err + } + + dm, err = initCBORDecMode() + if err != nil { + return err + } + + return nil +} + func init() { if emError != nil { panic(emError) diff --git a/comid/ccaplatformconfigid.go b/comid/ccaplatformconfigid.go index e485083f..a55271ff 100644 --- a/comid/ccaplatformconfigid.go +++ b/comid/ccaplatformconfigid.go @@ -3,11 +3,16 @@ package comid -import "fmt" +import ( + "encoding/json" + "errors" + "fmt" + "unicode/utf8" +) -type CCAPlatformConfigID string +var CCAPlatformConfigIDType = "cca.platform-config-id" -type TaggedCCAPlatformConfigID CCAPlatformConfigID +type CCAPlatformConfigID string func (o CCAPlatformConfigID) Empty() bool { return o == "" @@ -27,3 +32,67 @@ func (o CCAPlatformConfigID) Get() (CCAPlatformConfigID, error) { } return o, nil } + +type TaggedCCAPlatformConfigID CCAPlatformConfigID + +func NewTaggedCCAPlatformConfigID(val any) (*TaggedCCAPlatformConfigID, error) { + var ret TaggedCCAPlatformConfigID + + if val == nil { + return &ret, nil + } + + switch t := val.(type) { + case TaggedCCAPlatformConfigID: + ret = t + case *TaggedCCAPlatformConfigID: + ret = *t + case CCAPlatformConfigID: + ret = TaggedCCAPlatformConfigID(t) + case *CCAPlatformConfigID: + ret = TaggedCCAPlatformConfigID(*t) + case string: + ret = TaggedCCAPlatformConfigID(t) + case []byte: + if !utf8.Valid(t) { + return nil, errors.New("bytes do not form a valid UTF-8 string") + } + ret = TaggedCCAPlatformConfigID(t) + default: + return nil, fmt.Errorf("unexpected type for CCA platform-config-id: %T", t) + } + + return &ret, nil +} + +func (o TaggedCCAPlatformConfigID) Valid() error { + if o == "" { + return errors.New("empty value") + } + + return nil +} + +func (o TaggedCCAPlatformConfigID) String() string { + return string(o) +} + +func (o TaggedCCAPlatformConfigID) Type() string { + return CCAPlatformConfigIDType +} + +func (o TaggedCCAPlatformConfigID) IsZero() bool { + return len(o) == 0 +} + +func (o *TaggedCCAPlatformConfigID) UnmarshalJSON(data []byte) error { + var temp string + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + *o = TaggedCCAPlatformConfigID(temp) + + return nil +} diff --git a/comid/ccaplatformconfigid_test.go b/comid/ccaplatformconfigid_test.go index b0a0a37c..543390fd 100644 --- a/comid/ccaplatformconfigid_test.go +++ b/comid/ccaplatformconfigid_test.go @@ -29,3 +29,67 @@ func TestCCAPlatformConfigID_Get_nok(t *testing.T) { _, err := cca.Get() assert.EqualError(t, err, expectedErr) } + +func TestNewTaggedCCAPlatformConfigID(t *testing.T) { + testID := TaggedCCAPlatformConfigID("test") + untagged := CCAPlatformConfigID("test") + + for _, tv := range []struct { + Name string + Input any + Err string + Expected TaggedCCAPlatformConfigID + }{ + { + Name: "TaggedCCAPlatformConfigID ok", + Input: testID, + Expected: testID, + }, + { + Name: "*TaggedCCAPlatformConfigID ok", + Input: &testID, + Expected: testID, + }, + { + Name: "CCAPlatformConfigID ok", + Input: untagged, + Expected: testID, + }, + { + Name: "*CCAPlatformConfigID ok", + Input: &untagged, + Expected: testID, + }, + { + Name: "string ok", + Input: "test", + Expected: testID, + }, + { + Name: "[]byte ok", + Input: []byte{0x74, 0x65, 0x73, 0x74}, + Expected: testID, + }, + { + Name: "[]byte not ok", + Input: []byte{0x80, 0x65, 0x73, 0x74}, + Err: "bytes do not form a valid UTF-8 string", + }, + { + Name: "bad type", + Input: 7, + Err: "unexpected type for CCA platform-config-id: int", + }, + } { + t.Run(tv.Name, func(t *testing.T) { + out, err := NewTaggedCCAPlatformConfigID(tv.Input) + + if tv.Err != "" { + assert.Nil(t, out) + assert.EqualError(t, err, tv.Err) + } else { + assert.Equal(t, tv.Expected, *out) + } + }) + } +} diff --git a/comid/class.go b/comid/class.go index 085e703b..ab414740 100644 --- a/comid/class.go +++ b/comid/class.go @@ -23,39 +23,33 @@ type Class struct { // NewClassUUID instantiates a new Class object with the specified UUID as // identifier func NewClassUUID(uuid UUID) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetUUID(uuid) == nil { + classID, err := NewUUIDClassID(uuid) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // NewClassImplID instantiates a new Class object that identifies the specified PSA // Implementation ID func NewClassImplID(implID ImplID) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetImplID(implID) == nil { + classID, err := NewImplIDClassID(implID) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // NewClassOID instantiates a new Class object that identifies the OID func NewClassOID(oid string) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetOID(oid) == nil { + classID, err := NewOIDClassID(oid) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // SetVendor sets the vendor metadata to the supplied string @@ -131,7 +125,7 @@ func (o *Class) SetIndex(index uint64) *Class { // Valid checks the non-empty<> constraint on the map func (o Class) Valid() error { // check non-empty<{ ... }> - if (o.ClassID == nil || o.ClassID.Unset()) && + if (o.ClassID == nil || !o.ClassID.IsSet()) && o.Vendor == nil && o.Model == nil && o.Layer == nil && o.Index == nil { return fmt.Errorf("class must not be empty") } diff --git a/comid/classid.go b/comid/classid.go index 760f8e5d..acad47d5 100644 --- a/comid/classid.go +++ b/comid/classid.go @@ -5,245 +5,383 @@ package comid import ( "encoding/base64" + "encoding/binary" "encoding/json" + "errors" "fmt" + "strconv" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) -// ClassID represents a $class-id-type-choice, which can be one of TaggedUUID, -// TaggedOID, or TaggedImplID (PSA-specific extension) +// ClassID identifies the environment via a well-known identifier. This can be +// an OID, a UUID, or a profile-defined extension type. type ClassID struct { - val interface{} + Value IClassIDValue } -type ClassIDType uint16 +// NewClassID creates a new ClassID of the specified type using the specified value. +func NewClassID(val any, typ string) (*ClassID, error) { + factory, ok := classIDValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown class id type: %s", typ) + } -const ( - ClassIDTypeUUID = ClassIDType(iota) - ClassIDTypeImplID - ClassIDTypeOID + return factory(val) +} - ClassIDTypeUnknown = ^ClassIDType(0) -) +// Valid returns nil if the ClassID is valid, or an error describing the +// problem, if it is not. +func (o ClassID) Valid() error { + if o.Value == nil { + return errors.New("nil value") + } -// SetUUID sets the value of the targed ClassID to the supplied UUID -func (o *ClassID) SetUUID(uuid UUID) *ClassID { - if o != nil { - o.val = TaggedUUID(uuid) + return o.Value.Valid() +} + +// Type returns the type of the ClassID +func (o ClassID) Type() string { + if o.Value == nil { + return "" } - return o + + return o.Value.Type() } -type ImplID [32]byte -type TaggedImplID ImplID +// Bytes returns a []byte containing the raw bytes of the class id value +func (o ClassID) Bytes() []byte { + if o.Value == nil { + return []byte{} + } + return o.Value.Bytes() +} -func (o ImplID) MarshalJSON() ([]byte, error) { - return json.Marshal(o[:]) +// IsSet returns true iff the underlying class id value has been set (is not nil) +func (o ClassID) IsSet() bool { + return o.Value != nil } -func (o *ImplID) UnmarshalJSON(data []byte) error { - var b []byte +// MarshalCBOR serializes the target ClassID to CBOR +func (o ClassID) MarshalCBOR() ([]byte, error) { + return em.Marshal(o.Value) +} - if err := json.Unmarshal(data, &b); err != nil { - return fmt.Errorf("bad ImplID: %w", err) +// UnmarshalCBOR deserializes the supplied CBOR buffer into the target ClassID. +// It is undefined behavior to try and inspect the target ClassID in case this +// method returns an error. +func (o *ClassID) UnmarshalCBOR(data []byte) error { + return dm.Unmarshal(data, &o.Value) +} + +// UnmarshalJSON deserializes the supplied JSON object into the target ClassID +// The class id object must have the following shape: +// +// { +// "type": "", +// "value": +// } +// +// where must be one of the known IClassIDValue implementation +// type names (available in this implementation: "uuid", "oid", +// "psa.impl-id", "int"), and is the JSON encoding of the underlying +// class id value. The exact encoding is dependent. For the base +// implementation types it is +// +// oid: dot-separated integers, e.g. "1.2.3.4" +// psa.impl-id: base64-encoded bytes, e.g. "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" +// int: an integer value, e.g. 7 +func (o *ClassID) UnmarshalJSON(data []byte) error { + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("class id decoding failure: %w", err) } - if nb := len(b); nb != 32 { - return fmt.Errorf("bad ImplID format: got %d bytes, want 32", nb) + decoded, err := NewClassID(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal class id: %w", + err, + ) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) } - copy(o[:], b) + o.Value = decoded.Value return nil } -type TaggedOID OID +// MarshalJSON serializes the target ClassID to JSON +func (o ClassID) MarshalJSON() ([]byte, error) { + return extensions.TypeChoiceValueMarshalJSON(o.Value) +} -// SetImplID sets the value of the targed ClassID to the supplied PSA -// Implementation ID (see Section 3.2.2 of draft-tschofenig-rats-psa-token) -func (o *ClassID) SetImplID(implID ImplID) *ClassID { - if o != nil { - o.val = TaggedImplID(implID) +// String returns a printable string of the ClassID value. UUIDs use the +// canonical 8-4-4-4-12 format, PSA Implementation IDs are base64 encoded. +// OIDs are output in dotted-decimal notation. +func (o ClassID) String() string { + if o.Value == nil { + return "" } - return o + + return o.Value.String() } -func (o ClassID) GetImplID() (ImplID, error) { - switch t := o.val.(type) { +type IClassIDValue interface { + extensions.ITypeChoiceValue + + Bytes() []byte +} + +const ImplIDType = "psa.impl-id" + +type ImplID [32]byte + +func (o ImplID) String() string { + return base64.StdEncoding.EncodeToString(o[:]) +} + +func (o ImplID) Valid() error { + return nil +} + +type TaggedImplID ImplID + +func NewImplIDClassID(val any) (*ClassID, error) { + var ret TaggedImplID + + if val == nil { + return &ClassID{&TaggedImplID{}}, nil + } + + switch t := val.(type) { + case []byte: + if nb := len(t); nb != 32 { + return nil, fmt.Errorf("bad psa.impl-id: got %d bytes, want 32", nb) + } + + copy(ret[:], t) + case string: + v, err := base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("bad psa.impl-id: %w", err) + } + + if nb := len(v); nb != 32 { + return nil, fmt.Errorf("bad psa.impl-id: decoded %d bytes, want 32", nb) + } + + copy(ret[:], v) case TaggedImplID: - return ImplID(t), nil + copy(ret[:], t[:]) + case *TaggedImplID: + copy(ret[:], (*t)[:]) + case ImplID: + copy(ret[:], t[:]) + case *ImplID: + copy(ret[:], (*t)[:]) default: - return ImplID{}, fmt.Errorf("class-id type is: %T", t) + return nil, fmt.Errorf("unexpected type for psa.impl-id: %T", t) } + + return &ClassID{&ret}, nil } -// SetOID sets the value of the targed ClassID to the supplied OID. -// The OID is a string in dotted-decimal notation -func (o *ClassID) SetOID(s string) *ClassID { - if o != nil { - var berOID OID - if berOID.FromString(s) != nil { - return nil - } - o.val = TaggedOID(berOID) +func MustNewImplIDClassID(val any) *ClassID { + ret, err := NewImplIDClassID(val) + if err != nil { + panic(err) } - return o + + return ret } -// MarshalCBOR serializes the target ClassID to CBOR -func (o ClassID) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) +func (o TaggedImplID) Valid() error { + return ImplID(o).Valid() } -// UnmarshalCBOR deserializes the supplied CBOR buffer into the target ClassID. -// It is undefined behavior to try and inspect the target ClassID in case this -// method returns an error. -func (o *ClassID) UnmarshalCBOR(data []byte) error { - var implID TaggedImplID +func (o TaggedImplID) String() string { + return ImplID(o).String() +} - if dm.Unmarshal(data, &implID) == nil { - o.val = implID - return nil +func (o TaggedImplID) Type() string { + return ImplIDType +} + +func (o TaggedImplID) Bytes() []byte { + return o[:] +} + +func (o *TaggedImplID) MarshalJSON() ([]byte, error) { + return json.Marshal((*o)[:]) +} + +func (o *TaggedImplID) UnmarshalJSON(data []byte) error { + var out []byte + if err := json.Unmarshal(data, &out); err != nil { + return err } - var uuid TaggedUUID + if len(out) != 32 { + return fmt.Errorf("bad psa.impl-id: decoded %d bytes, want 32", len(out)) + } - if dm.Unmarshal(data, &uuid) == nil { - o.val = uuid - return nil + copy((*o)[:], out) + + return nil +} + +func NewOIDClassID(val any) (*ClassID, error) { + ret, err := NewTaggedOID(val) + if err != nil { + return nil, err } - var oid TaggedOID + return &ClassID{ret}, nil +} - if dm.Unmarshal(data, &oid) == nil { - o.val = oid - return nil +func MustNewOIDClassID(val any) *ClassID { + ret, err := NewOIDClassID(val) + if err != nil { + panic(err) } - return fmt.Errorf("unknown class id (CBOR: %x)", data) + return ret } -// UnmarshalJSON deserializes the supplied JSON object into the target ClassID -// The class id object must have one of the following shapes: -// -// UUID: -// -// { -// "type": "uuid", -// "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" -// } -// -// OID: -// -// { -// "type": "oid", -// "value": "2.16.840.1.113741.1.15.4.2" -// } -// -// PSA Implementation ID: -// -// { -// "type": "psa.impl-id", -// "value": "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" -// } -func (o *ClassID) UnmarshalJSON(data []byte) error { - var v tnv +func NewUUIDClassID(val any) (*ClassID, error) { + if val == nil { + return &ClassID{&TaggedUUID{}}, nil + } - if err := json.Unmarshal(data, &v); err != nil { - return err + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err } - switch v.Type { - case "uuid": // nolint: goconst - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - case "oid": - var x OID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedOID(x) - case "psa.impl-id": - var x ImplID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedImplID(x) - default: - return fmt.Errorf("unknown type '%s' for class id", v.Type) + return &ClassID{ret}, nil +} + +func MustNewUUIDClassID(val any) *ClassID { + ret, err := NewUUIDClassID(val) + if err != nil { + panic(err) } - return nil + return ret } -// MarshalJSON serializes the target ClassID to JSON -func (o ClassID) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedOID: - b, err = OID(t).MarshalJSON() +const IntType = "int" + +type TaggedInt int + +func NewIntClassID(val any) (*ClassID, error) { + if val == nil { + zeroVal := TaggedInt(0) + return &ClassID{&zeroVal}, nil + } + + var ret TaggedInt + + switch t := val.(type) { + case string: + i, err := strconv.Atoi(t) if err != nil { - return nil, err + return nil, fmt.Errorf("bad int: %w", err) } - v = tnv{Type: "oid", Value: b} - case TaggedImplID: - b, err = ImplID(t).MarshalJSON() - if err != nil { - return nil, err + ret = TaggedInt(i) + case []byte: + if len(t) != 8 { + return nil, fmt.Errorf("bad int: want 8 bytes, got %d bytes", len(t)) } - v = tnv{Type: "psa.impl-id", Value: b} + ret = TaggedInt(binary.BigEndian.Uint64(t)) + case int: + ret = TaggedInt(t) + case *int: + ret = TaggedInt(*t) + case int64: + ret = TaggedInt(t) + case *int64: + ret = TaggedInt(*t) + case uint64: + ret = TaggedInt(t) + case *uint64: + ret = TaggedInt(*t) default: - return nil, fmt.Errorf("unknown type %T for class-id", t) + return nil, fmt.Errorf("unexpected type for int: %T", t) } - return json.Marshal(v) + if err := ret.Valid(); err != nil { + return nil, err + } + + return &ClassID{&ret}, nil } -// Type returns the type of the target ClassID, i.e., one of UUID, OID or PSA -// Implementation ID -func (o ClassID) Type() ClassIDType { - switch o.val.(type) { - case TaggedUUID: - return ClassIDTypeUUID - case TaggedImplID: - return ClassIDTypeImplID - case TaggedOID: - return ClassIDTypeOID - } - return ClassIDTypeUnknown +func (o TaggedInt) String() string { + return fmt.Sprint(int(o)) } -// String returns a printable string of the ClassID value. UUIDs use the -// canonical 8-4-4-4-12 format, PSA Implementation IDs are base64 encoded. -// OIDs are output in dotted-decimal notation. -func (o ClassID) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - case TaggedImplID: - b := [32]byte(t) - return base64.StdEncoding.EncodeToString(b[:]) - case TaggedOID: - return OID(t).String() - default: - return "" - } +func (o TaggedInt) Valid() error { + return nil +} + +func (o TaggedInt) Type() string { + return "int" } -// Unset tests whether the target ClassID has been initialized -func (o ClassID) Unset() bool { - return o.val == nil || o.Type() == ClassIDTypeUnknown +func (o TaggedInt) Bytes() []byte { + var ret [8]byte + binary.BigEndian.PutUint64(ret[:], uint64(o)) + return ret[:] +} + +// IClassIDFactory defines the signature for the factory functions that may be +// registred using RegisterClassIDType to provide a new implementation of the +// corresponding type choice. The factory function should create a new *ClassID +// with the underlying value created based on the provided input. The range of +// valid inputs is up to the specific type choice implementation, however it +// _must_ accept nil as one of the inputs, and return the Zero value for +// implemented type. +// See also https://go.dev/ref/spec#The_zero_value +type IClassIDFactory func(any) (*ClassID, error) + +var classIDValueRegister = map[string]IClassIDFactory{ + OIDType: NewOIDClassID, + UUIDType: NewUUIDClassID, + IntType: NewIntClassID, + + ImplIDType: NewImplIDClassID, +} + +// RegisterClassIDType registers a new IClassIDValue implementation (created +// by the provided IClassIDFactory) under the specified CBOR tag. +func RegisterClassIDType(tag uint64, factory IClassIDFactory) error { + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := classIDValueRegister[typ]; exists { + return fmt.Errorf("class ID type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + classIDValueRegister[typ] = factory + + return nil } diff --git a/comid/classid_test.go b/comid/classid_test.go index 51b07391..aaeeb393 100644 --- a/comid/classid_test.go +++ b/comid/classid_test.go @@ -4,6 +4,8 @@ package comid import ( + "encoding/binary" + "encoding/json" "fmt" "testing" @@ -12,9 +14,7 @@ import ( ) func TestClassID_MarshalCBOR_UUID(t *testing.T) { - var tv ClassID - - require.NotNil(t, tv.SetUUID(TestUUID)) + tv := MustNewUUIDClassID(TestUUID) // 37(h'31FB5ABF023E4992AA4E95F9C1503BFA') // tag(37): d8 25 @@ -29,9 +29,7 @@ func TestClassID_MarshalCBOR_UUID(t *testing.T) { } func TestClassID_MarshalCBOR_ImplID(t *testing.T) { - var tv ClassID - - require.NotNil(t, tv.SetImplID(TestImplID)) + tv := MustNewImplIDClassID(TestImplID) // 600 (h'61636D652D696D706C656D656E746174696F6E2D69642D303030303030303031') // tag(600): d9 0258 @@ -66,7 +64,7 @@ func TestClassID_UnmarshalCBOR_UUID_OK(t *testing.T) { err := actual.UnmarshalCBOR(tv) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeUUID, actual.Type()) + assert.Equal(t, "uuid", actual.Type()) assert.Equal(t, TestUUIDString, actual.String()) } @@ -79,7 +77,7 @@ func TestClassID_UnmarshalCBOR_ImplID_OK(t *testing.T) { err := actual.UnmarshalCBOR(tv) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeImplID, actual.Type()) + assert.Equal(t, "psa.impl-id", actual.Type()) assert.Equal(t, expected, actual.String()) } @@ -88,12 +86,10 @@ func TestClassID_UnmarshalCBOR_badInput(t *testing.T) { hex := "582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031" tv := MustHexDecode(t, hex) - expectedError := fmt.Sprintf("unknown class id (CBOR: %s)", hex) - var actual ClassID err := actual.UnmarshalCBOR(tv) - assert.EqualError(t, err, expectedError) + assert.EqualError(t, err, "cbor: cannot unmarshal byte string into Go value of type comid.IClassIDValue") } func TestClassID_UnmarshalJSON_UUID(t *testing.T) { @@ -105,7 +101,7 @@ func TestClassID_UnmarshalJSON_UUID(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeUUID, actual.Type()) + assert.Equal(t, "uuid", actual.Type()) assert.Equal(t, TestUUIDString, actual.String()) } @@ -122,7 +118,7 @@ func TestClassID_UnmarshalJSON_ImplID(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeImplID, actual.Type()) + assert.Equal(t, "psa.impl-id", actual.Type()) // the returned string is the base64 encoding of the stored binary assert.Equal(t, expected, actual.String()) } @@ -133,8 +129,8 @@ func TestClassID_UnmarshalJSON_badInput_unknown_type(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "unknown type 'FOOBAR' for class id") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "unknown class id type: FOOBAR") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_missing_value(t *testing.T) { @@ -143,8 +139,8 @@ func TestClassID_UnmarshalJSON_badInput_missing_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID: unexpected end of JSON input") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "class id decoding failure: no value provided for psa.impl-id") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_empty_value(t *testing.T) { @@ -153,8 +149,8 @@ func TestClassID_UnmarshalJSON_badInput_empty_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID format: got 0 bytes, want 32") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "cannot unmarshal class id: bad psa.impl-id: decoded 0 bytes, want 32") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_badly_encoded_ImplID_value(t *testing.T) { @@ -163,8 +159,8 @@ func TestClassID_UnmarshalJSON_badInput_badly_encoded_ImplID_value(t *testing.T) var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID: illegal base64 data at input byte 0") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "cannot unmarshal class id: illegal base64 data at input byte 0") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_badly_encoded_UUID_value(t *testing.T) { @@ -173,8 +169,8 @@ func TestClassID_UnmarshalJSON_badInput_badly_encoded_UUID_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad UUID: invalid UUID length: 9") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "cannot unmarshal class id: bad UUID: invalid UUID length: 9") + assert.Equal(t, "", actual.Type()) } func TestClassID_SetOID_ok(t *testing.T) { @@ -200,8 +196,7 @@ func TestClassID_SetOID_ok(t *testing.T) { } for _, tv := range tvs { - c := ClassID{} - assert.NotNil(t, c.SetOID(tv)) + c := MustNewOIDClassID(tv) assert.Equal(t, tv, c.String()) } } @@ -219,7 +214,230 @@ func TestClassID_SetOID_bad(t *testing.T) { } for _, tv := range tvs { - c := ClassID{} - assert.Nil(t, c.SetOID(tv)) + c, err := NewOIDClassID(tv) + assert.NotNil(t, err) + assert.Nil(t, c) + } +} + +func Test_NewImplIDClassID(t *testing.T) { + classID, err := NewImplIDClassID(nil) + expected := [32]byte{} + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + taggedImplID := TaggedImplID(TestImplID) + + for _, v := range []any{ + TestImplID, + &TestImplID, + taggedImplID, + &taggedImplID, + taggedImplID.Bytes(), + } { + classID, err = NewImplIDClassID(v) + require.NoError(t, err) + assert.Equal(t, taggedImplID.Bytes(), classID.Bytes()) + } + + expected = [32]byte{ + 0x61, 0x63, 0x6d, 0x65, 0x2d, 0x69, 0x6d, 0x70, + 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2d, 0x69, 0x64, 0x2d, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x31, + } + classID, err = NewImplIDClassID("YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=") + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + _, err = NewImplIDClassID(7) + assert.EqualError(t, err, "unexpected type for psa.impl-id: int") +} + +func Test_NewUUIDClassID(t *testing.T) { + classID, err := NewUUIDClassID(nil) + + expected := [16]byte{} + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + taggedUUID := TaggedUUID(TestUUID) + + for _, v := range []any{ + TestUUID, + &TestUUID, + taggedUUID, + &taggedUUID, + taggedUUID.Bytes(), + } { + classID, err = NewUUIDClassID(v) + require.NoError(t, err) + assert.Equal(t, taggedUUID.Bytes(), classID.Bytes()) + } + + classID, err = NewUUIDClassID(taggedUUID.String()) + require.NoError(t, err) + assert.Equal(t, taggedUUID.Bytes(), classID.Bytes()) +} + +func Test_NewOIDClassID(t *testing.T) { + classID, err := NewOIDClassID(nil) + + expected := []byte{} + require.NoError(t, err) + assert.Equal(t, expected, classID.Bytes()) + + var oid OID + require.NoError(t, oid.FromString(TestOID)) + taggedOID := TaggedOID(oid) + + for _, v := range []any{ + TestOID, + oid, + &oid, + taggedOID, + &taggedOID, + taggedOID.Bytes(), + } { + classID, err = NewOIDClassID(v) + require.NoError(t, err) + expected := taggedOID.Bytes() + got := classID.Bytes() + assert.Equal(t, expected, got) + } + + classID, err = NewOIDClassID(taggedOID.String()) + require.NoError(t, err) + assert.Equal(t, taggedOID.Bytes(), classID.Bytes()) +} + +func Test_NewIntClassID(t *testing.T) { + classID, err := NewIntClassID(nil) + require.NoError(t, err) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, classID.Bytes()) + + testInt := 7 + testInt64 := int64(7) + testUint64 := uint64(7) + + var testBytes [8]byte + binary.BigEndian.PutUint64(testBytes[:], testUint64) + + for _, v := range []any{ + testInt, + &testInt, + testInt64, + &testInt64, + testUint64, + &testUint64, + "7", + testBytes[:], + } { + classID, err = NewIntClassID(v) + require.NoError(t, err) + got := classID.Bytes() + assert.Equal(t, testBytes[:], got) + } +} + +func Test_TaggedInt(t *testing.T) { + val := TaggedInt(7) + assert.Equal(t, "7", val.String()) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07}, val.Bytes()) + assert.Equal(t, "int", val.Type()) + assert.NoError(t, val.Valid()) + + classID := ClassID{&val} + + bytes, err := em.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, []byte{ + 0xd9, 0x02, 0x27, // tag 551 + 0x07, // int 7 + }, bytes) + + var out ClassID + err = dm.Unmarshal(bytes, &out) + require.NoError(t, err) + assert.Equal(t, classID, out) + + jsonBytes, err := json.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, `{"type":"int","value":7}`, string(jsonBytes)) + + out = ClassID{} + err = json.Unmarshal(jsonBytes, &out) + require.NoError(t, err) + assert.Equal(t, classID, out) +} + +type testClassID [4]byte + +func newTestClassID(val any) (*ClassID, error) { + return &ClassID{&testClassID{0x74, 0x65, 0x73, 0x74}}, nil +} + +func (o testClassID) Bytes() []byte { + return o[:] +} + +func (o testClassID) Type() string { + return "test-class-id" +} + +func (o testClassID) String() string { + return "test" +} + +func (o testClassID) Valid() error { + return nil +} + +func (o testClassID) MarshalJSON() ([]byte, error) { + return json.Marshal(o.String()) +} + +func (o *testClassID) UnmarshalJSON(data []byte) error { + var out string + if err := json.Unmarshal(data, &out); err != nil { + return err } + + if len(out) != 4 { + return fmt.Errorf("bad testClassID: decoded %d bytes, want 4", len(out)) + } + + copy((*o)[:], []byte(out)) + + return nil +} + +func Test_RegisterClassIDType(t *testing.T) { + err := RegisterClassIDType(99999, newTestClassID) + require.NoError(t, err) + + classID, err := newTestClassID(nil) + require.NoError(t, err) + + data, err := json.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-class-id","value":"test"}`) + + var out ClassID + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.Equal(t, classID.Bytes(), out.Bytes()) + + data, err = em.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9f, // tag 99999 + 0x44, // bstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }) + + var out2 ClassID + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, classID.Bytes(), out2.Bytes()) } diff --git a/comid/comid.go b/comid/comid.go index ea9779c7..5052cc0c 100644 --- a/comid/comid.go +++ b/comid/comid.go @@ -8,6 +8,8 @@ import ( "fmt" "net/url" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/swid" ) @@ -19,6 +21,8 @@ type Comid struct { Entities *Entities `cbor:"2,keyasint,omitempty" json:"entities,omitempty"` LinkedTags *LinkedTags `cbor:"3,keyasint,omitempty" json:"linked-tags,omitempty"` Triples Triples `cbor:"4,keyasint" json:"triples"` + + Extensions } // NewComid instantiates an empty Comid @@ -26,6 +30,16 @@ func NewComid() *Comid { return &Comid{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *Comid) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Comid) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetLanguage sets the language used in the target Comid to the supplied // language tag. See also: BCP 47 and the IANA Language subtag registry. func (o *Comid) SetLanguage(language string) *Comid { @@ -106,7 +120,7 @@ func (o *Comid) AddEntity(name string, regID *string, roles ...Role) *Comid { } e := Entity{ - EntityName: name, + EntityName: MustNewStringEntityName(name), RegID: uri, Roles: rs, } @@ -229,7 +243,7 @@ func (o Comid) Valid() error { return fmt.Errorf("triples validation failed: %w", err) } - return nil + return o.Extensions.validComid(&o) } // ToCBOR serializes the target Comid to CBOR @@ -238,17 +252,17 @@ func (o Comid) ToCBOR() ([]byte, error) { return nil, err } - return em.Marshal(&o) + return encoding.SerializeStructToCBOR(em, &o) } // FromCBOR deserializes a CBOR-encoded CoMID into the target Comid func (o *Comid) FromCBOR(data []byte) error { - return dm.Unmarshal(data, o) + return encoding.PopulateStructFromCBOR(dm, data, o) } // FromJSON deserializes a JSON-encoded CoMID into the target Comid func (o *Comid) FromJSON(data []byte) error { - return json.Unmarshal(data, o) + return encoding.PopulateStructFromJSON(data, o) } // ToJSON serializes the target Comid to JSON @@ -257,7 +271,7 @@ func (o Comid) ToJSON() ([]byte, error) { return nil, err } - return json.Marshal(&o) + return encoding.SerializeStructToJSON(&o) } func (o Comid) ToJSONPretty(indent string) ([]byte, error) { diff --git a/comid/comid_test.go b/comid/comid_test.go new file mode 100644 index 00000000..a8cdf198 --- /dev/null +++ b/comid/comid_test.go @@ -0,0 +1,51 @@ +package comid + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/veraison/swid" +) + +func Test_Comid_Extensions(t *testing.T) { + c := NewComid() + assert.Nil(t, c.GetExtensions()) + assert.Equal(t, "", c.MustGetString("field-one")) + + err := c.Set("field-one", "foo") + assert.EqualError(t, err, "extension not found: field-one") + + type ComidExt struct { + FieldOne string `cbor:"-1,keyasint" json:"field-one"` + } + + c.RegisterExtensions(&ComidExt{}) + + err = c.Set("field-one", "foo") + assert.NoError(t, err) + assert.Equal(t, "foo", c.MustGetString("-1")) +} + +func Test_Comid_ToJSONPretty(t *testing.T) { + c := NewComid() + + _, err := c.ToJSONPretty(" ") + assert.EqualError(t, err, "tag-identity validation failed: empty tag-id") + + c.TagIdentity = TagIdentity{TagID: *swid.NewTagID("test"), TagVersion: 1} + c.Triples = Triples{ReferenceValues: &[]ReferenceValue{}} + + expected := `{ + "tag-identity": { + "id": "test", + "version": 1 + }, + "triples": { + "reference-values": [] + } +}` + v, err := c.ToJSONPretty(" ") + require.NoError(t, err) + assert.Equal(t, expected, string(v)) +} diff --git a/comid/cryptokey.go b/comid/cryptokey.go index 53d99df8..3133b95e 100644 --- a/comid/cryptokey.go +++ b/comid/cryptokey.go @@ -14,6 +14,8 @@ import ( "fmt" "github.com/fxamacker/cbor/v2" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/go-cose" "github.com/veraison/swid" ) @@ -58,50 +60,12 @@ type CryptoKey struct { // specified crypto key type. For PKIX types, k must be a string. For COSE_Key, // k must be a []byte. For thumbprint types, k must be a swid.HashEntry. func NewCryptoKey(k any, typ string) (*CryptoKey, error) { - switch typ { - case PKIXBase64KeyType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64Key(v) - case PKIXBase64CertType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64Cert(v) - case PKIXBase64CertPathType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64CertPath(v) - case COSEKeyType: - v, ok := k.([]byte) - if !ok { - return nil, fmt.Errorf("value must be a []byte; found %T", k) - } - return NewCOSEKey(v) - case ThumbprintType, CertThumbprintType, CertPathThumbprintType: - v, ok := k.(swid.HashEntry) - if !ok { - return nil, fmt.Errorf("value must be a swid.HashEntry; found %T", k) - } - switch typ { - case ThumbprintType: - return NewThumbprint(v) - case CertThumbprintType: - return NewCertThumbprint(v) - case CertPathThumbprintType: - return NewCertPathThumbprint(v) - default: - // Should never here because of the the outer case clause - panic(fmt.Sprintf("unexpected thumbprint type: %s", typ)) - } - default: + factory, ok := cryptoKeyValueRegister[typ] + if !ok { return nil, fmt.Errorf("unexpected CryptoKey type: %s", typ) } + + return factory(k) } // MustNewCryptoKey is the same as NewCryptoKey, but does not return an error, @@ -126,6 +90,11 @@ func (o CryptoKey) Valid() error { return o.Value.Valid() } +// Type returns the type of the CryptoKey value +func (o CryptoKey) Type() string { + return o.Value.Type() +} + // PublicKey returns a crypto.PublicKey constructed from the CryptoKey's // underlying value. This returns an error if the CryptoKey is one of the // thumbprint types. @@ -136,30 +105,14 @@ func (o CryptoKey) PublicKey() (crypto.PublicKey, error) { // MarshalJSON returns a []byte containing the JSON representation of the // CryptoKey. func (o CryptoKey) MarshalJSON() ([]byte, error) { - value := struct { - Type string `json:"type"` - Value string `json:"value"` - }{ - Value: o.Value.String(), - } - - switch o.Value.(type) { - case TaggedPKIXBase64Key: - value.Type = PKIXBase64KeyType - case TaggedPKIXBase64Cert: - value.Type = PKIXBase64CertType - case TaggedPKIXBase64CertPath: - value.Type = PKIXBase64CertPathType - case TaggedCOSEKey: - value.Type = COSEKeyType - case TaggedThumbprint: - value.Type = ThumbprintType - case TaggedCertThumbprint: - value.Type = CertThumbprintType - case TaggedCertPathThumbprint: - value.Type = CertPathThumbprintType - default: - return nil, fmt.Errorf("unexpected ICryptoKeyValue type: %T", o.Value) + valueBytes, err := json.Marshal(o.Value.String()) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, } return json.Marshal(value) @@ -168,10 +121,7 @@ func (o CryptoKey) MarshalJSON() ([]byte, error) { // UnmarshalJSON populates the CryptoKey from the JSON representation inside // the provided []byte. func (o *CryptoKey) UnmarshalJSON(b []byte) error { - var value struct { - Type string `json:"type"` - Value string `json:"value"` - } + var value encoding.TypeAndValue if err := json.Unmarshal(b, &value); err != nil { return err @@ -181,36 +131,23 @@ func (o *CryptoKey) UnmarshalJSON(b []byte) error { return errors.New("key type not set") } - switch value.Type { - case PKIXBase64KeyType: - o.Value = TaggedPKIXBase64Key(value.Value) - case PKIXBase64CertType: - o.Value = TaggedPKIXBase64Cert(value.Value) - case PKIXBase64CertPathType: - o.Value = TaggedPKIXBase64CertPath(value.Value) - case COSEKeyType: - data, err := base64.StdEncoding.DecodeString(value.Value) - if err != nil { - return fmt.Errorf("base64 decode error: %w", err) - } - o.Value = TaggedCOSEKey(data) - case ThumbprintType, CertThumbprintType, CertPathThumbprintType: - he, err := swid.ParseHashEntry(value.Value) - if err != nil { - return fmt.Errorf("swid.HashEntry decode error: %w", err) - } - switch value.Type { - case ThumbprintType: - o.Value = TaggedThumbprint{digest{he}} - case CertThumbprintType: - o.Value = TaggedCertThumbprint{digest{he}} - case CertPathThumbprintType: - o.Value = TaggedCertPathThumbprint{digest{he}} - } - default: + factory, ok := cryptoKeyValueRegister[value.Type] + if !ok { return fmt.Errorf("unexpected ICryptoKeyValue type: %q", value.Type) } + var valueString string + if err := json.Unmarshal(value.Value, &valueString); err != nil { + return err + } + + k, err := factory(valueString) + if err != nil { + return err + } + + o.Value = k.Value + return o.Valid() } @@ -229,11 +166,8 @@ func (o *CryptoKey) UnmarshalCBOR(b []byte) error { // ICryptoKeyValue is the interface implemented by the concrete CryptoKey value // types. type ICryptoKeyValue interface { - // String returns the string representation of the ICryptoKeyValue. - String() string - // Valid returns an error if validation of the ICryptoKeyValue fails, - // or nil if it succeeds. - Valid() error + extensions.ITypeChoiceValue + // PublicKey returns a crypto.PublicKey constructed from the // ICryptoKeyValue's underlying value. This returns an error if the // ICryptoKeyValue is one of the thumbprint types. @@ -244,7 +178,12 @@ type ICryptoKeyValue interface { // https://www.rfc-editor.org/rfc/rfc7468#section-13 type TaggedPKIXBase64Key string -func NewPKIXBase64Key(s string) (*CryptoKey, error) { +func NewPKIXBase64Key(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } + key := TaggedPKIXBase64Key(s) if err := key.Valid(); err != nil { return nil, err @@ -252,8 +191,8 @@ func NewPKIXBase64Key(s string) (*CryptoKey, error) { return &CryptoKey{key}, nil } -func MustNewPKIXBase64Key(s string) *CryptoKey { - key, err := NewPKIXBase64Key(s) +func MustNewPKIXBase64Key(k any) *CryptoKey { + key, err := NewPKIXBase64Key(k) if err != nil { panic(err) } @@ -269,6 +208,10 @@ func (o TaggedPKIXBase64Key) Valid() error { return err } +func (o TaggedPKIXBase64Key) Type() string { + return PKIXBase64KeyType +} + func (o TaggedPKIXBase64Key) PublicKey() (crypto.PublicKey, error) { if string(o) == "" { return nil, errors.New("key value not set") @@ -302,7 +245,12 @@ func (o TaggedPKIXBase64Key) PublicKey() (crypto.PublicKey, error) { // certificate. See https://www.rfc-editor.org/rfc/rfc7468#section-5 type TaggedPKIXBase64Cert string -func NewPKIXBase64Cert(s string) (*CryptoKey, error) { +func NewPKIXBase64Cert(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } + cert := TaggedPKIXBase64Cert(s) if err := cert.Valid(); err != nil { return nil, err @@ -310,8 +258,8 @@ func NewPKIXBase64Cert(s string) (*CryptoKey, error) { return &CryptoKey{cert}, nil } -func MustNewPKIXBase64Cert(s string) *CryptoKey { - cert, err := NewPKIXBase64Cert(s) +func MustNewPKIXBase64Cert(k any) *CryptoKey { + cert, err := NewPKIXBase64Cert(k) if err != nil { panic(err) } @@ -327,6 +275,10 @@ func (o TaggedPKIXBase64Cert) Valid() error { return err } +func (o TaggedPKIXBase64Cert) Type() string { + return PKIXBase64CertType +} + func (o TaggedPKIXBase64Cert) PublicKey() (crypto.PublicKey, error) { cert, err := o.cert() if err != nil { @@ -375,7 +327,11 @@ func (o TaggedPKIXBase64Cert) cert() (*x509.Certificate, error) { // directly certifies the one preceding. type TaggedPKIXBase64CertPath string -func NewPKIXBase64CertPath(s string) (*CryptoKey, error) { +func NewPKIXBase64CertPath(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } cert := TaggedPKIXBase64CertPath(s) if err := cert.Valid(); err != nil { @@ -385,8 +341,8 @@ func NewPKIXBase64CertPath(s string) (*CryptoKey, error) { return &CryptoKey{cert}, nil } -func MustNewPKIXBase64CertPath(s string) *CryptoKey { - cert, err := NewPKIXBase64CertPath(s) +func MustNewPKIXBase64CertPath(k any) *CryptoKey { + cert, err := NewPKIXBase64CertPath(k) if err != nil { panic(err) @@ -404,6 +360,10 @@ func (o TaggedPKIXBase64CertPath) Valid() error { return err } +func (o TaggedPKIXBase64CertPath) Type() string { + return PKIXBase64CertPathType +} + func (o TaggedPKIXBase64CertPath) PublicKey() (crypto.PublicKey, error) { certs, err := o.certPath() if err != nil { @@ -468,7 +428,22 @@ func (o TaggedPKIXBase64CertPath) certPath() ([]*x509.Certificate, error) { // https://www.rfc-editor.org/rfc/rfc9052#section-7 type TaggedCOSEKey []byte -func NewCOSEKey(b []byte) (*CryptoKey, error) { +func NewCOSEKey(k any) (*CryptoKey, error) { + var b []byte + var err error + + switch t := k.(type) { + case []byte: + b = t + case string: + b, err = base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("base64 decode error: %w", err) + } + default: + return nil, fmt.Errorf("value must be a []byte or a string; found %T", k) + } + key := TaggedCOSEKey(b) if err := key.Valid(); err != nil { @@ -478,8 +453,8 @@ func NewCOSEKey(b []byte) (*CryptoKey, error) { return &CryptoKey{key}, nil } -func MustNewCOSEKey(b []byte) *CryptoKey { - key, err := NewCOSEKey(b) +func MustNewCOSEKey(k any) *CryptoKey { + key, err := NewCOSEKey(k) if err != nil { panic(err) @@ -509,6 +484,10 @@ func (o TaggedCOSEKey) Valid() error { return err } +func (o TaggedCOSEKey) Type() string { + return COSEKeyType +} + func (o TaggedCOSEKey) PublicKey() (crypto.PublicKey, error) { if len(o) == 0 { return nil, errors.New("empty COSE_Key value") @@ -608,7 +587,22 @@ type TaggedThumbprint struct { digest } -func NewThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -618,8 +612,8 @@ func NewThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewThumbprint(he) +func MustNewThumbprint(k any) *CryptoKey { + key, err := NewThumbprint(k) if err != nil { panic(err) @@ -628,13 +622,32 @@ func MustNewThumbprint(he swid.HashEntry) *CryptoKey { return key } +func (o TaggedThumbprint) Type() string { + return ThumbprintType +} + // TaggedCertThumbprint represents a digest of a certificate. The digest value // may be used to find the certificate if contained in a lookup table. type TaggedCertThumbprint struct { digest } -func NewCertThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewCertThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedCertThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -644,8 +657,8 @@ func NewCertThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewCertThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewCertThumbprint(he) +func MustNewCertThumbprint(k any) *CryptoKey { + key, err := NewCertThumbprint(k) if err != nil { panic(err) @@ -654,6 +667,10 @@ func MustNewCertThumbprint(he swid.HashEntry) *CryptoKey { return key } +func (o TaggedCertThumbprint) Type() string { + return CertThumbprintType +} + // TaggedCertPathThumbprint represents a digest of a certification path. The // digest value may be used to find the certificate path if contained in a // lookup table. @@ -661,7 +678,22 @@ type TaggedCertPathThumbprint struct { digest } -func NewCertPathThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewCertPathThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedCertPathThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -671,8 +703,8 @@ func NewCertPathThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewCertPathThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewCertPathThumbprint(he) +func MustNewCertPathThumbprint(k any) *CryptoKey { + key, err := NewCertPathThumbprint(k) if err != nil { panic(err) @@ -680,3 +712,52 @@ func MustNewCertPathThumbprint(he swid.HashEntry) *CryptoKey { return key } + +func (o TaggedCertPathThumbprint) Type() string { + return CertPathThumbprintType +} + +// ICryptoKeyFactory defines the signature for the factory functions that may be +// registred using RegisterCryptoKeyType to provide a new implementation of the +// corresponding type choice. The factory function should create a new *CryptoKey +// with the underlying value created based on the provided input. The range of +// valid inputs is up to the specific type choice implementation, however it +// _must_ accept nil as one of the inputs, and return the Zero value for +// implemented type. +// See also https://go.dev/ref/spec#The_zero_value +type ICryptoKeyFactory func(any) (*CryptoKey, error) + +var cryptoKeyValueRegister = map[string]ICryptoKeyFactory{ + // types defined by the core spec + PKIXBase64KeyType: NewPKIXBase64Key, + PKIXBase64CertType: NewPKIXBase64Cert, + PKIXBase64CertPathType: NewPKIXBase64CertPath, + COSEKeyType: NewCOSEKey, + ThumbprintType: NewThumbprint, + CertThumbprintType: NewCertThumbprint, + CertPathThumbprintType: NewCertPathThumbprint, +} + +// RegisterCryptoKeyType registers a new ICryptoKeyValue implementation +// (created by the provided ICryptoKeyFactory) under the specified type name +// and CBOR tag. +func RegisterCryptoKeyType(tag uint64, factory ICryptoKeyFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := cryptoKeyValueRegister[typ]; exists { + return fmt.Errorf("crypto key type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + cryptoKeyValueRegister[typ] = factory + + return nil +} diff --git a/comid/cryptokey_test.go b/comid/cryptokey_test.go index e1a3b4f8..8a53edeb 100644 --- a/comid/cryptokey_test.go +++ b/comid/cryptokey_test.go @@ -4,6 +4,7 @@ package comid import ( + "crypto" "encoding/base64" "encoding/json" "fmt" @@ -119,7 +120,7 @@ func Test_CryptoKey_NewCOSEKey(t *testing.T) { assert.EqualError(t, err, "empty COSE_Key bytes") _, err = NewCOSEKey([]byte("DEADBEEF")) - assert.Contains(t, err.Error(), "cbor: cannot unmarshal") + assert.Contains(t, err.Error(), "cbor: 3 bytes of extraneous data starting at index 5") badKey := []byte{ // taken from go-cose unit tests 0xa2, // map(2) @@ -149,7 +150,7 @@ func Test_CryptoKey_NewCOSEKey(t *testing.T) { } func Test_CryptoKey_NewThumbprint(t *testing.T) { - type newKeyFunc func(swid.HashEntry) (*CryptoKey, error) + type newKeyFunc func(any) (*CryptoKey, error) for _, newFunc := range []newKeyFunc{ NewThumbprint, @@ -177,7 +178,7 @@ func Test_CryptoKey_NewThumbprint(t *testing.T) { assert.Contains(t, err.Error(), "length mismatch for hash algorithm") } - type mustNewKeyFunc func(swid.HashEntry) *CryptoKey + type mustNewKeyFunc func(any) *CryptoKey for _, mustNewFunc := range []mustNewKeyFunc{ MustNewThumbprint, @@ -271,7 +272,7 @@ func Test_CryptoKey_UnmarshalJSON_negative(t *testing.T) { }, { Val: `{"value":"deadbeef"}`, - ErrMsg: "key type not set", + ErrMsg: "type not set", }, { Val: `{"type": "cose-key", "value":";;;"}`, @@ -371,22 +372,22 @@ func Test_NewCryptoKey_negative(t *testing.T) { { Type: COSEKeyType, In: 7, - ErrMsg: "value must be a []byte; found int", + ErrMsg: "value must be a []byte or a string; found int", }, { Type: ThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: CertThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: CertPathThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: "random-key", @@ -399,3 +400,55 @@ func Test_NewCryptoKey_negative(t *testing.T) { assert.ErrorContains(t, err, tv.ErrMsg) } } + +type testCryptoKey [4]byte + +func newTestCryptoKey(val any) (*CryptoKey, error) { + return &CryptoKey{&testCryptoKey{0x74, 0x64, 0x73, 0x74}}, nil +} + +func (o testCryptoKey) PublicKey() (crypto.PublicKey, error) { + return crypto.PublicKey(o[:]), nil +} + +func (o testCryptoKey) Type() string { + return "test-crypto-key" +} + +func (o testCryptoKey) String() string { + return "test" +} + +func (o testCryptoKey) Valid() error { + return nil +} + +func Test_RegisterCryptoKey(t *testing.T) { + err := RegisterCryptoKeyType(99998, newTestCryptoKey) + require.NoError(t, err) + + key, err := newTestCryptoKey(nil) + require.NoError(t, err) + + data, err := json.Marshal(key) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-crypto-key","value":"test"}`) + + var out CryptoKey + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.EqualValues(t, key, &out) + + data, err = em.Marshal(key) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9e, // tag 99998 + 0x44, // bstr(4) + 0x74, 0x64, 0x73, 0x74, // "test" + }) + + var out2 CryptoKey + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, key, &out2) +} diff --git a/comid/devidentitykey_test.go b/comid/devidentitykey_test.go index 14f2f6e1..b12f6a41 100644 --- a/comid/devidentitykey_test.go +++ b/comid/devidentitykey_test.go @@ -23,12 +23,12 @@ func TestDevIdentityKey_Valid_empty(t *testing.T) { testerr: "environment validation failed: environment must not be empty", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewUEIDInstance(TestUEID)}, verifkey: CryptoKeys{}, testerr: "verification keys validation failed: no keys to validate", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewUEIDInstance(TestUEID)}, verifkey: CryptoKeys{&invalidKey}, testerr: "verification keys validation failed: invalid key at index 0: key value not set", }, diff --git a/comid/entity.go b/comid/entity.go index 694fef6c..911c2111 100644 --- a/comid/entity.go +++ b/comid/entity.go @@ -3,31 +3,47 @@ package comid -import "fmt" +import ( + "encoding/json" + "errors" + "fmt" + "unicode/utf8" -type TaggedURI string - -func (o TaggedURI) Empty() bool { - return o == "" -} + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) // Entity stores an entity-map capable of CBOR and JSON serializations. type Entity struct { - EntityName string `cbor:"0,keyasint" json:"name"` - RegID *TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` - Roles Roles `cbor:"2,keyasint" json:"roles"` + EntityName *EntityName `cbor:"0,keyasint" json:"name"` + RegID *TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` + Roles Roles `cbor:"2,keyasint" json:"roles"` + + Extensions +} + +// RegisterExtensions registers a struct as a collections of extensions +func (o *Entity) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) } +// GetExtensions returns pervisouosly registered extension +func (o *Entity) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + +// SetEntityName is used to set the EntityName field of Entity using supplied name func (o *Entity) SetEntityName(name string) *Entity { if o != nil { if name == "" { return nil } - o.EntityName = name + o.EntityName = MustNewStringEntityName(name) } return o } +// SetRegID is used to set the RegID field of Entity using supplied uri func (o *Entity) SetRegID(uri string) *Entity { if o != nil { if uri == "" { @@ -39,6 +55,7 @@ func (o *Entity) SetRegID(uri string) *Entity { return o } +// SetRoles appends the supplied roles to the target entity. func (o *Entity) SetRoles(roles ...Role) *Entity { if o != nil { o.Roles.Add(roles...) @@ -46,11 +63,16 @@ func (o *Entity) SetRoles(roles ...Role) *Entity { return o } +// Valid checks for validity of the fields within each Entity func (o Entity) Valid() error { - if o.EntityName == "" { + if o.EntityName == nil { return fmt.Errorf("invalid entity: empty entity-name") } + if err := o.EntityName.Valid(); err != nil { + return fmt.Errorf("invalid entity: %w", err) + } + if o.RegID != nil && o.RegID.Empty() { return fmt.Errorf("invalid entity: empty reg-id") } @@ -59,7 +81,27 @@ func (o Entity) Valid() error { return fmt.Errorf("invalid entity: %w", err) } - return nil + return o.Extensions.validEntity(&o) +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Entity) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Entity) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Entity) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Entity) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Entities is an array of entity-map's @@ -78,6 +120,7 @@ func (o *Entities) AddEntity(e Entity) *Entities { return o } +// Valid iterates over the range of individual entities to check for validity func (o Entities) Valid() error { for i, m := range o { if err := m.Valid(); err != nil { @@ -86,3 +129,219 @@ func (o Entities) Valid() error { } return nil } + +// EntityName encapsulates the name of the associated Entity. The CoRIM +// specification only allows for text (string) name, but this may be extended +// by other specifications. +type EntityName struct { + Value IEntityNameValue +} + +// NewEntityName creates a new EntityName of the specified type using the +// provided value. +func NewEntityName(val any, typ string) (*EntityName, error) { + factory, ok := entityNameValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unexpected entity name type: %s", typ) + } + + return factory(val) +} + +// MustNewEntityName is like NewEntityName, except it doesn't return an error, +// assuming that the provided value is valid. It panics if that isn't the case. +func MustNewEntityName(val any, typ string) *EntityName { + ret, err := NewEntityName(val, typ) + if err != nil { + panic(err) + } + + return ret +} + +func (o EntityName) String() string { + return o.Value.String() +} + +func (o EntityName) Valid() error { + if o.Value == nil { + return errors.New("empty entity name") + } + + return o.Value.Valid() +} + +func (o EntityName) MarshalCBOR() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + return em.Marshal(o.Value) +} + +func (o *EntityName) UnmarshalCBOR(data []byte) error { + if len(data) == 0 { + return errors.New("empty") + } + + majorType := (data[0] & 0xe0) >> 5 + if majorType == 3 { // text string + var text string + + if err := dm.Unmarshal(data, &text); err != nil { + return err + } + + name := StringEntityName(text) + o.Value = &name + + return nil + } + + return dm.Unmarshal(data, &o.Value) +} + +func (o EntityName) MarshalJSON() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + if o.Value.Type() == extensions.StringType { + return json.Marshal(o.Value.String()) + } + + return extensions.TypeChoiceValueMarshalJSON(o.Value) +} + +func (o *EntityName) UnmarshalJSON(data []byte) error { + var text string + if err := json.Unmarshal(data, &text); err == nil { + *o = *MustNewStringEntityName(text) + return nil + } + + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("entity name decoding failure: %w", err) + } + + decoded, err := NewEntityName(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal entity name: %w", + err, + ) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + + return nil +} + +type IEntityNameValue interface { + extensions.ITypeChoiceValue +} + +type StringEntityName string + +func NewStringEntityName(val any) (*EntityName, error) { + var ret StringEntityName + + if val == nil { + ret = StringEntityName("") + return &EntityName{&ret}, nil + } + + switch t := val.(type) { + case string: + ret = StringEntityName(t) + case []byte: + if !utf8.Valid(t) { + return nil, errors.New("bytes do not form a valid UTF-8 string") + } + + ret = StringEntityName(t) + default: + return nil, fmt.Errorf("unexpected type for string entity name: %T", t) + } + + return &EntityName{&ret}, nil +} + +func MustNewStringEntityName(val any) *EntityName { + ret, err := NewStringEntityName(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o StringEntityName) String() string { + return string(o) +} + +func (o StringEntityName) Type() string { + return extensions.StringType +} + +func (o StringEntityName) Valid() error { + if o == "" { + return errors.New("empty entity-name") + } + + return nil +} + +// IEntityNameFactory defines the signature for the factory functions that may +// be registred using RegisterEntityNameType to provide a new implementation of +// the corresponding type choice. The factory function should create a new +// *EntityName with the underlying value created based on the provided input. +// The range of valid inputs is up to the specific type choice implementation, +// however it _must_ accept nil as one of the inputs, and return the Zero value +// for implemented type. +// See also https://go.dev/ref/spec#The_zero_value +type IEntityNameFactory func(any) (*EntityName, error) + +var entityNameValueRegister = map[string]IEntityNameFactory{ + extensions.StringType: NewStringEntityName, +} + +// RegisterEntityNameType registers a new IEntityNameValue implementation +// (created by the provided IEntityNameFactory) under the specified type name +// and CBOR tag. +func RegisterEntityNameType(tag uint64, factory IEntityNameFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := entityNameValueRegister[typ]; exists { + return fmt.Errorf("entity name type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + entityNameValueRegister[typ] = factory + + return nil +} + +type TaggedURI string + +func (o TaggedURI) Empty() bool { + return o == "" +} diff --git a/comid/entity_test.go b/comid/entity_test.go index 61c0d1b3..d634a2f4 100644 --- a/comid/entity_test.go +++ b/comid/entity_test.go @@ -4,6 +4,8 @@ package comid import ( + "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -83,3 +85,185 @@ func TestEntity_SetRegID_empty(t *testing.T) { assert.Nil(t, e.SetRegID("")) } + +type testEntityName uint64 + +func newTestEntityName(val any) (*EntityName, error) { + if val == nil { + v := testEntityName(0) + return &EntityName{&v}, nil + } + + u, ok := val.(uint64) + if !ok { + return nil, errors.New("must be uint64") + } + + v := testEntityName(u) + return &EntityName{&v}, nil +} + +func (o testEntityName) Type() string { + return "test" +} + +func (o testEntityName) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o testEntityName) Valid() error { + return nil +} + +type testEntityNameBadType struct { + testEntityName +} + +func newTestEntityNameBadType(val any) (*EntityName, error) { + v := testEntityNameBadType{testEntityName(7)} + return &EntityName{&v}, nil +} + +func (o testEntityNameBadType) Type() string { + return "string" +} + +func Test_RegisterEntityNameType(t *testing.T) { + err := RegisterEntityNameType(32, newTestEntityName) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterEntityNameType(99994, newTestEntityNameBadType) + assert.EqualError(t, err, `entity name type with name "string" already exists`) + + registerTestEntityNameType(t) +} + +// Since there only one, untagged, entity name type in the core spec, we use +// the test type define above in order to test the marshaling code works +// properly. Since global environment is not reset when running multiple tests, +// we cannot simply call RegisterEntityNameType() inside each test that relies +// on the test type, as that will cause the "tag already registered" error. On +// the other hand, we do not want to create inter-test dependencies by relying +// that the test registering the type is run before the others that rely on it. +// To get around this, use this global flag to only register the test type if a +// previous test hasn't already done so. +var testEntityNameTypeRegistered = false + +func registerTestEntityNameType(t *testing.T) { + if !testEntityNameTypeRegistered { + err := RegisterEntityNameType(99994, newTestEntityName) + require.NoError(t, err) + + testEntityNameTypeRegistered = true + } +} + +func TestEntityName_CBOR(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte{ + 0x64, // tstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }, + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9a, // tag 99994 + 0x07, // unsigned int(7) + }, + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalCBOR() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalCBOR(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func TestEntityName_JSON(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte(`"test"`), + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte(`{"type":"test","value":7}`), + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalJSON() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalJSON(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func Test_NewStringEntityName(t *testing.T) { + out, err := NewStringEntityName(nil) + require.NoError(t, err) + assert.EqualError(t, out.Valid(), "empty entity-name") + + out, err = NewStringEntityName([]byte("test")) + require.NoError(t, err) + assert.Equal(t, "test", out.String()) + + _, err = NewStringEntityName(7) + assert.EqualError(t, err, "unexpected type for string entity name: int") +} + +func Test_MustNewEntityName(t *testing.T) { + out := MustNewEntityName("test", "string") + assert.Equal(t, "test", out.String()) + + assert.Panics(t, func() { + MustNewEntityName(7, "int") + }) +} diff --git a/comid/environment_test.go b/comid/environment_test.go index c078ce90..f33d1ebf 100644 --- a/comid/environment_test.go +++ b/comid/environment_test.go @@ -46,7 +46,7 @@ func TestEnvironment_Valid_empty_group(t *testing.T) { err := tv.Valid() - assert.EqualError(t, err, "group validation failed: invalid group id") + assert.EqualError(t, err, "group validation failed: no value set") } func TestEnvironment_Valid_ok_with_class(t *testing.T) { tv := Environment{ @@ -78,7 +78,7 @@ func TestEnvironment_ToCBOR_class_only(t *testing.T) { func TestEnvironment_ToCBOR_class_and_instance(t *testing.T) { tv := Environment{ Class: NewClassUUID(TestUUID), - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewUEIDInstance(TestUEID), } require.NotNil(t, tv.Class) require.NotNil(t, tv.Instance) @@ -96,7 +96,7 @@ func TestEnvironment_ToCBOR_class_and_instance(t *testing.T) { func TestEnvironment_ToCBOR_instance_only(t *testing.T) { tv := Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewUEIDInstance(TestUEID), } require.NotNil(t, tv.Instance) @@ -113,7 +113,7 @@ func TestEnvironment_ToCBOR_instance_only(t *testing.T) { func TestEnvironment_ToCBOR_group_only(t *testing.T) { tv := Environment{ - Group: NewGroupUUID(TestUUID), + Group: MustNewUUIDGroup(TestUUID), } require.NotNil(t, tv.Group) @@ -180,7 +180,7 @@ func TestEnvironment_FromCBOR_class_and_instance(t *testing.T) { assert.NotNil(t, actual.Class) assert.Equal(t, TestUUIDString, actual.Class.ClassID.String()) assert.NotNil(t, actual.Instance) - assert.Equal(t, TestUEIDString, actual.Instance.String()) + assert.Equal(t, []byte(TestUEID), actual.Instance.Bytes()) assert.Nil(t, actual.Group) } @@ -197,3 +197,25 @@ func TestEnvironment_FromCBOR_group_only(t *testing.T) { assert.NotNil(t, actual.Group) assert.Equal(t, TestUUIDString, actual.Group.String()) } + +func TestEnviroment_JSON(t *testing.T) { + testEnv := Environment{ + Class: NewClassUUID(TestUUID), + } + + out, err := testEnv.ToJSON() + require.NoError(t, err) + assert.Equal(t, `{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}}`, string(out)) + + var outEnv Environment + + err = outEnv.FromJSON(out) + require.NoError(t, err) + assert.Equal(t, testEnv, outEnv) + + _, err = Environment{}.ToJSON() + assert.EqualError(t, err, "environment must not be empty") + + err = outEnv.FromJSON([]byte(`{"class": 7}`)) + assert.EqualError(t, err, "json: cannot unmarshal number into Go struct field Environment.class of type comid.Class") +} diff --git a/comid/example_cca_refval_test.go b/comid/example_cca_refval_test.go index a9d1102b..9b52304c 100644 --- a/comid/example_cca_refval_test.go +++ b/comid/example_cca_refval_test.go @@ -66,21 +66,23 @@ func extractCCARefVal(rv ReferenceValue) error { if !m.Key.IsSet() { return fmt.Errorf("mKey not set at index %d", i) } - if m.Key.IsPSARefValID() { + + switch t := m.Key.Value.(type) { + case *TaggedPSARefValID: if err := extractSwMeasurement(m); err != nil { return fmt.Errorf("extracting measurement at index %d: %w", i, err) } - } - if m.Key.IsCCAPlatformConfigID() { + case *TaggedCCAPlatformConfigID: if err := extractCCARefValID(m.Key); err != nil { return fmt.Errorf("extracting cca-refval-id: %w", err) } if err := extractRawValue(m.Val.RawValue); err != nil { return fmt.Errorf("extracting raw vlue: %w", err) } - - return nil + default: + return fmt.Errorf("unexpected Mkey type: %T", t) } + } return nil @@ -105,9 +107,9 @@ func extractCCARefValID(k *Mkey) error { return fmt.Errorf("no measurement key") } - id, err := k.GetCCAPlatformConfigID() - if err != nil { - return fmt.Errorf("getting CCA platform config id: %w", err) + id, ok := k.Value.(*TaggedCCAPlatformConfigID) + if !ok { + return fmt.Errorf("expected CCA platform config id, found: %T", k.Value) } fmt.Printf("Label: %s\n", id) return nil diff --git a/comid/example_psa_keys_test.go b/comid/example_psa_keys_test.go index edf95989..3780cbbd 100644 --- a/comid/example_psa_keys_test.go +++ b/comid/example_psa_keys_test.go @@ -70,12 +70,7 @@ func extractInstanceID(i *Instance) error { return fmt.Errorf("no instance") } - instID, err := i.GetUEID() - if err != nil { - return fmt.Errorf("extracting implemenetation-id: %w", err) - } - - fmt.Printf("InstanceID: %x\n", instID) + fmt.Printf("InstanceID: %x\n", i.Bytes()) return nil } diff --git a/comid/example_psa_refval_test.go b/comid/example_psa_refval_test.go index 230209d0..18817d91 100644 --- a/comid/example_psa_refval_test.go +++ b/comid/example_psa_refval_test.go @@ -111,9 +111,10 @@ func extractPSARefValID(k *Mkey) error { return fmt.Errorf("no measurement key") } - id, err := k.GetPSARefValID() - if err != nil { - return fmt.Errorf("getting PSA refval id: %w", err) + id, ok := k.Value.(*TaggedPSARefValID) + + if !ok { + return fmt.Errorf("expected PSA refval id, found: %T", k.Value) } fmt.Printf("SignerID: %x\n", id.SignerID) @@ -142,12 +143,11 @@ func extractImplementationID(c *Class) error { return fmt.Errorf("no class-id") } - implID, err := classID.GetImplID() - if err != nil { - return fmt.Errorf("extracting implemenetation-id: %w", err) + if classID.Type() != ImplIDType { + return fmt.Errorf("class id is not a psa.impl-id") } - fmt.Printf("ImplementationID: %x\n", implID) + fmt.Printf("ImplementationID: %x\n", classID.Bytes()) return nil } diff --git a/comid/example_test.go b/comid/example_test.go index 7f159d44..9d3985c6 100644 --- a/comid/example_test.go +++ b/comid/example_test.go @@ -26,18 +26,18 @@ func Example_encode() { SetModel("RoadRunner"). SetLayer(0). SetIndex(1), - Instance: NewInstanceUEID(TestUEID), - Group: NewGroupUUID(TestUUID), + Instance: MustNewUEIDInstance(TestUEID), + Group: MustNewUUIDGroup(TestUUID), }, Measurements: *NewMeasurements(). AddMeasurement( - NewMeasurement(). - SetKeyUUID(TestUUID). + MustNewUUIDMeasurement(TestUUID). SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0xff, 0xff, 0xff, 0xff}). SetSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). AddDigest(swid.Sha256_32, []byte{0xff, 0xff, 0xff, 0xff}). - SetOpFlags(OpFlagNotSecure, OpFlagDebug). + SetFlagsTrue(FlagIsDebug). + SetFlagsFalse(FlagIsSecure). SetSerialNumber("C02X70VHJHD5"). SetUEID(TestUEID). SetUUID(TestUUID). @@ -54,18 +54,18 @@ func Example_encode() { SetModel("RoadRunner"). SetLayer(0). SetIndex(1), - Instance: NewInstanceUEID(TestUEID), - Group: NewGroupUUID(TestUUID), + Instance: MustNewUEIDInstance(TestUEID), + Group: MustNewUUIDGroup(TestUUID), }, Measurements: *NewMeasurements(). AddMeasurement( - NewMeasurement(). - SetKeyUUID(TestUUID). + MustNewUUIDMeasurement(TestUUID). SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0xff, 0xff, 0xff, 0xff}). SetMinSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). AddDigest(swid.Sha256_32, []byte{0xff, 0xff, 0xff, 0xff}). - SetOpFlags(OpFlagNotSecure, OpFlagDebug, OpFlagNotConfigured). + SetFlagsTrue(FlagIsDebug). + SetFlagsFalse(FlagIsSecure, FlagIsConfigured). SetSerialNumber("C02X70VHJHD5"). SetUEID(TestUEID). SetUUID(TestUUID). @@ -77,7 +77,7 @@ func Example_encode() { AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUUID(uuid.UUID(TestUUID)), + Instance: MustNewUUIDInstance(uuid.UUID(TestUUID)), }, VerifKeys: *NewCryptoKeys(). Add( @@ -87,7 +87,7 @@ func Example_encode() { ).AddDevIdentityKey( DevIdentityKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewUEIDInstance(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( @@ -107,8 +107,8 @@ func Example_encode() { } // Output: - //a50065656e2d474201a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740282a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c6502820100a20069454d4341204c74642e0281020382a200781a6d792d6e733a61636d652d726f616472756e6e65722d626173650100a20078196d792d6e733a61636d652d726f616472756e6e65722d6f6c64010104a4008182a300a500d86f445502c000016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90228020282820644abcdef00820644ffffffff030a04d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa018182a300a500d8255031fb5abf023e4992aa4e95f9c1503bfa016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90229020282820644abcdef00820644ffffffff030b04d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa028182a101d8255031fb5abf023e4992aa4e95f9c1503bfa81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d038182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d - //{"lang":"en-GB","tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator"]},{"name":"EMCA Ltd.","roles":["maintainer"]}],"linked-tags":[{"target":"my-ns:acme-roadrunner-base","rel":"supplements"},{"target":"my-ns:acme-roadrunner-old","rel":"replaces"}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"oid","value":"2.5.2.8192"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"exact-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"op-flags":["notSecure","debug"],"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"endorsed-values":[{"environment":{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"min-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"op-flags":["notConfigured","notSecure","debug"],"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}],"dev-identity-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} + // a50065656e2d474201a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740282a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c6502820100a20069454d4341204c74642e0281020382a200781a6d792d6e733a61636d652d726f616472756e6e65722d626173650100a20078196d792d6e733a61636d652d726f616472756e6e65722d6f6c64010104a4008182a300a500d86f445502c000016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90228020282820644abcdef00820644ffffffff03a201f403f504d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa018182a300a500d8255031fb5abf023e4992aa4e95f9c1503bfa016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90229020282820644abcdef00820644ffffffff03a300f401f403f504d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa028182a101d8255031fb5abf023e4992aa4e95f9c1503bfa81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d038182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d + // {"lang":"en-GB","tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator"]},{"name":"EMCA Ltd.","roles":["maintainer"]}],"linked-tags":[{"target":"my-ns:acme-roadrunner-base","rel":"supplements"},{"target":"my-ns:acme-roadrunner-old","rel":"replaces"}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"oid","value":"2.5.2.8192"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"exact-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"flags":{"is-secure":false,"is-debug":true},"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"endorsed-values":[{"environment":{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"min-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"flags":{"is-configured":false,"is-secure":false,"is-debug":true},"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}],"dev-identity-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} } func Example_encode_PSA() { @@ -124,25 +124,23 @@ func Example_encode_PSA() { }, Measurements: *NewMeasurements(). AddMeasurement( - NewPSAMeasurement( - *NewPSARefValID(TestSignerID). - SetLabel("BL"). - SetVersion("5.0.5"), - ).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), + MustNewPSAMeasurement( + MustCreatePSARefValID( + TestSignerID, "BL", "5.0.5", + )).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), ). AddMeasurement( - NewPSAMeasurement( - *NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.3.5"), - ).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), + MustNewPSAMeasurement( + MustCreatePSARefValID( + TestSignerID, "PRoT", "1.3.5", + )).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), ), }, ). AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewUEIDInstance(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( @@ -162,8 +160,8 @@ func Example_encode_PSA() { } // Output: - //a301a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740281a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c65028301000204a2008182a100a300d90258582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031016941434d45204c74642e026e526f616452756e6e657220322e3082a200d90259a30162424c0465352e302e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00a200d90259a3016450526f540465312e332e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00028182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d - //{"tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator","maintainer"]}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"psa.impl-id","value":"YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE="},"vendor":"ACME Ltd.","model":"RoadRunner 2.0"}},"measurements":[{"key":{"type":"psa.refval-id","value":{"label":"BL","version":"5.0.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}},{"key":{"type":"psa.refval-id","value":{"label":"PRoT","version":"1.3.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} + // a301a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740281a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c65028301000204a2008182a100a300d90258582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031016941434d45204c74642e026e526f616452756e6e657220322e3082a200d90259a30162424c0465352e302e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00a200d90259a3016450526f540465312e332e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00028182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d + // {"tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator","maintainer"]}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"psa.impl-id","value":"YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE="},"vendor":"ACME Ltd.","model":"RoadRunner 2.0"}},"measurements":[{"key":{"type":"psa.refval-id","value":{"label":"BL","version":"5.0.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}},{"key":{"type":"psa.refval-id","value":{"label":"PRoT","version":"1.3.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} } func Example_encode_PSA_attestation_verification() { @@ -173,7 +171,7 @@ func Example_encode_PSA_attestation_verification() { AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewUEIDInstance(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( diff --git a/comid/extensions.go b/comid/extensions.go new file mode 100644 index 00000000..df9d0d2f --- /dev/null +++ b/comid/extensions.go @@ -0,0 +1,173 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package comid + +import ( + "github.com/veraison/corim/extensions" +) + +type IComidConstrainer interface { + ConstrainComid(*Comid) error +} + +type ITriplesConstrainer interface { + ValidTriples(*Triples) error +} + +type IMvalConstrainer interface { + ConstrainMval(*Mval) error +} + +type IEntityConstrainer interface { + ConstrainEntity(*Entity) error +} + +type IFlagsMapConstrainer interface { + ConstrainFlagsMap(*FlagsMap) error +} + +type IFlagSetter interface { + AnySet() bool + SetTrue(Flag) + SetFalse(Flag) + Clear(Flag) + Get(Flag) *bool +} + +type Extensions struct { + extensions.Extensions +} + +func (o *Extensions) validComid(comid *Comid) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IComidConstrainer) + if ok { + if err := ev.ConstrainComid(comid); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) validTriples(triples *Triples) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(ITriplesConstrainer) + if ok { + if err := ev.ValidTriples(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) validMval(triples *Mval) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IMvalConstrainer) + if ok { + if err := ev.ConstrainMval(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) validEntity(triples *Entity) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IEntityConstrainer) + if ok { + if err := ev.ConstrainEntity(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) validFlagsMap(triples *FlagsMap) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IFlagsMapConstrainer) + if ok { + if err := ev.ConstrainFlagsMap(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) setTrue(flag Flag) { + if !o.HaveExtensions() { + return + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + ev.SetTrue(flag) + } +} + +func (o *Extensions) setFalse(flag Flag) { + if !o.HaveExtensions() { + return + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + ev.SetFalse(flag) + } +} + +func (o *Extensions) clear(flag Flag) { + if !o.HaveExtensions() { + return + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + ev.Clear(flag) + } +} + +func (o *Extensions) get(flag Flag) *bool { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + return ev.Get(flag) + } + + return nil +} + +func (o *Extensions) anySet() bool { + if !o.HaveExtensions() { + return false + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + return ev.AnySet() + } + + return false +} diff --git a/comid/extensions_test.go b/comid/extensions_test.go new file mode 100644 index 00000000..2d96f709 --- /dev/null +++ b/comid/extensions_test.go @@ -0,0 +1,94 @@ +package comid + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +var FlagTestFlag = Flag(-1) + +type TestExtension struct { + TestFlag *bool +} + +func (o *TestExtension) ConstrainComid(v *Comid) error { + return errors.New("invalid") +} + +func (o *TestExtension) ValidTriples(v *Triples) error { + return errors.New("invalid") +} + +func (o *TestExtension) ConstrainMval(v *Mval) error { + return errors.New("invalid") +} + +func (o *TestExtension) ConstrainFlagsMap(v *FlagsMap) error { + return errors.New("invalid") +} + +func (o *TestExtension) ConstrainEntity(v *Entity) error { + return errors.New("invalid") +} + +func (o *TestExtension) SetTrue(flag Flag) { + if flag == FlagTestFlag { + o.TestFlag = &True + } +} +func (o *TestExtension) SetFalse(flag Flag) { + if flag == FlagTestFlag { + o.TestFlag = &False + } +} + +func (o *TestExtension) Clear(flag Flag) { + if flag == FlagTestFlag { + o.TestFlag = nil + } +} + +func (o *TestExtension) Get(flag Flag) *bool { + if flag == FlagTestFlag { + return o.TestFlag + } + + return nil +} + +func (o *TestExtension) AnySet() bool { + return o.TestFlag != nil +} + +func Test_Extensions(t *testing.T) { + exts := Extensions{} + exts.Register(&TestExtension{}) + + err := exts.validComid(nil) + assert.EqualError(t, err, "invalid") + + err = exts.validTriples(nil) + assert.EqualError(t, err, "invalid") + + err = exts.validMval(nil) + assert.EqualError(t, err, "invalid") + + err = exts.validEntity(nil) + assert.EqualError(t, err, "invalid") + + err = exts.validFlagsMap(nil) + assert.EqualError(t, err, "invalid") + + assert.False(t, exts.anySet()) + + exts.setTrue(FlagTestFlag) + + exts.setFalse(FlagTestFlag) + assert.False(t, *exts.get(FlagTestFlag)) + + exts.clear(FlagTestFlag) + assert.Nil(t, exts.get(FlagTestFlag)) + assert.False(t, exts.anySet()) +} diff --git a/comid/flagsmap.go b/comid/flagsmap.go new file mode 100644 index 00000000..94ed3e35 --- /dev/null +++ b/comid/flagsmap.go @@ -0,0 +1,217 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 + +package comid + +import ( + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) + +var True = true +var False = false + +// Flag indicates whether a particular operational mode is active within the +// measured environment. +type Flag int + +const ( + FlagIsConfigured Flag = iota + FlagIsSecure + FlagIsRecovery + FlagIsDebug + FlagIsReplayProtected + FlagIsIntegrityProtected + FlagIsRuntimeMeasured + FlagIsImmutable + FlagIsTcb +) + +// FlagsMap describes a number of boolean operational modes. If a value is nil, +// then the operational mode is unknown. +type FlagsMap struct { + // IsConfigured indicates whether the measured environment is fully + // configured for normal operation. + IsConfigured *bool `cbor:"0,keyasint,omitempty" json:"is-configured,omitempty"` + // IsSecure indicates whether the measured environment's configurable + // security settings are fully enabled. + IsSecure *bool `cbor:"1,keyasint,omitempty" json:"is-secure,omitempty"` + // IsRecovery indicates whether the measured environment is in recovery + // mode. + IsRecovery *bool `cbor:"2,keyasint,omitempty" json:"is-recovery,omitempty"` + // IsDebug indicates whether the measured environment is in a debug + // enabled mode. + IsDebug *bool `cbor:"3,keyasint,omitempty" json:"is-debug,omitempty"` + // IsReplayProtected indicates whether the measured environment is + // protected from replay by a previous image that differs from the + // current image. + IsReplayProtected *bool `cbor:"4,keyasint,omitempty" json:"is-replay-protected,omitempty"` + // IsIntegrityProtected indicates whether the measured environment is + // protected from unauthorized update. + IsIntegrityProtected *bool `cbor:"5,keyasint,omitempty" json:"is-integrity-protected,omitempty"` + // IsRuntimeMeasured indicates whether the measured environment is + // measured after being loaded into memory. + IsRuntimeMeasured *bool `cbor:"6,keyasint,omitempty" json:"is-runtime-meas,omitempty"` + // IsImmutable indicates whether the measured environment is immutable. + IsImmutable *bool `cbor:"7,keyasint,omitempty" json:"is-immutable,omitempty"` + // IsTcb indicates whether the measured environment is a trusted + // computing base. + IsTcb *bool `cbor:"8,keyasint,omitempty" json:"is-tcb,omitempty"` + + Extensions +} + +func NewFlagsMap() *FlagsMap { + return &FlagsMap{} +} + +func (o *FlagsMap) AnySet() bool { + if o.IsConfigured != nil || o.IsSecure != nil || o.IsRecovery != nil || o.IsDebug != nil || + o.IsReplayProtected != nil || o.IsIntegrityProtected != nil || + o.IsRuntimeMeasured != nil || o.IsImmutable != nil || o.IsTcb != nil { + return true + } + + return o.Extensions.anySet() +} + +func (o *FlagsMap) SetTrue(flags ...Flag) { + for _, flag := range flags { + switch flag { + case FlagIsConfigured: + o.IsConfigured = &True + case FlagIsSecure: + o.IsSecure = &True + case FlagIsRecovery: + o.IsRecovery = &True + case FlagIsDebug: + o.IsDebug = &True + case FlagIsReplayProtected: + o.IsReplayProtected = &True + case FlagIsIntegrityProtected: + o.IsIntegrityProtected = &True + case FlagIsRuntimeMeasured: + o.IsRuntimeMeasured = &True + case FlagIsImmutable: + o.IsImmutable = &True + case FlagIsTcb: + o.IsTcb = &True + default: + o.Extensions.setTrue(flag) + } + } +} + +func (o *FlagsMap) SetFalse(flags ...Flag) { + for _, flag := range flags { + switch flag { + case FlagIsConfigured: + o.IsConfigured = &False + case FlagIsSecure: + o.IsSecure = &False + case FlagIsRecovery: + o.IsRecovery = &False + case FlagIsDebug: + o.IsDebug = &False + case FlagIsReplayProtected: + o.IsReplayProtected = &False + case FlagIsIntegrityProtected: + o.IsIntegrityProtected = &False + case FlagIsRuntimeMeasured: + o.IsRuntimeMeasured = &False + case FlagIsImmutable: + o.IsImmutable = &False + case FlagIsTcb: + o.IsTcb = &False + default: + o.Extensions.setFalse(flag) + } + } +} + +func (o *FlagsMap) Clear(flags ...Flag) { + for _, flag := range flags { + switch flag { + case FlagIsConfigured: + o.IsConfigured = nil + case FlagIsSecure: + o.IsSecure = nil + case FlagIsRecovery: + o.IsRecovery = nil + case FlagIsDebug: + o.IsDebug = nil + case FlagIsReplayProtected: + o.IsReplayProtected = nil + case FlagIsIntegrityProtected: + o.IsIntegrityProtected = nil + case FlagIsRuntimeMeasured: + o.IsRuntimeMeasured = nil + case FlagIsImmutable: + o.IsImmutable = nil + case FlagIsTcb: + o.IsTcb = nil + default: + o.Extensions.clear(flag) + } + } +} + +func (o *FlagsMap) Get(flag Flag) *bool { + switch flag { + case FlagIsConfigured: + return o.IsConfigured + case FlagIsSecure: + return o.IsSecure + case FlagIsRecovery: + return o.IsRecovery + case FlagIsDebug: + return o.IsDebug + case FlagIsReplayProtected: + return o.IsReplayProtected + case FlagIsIntegrityProtected: + return o.IsIntegrityProtected + case FlagIsRuntimeMeasured: + return o.IsRuntimeMeasured + case FlagIsImmutable: + return o.IsImmutable + case FlagIsTcb: + return o.IsTcb + default: + return o.Extensions.get(flag) + } +} + +// RegisterExtensions registers a struct as a collections of extensions +func (o *FlagsMap) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *FlagsMap) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + +// UnmarshalCBOR deserializes from CBOR +func (o *FlagsMap) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *FlagsMap) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *FlagsMap) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *FlagsMap) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) +} + +// Valid returns an error if the FlagsMap is invalid. +func (o FlagsMap) Valid() error { + return o.Extensions.validFlagsMap(&o) +} diff --git a/comid/flagsmap_test.go b/comid/flagsmap_test.go new file mode 100644 index 00000000..e9bc477c --- /dev/null +++ b/comid/flagsmap_test.go @@ -0,0 +1,41 @@ +package comid + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_FlagsMap(t *testing.T) { + fm := NewFlagsMap() + assert.False(t, fm.AnySet()) + + for _, flag := range []Flag{ + FlagIsConfigured, + FlagIsSecure, + FlagIsRecovery, + FlagIsDebug, + FlagIsReplayProtected, + FlagIsIntegrityProtected, + FlagIsRuntimeMeasured, + FlagIsImmutable, + FlagIsTcb, + } { + fm.SetTrue(flag) + assert.True(t, fm.AnySet()) + assert.Equal(t, true, *fm.Get(flag)) + + fm.SetFalse(flag) + assert.True(t, fm.AnySet()) + assert.Equal(t, false, *fm.Get(flag)) + + fm.Clear(flag) + assert.False(t, fm.AnySet()) + assert.Equal(t, (*bool)(nil), fm.Get(flag)) + } + + fm.SetTrue(Flag(-1)) + fm.SetFalse(Flag(-1)) + assert.False(t, fm.AnySet()) + assert.Equal(t, (*bool)(nil), fm.Get(Flag(-1))) +} diff --git a/comid/group.go b/comid/group.go index adb9cdb4..789ed197 100644 --- a/comid/group.go +++ b/comid/group.go @@ -5,66 +5,51 @@ package comid import ( "encoding/json" + "errors" "fmt" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) // Group stores a group identity. The supported format is UUID. type Group struct { - val interface{} + Value IGroupValue } // NewGroup instantiates an empty group -func NewGroup() *Group { - return &Group{} -} - -// SetUUID sets the identity of the target group to the supplied UUID -func (o *Group) SetUUID(val UUID) *Group { - if o != nil { - o.val = TaggedUUID(val) +func NewGroup(val any, typ string) (*Group, error) { + factory, ok := groupValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown group type: %s", typ) } - return o -} -// NewGroupUUID instantiates a new group with the supplied UUID identity -func NewGroupUUID(val UUID) *Group { - return NewGroup().SetUUID(val) + return factory(val) } // Valid checks for the validity of given group func (o Group) Valid() error { - if o.String() == "" { - return fmt.Errorf("invalid group id") + if o.Value == nil { + return errors.New("no value set") } - return nil + + return o.Value.Valid() } // String returns a printable string of the Group value. UUIDs use the // canonical 8-4-4-4-12 format, UEIDs are hex encoded. func (o Group) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - default: - return "" - } + return o.Value.String() } // MarshalCBOR serializes the target group to CBOR func (o Group) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) + return em.Marshal(o.Value) } // UnmarshalCBOR deserializes the supplied CBOR into the target group func (o *Group) UnmarshalCBOR(data []byte) error { - var uuid TaggedUUID - - if dm.Unmarshal(data, &uuid) == nil { - o.val = uuid - return nil - } - - return fmt.Errorf("unknown group type (CBOR: %x)", data) + return dm.Unmarshal(data, &o.Value) } // UnmarshalJSON deserializes the supplied JSON type/value object into the Group @@ -75,43 +60,97 @@ func (o *Group) UnmarshalCBOR(data []byte) error { // "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" // } func (o *Group) UnmarshalJSON(data []byte) error { - var v tnv + var tnv encoding.TypeAndValue - if err := json.Unmarshal(data, &v); err != nil { + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("group decoding failure: %w", err) + } + + decoded, err := NewGroup(nil, tnv.Type) + if err != nil { return err } - switch v.Type { - case "uuid": - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - default: - return fmt.Errorf("unknown type %s for group", v.Type) + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal group: %w", + err, + ) } + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + return nil } func (o Group) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "ueid", Value: b} - default: - return nil, fmt.Errorf("unknown type %T for group", t) + return extensions.TypeChoiceValueMarshalJSON(o.Value) +} + +type IGroupValue interface { + extensions.ITypeChoiceValue +} + +func NewUUIDGroup(val any) (*Group, error) { + if val == nil { + return &Group{&TaggedUUID{}}, nil + } + + u, err := NewTaggedUUID(val) + if err != nil { + return nil, err + } + + return &Group{u}, nil +} + +func MustNewUUIDGroup(val any) *Group { + ret, err := NewUUIDGroup(val) + if err != nil { + panic(err) + } + + return ret +} + +// IGroupFactory defines the signature for the factory functions that may be +// registred using RegisterGroupType to provide a new implementation of the +// corresponding type choice. The factory function should create a new *Group +// with the underlying value created based on the provided input. The range of +// valid inputs is up to the specific type choice implementation, however it +// _must_ accept nil as one of the inputs, and return the Zero value for +// implemented type. +// See also https://go.dev/ref/spec#The_zero_value +type IGroupFactory func(any) (*Group, error) + +var groupValueRegister = map[string]IGroupFactory{ + UUIDType: NewUUIDGroup, +} + +// RegisterGroupType registers a new IGroupValue implementation +// (created by the provided IGroupFactory) under the specified type name +// and CBOR tag. +func RegisterGroupType(tag uint64, factory IGroupFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := groupValueRegister[typ]; exists { + return fmt.Errorf("Group type with name %q already exists", typ) } - return json.Marshal(v) + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + groupValueRegister[typ] = factory + + return nil } diff --git a/comid/group_test.go b/comid/group_test.go new file mode 100644 index 00000000..4363f3b2 --- /dev/null +++ b/comid/group_test.go @@ -0,0 +1,62 @@ +package comid + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testGroup uint64 + +func newTestGroup(val any) (*Group, error) { + v := testGroup(7) + return &Group{&v}, nil +} + +func (o testGroup) Type() string { + return "test-value" +} + +func (o testGroup) String() string { + return "test" +} + +func (o testGroup) Valid() error { + return nil +} + +type testGroupBadType struct { + testGroup +} + +func newTestGroupBadType(val any) (*Group, error) { + v := testGroupBadType{testGroup(7)} + return &Group{&v}, nil +} + +func (o testGroupBadType) Type() string { + return "uuid" +} + +func Test_RegisterGroupType(t *testing.T) { + err := RegisterGroupType(32, newTestGroup) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterGroupType(99993, newTestGroupBadType) + assert.EqualError(t, err, `Group type with name "uuid" already exists`) + + err = RegisterGroupType(99993, newTestGroup) + require.NoError(t, err) + +} + +func TestGroup_UmarshalJSON(t *testing.T) { + var group Group + + err := group.UnmarshalJSON([]byte(`{`)) + assert.EqualError(t, err, "group decoding failure: unexpected end of JSON input") + + err = group.UnmarshalJSON([]byte(`{"type":"uuid","value":"aaaa"}`)) + assert.EqualError(t, err, "cannot unmarshal group: bad UUID: invalid UUID length: 4") +} diff --git a/comid/instance.go b/comid/instance.go index 8201145d..83631b32 100644 --- a/comid/instance.go +++ b/comid/instance.go @@ -1,51 +1,27 @@ package comid import ( - "encoding/hex" "encoding/json" "fmt" - "github.com/google/uuid" - "github.com/veraison/eat" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) // Instance stores an instance identity. The supported formats are UUID and UEID. type Instance struct { - val interface{} + Value IInstanceValue } -// NewInstance instantiates an empty instance -func NewInstance() *Instance { - return &Instance{} -} - -// SetUEID sets the identity of the target instance to the supplied UEID -func (o *Instance) SetUEID(val eat.UEID) *Instance { - if o != nil { - if val.Validate() != nil { - return nil - } - o.val = TaggedUEID(val) +// NewInstance creates a new instance with the value of the specified type +// populated using the provided value. +func NewInstance(val any, typ string) (*Instance, error) { + factory, ok := instanceValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown instance type: %s", typ) } - return o -} -// SetUUID sets the identity of the target instance to the supplied UUID -func (o *Instance) SetUUID(val uuid.UUID) *Instance { - if o != nil { - o.val = TaggedUUID(val) - } - return o -} - -// NewInstanceUEID instantiates a new instance with the supplied UEID identity -func NewInstanceUEID(val eat.UEID) *Instance { - return NewInstance().SetUEID(val) -} - -// NewInstanceUUID instantiates a new instance with the supplied UUID identity -func NewInstanceUUID(val uuid.UUID) *Instance { - return NewInstance().SetUUID(val) + return factory(val) } // Valid checks for the validity of given instance @@ -56,124 +32,186 @@ func (o Instance) Valid() error { return nil } -func (o Instance) GetUEID() (eat.UEID, error) { - switch t := o.val.(type) { - case TaggedUEID: - return eat.UEID(t), nil - default: - return eat.UEID{}, fmt.Errorf("instance-id type is: %T", t) - } -} - -func (o Instance) GetUUID() (UUID, error) { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t), nil - default: - return UUID{}, fmt.Errorf("instance-id type is: %T", t) - } -} - // String returns a printable string of the Instance value. UUIDs use the // canonical 8-4-4-4-12 format, UEIDs are hex encoded. func (o Instance) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - case TaggedUEID: - return hex.EncodeToString(t) - default: + if o.Value == nil { return "" } + + return o.Value.String() +} + +// Type returns a string naming the type of the underlying Instance value. +func (o Instance) Type() string { + return o.Value.Type() +} + +// Bytes returns a []byte containing the bytes of the underlying Instance +// value. +func (o Instance) Bytes() []byte { + return o.Value.Bytes() } // MarshalCBOR serializes the target instance to CBOR func (o Instance) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) + return em.Marshal(o.Value) } func (o *Instance) UnmarshalCBOR(data []byte) error { - var ueid TaggedUEID - - if dm.Unmarshal(data, &ueid) == nil { - o.val = ueid - return nil - } - - var u TaggedUUID - - if dm.Unmarshal(data, &u) == nil { - o.val = u - return nil - } - - return fmt.Errorf("unknown instance type (CBOR: %x)", data) + return dm.Unmarshal(data, &o.Value) } -// UnmarshalJSON deserializes the supplied JSON type/value object into the Group -// target. The supported formats are UUID, e.g.: +// UnmarshalJSON deserializes the supplied JSON object into the target Instance +// The instance object must have the following shape: // // { -// "type": "uuid", -// "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" +// "type": "", +// "value": // } // -// and UEID: +// where must be one of the known IInstanceValue implementation +// type names (available in the base implementation: "ueid" and "uuid"), and +// is the JSON encoding of the instance value. The exact +// encoding is dependent. For the base implmentation types it is // -// { -// "type": "ueid", -// "value": "Ad6tvu/erb7v3q2+796tvu8=" -// } +// ueid: base64-encoded bytes, e.g. "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" func (o *Instance) UnmarshalJSON(data []byte) error { - var v tnv + var tnv encoding.TypeAndValue - if err := json.Unmarshal(data, &v); err != nil { + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("instance decoding failure: %w", err) + } + + decoded, err := NewInstance(nil, tnv.Type) + if err != nil { return err } - switch v.Type { - case "uuid": - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - case "ueid": - var x UEID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUEID(x) - default: - return fmt.Errorf("unknown type %s for instance", v.Type) + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal instance: %w", + err, + ) } + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + return nil } +// MarshalJSON serializes the Instance into a JSON object. func (o Instance) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedUEID: - b, err = UEID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "ueid", Value: b} - default: - return nil, fmt.Errorf("unknown type %T for instance", t) - } - - return json.Marshal(v) + valueBytes, err := json.Marshal(o.Value) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, + } + + return json.Marshal(value) +} + +// IInstanceValue is the interface implemented by all Instance value +// implementations. +type IInstanceValue interface { + extensions.ITypeChoiceValue + + Bytes() []byte +} + +// NewUEIDInstance instantiates a new instance with the supplied UEID identity. +func NewUEIDInstance(val any) (*Instance, error) { + if val == nil { + return &Instance{&TaggedUEID{}}, nil + } + + ret, err := NewTaggedUEID(val) + if err != nil { + return nil, err + } + return &Instance{ret}, nil +} + +// MustNewUEIDInstance is like NewUEIDInstance execept it does not return an +// error, assuming that the provided value is valid. It panics if that isn't +// the case. +func MustNewUEIDInstance(val any) *Instance { + ret, err := NewUEIDInstance(val) + if err != nil { + panic(err) + } + + return ret +} + +// NewUUIDInstance instantiates a new instance with the supplied UUID identity +func NewUUIDInstance(val any) (*Instance, error) { + if val == nil { + return &Instance{&TaggedUUID{}}, nil + } + + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err + } + + return &Instance{ret}, nil +} + +// MustNewUUIDInstance is like NewUUIDInstance execept it does not return an +// error, assuming that the provided value is valid. It panics if that isn't +// the case. +func MustNewUUIDInstance(val any) *Instance { + ret, err := NewUUIDInstance(val) + if err != nil { + panic(err) + } + + return ret +} + +// IInstanceFactory defines the signature for the factory functions that may be +// registred using RegisterInstanceType to provide a new implementation of the +// corresponding type choice. The factory function should create a new *Instance +// with the underlying value created based on the provided input. The range of +// valid inputs is up to the specific type choice implementation, however it +// _must_ accept nil as one of the inputs, and return the Zero value for +// implemented type. +// See also https://go.dev/ref/spec#The_zero_value +type IInstanceFactory func(any) (*Instance, error) + +var instanceValueRegister = map[string]IInstanceFactory{ + UEIDType: NewUEIDInstance, + UUIDType: NewUUIDInstance, +} + +// RegisterInstanceType registers a new IInstanceValue implementation (created +// by the provided IInstanceFactory) under the specified CBOR tag. +func RegisterInstanceType(tag uint64, factory IInstanceFactory) error { + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := instanceValueRegister[typ]; exists { + return fmt.Errorf("class ID type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + instanceValueRegister[typ] = factory + + return nil } diff --git a/comid/instance_test.go b/comid/instance_test.go index 34e671e7..6ea22fd4 100644 --- a/comid/instance_test.go +++ b/comid/instance_test.go @@ -1,24 +1,69 @@ package comid import ( + "encoding/json" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestInstance_GetUUID_OK(t *testing.T) { - inst := NewInstanceUUID(uuid.UUID(TestUUID)) - require.NotNil(t, inst) - u, err := inst.GetUUID() - assert.Nil(t, err) - assert.Equal(t, u, TestUUID) + inst := MustNewUUIDInstance(TestUUID) + u, ok := inst.Value.(*TaggedUUID) + assert.True(t, ok) + assert.EqualValues(t, TestUUID, *u) } -func TestInstance_GetUUID_NOK(t *testing.T) { - inst := &Instance{} - expectedErr := "instance-id type is: " - _, err := inst.GetUUID() - assert.EqualError(t, err, expectedErr) +type testInstance string + +func newTestInstance(val any) (*Instance, error) { + ret := testInstance("test") + return &Instance{&ret}, nil +} + +func (o testInstance) Bytes() []byte { + return []byte(o) +} + +func (o testInstance) Type() string { + return "test-instance" +} + +func (o testInstance) String() string { + return string(o) +} + +func (o testInstance) Valid() error { + return nil +} + +func Test_RegisterInstanceType(t *testing.T) { + err := RegisterInstanceType(99997, newTestInstance) + require.NoError(t, err) + + instance, err := newTestInstance(nil) + require.NoError(t, err) + + data, err := json.Marshal(instance) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-instance","value":"test"}`) + + var out Instance + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.Equal(t, instance.Bytes(), out.Bytes()) + + data, err = em.Marshal(instance) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9d, // tag 99997 + 0x64, // tstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }) + + var out2 Instance + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, instance.Bytes(), out2.Bytes()) } diff --git a/comid/measurement.go b/comid/measurement.go index 4a560b67..f942855a 100644 --- a/comid/measurement.go +++ b/comid/measurement.go @@ -5,206 +5,296 @@ package comid import ( "encoding/json" + "errors" "fmt" "net" + "strconv" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/eat" "github.com/veraison/swid" ) const MaxUint64 = ^uint64(0) -// Measurement stores a measurement-map with CBOR and JSON serializations. -type Measurement struct { - Key *Mkey `cbor:"0,keyasint,omitempty" json:"key,omitempty"` - Val Mval `cbor:"1,keyasint" json:"value"` -} - // Mkey stores a $measured-element-type-choice. // The supported types are UUID, PSA refval-id, CCA platform-config-id and unsigned integer // TO DO Add tagged OID: see https://github.com/veraison/corim/issues/35 type Mkey struct { - val interface{} + Value IMKeyValue +} + +// NewMkey creates a new Mkey of the specfied type using the provided value. +func NewMkey(val any, typ string) (*Mkey, error) { + factory, ok := mkeyValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unexpected measurement key type: %q", typ) + } + + return factory(nil) } +// MustNewMkey is like NewMkey, execept it does not return an error, assuming +// that the provided value is valid. It panics if that is not the case. +func MustNewMkey(val any, typ string) *Mkey { + ret, err := NewMkey(val, typ) + if err != nil { + panic(err) + } + + return ret +} + +// IsSet returns true if the value of the Mkey is set. func (o Mkey) IsSet() bool { - return o.val != nil + return o.Value != nil } +// Valid returns nil if the Mkey is valid or an error describing the problem, +// if it is not. func (o Mkey) Valid() error { - switch t := o.val.(type) { - case TaggedUUID: - if UUID(t).Empty() { - return fmt.Errorf("empty UUID") - } - return nil - case TaggedPSARefValID: - return PSARefValID(t).Valid() - case TaggedCCAPlatformConfigID: - if CCAPlatformConfigID(t).Empty() { - return fmt.Errorf("empty CCAPlatformConfigID") - } - case uint64: - if o.val == nil { - return fmt.Errorf("empty uint Mkey") - } - return nil - default: - return fmt.Errorf("unknown measurement key type: %T", t) + if o.Value == nil { + return errors.New("Mkey value not set") } + + if err := o.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", o.Value.Type(), err) + } + return nil } -func (o Mkey) IsPSARefValID() bool { - _, ok := o.val.(TaggedPSARefValID) - return ok -} +// UnmarshalJSON deserializes the supplied JSON object into the target MKey +// The key object must have the following shape: +// +// { +// "type": "", +// "value": +// } +// +// where must be one of the known IMKeyValue implementation +// type names (available in the base implementation: "uuid", "oid", +// "psa.impl-id"), and is the class id value serialized to +// JSON. The exact serialization is depenent. For the base +// implementation types it is +// +// oid: dot-seprated integers, e.g. "1.2.3.4" +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" +// psa.refval-id: JSON representation of the PSA refval-id +func (o *Mkey) UnmarshalJSON(data []byte) error { + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return err + } + + decoded, err := NewMkey(nil, tnv.Type) + if err != nil { + return err + } -func (o Mkey) IsCCAPlatformConfigID() bool { - _, ok := o.val.(TaggedCCAPlatformConfigID) - return ok + if err := json.Unmarshal(tnv.Value, decoded.Value); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + + return nil } -func (o Mkey) GetPSARefValID() (PSARefValID, error) { - switch t := o.val.(type) { - case TaggedPSARefValID: - return PSARefValID(t), nil - default: - return PSARefValID{}, fmt.Errorf("measurement-key type is: %T", t) +// MarshalJSON serializes the target Mkey into the type'n'value JSON object +func (o Mkey) MarshalJSON() ([]byte, error) { + valueBytes, err := json.Marshal(o.Value) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, } + + return json.Marshal(value) } -func (o Mkey) GetCCAPlatformConfigID() (CCAPlatformConfigID, error) { - switch t := o.val.(type) { - case TaggedCCAPlatformConfigID: - return CCAPlatformConfigID(t), nil - default: - return CCAPlatformConfigID(""), fmt.Errorf("measurement-key type is: %T", t) +// MarshalCBOR serializes the taret mkey into CBOR-encoded bytes. +func (o Mkey) MarshalCBOR() ([]byte, error) { + return em.Marshal(o.Value) +} + +// UnmarshalCBOR deserializes the Mkey from the provided CBOR bytes. +func (o *Mkey) UnmarshalCBOR(data []byte) error { + if len(data) == 0 { + return errors.New("empty input") + } + + majorType := (data[0] & 0xe0) >> 5 + if majorType == 6 { // tag + return dm.Unmarshal(data, &o.Value) + } + + // untagged value must be a uint + + var val UintMkey + if err := dm.Unmarshal(data, &val); err != nil { + return err } + + o.Value = &val + return nil } -func (o Mkey) GetKeyUint() (uint64, error) { - switch t := o.val.(type) { +// IMKeyValue is the interface implemented by all Mkey value implementations. +type IMKeyValue interface { + extensions.ITypeChoiceValue +} + +const UintType = "uint" + +type UintMkey uint64 + +func NewUintMkey(val any) (*UintMkey, error) { + var ret UintMkey + + if val == nil { + return &ret, nil + } + + switch t := val.(type) { + case UintMkey: + ret = t + case *UintMkey: + ret = *t + case string: + u, err := strconv.ParseUint(t, 10, 64) + if err != nil { + return nil, err + } + ret = UintMkey(u) case uint64: - return t, nil + ret = UintMkey(t) + case uint: + ret = UintMkey(t) default: - return MaxUint64, fmt.Errorf("measurement-key type is: %T", t) + return nil, fmt.Errorf("unexpected type for UintMkey: %T", t) } + + return &ret, nil } -// UnmarshalJSON deserializes the type'n'value JSON object into the target Mkey -func (o *Mkey) UnmarshalJSON(data []byte) error { - var v tnv +func (o UintMkey) Valid() error { + return nil +} + +func (o UintMkey) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o UintMkey) Type() string { + return UintType +} + +func (o *UintMkey) UnmarshalJSON(data []byte) error { + var tmp uint64 - if err := json.Unmarshal(data, &v); err != nil { + if err := json.Unmarshal(data, &tmp); err != nil { return err } - switch v.Type { - case "uuid": - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type UUID: %w", - err, - ) - } - o.val = TaggedUUID(x) - case "psa.refval-id": - var x PSARefValID - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type PSARefValID: %w", - err, - ) - } - if err := x.Valid(); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type PSARefValID: %w", - err, - ) - } - o.val = TaggedPSARefValID(x) - case "cca.platform-config-id": - var x CCAPlatformConfigID - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: %w", - err, - ) - } - if x.Empty() { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: empty label", - ) - } - o.val = TaggedCCAPlatformConfigID(x) - case "uint": - var x uint64 - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type uint: %w", - err, - ) - } - o.val = x - default: - return fmt.Errorf("unknown type %s for $measured-element-type-choice", v.Type) - } + *o = UintMkey(tmp) return nil } -// MarshalJSON serializes the target Mkey into the type'n'value JSON object -// Supported types are: uuid, psa.refval-id and unsigned integer -func (o Mkey) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - uuidString := UUID(t).String() - b, err = json.Marshal(uuidString) - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedPSARefValID: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "psa.refval-id", Value: b} - case TaggedCCAPlatformConfigID: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "cca.platform-config-id", Value: b} +func NewMkeyOID(val any) (*Mkey, error) { + ret, err := NewTaggedOID(val) + if err != nil { + return nil, err + } - case uint64: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "uint", Value: b} + return &Mkey{ret}, nil +} - default: - return nil, fmt.Errorf("unknown type %T for mkey", t) +func NewMkeyUUID(val any) (*Mkey, error) { + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err } - return json.Marshal(v) + return &Mkey{ret}, nil } -func (o Mkey) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) +func NewMkeyUint(val any) (*Mkey, error) { + ret, err := NewUintMkey(val) + if err != nil { + return nil, err + } + + return &Mkey{ret}, nil } -func (o *Mkey) UnmarshalCBOR(data []byte) error { - return dm.Unmarshal(data, &o.val) +func NewMkeyPSARefvalID(val any) (*Mkey, error) { + ret, err := NewTaggedPSARefValID(val) + if err != nil { + return nil, err + } + + return &Mkey{ret}, nil +} + +func NewMkeyCCAPlatformConfigID(val any) (*Mkey, error) { + ret, err := NewTaggedCCAPlatformConfigID(val) + if err != nil { + return nil, err + } + + return &Mkey{ret}, nil +} + +// IMkeyFactory defines the signature for the factory functions that may be +// registred using RegisterMkeyType to provide a new implementation of the +// corresponding type choice. The factory function should create a new *Mkey +// with the underlying value created based on the provided input. The range of +// valid inputs is up to the specific type choice implementation, however it +// _must_ accept nil as one of the inputs, and return the Zero value for +// implemented type. +// See also https://go.dev/ref/spec#The_zero_value +type IMkeyFactory = func(val any) (*Mkey, error) + +var mkeyValueRegister = map[string]IMkeyFactory{ + OIDType: NewMkeyOID, + UUIDType: NewMkeyUUID, + UintType: NewMkeyUint, + PSARefValIDType: NewMkeyPSARefvalID, + CCAPlatformConfigIDType: NewMkeyCCAPlatformConfigID, +} + +// RegisterMkeyType registers a new IMKeyValue implementation +// (created by the provided IMKeyFactory) under the specified CBOR tag. +func RegisterMkeyType(tag uint64, factory IMkeyFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := mkeyValueRegister[typ]; exists { + return fmt.Errorf("mesurement key type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + mkeyValueRegister[typ] = factory + + return nil } // Mval stores a measurement-values-map with JSON and CBOR serializations. @@ -212,7 +302,7 @@ type Mval struct { Ver *Version `cbor:"0,keyasint,omitempty" json:"version,omitempty"` SVN *SVN `cbor:"1,keyasint,omitempty" json:"svn,omitempty"` Digests *Digests `cbor:"2,keyasint,omitempty" json:"digests,omitempty"` - OpFlags *OpFlags `cbor:"3,keyasint,omitempty" json:"op-flags,omitempty"` + Flags *FlagsMap `cbor:"3,keyasint,omitempty" json:"flags,omitempty"` RawValue *RawValue `cbor:"4,keyasint,omitempty" json:"raw-value,omitempty"` RawValueMask *[]byte `cbor:"5,keyasint,omitempty" json:"raw-value-mask,omitempty"` MACAddr *MACaddr `cbor:"6,keyasint,omitempty" json:"mac-addr,omitempty"` @@ -220,13 +310,45 @@ type Mval struct { SerialNumber *string `cbor:"8,keyasint,omitempty" json:"serial-number,omitempty"` UEID *eat.UEID `cbor:"9,keyasint,omitempty" json:"ueid,omitempty"` UUID *UUID `cbor:"10,keyasint,omitempty" json:"uuid,omitempty"` + + Extensions +} + +// RegisterExtensions registers a struct as a collections of extensions +func (o *Mval) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Mval) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Mval) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Mval) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Mval) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Mval) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } func (o Mval) Valid() error { if o.Ver == nil && o.SVN == nil && o.Digests == nil && - o.OpFlags == nil && + o.Flags == nil && o.RawValue == nil && o.RawValueMask == nil && o.MACAddr == nil && @@ -249,8 +371,8 @@ func (o Mval) Valid() error { } } - if o.OpFlags != nil { - if err := o.OpFlags.Valid(); err != nil { + if o.Flags != nil { + if err := o.Flags.Valid(); err != nil { return err } } @@ -259,7 +381,7 @@ func (o Mval) Valid() error { // TODO(tho) MAC addr & friends (see https://github.com/veraison/corim/issues/18) - return nil + return o.Extensions.validMval(&o) } // Version stores a version-map with JSON and CBOR serializations. @@ -295,95 +417,112 @@ func (o Version) Valid() error { return nil } -// NewMeasurement instantiates an empty measurement -func NewMeasurement() *Measurement { - return &Measurement{} +// Measurement stores a measurement-map with CBOR and JSON serializations. +type Measurement struct { + Key *Mkey `cbor:"0,keyasint,omitempty" json:"key,omitempty"` + Val Mval `cbor:"1,keyasint" json:"value"` + AuthorizedBy *CryptoKey `cbor:"2,keyasint,omitempty" json:"authorized-by,omitempty"` } -// SetKeyPSARefValID sets the key of the target measurement-map to the supplied -// PSA refval-id -func (o *Measurement) SetKeyPSARefValID(psaRefValID PSARefValID) *Measurement { - if o != nil { - if psaRefValID.Valid() != nil { - return nil - } - o.Key = &Mkey{ - val: TaggedPSARefValID(psaRefValID), - } +func NewMeasurement(val any, typ string) (*Measurement, error) { + keyFactory, ok := mkeyValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown Mkey type: %s", typ) } - return o -} -// SetKeyCCAPlatformConfigID sets the key of the target measurement-map to the supplied -// CCA platform-config-id -func (o *Measurement) SetKeyCCAPlatformConfigID(ccaPlatformConfigID CCAPlatformConfigID) *Measurement { - if o != nil { - if ccaPlatformConfigID.Empty() { - return nil - } - o.Key = &Mkey{ - val: TaggedCCAPlatformConfigID(ccaPlatformConfigID), - } + key, err := keyFactory(val) + if err != nil { + return nil, fmt.Errorf("invalid key: %w", err) } - return o -} -// SetKeyKeyUUID sets the key of the target measurement-map to the supplied -// UUID -func (o *Measurement) SetKeyUUID(u UUID) *Measurement { - if o != nil { - if u.Empty() { - return nil - } + if err = key.Valid(); err != nil { + return nil, fmt.Errorf("invalid key: %w", err) + } - if u.Valid() != nil { - return nil - } + var ret Measurement + ret.Key = key - o.Key = &Mkey{ - val: TaggedUUID(u), - } - } - return o + return &ret, nil } -// SetKeyUint sets the key of the target measurement-map to the supplied -// unsigned integer -func (o *Measurement) SetKeyUint(u uint64) *Measurement { - if o != nil { - o.Key = &Mkey{ - val: u, - } +func MustNewMeasurement(val any, typ string) *Measurement { + ret, err := NewMeasurement(val, typ) + + if err != nil { + panic(err) } - return o + + return ret } // NewPSAMeasurement instantiates a new measurement-map with the key set to the // supplied PSA refval-id -func NewPSAMeasurement(psaRefValID PSARefValID) *Measurement { - m := &Measurement{} - return m.SetKeyPSARefValID(psaRefValID) +func NewPSAMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, PSARefValIDType) +} + +func MustNewPSAMeasurement(key any) *Measurement { + ret, err := NewPSAMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewCCAPlatCfgMeasurement instantiates a new measurement-map with the key set to the // supplied CCA platform-config-id -func NewCCAPlatCfgMeasurement(ccaPlatformConfigID CCAPlatformConfigID) *Measurement { - m := &Measurement{} - return m.SetKeyCCAPlatformConfigID(ccaPlatformConfigID) +func NewCCAPlatCfgMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, CCAPlatformConfigIDType) +} + +func MustNewCCAPlatCfgMeasurement(key any) *Measurement { + ret, err := NewCCAPlatCfgMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewUUIDMeasurement instantiates a new measurement-map with the key set to the // supplied UUID -func NewUUIDMeasurement(uuid UUID) *Measurement { - m := &Measurement{} - return m.SetKeyUUID(uuid) +func NewUUIDMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, UUIDType) +} + +func MustNewUUIDMeasurement(key any) *Measurement { + ret, err := NewUUIDMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewUintMeasurement instantiates a new measurement-map with the key set to the // supplied Uint -func NewUintMeasurement(mkey uint64) *Measurement { - m := &Measurement{} - return m.SetKeyUint(mkey) +func NewUintMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, UintType) +} + +func MustNewUintMeasurement(key any) *Measurement { + ret, err := NewUintMeasurement(key) + + if err != nil { + panic(err) + } + + return ret +} + +// NewOIDMeasurement instantiates a new measurement-map with the key set to the +// supplied OID +func NewOIDMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, OIDType) } func (o *Measurement) SetVersion(ver string, scheme int64) *Measurement { @@ -413,26 +552,14 @@ func (o *Measurement) SetRawValueBytes(rawValue, rawValueMask []byte) *Measureme // SetSVN sets the supplied svn in the measurement-values-map of the target // measurement func (o *Measurement) SetSVN(svn uint64) *Measurement { - if o != nil { - s := SVN{} - if s.SetSVN(svn) == nil { - return nil - } - o.Val.SVN = &s - } + o.Val.SVN = MustNewTaggedSVN(svn) return o } // SetMinSVN sets the supplied min-svn in the measurement-values-map of the // target measurement func (o *Measurement) SetMinSVN(svn uint64) *Measurement { - if o != nil { - s := SVN{} - if s.SetMinSVN(svn) == nil { - return nil - } - o.Val.SVN = &s - } + o.Val.SVN = MustNewTaggedMinSVN(svn) return o } @@ -450,16 +577,51 @@ func (o *Measurement) AddDigest(algID uint64, digest []byte) *Measurement { } o.Val.Digests = ds } + return o } -// SetOpFlags sets the supplied operational flags in the measurement-values-map -// of the target measurement -func (o *Measurement) SetOpFlags(flags ...OpFlags) *Measurement { +// SetFlagsTrue sets the supplied operational flags to true in the +// measurement-values-map of the target measurement +func (o *Measurement) SetFlagsTrue(flags ...Flag) *Measurement { if o != nil { - o.Val.OpFlags = NewOpFlags() - o.Val.OpFlags.SetOpFlags(flags...) + if o.Val.Flags == nil { + o.Val.Flags = NewFlagsMap() + } + o.Val.Flags.SetTrue(flags...) } + + return o +} + +// SetFlagsFalse sets the supplied operational flags to true in the +// measurement-values-map of the target measurement +func (o *Measurement) SetFlagsFalse(flags ...Flag) *Measurement { + if o != nil { + if o.Val.Flags == nil { + o.Val.Flags = NewFlagsMap() + } + o.Val.Flags.SetFalse(flags...) + } + + return o +} + +// ClearFlags clears the supplied operational flags in the +// measurement-values-map of the target measurement +func (o *Measurement) ClearFlags(flags ...Flag) *Measurement { + if o != nil { + if o.Val.Flags == nil { + return o + } + + o.Val.Flags.Clear(flags...) + + if !o.Val.Flags.AnySet() { + o.Val.Flags = nil + } + } + return o } diff --git a/comid/measurement_test.go b/comid/measurement_test.go index e6ea5a2b..c4f95e11 100644 --- a/comid/measurement_test.go +++ b/comid/measurement_test.go @@ -4,6 +4,7 @@ package comid import ( + "crypto" "fmt" "testing" @@ -14,86 +15,86 @@ import ( ) func TestMeasurement_NewUUIDMeasurement_good_uuid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - assert.NotNil(t, tv) + _, err := NewUUIDMeasurement(TestUUID) + assert.NoError(t, err) } func TestMeasurement_NewUUIDMeasurement_empty_uuid(t *testing.T) { emptyUUID := UUID{} - tv := NewUUIDMeasurement(emptyUUID) + _, err := NewUUIDMeasurement(emptyUUID) - assert.Nil(t, tv) + assert.EqualError(t, err, + "invalid key: expecting RFC4122 UUID, got Reserved instead") } func TestMeasurement_NewUIntMeasurement(t *testing.T) { var TestUint uint64 = 35 - tv := NewUintMeasurement(TestUint) + _, err := NewUintMeasurement(TestUint) - assert.NotNil(t, tv) + assert.NoError(t, err) } func TestMeasurement_NewPSAMeasurement_empty(t *testing.T) { emptyPSARefValID := PSARefValID{} - tv := NewPSAMeasurement(emptyPSARefValID) - assert.Nil(t, tv) + _, err := NewPSAMeasurement(emptyPSARefValID) + + assert.EqualError(t, err, "invalid key: invalid psa.refval-id: missing mandatory signer ID") } func TestMeasurement_NewPSAMeasurement_no_values(t *testing.T) { - psaRefValID := - NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.2.3") + psaRefValID, err := NewPSARefValID(TestSignerID) + require.NoError(t, err) + psaRefValID.SetLabel("PRoT") + psaRefValID.SetVersion("1.2.3") require.NotNil(t, psaRefValID) - tv := NewPSAMeasurement(*psaRefValID) - assert.NotNil(t, tv) + tv, err := NewPSAMeasurement(*psaRefValID) + assert.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } func TestMeasurement_NewCCAPlatCfgMeasurement_no_values(t *testing.T) { ccaplatID := CCAPlatformConfigID(TestCCALabel) - tv := NewCCAPlatCfgMeasurement(ccaplatID) - assert.NotNil(t, tv) + tv, err := NewCCAPlatCfgMeasurement(ccaplatID) + assert.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } func TestMeasurement_NewCCAPlatCfgMeasurement_valid_meas(t *testing.T) { ccaplatID := CCAPlatformConfigID(TestCCALabel) - tv := NewCCAPlatCfgMeasurement(ccaplatID).SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{}) - assert.NotNil(t, tv) + tv, err := NewCCAPlatCfgMeasurement(ccaplatID) + assert.NoError(t, err) - err := tv.Valid() - assert.Nil(t, err) + tv.SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{}) + + err = tv.Valid() + assert.NoError(t, err) } func TestMeasurement_NewPSAMeasurement_one_value(t *testing.T) { - psaRefValID := - NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.2.3") - require.NotNil(t, psaRefValID) + tv, err := NewPSAMeasurement(MustCreatePSARefValID(TestSignerID, "PRoT", "1.2.3")) + require.NoError(t, err) - tv := NewPSAMeasurement(*psaRefValID).SetIPaddr(TestIPaddr) - assert.NotNil(t, tv) + tv.SetIPaddr(TestIPaddr) - err := tv.Valid() + err = tv.Valid() assert.Nil(t, err) } func TestMeasurement_NewUUIDMeasurement_no_values(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } @@ -101,26 +102,27 @@ func TestMeasurement_NewUUIDMeasurement_some_value(t *testing.T) { var vs swid.VersionScheme require.NoError(t, vs.SetCode(swid.VersionSchemeSemVer)) - tv := NewUUIDMeasurement(TestUUID). - SetMinSVN(2). - SetOpFlags(OpFlagDebug). + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) + + tv.SetMinSVN(2). + SetFlagsTrue(FlagIsDebug). SetVersion("1.2.3", swid.VersionSchemeSemVer) - require.NotNil(t, tv) - err := tv.Valid() + err = tv.Valid() assert.Nil(t, err) } func TestMeasurement_NewUUIDMeasurement_bad_digest(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) assert.Nil(t, tv.AddDigest(swid.Sha256, []byte{0xff})) } func TestMeasurement_NewUUIDMeasurement_bad_ueid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) badUEID := eat.UEID{ 0xFF, // Invalid @@ -131,8 +133,8 @@ func TestMeasurement_NewUUIDMeasurement_bad_ueid(t *testing.T) { } func TestMeasurement_NewUUIDMeasurement_bad_uuid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) nonRFC4122UUID, err := ParseUUID("f47ac10b-58cc-4372-c567-0e02b2c3d479") require.Nil(t, err) @@ -141,13 +143,13 @@ func TestMeasurement_NewUUIDMeasurement_bad_uuid(t *testing.T) { } var ( - testMKeyUintMin uint64 = 0 - testMKeyUintMax uint64 = ^uint64(0) + testMKeyUintMin uint64 + testMKeyUintMax = ^uint64(0) ) func TestMkey_Valid_no_value(t *testing.T) { mkey := &Mkey{} - expectedErr := "unknown measurement key type: " + expectedErr := "Mkey value not set" err := mkey.Valid() assert.EqualError(t, err, expectedErr) } @@ -183,19 +185,19 @@ func TestMKey_MarshalCBOR_CCAPlatformConfigID_ok(t *testing.T) { func TestMKey_UnmarshalCBOR_CCAPlatformConfigID_ok(t *testing.T) { tvs := []struct { input []byte - expected CCAPlatformConfigID + expected TaggedCCAPlatformConfigID }{ { input: MustHexDecode(t, "d9025a736363612d706c6174666f726d2d636f6e666967"), - expected: CCAPlatformConfigID(TestCCALabel), + expected: TaggedCCAPlatformConfigID(TestCCALabel), }, { input: MustHexDecode(t, "d9025a716d7974657374706c6174666f726d666967"), - expected: CCAPlatformConfigID("mytestplatformfig"), + expected: TaggedCCAPlatformConfigID("mytestplatformfig"), }, { input: MustHexDecode(t, "d9025a6c6d79746573746c6162656c32"), - expected: CCAPlatformConfigID("mytestlabel2"), + expected: TaggedCCAPlatformConfigID("mytestlabel2"), }, } @@ -203,9 +205,9 @@ func TestMKey_UnmarshalCBOR_CCAPlatformConfigID_ok(t *testing.T) { mkey := &Mkey{} err := mkey.UnmarshalCBOR(tv.input) assert.Nil(t, err) - actual, err := mkey.GetCCAPlatformConfigID() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + actual, ok := mkey.Value.(*TaggedCCAPlatformConfigID) + assert.True(t, ok) + assert.Equal(t, tv.expected, *actual) fmt.Printf("CBOR: %x\n", actual) } } @@ -230,36 +232,13 @@ func TestMKey_MarshalCBOR_uint_ok(t *testing.T) { } for _, tv := range tvs { - mkey := &Mkey{tv.mkey} + mkey := &Mkey{UintMkey(tv.mkey)} actual, err := mkey.MarshalCBOR() assert.Nil(t, err) assert.Equal(t, tv.expected, actual) fmt.Printf("CBOR: %x\n", actual) } } -func TestMkey_MarshalCBOR_uint_not_ok(t *testing.T) { - tvs := []struct { - input interface{} - expected string - }{ - { - input: 123.456, - expected: "unknown measurement key type: float64", - }, - { - input: "sample", - expected: "unknown measurement key type: string", - }, - } - - for _, tv := range tvs { - mkey := &Mkey{tv.input} - _, err := mkey.MarshalCBOR() - assert.Nil(t, err) - err = mkey.Valid() - assert.EqualError(t, err, tv.expected) - } -} func TestMkey_UnmarshalCBOR_uint_ok(t *testing.T) { tvs := []struct { @@ -284,10 +263,10 @@ func TestMkey_UnmarshalCBOR_uint_ok(t *testing.T) { mKey := &Mkey{} err := mKey.UnmarshalCBOR(tv.mkey) - assert.Nil(t, err) - actual, err := mKey.GetKeyUint() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + require.NoError(t, err) + actual, ok := mKey.Value.(*UintMkey) + require.True(t, ok) + assert.Equal(t, tv.expected, uint64(*actual)) } } @@ -315,38 +294,9 @@ func TestMkey_UnmarshalCBOR_not_ok(t *testing.T) { } } -func TestMkey_UnmarshalCBOR_uint_not_ok(t *testing.T) { - tvs := []struct { - input []byte - expected string - }{ - { - input: []byte{0xd8, 0x25, 0x50, 0x31, 0xfb, 0x5a, 0xbf, 0x02, - 0x3e, 0x49, 0x92, 0xaa, 0x4e, 0x95, 0xf9, 0xc1, - 0x50, 0x3b, 0xfa}, - expected: "measurement-key type is: comid.TaggedUUID", - }, - { - input: []byte{0xd8, 0x21, 0x50, 0x31, 0xfb, 0x5a, 0xff, 0x12, - 0xFF, 0xFF, 0x92, 0xaa, 0x4e, 0x95, 0xf9, 0xc1, - 0x50, 0x3b, 0xfa}, - expected: "measurement-key type is: cbor.Tag", - }, - } - - for _, tv := range tvs { - mKey := &Mkey{} - - err := mKey.UnmarshalCBOR(tv.input) - assert.Nil(t, err) - _, err = mKey.GetKeyUint() - assert.EqualError(t, err, tv.expected) - } -} - func TestMKey_MarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { refval := TestCCALabel - mkey := &Mkey{val: TaggedCCAPlatformConfigID(refval)} + mkey := &Mkey{Value: TaggedCCAPlatformConfigID(refval)} expected := `{"type":"cca.platform-config-id","value":"cca-platform-config"}` @@ -359,30 +309,29 @@ func TestMKey_MarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { func TestMKey_UnMarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { input := []byte(`{"type":"cca.platform-config-id","value":"cca-platform-config"}`) - expected := CCAPlatformConfigID(TestCCALabel) + expected := TaggedCCAPlatformConfigID(TestCCALabel) mKey := &Mkey{} err := mKey.UnmarshalJSON(input) assert.Nil(t, err) - actual, err := mKey.GetCCAPlatformConfigID() - assert.Nil(t, err) - assert.Equal(t, expected, actual) + actual, ok := mKey.Value.(*TaggedCCAPlatformConfigID) + assert.True(t, ok) + assert.Equal(t, expected, *actual) } func TestMKey_UnMarshalJSON_CCAPlatformConfigID_not_ok(t *testing.T) { input := []byte(`{"type":"cca.platform-config-id","value":""}`) - expected := "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: empty label" + expected := "invalid cca.platform-config-id: empty value" mKey := &Mkey{} err := mKey.UnmarshalJSON(input) - assert.NotNil(t, err) - assert.Equal(t, expected, err.Error()) - + assert.EqualError(t, err, expected) } + func TestMkey_MarshalJSON_uint_ok(t *testing.T) { tvs := []struct { mkey uint64 @@ -404,7 +353,7 @@ func TestMkey_MarshalJSON_uint_ok(t *testing.T) { for _, tv := range tvs { - mkey := &Mkey{tv.mkey} + mkey := &Mkey{UintMkey(tv.mkey)} actual, err := mkey.MarshalJSON() assert.Nil(t, err) @@ -414,31 +363,6 @@ func TestMkey_MarshalJSON_uint_ok(t *testing.T) { } } -func TestMkey_MarshalJSON_uint_not_ok(t *testing.T) { - tvs := []struct { - input interface{} - expected string - }{ - { - input: 123.456, - expected: "unknown type float64 for mkey", - }, - { - input: "sample", - expected: "unknown type string for mkey", - }, - } - - for _, tv := range tvs { - - mkey := &Mkey{tv.input} - - _, err := mkey.MarshalJSON() - - assert.EqualError(t, err, tv.expected) - } -} - func TestMkey_UnmarshalJSON_uint_ok(t *testing.T) { tvs := []struct { input []byte @@ -463,9 +387,9 @@ func TestMkey_UnmarshalJSON_uint_ok(t *testing.T) { err := mKey.UnmarshalJSON(tv.input) assert.Nil(t, err) - actual, err := mKey.GetKeyUint() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + actual, ok := mKey.Value.(*UintMkey) + assert.True(t, ok) + assert.Equal(t, tv.expected, uint64(*actual)) } } @@ -476,11 +400,11 @@ func TestMkey_UnmarshalJSON_notok(t *testing.T) { }{ { input: []byte(`{"type":"uint","value":"abcdefg"}`), - expected: "cannot unmarshal $measured-element-type-choice of type uint: json: cannot unmarshal string into Go value of type uint64", + expected: `invalid uint: json: cannot unmarshal string into Go value of type uint64`, }, { input: []byte(`{"type":"uint","value":123.456}`), - expected: "cannot unmarshal $measured-element-type-choice of type uint: json: cannot unmarshal number 123.456 into Go value of type uint64", + expected: "invalid uint: json: cannot unmarshal number 123.456 into Go value of type uint64", }, } @@ -493,27 +417,102 @@ func TestMkey_UnmarshalJSON_notok(t *testing.T) { } } -func TestMkey_UnmarshalJSON_uint_notok(t *testing.T) { +func TestNewUintMkey(t *testing.T) { + testVal := UintMkey(7) + tvs := []struct { - input []byte - expected string + input any + expected UintMkey + err string }{ { - input: []byte(`{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}`), - expected: "measurement-key type is: comid.TaggedUUID", + input: testVal, + expected: testVal, + }, + { + input: &testVal, + expected: testVal, }, { - input: []byte(`{"type":"psa.refval-id","value":{"label": "BL","version": "2.1.0","signer-id": "rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}}`), - expected: "measurement-key type is: comid.TaggedPSARefValID", + input: uint(7), + expected: testVal, + }, + { + input: uint64(7), + expected: testVal, + }, + { + input: "7", + expected: testVal, + }, + { + input: true, + err: "unexpected type for UintMkey: bool", }, } for _, tv := range tvs { - mKey := &Mkey{} - - err := mKey.UnmarshalJSON(tv.input) - assert.Nil(t, err) - _, err = mKey.GetKeyUint() - assert.EqualError(t, err, tv.expected) + out, err := NewUintMkey(tv.input) + if tv.err != "" { + assert.Nil(t, out) + assert.EqualError(t, err, tv.err) + } else { + assert.Equal(t, tv.expected, *out) + } } } + +func TestNewMkeyOID(t *testing.T) { + var expectedOID OID + require.NoError(t, expectedOID.FromString(TestOID)) + expected := TaggedOID(expectedOID) + + out, err := NewMkeyOID(TestOID) + require.NoError(t, err) + assert.Equal(t, &expected, out.Value) +} + +type testMkey [4]byte + +func newTestMkey(val any) (*Mkey, error) { + return &Mkey{&testMkey{0x74, 0x64, 0x73, 0x74}}, nil +} + +func (o testMkey) PublicKey() (crypto.PublicKey, error) { + return crypto.PublicKey(o[:]), nil +} + +func (o testMkey) Type() string { + return "test-mkey" +} + +func (o testMkey) String() string { + return "test" +} + +func (o testMkey) Valid() error { + return nil +} + +type badMkey struct { + testMkey +} + +func (o badMkey) Type() string { + return "uuid" +} + +func newBadMkey(val any) (*Mkey, error) { + return &Mkey{&badMkey{testMkey{0x74, 0x64, 0x73, 0x74}}}, nil +} + +func TestRegisterMkeyType(t *testing.T) { + err := RegisterMkeyType(32, newTestMkey) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterMkeyType(99996, newBadMkey) + assert.EqualError(t, err, `mesurement key type with name "uuid" already exists`) + + err = RegisterMkeyType(99996, newTestMkey) + assert.NoError(t, err) +} diff --git a/comid/oid.go b/comid/oid.go index 822d7deb..24d366e6 100644 --- a/comid/oid.go +++ b/comid/oid.go @@ -11,6 +11,8 @@ import ( "strings" ) +const OIDType = "oid" + // BER-encoded absolute OID type OID []byte @@ -152,3 +154,78 @@ func (o *OID) UnmarshalJSON(data []byte) error { func (o OID) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } + +type TaggedOID OID + +func NewTaggedOID(val any) (*TaggedOID, error) { + ret := TaggedOID{} + + if val == nil { + return &ret, nil + } + + switch t := val.(type) { + case string: + var berOID OID + if err := berOID.FromString(t); err != nil { + return nil, err + } + + ret = TaggedOID(berOID) + case []byte: + ret = make([]byte, len(t)) + copy(ret, t) + case TaggedOID: + ret = make([]byte, len(t)) + copy(ret, t) + case OID: + ret = make([]byte, len(t)) + copy(ret, t) + case *TaggedOID: + ret = make([]byte, len(*t)) + copy(ret, (*t)) + case *OID: + ret = make([]byte, len(*t)) + copy(ret, (*t)) + } + + return &ret, nil +} + +func (o TaggedOID) Type() string { + return OIDType +} + +func (o TaggedOID) String() string { + return OID(o).String() +} + +func (o TaggedOID) Valid() error { + return nil +} + +func (o TaggedOID) Bytes() []byte { + return o +} + +func (o *TaggedOID) FromString(s string) error { + return (*OID)(o).FromString(s) +} + +func (o *TaggedOID) UnmarshalJSON(data []byte) error { + var s string + + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + if err := o.FromString(s); err != nil { + return err + } + + return nil +} + +func (o TaggedOID) MarshalJSON() ([]byte, error) { + return json.Marshal(o.String()) +} diff --git a/comid/opflag.go b/comid/opflag.go deleted file mode 100644 index ce0c8d8b..00000000 --- a/comid/opflag.go +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2021 Contributors to the Veraison project. -// SPDX-License-Identifier: Apache-2.0 - -package comid - -import ( - "encoding/json" - "fmt" -) - -// OpFlags implements the flags-type, mapping to DiceTcbInfo.flags via the -// operational flags not-configured, not-secure, recovery and debug. -// If the flags field is omitted, all flags are assumed to be 0. -type OpFlags uint8 - -const ( - OpFlagNotConfigured OpFlags = 1 << iota - OpFlagNotSecure - OpFlagRecovery - OpFlagDebug -) - -func NewOpFlags() *OpFlags { - return new(OpFlags) -} - -func (o OpFlags) Strings() []string { - var a []string - - if o&OpFlagNotConfigured != 0 { - a = append(a, "notConfigured") - } - - if o&OpFlagNotSecure != 0 { - a = append(a, "notSecure") - } - - if o&OpFlagRecovery != 0 { - a = append(a, "recovery") - } - - if o&OpFlagDebug != 0 { - a = append(a, "debug") - } - - return a -} - -func (o OpFlags) Valid() error { - // While any combination in the lower half-byte is acceptable, the most - // significant nibble must be all zeroes. - if o&0xf0 != 0 { - return fmt.Errorf("op-flags has unknown bits asserted: %02x", o) - } - - return nil -} - -// SetFlags sets the target object as specified. As many flags as necessary can -// be specified in one call. -func (o *OpFlags) SetOpFlags(flags ...OpFlags) *OpFlags { - if o != nil { - for _, flag := range flags { - *o |= flag - } - } - return o -} - -func (o OpFlags) IsSet(flag OpFlags) bool { - return o&flag != 0 -} - -// UnmarshalJSON provides a custom deserializer for the OpFlags type that uses an -// array of identifiers rather than a bit set, e.g.: -// -// "op-flags": [ -// "notSecure", -// "debug" -// ] -func (o *OpFlags) UnmarshalJSON(data []byte) error { - var a []string - - if err := json.Unmarshal(data, &a); err != nil { - return err - } - - if len(a) == 0 { - *o = 0 - return nil - } - - for _, s := range a { - switch s { - case "notSecure": - *o |= OpFlagNotSecure - case "notConfigured": - *o |= OpFlagNotConfigured - case "recovery": - *o |= OpFlagRecovery - case "debug": - *o |= OpFlagDebug - default: - // ignore unknown opflags - continue - } - } - - return nil -} - -func (o OpFlags) MarshalJSON() ([]byte, error) { - return json.Marshal(o.Strings()) -} diff --git a/comid/opflag_test.go b/comid/opflag_test.go deleted file mode 100644 index 261dc963..00000000 --- a/comid/opflag_test.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2021 Contributors to the Veraison project. -// SPDX-License-Identifier: Apache-2.0 - -package comid - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestFlags_UnmarshalJSON_skip_unknown(t *testing.T) { - tv := []byte(`[ "notSecure", "mysteriousFlagWhichWillBeIgnored" ]`) - - flags := NewOpFlags().SetOpFlags(OpFlagNotSecure) - require.NotNil(t, flags) - expected := *flags - - var actual OpFlags - err := actual.UnmarshalJSON(tv) - - assert.Nil(t, err) - assert.Equal(t, expected, actual) - assert.True(t, actual.IsSet(OpFlagNotSecure)) -} - -func TestFlags_UnmarshalJSON_all_known(t *testing.T) { - tv := []byte(`[ "notSecure", "notConfigured", "recovery", "debug" ]`) - - flags := NewOpFlags(). - SetOpFlags(OpFlagNotSecure). - SetOpFlags(OpFlagNotConfigured). - SetOpFlags(OpFlagRecovery). - SetOpFlags(OpFlagDebug) - require.NotNil(t, flags) - expected := *flags - - var actual OpFlags - err := actual.UnmarshalJSON(tv) - - fmt.Printf("CBOR: %02x\n", actual) - - assert.Nil(t, err) - assert.Equal(t, expected, actual) - assert.True(t, actual.IsSet(OpFlagNotSecure)) - assert.True(t, actual.IsSet(OpFlagRecovery)) - assert.True(t, actual.IsSet(OpFlagDebug)) - assert.True(t, actual.IsSet(OpFlagNotConfigured)) -} - -func TestFlags_UnmarshalJSON_empty(t *testing.T) { - tv := []byte(`[ ]`) - - flags := NewOpFlags() - require.NotNil(t, flags) - expected := *flags - - var actual OpFlags - err := actual.UnmarshalJSON(tv) - - fmt.Printf("%02x\n", actual) - - assert.Nil(t, err) - assert.Equal(t, expected, actual) - assert.False(t, actual.IsSet(OpFlagNotSecure)) - assert.False(t, actual.IsSet(OpFlagRecovery)) - assert.False(t, actual.IsSet(OpFlagDebug)) - assert.False(t, actual.IsSet(OpFlagNotConfigured)) -} - -func TestFlags_Valid_ok(t *testing.T) { - // all valid flags combinations - for i := 1; i <= 15; i++ { - tv := OpFlags(i) - - assert.Nil(t, tv.Valid()) - } -} - -func TestFlags_Valid_bad_combos(t *testing.T) { - for i := 1; i <= 15; i++ { - for j := 1; j <= 15; j++ { - tv := OpFlags(i<<4 | j) - - expectedErr := fmt.Sprintf("op-flags has unknown bits asserted: %02x", tv) - - assert.EqualError(t, tv.Valid(), expectedErr) - } - } -} diff --git a/comid/psareferencevalue.go b/comid/psareferencevalue.go index cde3304b..0ffacf97 100644 --- a/comid/psareferencevalue.go +++ b/comid/psareferencevalue.go @@ -4,9 +4,12 @@ package comid import ( + "encoding/json" "fmt" ) +var PSARefValIDType = "psa.refval-id" + // PSARefValID stores a PSA refval-id with CBOR and JSON serializations // (See https://datatracker.ietf.org/doc/html/draft-xyz-rats-psa-endorsements) type PSARefValID struct { @@ -30,18 +33,56 @@ func (o PSARefValID) Valid() error { return nil } -type TaggedPSARefValID PSARefValID +func CreatePSARefValID(signerID []byte, label, version string) (*PSARefValID, error) { + ret, err := NewPSARefValID(signerID) + if err != nil { + return nil, err + } -func NewPSARefValID(signerID []byte) *PSARefValID { - switch len(signerID) { - case 32, 48, 64: - default: - return nil + ret.SetLabel(label) + ret.SetVersion(version) + + return ret, nil +} + +func MustCreatePSARefValID(signerID []byte, label, version string) *PSARefValID { + ret, err := CreatePSARefValID(signerID, label, version) + + if err != nil { + panic(err) } - return &PSARefValID{ - SignerID: signerID, + return ret +} + +func NewPSARefValID(val any) (*PSARefValID, error) { + var ret PSARefValID + + if val == nil { + return &ret, nil } + + switch t := val.(type) { + case PSARefValID: + ret = t + case *PSARefValID: + ret = *t + case string: + if err := json.Unmarshal([]byte(t), &ret); err != nil { + return nil, err + } + case []byte: + switch len(t) { + case 32, 48, 64: + ret.SignerID = t + default: + return nil, fmt.Errorf("invalid PSA RefVal ID length: %d", len(t)) + } + default: + return nil, fmt.Errorf("unexpected type for PSA RefVal ID: %T", t) + } + + return &ret, nil } func (o *PSARefValID) SetLabel(label string) *PSARefValID { @@ -57,3 +98,46 @@ func (o *PSARefValID) SetVersion(version string) *PSARefValID { } return o } + +type TaggedPSARefValID PSARefValID + +func NewTaggedPSARefValID(val any) (*TaggedPSARefValID, error) { + var ret TaggedPSARefValID + + switch t := val.(type) { + case TaggedPSARefValID: + ret = t + case *TaggedPSARefValID: + ret = *t + default: + refvalID, err := NewPSARefValID(val) + if err != nil { + return nil, err + } + ret = TaggedPSARefValID(*refvalID) + + } + + return &ret, nil +} + +func (o TaggedPSARefValID) Valid() error { + return PSARefValID(o).Valid() +} + +func (o TaggedPSARefValID) String() string { + ret, err := json.Marshal(o) + if err != nil { + return "" + } + + return string(ret) +} + +func (o TaggedPSARefValID) Type() string { + return PSARefValIDType +} + +func (o TaggedPSARefValID) IsZero() bool { + return len(o.SignerID) == 0 +} diff --git a/comid/psareferencevalue_test.go b/comid/psareferencevalue_test.go index 069c432e..72116cc2 100644 --- a/comid/psareferencevalue_test.go +++ b/comid/psareferencevalue_test.go @@ -4,9 +4,11 @@ package comid import ( + "fmt" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestPSARefValID_Valid_SignerID_range(t *testing.T) { @@ -15,13 +17,27 @@ func TestPSARefValID_Valid_SignerID_range(t *testing.T) { for i := 1; i <= 100; i++ { signerID = append(signerID, byte(0xff)) - tv := NewPSARefValID(signerID) + tv, err := NewPSARefValID(signerID) + switch i { case 32, 48, 64: assert.NotNil(t, tv) assert.Nil(t, tv.Valid()) default: assert.Nil(t, tv) + assert.EqualError( + t, + err, + fmt.Sprintf("invalid PSA RefVal ID length: %d", i), + ) } } } + +func TestPSARefValID_Streing(t *testing.T) { + signerID := MustHexDecode(t, "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + refvalID, err := NewTaggedPSARefValID(signerID) + require.NoError(t, err) + + assert.Equal(t, `{"signer-id":"3q2+796tvu/erb7v3q2+796tvu/erb7v3q2+796tvu8="}`, refvalID.String()) +} diff --git a/comid/referencevalue_test.go b/comid/referencevalue_test.go new file mode 100644 index 00000000..f999c876 --- /dev/null +++ b/comid/referencevalue_test.go @@ -0,0 +1,21 @@ +package comid + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ReferenceValue(t *testing.T) { + rv := ReferenceValue{} + err := rv.Valid() + assert.EqualError(t, err, "environment validation failed: environment must not be empty") + + id, err := uuid.NewUUID() + require.NoError(t, err) + rv.Environment.Instance = MustNewUUIDInstance(id) + err = rv.Valid() + assert.EqualError(t, err, "measurements validation failed: no measurement entries") +} diff --git a/comid/rel.go b/comid/rel.go index ad35ed36..39e40d64 100644 --- a/comid/rel.go +++ b/comid/rel.go @@ -18,6 +18,37 @@ const ( RelUnset = ^Rel(0) ) +var ( + relToString = map[Rel]string{ + RelReplaces: "replaces", + RelSupplements: "supplements", + } + + stringToRel = map[string]Rel{ + "replaces": RelReplaces, + "supplements": RelSupplements, + } +) + +// RegisterRel creates a new Rel association between the provided value and +// name. An error is returned if either clashes with any of the existing roles. +func RegisterRel(val int64, name string) error { + rel := Rel(val) + + if _, ok := relToString[rel]; ok { + return fmt.Errorf("rel with value %d already exists", val) + } + + if _, ok := stringToRel[name]; ok { + return fmt.Errorf("rel with name %q already exists", name) + } + + relToString[rel] = name + stringToRel[name] = rel + + return nil +} + func NewRel() *Rel { r := RelUnset return &r @@ -43,14 +74,12 @@ func (o Rel) Valid() error { } func (o Rel) String() string { - switch o { - case RelReplaces: - return "replaces" - case RelSupplements: - return "supplements" - default: - return fmt.Sprintf("rel(%d)", o) + ret, ok := relToString[o] + if ok { + return ret } + + return fmt.Sprintf("rel(%d)", o) } func (o Rel) ToCBOR() ([]byte, error) { diff --git a/comid/rel_test.go b/comid/rel_test.go index 29de1629..86f6d748 100644 --- a/comid/rel_test.go +++ b/comid/rel_test.go @@ -163,3 +163,20 @@ func TestRel_ToCBOR_fail_unset(t *testing.T) { assert.EqualError(t, err, "rel is unset") } + +func Test_RegisterRel(t *testing.T) { + err := RegisterRel(1, "augments") + assert.EqualError(t, err, "rel with value 1 already exists") + + err = RegisterRel(3, "replaces") + assert.EqualError(t, err, `rel with name "replaces" already exists`) + + err = RegisterRel(3, "augments") + assert.NoError(t, err) + + rel := Rel(3) + + out, err := rel.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, `"augments"`, string(out)) +} diff --git a/comid/role.go b/comid/role.go index 1b08322b..3cd3d0a4 100644 --- a/comid/role.go +++ b/comid/role.go @@ -40,6 +40,35 @@ var ( } ) +// String returns the string representation of the Role. +func (o Role) String() string { + text, ok := roleToString[o] + if ok { + return text + } + + return fmt.Sprintf("Role(%d)", o) +} + +// RegisterRole creates a new Role association between the provided value and +// name. An error is returned if either clashes with any of the existing roles. +func RegisterRole(val int64, name string) error { + role := Role(val) + + if _, ok := roleToString[role]; ok { + return fmt.Errorf("role with value %d already exists", val) + } + + if _, ok := stringToRole[name]; ok { + return fmt.Errorf("role with name %q already exists", name) + } + + roleToString[role] = name + stringToRole[name] = role + + return nil +} + type Roles []Role func NewRoles() *Roles { diff --git a/comid/role_test.go b/comid/role_test.go index 4479960c..103f4085 100644 --- a/comid/role_test.go +++ b/comid/role_test.go @@ -210,3 +210,25 @@ func TestRoles_UnmarshalJSON_fail(t *testing.T) { assert.EqualError(t, err, tv.expectedErr) } } + +func Test_Role_String(t *testing.T) { + assert.Equal(t, "maintainer", RoleMaintainer.String()) + assert.Equal(t, "Role(9999)", Role(9999).String()) +} + +func Test_RegisterRole(t *testing.T) { + err := RegisterRole(1, "owner") + assert.EqualError(t, err, "role with value 1 already exists") + + err = RegisterRole(3, "maintainer") + assert.EqualError(t, err, `role with name "maintainer" already exists`) + + err = RegisterRole(3, "owner") + assert.NoError(t, err) + + roles := NewRoles().Add(Role(3)) + + out, err := roles.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, `["owner"]`, string(out)) +} diff --git a/comid/svn.go b/comid/svn.go index ed7a26bb..29fea097 100644 --- a/comid/svn.go +++ b/comid/svn.go @@ -6,102 +6,262 @@ package comid import ( "encoding/json" "fmt" -) + "strconv" -type TaggedSVN uint64 -type TaggedMinSVN uint64 + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) +// SVN is the Security Version Number. This typically changes only when a +// security relevant change is needed to the measured environment. type SVN struct { - val interface{} + Value ISVNValue } -func (o *SVN) SetSVN(val uint64) *SVN { - if o != nil { - o.val = TaggedSVN(val) +// NewSVN creates a new SVN of the specified and value. The type must be one of +// the strings defined by the spec ("exact-value", "min-value"), or has been +// registered with RegisterSVNType(). +func NewSVN(val any, typ string) (*SVN, error) { + factory, ok := svnValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown SVN type: %s", typ) } - return o + + return factory(val) } -func (o *SVN) SetMinSVN(val uint64) *SVN { - if o != nil { - o.val = TaggedMinSVN(val) +// MustNewSVN is like NewSVN but does not return an error, assuming that the +// provided value is valid. It panics if this is not the case. +func MustNewSVN(val any, typ string) *SVN { + ret, err := NewSVN(val, typ) + if err != nil { + panic(err) } - return o + + return ret } +// MarshalCBOR returns the CBOR encoding of the SVN. func (o SVN) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) + return em.Marshal(o.Value) } +// UnmarshalCBOR populates the SVN form the provided CBOR bytes. func (o *SVN) UnmarshalCBOR(data []byte) error { - var svn TaggedSVN + return dm.Unmarshal(data, &o.Value) +} + +// UnmarshalJSON deserializes the supplied JSON object into the target SVN +// The SVN object must have the following shape: +// +// { +// "type": "", +// "value": +// } +// +// where must be one of the known ISVNValue implementation +// type names (available in the base implementation: "exact-value", +// "min-value"), and is the JSON encoding of the underlying +// class id value. The exact encoding is dependent. For both base +// types, it is an integer (JSON number). +func (o *SVN) UnmarshalJSON(data []byte) error { + var tnv encoding.TypeAndValue - if dm.Unmarshal(data, &svn) == nil { - o.val = svn - return nil + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("SVN decoding failure: %w", err) } - var minsvn TaggedMinSVN + decoded, err := NewSVN(nil, tnv.Type) + if err != nil { + return err + } - if dm.Unmarshal(data, &minsvn) == nil { - o.val = svn - return nil + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf("invalid SVN %s: %w", tnv.Type, err) } - return fmt.Errorf("unknown SVN (CBOR: %x)", data) + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid SVN %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + + return nil +} + +// MarshalJSON serializes the SVN int a JSON object +func (o SVN) MarshalJSON() ([]byte, error) { + return extensions.TypeChoiceValueMarshalJSON(o.Value) } -type svnJSONRepr tnv +// ISVNValue is the interface that must be implemented by all SVN values. +type ISVNValue interface { + extensions.ITypeChoiceValue +} -// Supported formats: -// { "type": "exact-value", "value": 123 } -> SVN -// { "type": "min-value", "value": 123 } -> MinSVN -func (o *SVN) UnmarshalJSON(data []byte) error { - var s svnJSONRepr +const ( + ExactValueType = "exact-value" + MinValueType = "min-value" +) - if err := json.Unmarshal(data, &s); err != nil { - return fmt.Errorf("SVN decoding failure: %w", err) - } +type TaggedSVN uint64 + +func NewTaggedSVN(val any) (*SVN, error) { + var ret TaggedSVN - var x uint64 - if err := json.Unmarshal(s.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal svn or min-svn: %w", - err, - ) + if val == nil { + return &SVN{&ret}, nil } - switch s.Type { - case "exact-value": - o.val = TaggedSVN(x) - case "min-value": - o.val = TaggedMinSVN(x) + switch t := val.(type) { + case string: + u, err := strconv.ParseUint(t, 10, 64) + if err != nil { + return nil, err + } + ret = TaggedSVN(u) + case TaggedSVN: + ret = t + case *TaggedSVN: + ret = *t + case uint64: + ret = TaggedSVN(t) + case uint: + ret = TaggedSVN(t) + case int: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedSVN(t) + case int64: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedSVN(t) default: - return fmt.Errorf("unknown comparison operator %s", s.Type) + return nil, fmt.Errorf("unexpected type for SVN exact-value: %T", t) } + return &SVN{&ret}, nil +} + +func MustNewTaggedSVN(val any) *SVN { + ret, err := NewTaggedSVN(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o TaggedSVN) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o TaggedSVN) Type() string { + return ExactValueType +} + +func (o TaggedSVN) Valid() error { return nil } -func (o SVN) MarshalJSON() ([]byte, error) { - var ( - v svnJSONRepr - b []byte - err error - ) +type TaggedMinSVN uint64 - b, err = json.Marshal(o.val) - if err != nil { - return nil, err +func NewTaggedMinSVN(val any) (*SVN, error) { + var ret TaggedMinSVN + + if val == nil { + return &SVN{&ret}, nil } - switch t := o.val.(type) { - case TaggedSVN: - v = svnJSONRepr{Type: "exact-value", Value: b} + + switch t := val.(type) { + case string: + u, err := strconv.ParseUint(t, 10, 64) + if err != nil { + return nil, err + } + ret = TaggedMinSVN(u) case TaggedMinSVN: - v = svnJSONRepr{Type: "min-value", Value: b} + ret = t + case *TaggedMinSVN: + ret = *t + case uint64: + ret = TaggedMinSVN(t) + case uint: + ret = TaggedMinSVN(t) + case int: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedMinSVN(t) + case int64: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedMinSVN(t) default: - return nil, fmt.Errorf("unknown SVN type: %T", t) + return nil, fmt.Errorf("unexpected type for SVN min-value: %T", t) + } + + return &SVN{&ret}, nil +} + +func MustNewTaggedMinSVN(val any) *SVN { + ret, err := NewTaggedMinSVN(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o TaggedMinSVN) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o TaggedMinSVN) Type() string { + return MinValueType +} + +func (o TaggedMinSVN) Valid() error { + return nil +} + +// ISVNFactory defines the signature for the factory functions that may be +// registred using RegisterSVNType to provide a new implementation of the +// corresponding type choice. The factory function should create a new *SVN +// with the underlying value created based on the provided input. The range of +// valid inputs is up to the specific type choice implementation, however it +// _must_ accept nil as one of the inputs, and return the Zero value for +// implemented type. +// See also https://go.dev/ref/spec#The_zero_value +type ISVNFactory func(any) (*SVN, error) + +var svnValueRegister = map[string]ISVNFactory{ + ExactValueType: NewTaggedSVN, + MinValueType: NewTaggedMinSVN, +} + +// RegisterSVNType registers a new ISVNValue implementation +// (created by the provided ISVNFactory) under the specified CBOR tag. +func RegisterSVNType(tag uint64, factory ISVNFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := svnValueRegister[typ]; exists { + return fmt.Errorf("SVN type with name %q already exists", typ) } - return json.Marshal(v) + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + svnValueRegister[typ] = factory + + return nil } diff --git a/comid/svn_test.go b/comid/svn_test.go new file mode 100644 index 00000000..0ffb8f28 --- /dev/null +++ b/comid/svn_test.go @@ -0,0 +1,182 @@ +package comid + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewSVN(t *testing.T) { + for _, tv := range []struct { + Name string + Input any + Expected uint64 + Err string + }{ + { + Name: "string ok", + Input: "7", + Expected: 7, + Err: "", + }, + { + Name: "string err", + Input: "test", + Expected: 0, + Err: `strconv.ParseUint: parsing "test": invalid syntax`, + }, + { + Name: "uint", + Input: uint(7), + Expected: 7, + Err: "", + }, + { + Name: "uint64", + Input: uint64(7), + Expected: 7, + Err: "", + }, + { + Name: "int ok", + Input: 7, + Expected: 7, + Err: "", + }, + { + Name: "int not ok", + Input: -7, + Expected: 0, + Err: "SVN cannot be negative: -7", + }, + { + Name: "int64 ok", + Input: int64(7), + Expected: 7, + Err: "", + }, + { + Name: "int64 not ok", + Input: int64(-7), + Expected: 0, + Err: "SVN cannot be negative: -7", + }, + { + Name: "nil", + Input: nil, + Expected: 0, + Err: "", + }, + } { + t.Run(tv.Name, func(t *testing.T) { + ret, err := NewSVN(tv.Input, "exact-value") + exact := TaggedSVN(tv.Expected) + expected := SVN{&exact} + + if tv.Err != "" { + assert.EqualError(t, err, tv.Err) + } else { + assert.Equal(t, &expected, ret) + } + + retMin, err := NewSVN(tv.Input, "min-value") + min := TaggedMinSVN(tv.Expected) + expected = SVN{&min} + + if tv.Err != "" { + assert.EqualError(t, err, tv.Err) + } else { + assert.Equal(t, &expected, retMin) + } + }) + } + + in := TaggedSVN(7) + + _, err := NewSVN(in, "exact-value") + assert.NoError(t, err) + + _, err = NewSVN(&in, "exact-value") + assert.NoError(t, err) + + _, err = NewSVN(true, "exact-value") + assert.EqualError(t, err, "unexpected type for SVN exact-value: bool") + + inMin := TaggedMinSVN(7) + + _, err = NewSVN(inMin, "min-value") + assert.NoError(t, err) + + _, err = NewSVN(&inMin, "min-value") + assert.NoError(t, err) + + _, err = NewSVN(true, "min-value") + assert.EqualError(t, err, "unexpected type for SVN min-value: bool") + + _, err = NewSVN(true, "test") + assert.EqualError(t, err, "unknown SVN type: test") + + ret := MustNewSVN(7, "exact-value") + assert.NotNil(t, ret) + + assert.Panics(t, func() { MustNewSVN(true, "exact-value") }) +} + +func TestSVN_JSON(t *testing.T) { + var v SVN + + err := v.UnmarshalJSON([]byte(`{"type":"exact-value","value":2.3}`)) + assert.EqualError(t, err, "invalid SVN exact-value: json: cannot unmarshal number 2.3 into Go value of type comid.TaggedSVN") + + err = v.UnmarshalJSON([]byte(`{"type":"test","value":7}`)) + assert.EqualError(t, err, "unknown SVN type: test") + + err = v.UnmarshalJSON([]byte(`@@@`)) + assert.EqualError(t, err, "SVN decoding failure: invalid character '@' looking for beginning of value") + +} + +type testSVN uint64 + +func newTestSVN(val any) (*SVN, error) { + v := testSVN(7) + return &SVN{&v}, nil +} + +func (o testSVN) Type() string { + return "test-value" +} + +func (o testSVN) String() string { + return "test" +} + +func (o testSVN) Valid() error { + return nil +} + +type testSVNBadType struct { + testSVN +} + +func newTestSVNBadType(val any) (*SVN, error) { + v := testSVNBadType{testSVN(7)} + return &SVN{&v}, nil +} + +func (o testSVNBadType) Type() string { + return "min-value" +} + +func Test_RegisterSVNType(t *testing.T) { + err := RegisterSVNType(32, newTestSVN) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterSVNType(99995, newTestSVNBadType) + assert.EqualError(t, err, `SVN type with name "min-value" already exists`) + + err = RegisterSVNType(99995, newTestSVN) + require.NoError(t, err) + +} diff --git a/comid/test_vars.go b/comid/test_vars.go index 5735189e..37b9be1a 100644 --- a/comid/test_vars.go +++ b/comid/test_vars.go @@ -201,7 +201,7 @@ func MustHexDecode(t *testing.T, s string) []byte { } func b64TestImplID() string { - var implID []byte = TestImplID[:] + var implID = TestImplID[:] return base64.StdEncoding.EncodeToString(implID) } diff --git a/comid/triples.go b/comid/triples.go index 28df6b75..e3e9ac90 100644 --- a/comid/triples.go +++ b/comid/triples.go @@ -3,13 +3,50 @@ package comid -import "fmt" +import ( + "fmt" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) type Triples struct { ReferenceValues *[]ReferenceValue `cbor:"0,keyasint,omitempty" json:"reference-values,omitempty"` EndorsedValues *[]EndorsedValue `cbor:"1,keyasint,omitempty" json:"endorsed-values,omitempty"` AttestVerifKeys *[]AttestVerifKey `cbor:"2,keyasint,omitempty" json:"attester-verification-keys,omitempty"` DevIdentityKeys *[]DevIdentityKey `cbor:"3,keyasint,omitempty" json:"dev-identity-keys,omitempty"` + + Extensions +} + +// RegisterExtensions registers a struct as a collections of extensions +func (o *Triples) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Triples) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Triples) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Triples) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Triples) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Triples) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Valid checks that the Triples is valid as per the specification @@ -52,7 +89,7 @@ func (o Triples) Valid() error { } } - return nil + return o.Extensions.validTriples(&o) } func (o *Triples) AddReferenceValue(val ReferenceValue) *Triples { diff --git a/comid/ueid.go b/comid/ueid.go index 765d1b86..cf64ffa2 100644 --- a/comid/ueid.go +++ b/comid/ueid.go @@ -4,18 +4,17 @@ package comid import ( - "encoding/json" + "encoding/base64" "fmt" "github.com/veraison/eat" ) +const UEIDType = "ueid" + // UEID is an Unique Entity Identifier type UEID eat.UEID -// TaggedUEID is an alias to allow automatic tagging of an UEID type -type TaggedUEID UEID - func (o UEID) Empty() bool { return len(o) == 0 } @@ -28,25 +27,65 @@ func (o UEID) Valid() error { return nil } -// UnmarshalJSON deserializes the supplied string into the UEID target -func (o *UEID) UnmarshalJSON(data []byte) error { - var b []byte +func (o UEID) String() string { + return base64.StdEncoding.EncodeToString(o) +} + +// TaggedUEID is an alias to allow automatic tagging of an UEID type +type TaggedUEID UEID + +func NewTaggedUEID(val any) (*TaggedUEID, error) { + var ret TaggedUEID - if err := json.Unmarshal(data, &b); err != nil { - return err + if val == nil { + return &ret, nil } - u := UEID(b) + switch t := val.(type) { + case string: + b, err := base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("bad UEID: %w", err) + } - if err := u.Valid(); err != nil { - return err + ret = TaggedUEID(b) + case []byte: + ret = TaggedUEID(t) + case TaggedUEID: + ret = append(ret, t...) + case *TaggedUEID: + ret = append(ret, *t...) + case UEID: + ret = append(ret, t...) + case *UEID: + ret = append(ret, *t...) + case eat.UEID: + ret = append(ret, t...) + case *eat.UEID: + ret = append(ret, *t...) + default: + return nil, fmt.Errorf("unexpeted type for UEID: %T", t) } - *o = u + if err := ret.Valid(); err != nil { + return nil, err + } - return nil + return &ret, nil +} + +func (o TaggedUEID) Valid() error { + return UEID(o).Valid() +} + +func (o TaggedUEID) String() string { + return UEID(o).String() +} + +func (o TaggedUEID) Type() string { + return "ueid" } -func (o UEID) MarshalJSON() ([]byte, error) { - return json.Marshal([]byte(o)) +func (o TaggedUEID) Bytes() []byte { + return []byte(o) } diff --git a/comid/ueid_test.go b/comid/ueid_test.go new file mode 100644 index 00000000..a119ad44 --- /dev/null +++ b/comid/ueid_test.go @@ -0,0 +1,30 @@ +package comid + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewTaggedUEID(t *testing.T) { + ueid := UEID(TestUEID) + tagged := TaggedUEID(TestUEID) + bytes := MustHexDecode(t, TestUEIDString) + + for _, v := range []any{ + TestUEID, + &TestUEID, + ueid, + &ueid, + tagged, + &tagged, + bytes, + base64.StdEncoding.EncodeToString(bytes), + } { + ret, err := NewTaggedUEID(v) + require.NoError(t, err) + assert.Equal(t, []byte(TestUEID), ret.Bytes()) + } +} diff --git a/comid/uuid.go b/comid/uuid.go index c5c78d61..68fc630b 100644 --- a/comid/uuid.go +++ b/comid/uuid.go @@ -10,12 +10,11 @@ import ( "github.com/google/uuid" ) +const UUIDType = "uuid" + // UUID represents an Universally Unique Identifier (UUID, see RFC4122) type UUID uuid.UUID -// TaggedUUID is an alias to allow automatic tagging of a UUID type -type TaggedUUID UUID - // ParseUUID parses the supplied string into a UUID func ParseUUID(s string) (UUID, error) { v, err := uuid.Parse(s) @@ -64,3 +63,89 @@ func (o *UUID) UnmarshalJSON(data []byte) error { func (o UUID) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } + +// TaggedUUID is an alias to allow automatic tagging of a UUID type +type TaggedUUID UUID + +func NewTaggedUUID(val any) (*TaggedUUID, error) { + var ret TaggedUUID + + switch t := val.(type) { + case string: + u, err := ParseUUID(t) + if err != nil { + return nil, fmt.Errorf("bad UUID: %w", err) + } + ret = TaggedUUID(u) + case []byte: + if len(t) != 16 { + return nil, fmt.Errorf( + "unexpected size for UUID: expected 16 bytes, found %d", + len(t), + ) + } + + copy(ret[:], t) + case TaggedUUID: + copy(ret[:], t[:]) + case *TaggedUUID: + copy(ret[:], (*t)[:]) + case UUID: + copy(ret[:], t[:]) + case *UUID: + copy(ret[:], (*t)[:]) + case uuid.UUID: + copy(ret[:], t[:]) + case *uuid.UUID: + copy(ret[:], (*t)[:]) + default: + return nil, fmt.Errorf("unexpected type for UUID: %T", t) + } + + if err := ret.Valid(); err != nil { + return nil, err + } + + return &ret, nil +} + +// String returns a string representation of the binary UUID +func (o TaggedUUID) String() string { + return UUID(o).String() +} + +func (o TaggedUUID) Valid() error { + return UUID(o).Valid() +} + +// Type returns a string containing type name. This is part of the +// ITypeChoiceValue implementation. +func (o TaggedUUID) Type() string { + return UUIDType +} + +// Bytes returns a []byte containing the raw UUID bytes +func (o TaggedUUID) Bytes() []byte { + return o[:] +} + +func (o TaggedUUID) MarshalJSON() ([]byte, error) { + temp := o.String() + return json.Marshal(temp) +} + +func (o *TaggedUUID) UnmarshalJSON(data []byte) error { + var temp string + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + u, err := ParseUUID(temp) + if err != nil { + return fmt.Errorf("bad UUID: %w", err) + } + + *o = TaggedUUID(u) + + return nil +} diff --git a/comid/uuid_test.go b/comid/uuid_test.go new file mode 100644 index 00000000..5308854a --- /dev/null +++ b/comid/uuid_test.go @@ -0,0 +1,24 @@ +package comid + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUUID_JSON(t *testing.T) { + val := TaggedUUID(TestUUID) + expected := fmt.Sprintf(`"%s"`, val.String()) + + out, err := val.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, expected, string(out)) + + var outUUID TaggedUUID + + err = outUUID.UnmarshalJSON(out) + require.NoError(t, err) + assert.Equal(t, val, outUUID) +} diff --git a/corim/cbor.go b/corim/cbor.go index ec15e95f..17464d1d 100644 --- a/corim/cbor.go +++ b/corim/cbor.go @@ -4,6 +4,7 @@ package corim import ( + "fmt" "reflect" cbor "github.com/fxamacker/cbor/v2" @@ -18,13 +19,13 @@ var ( var ( CoswidTag = []byte{0xd9, 0x01, 0xf9} // 505() ComidTag = []byte{0xd9, 0x01, 0xfa} // 506() -) -func corimTags() cbor.TagSet { - corimTagsMap := map[uint64]interface{}{ + corimTagsMap = map[uint64]interface{}{ 32: comid.TaggedURI(""), } +) +func corimTags() cbor.TagSet { opts := cbor.TagOptions{ EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired, @@ -57,6 +58,28 @@ func initCBORDecMode() (dm cbor.DecMode, err error) { return decOpt.DecModeWithTags(corimTags()) } +func registerCORIMTag(tag uint64, t interface{}) error { + if _, exists := corimTagsMap[tag]; exists { + return fmt.Errorf("tag %d is already registered", tag) + } + + corimTagsMap[tag] = t + + var err error + + em, err = initCBOREncMode() + if err != nil { + return err + } + + dm, err = initCBORDecMode() + if err != nil { + return err + } + + return nil +} + func init() { if emError != nil { panic(emError) diff --git a/corim/entity.go b/corim/entity.go index 49aa1468..8a3169ac 100644 --- a/corim/entity.go +++ b/corim/entity.go @@ -4,29 +4,47 @@ package corim import ( + "encoding/json" + "errors" "fmt" + "unicode/utf8" "github.com/veraison/corim/comid" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) // Entity stores an entity-map capable of CBOR and JSON serializations. type Entity struct { - EntityName string `cbor:"0,keyasint" json:"name"` + EntityName *EntityName `cbor:"0,keyasint" json:"name"` RegID *comid.TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` Roles Roles `cbor:"2,keyasint" json:"roles"` + + Extensions } func NewEntity() *Entity { return &Entity{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *Entity) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Entity) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetEntityName is used to set the EntityName field of Entity using supplied name -func (o *Entity) SetEntityName(name string) *Entity { +func (o *Entity) SetEntityName(name any) *Entity { + if o != nil { if name == "" { return nil } - o.EntityName = name + o.EntityName = MustNewStringEntityName(name) } return o } @@ -60,10 +78,14 @@ func (o *Entity) SetRoles(roles ...Role) *Entity { // Valid checks for validity of the fields within each Entity func (o Entity) Valid() error { - if o.EntityName == "" { + if o.EntityName == nil { return fmt.Errorf("invalid entity: empty entity-name") } + if err := o.EntityName.Valid(); err != nil { + return fmt.Errorf("invalid entity: %w", err) + } + if o.RegID != nil && o.RegID.Empty() { return fmt.Errorf("invalid entity: empty reg-id") } @@ -72,7 +94,27 @@ func (o Entity) Valid() error { return fmt.Errorf("invalid entity: %w", err) } - return nil + return o.Extensions.validEntity(&o) +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Entity) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Entity) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Entity) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Entity) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Entities is an array of entity-map's @@ -100,3 +142,233 @@ func (o Entities) Valid() error { } return nil } + +// EntityName encapsulates the name of the associated Entity. The CoRIM +// specification only allows for text (string) name, but this may be extended +// by other specifications. +type EntityName struct { + Value IEntityNameValue +} + +// NewEntityName creates a new EntityName of the specified type using the +// provided value. +func NewEntityName(val any, typ string) (*EntityName, error) { + factory, ok := entityNameValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unexpected entity name type: %s", typ) + } + + return factory(val) +} + +// MustNewEntityName is like NewEntityName, except it doesn't return an error, +// assuming that the provided value is valid. It panics if that isn't the case. +func MustNewEntityName(val any, typ string) *EntityName { + ret, err := NewEntityName(val, typ) + if err != nil { + panic(err) + } + + return ret +} + +// String returns the string representation of the EntityName +func (o EntityName) String() string { + return o.Value.String() +} + +// Valid returns nil if the underlying EntityName value is valid, or an error +// describing the problem otherwise. +func (o EntityName) Valid() error { + if o.Value == nil { + return errors.New("empty entity name") + } + + return o.Value.Valid() +} + +// MarshalCBOR serializes the EntityName into CBOR-encoded bytes. +func (o EntityName) MarshalCBOR() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + return em.Marshal(o.Value) +} + +// UnmarshalCBOR deserializes the EntityName from CBOR-encoded bytes. +func (o *EntityName) UnmarshalCBOR(data []byte) error { + if len(data) == 0 { + return errors.New("empty") + } + + majorType := (data[0] & 0xe0) >> 5 + if majorType == 3 { // text string + var text string + + if err := dm.Unmarshal(data, &text); err != nil { + return err + } + + name := StringEntityName(text) + o.Value = &name + + return nil + } + + return dm.Unmarshal(data, &o.Value) +} + +// MarshalJSON serializes the EntityName into a JSON object. +func (o EntityName) MarshalJSON() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + if o.Value.Type() == extensions.StringType { + return json.Marshal(o.Value.String()) + } + + valueBytes, err := json.Marshal(o.Value) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, + } + + return json.Marshal(value) +} + +// UnmarshalJSON deserializes EntityName from the provided JSON object. +func (o *EntityName) UnmarshalJSON(data []byte) error { + var text string + if err := json.Unmarshal(data, &text); err == nil { + *o = *MustNewStringEntityName(text) + return nil + } + + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("entity name decoding failure: %w", err) + } + + decoded, err := NewEntityName(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal entity name: %w", + err, + ) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + + return nil +} + +// IEntityNameValue is the interface implemented by all EntityName value types. +type IEntityNameValue interface { + extensions.ITypeChoiceValue +} + +// StringEntityName is a text string EntityName with no other contraints. This +// is the only EntityName value type defined by the CoRIM specification itself. +type StringEntityName string + +func NewStringEntityName(val any) (*EntityName, error) { + var ret StringEntityName + + if val == nil { + ret = StringEntityName("") + return &EntityName{&ret}, nil + } + + switch t := val.(type) { + case string: + ret = StringEntityName(t) + case []byte: + if !utf8.Valid(t) { + return nil, errors.New("bytes do not form a valid UTF-8 string") + } + + ret = StringEntityName(t) + default: + return nil, fmt.Errorf("unexpected type for string entity name: %T", t) + } + + return &EntityName{&ret}, nil +} + +func MustNewStringEntityName(val any) *EntityName { + ret, err := NewStringEntityName(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o StringEntityName) String() string { + return string(o) +} + +func (o StringEntityName) Type() string { + return extensions.StringType +} + +func (o StringEntityName) Valid() error { + if o == "" { + return errors.New("empty entity-name") + } + + return nil +} + +// IEntityNameFactory defines the signature for the factory functions that may +// be registred using RegisterEntityNameType to provide a new implementation of +// the corresponding type choice. The factory function should create a new +// *EntityName with the underlying value created based on the provided input. +// The range of valid inputs is up to the specific type choice implementation, +// however it _must_ accept nil as one of the inputs, and return the Zero value +// for implemented type. +// See also https://go.dev/ref/spec#The_zero_value +type IEntityNameFactory func(any) (*EntityName, error) + +var entityNameValueRegister = map[string]IEntityNameFactory{ + extensions.StringType: NewStringEntityName, +} + +// RegisterEntityNameType registers a new IEntityNameValue implementation +// (created by the provided IEntityNameFactory) under the specified type name +// and CBOR tag. +func RegisterEntityNameType(tag uint64, factory IEntityNameFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := entityNameValueRegister[typ]; exists { + return fmt.Errorf("entity name type with name %q already exists", typ) + } + + if err := registerCORIMTag(tag, nilVal.Value); err != nil { + return err + } + + entityNameValueRegister[typ] = factory + + return nil +} diff --git a/corim/entity_test.go b/corim/entity_test.go index 457b3770..aec1be92 100644 --- a/corim/entity_test.go +++ b/corim/entity_test.go @@ -4,6 +4,8 @@ package corim import ( + "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -21,7 +23,7 @@ func TestEntity_Valid_uninitialized(t *testing.T) { func TestEntity_Valid_empty_name(t *testing.T) { tv := Entity{ - EntityName: "", + EntityName: MustNewStringEntityName(""), } err := tv.Valid() @@ -33,7 +35,7 @@ func TestEntity_Valid_non_nil_empty_URI(t *testing.T) { emptyRegID := comid.TaggedURI("") tv := Entity{ - EntityName: "ACME Ltd.", + EntityName: MustNewStringEntityName("ACME Ltd."), RegID: &emptyRegID, } @@ -46,7 +48,7 @@ func TestEntity_Valid_missing_roles(t *testing.T) { regID := comid.TaggedURI("http://acme.example") tv := Entity{ - EntityName: "ACME Ltd.", + EntityName: MustNewStringEntityName("ACME Ltd."), RegID: ®ID, } @@ -59,7 +61,7 @@ func TestEntity_Valid_unknown_role(t *testing.T) { regID := comid.TaggedURI("http://acme.example") tv := Entity{ - EntityName: "ACME Ltd.", + EntityName: MustNewStringEntityName("ACME Ltd."), RegID: ®ID, Roles: Roles{Role(666)}, } @@ -92,3 +94,185 @@ func TestEntities_Valid_empty(t *testing.T) { err := es.Valid() assert.EqualError(t, err, "entity at index 0: invalid entity: empty entity-name") } + +type testEntityName uint64 + +func newTestEntityName(val any) (*EntityName, error) { + if val == nil { + v := testEntityName(0) + return &EntityName{&v}, nil + } + + u, ok := val.(uint64) + if !ok { + return nil, errors.New("must be uint64") + } + + v := testEntityName(u) + return &EntityName{&v}, nil +} + +func (o testEntityName) Type() string { + return "test" +} + +func (o testEntityName) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o testEntityName) Valid() error { + return nil +} + +type testEntityNameBadType struct { + testEntityName +} + +func newTestEntityNameBadType(val any) (*EntityName, error) { + v := testEntityNameBadType{testEntityName(7)} + return &EntityName{&v}, nil +} + +func (o testEntityNameBadType) Type() string { + return "string" +} + +func Test_RegisterEntityNameType(t *testing.T) { + err := RegisterEntityNameType(32, newTestEntityName) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterEntityNameType(99994, newTestEntityNameBadType) + assert.EqualError(t, err, `entity name type with name "string" already exists`) + + registerTestEntityNameType(t) +} + +// Since there only one, untagged, entity name type in the core spec, we use +// the test type define above in order to test the marshaling code works +// properly. Since global environment is not reset when running multiple tests, +// we cannot simply call RegisterEntityNameType() inside each test that relies +// on the test type, as that will cause the "tag already registered" error. On +// the other hand, we do not want to create inter-test dependencies by relying +// that the test registering the type is run before the others that rely on it. +// To get around this, use this global flag to only register the test type if a +// previous test hasn't already done so. +var testEntityNameTypeRegistered = false + +func registerTestEntityNameType(t *testing.T) { + if !testEntityNameTypeRegistered { + err := RegisterEntityNameType(99994, newTestEntityName) + require.NoError(t, err) + + testEntityNameTypeRegistered = true + } +} + +func TestEntityName_CBOR(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte{ + 0x64, // tstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }, + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9a, // tag 99994 + 0x07, // unsigned int(7) + }, + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalCBOR() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalCBOR(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func TestEntityName_JSON(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte(`"test"`), + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte(`{"type":"test","value":7}`), + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalJSON() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalJSON(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func Test_NewStringEntityName(t *testing.T) { + out, err := NewStringEntityName(nil) + require.NoError(t, err) + assert.EqualError(t, out.Valid(), "empty entity-name") + + out, err = NewStringEntityName([]byte("test")) + require.NoError(t, err) + assert.Equal(t, "test", out.String()) + + _, err = NewStringEntityName(7) + assert.EqualError(t, err, "unexpected type for string entity name: int") +} + +func Test_MustNewEntityName(t *testing.T) { + out := MustNewEntityName("test", "string") + assert.Equal(t, "test", out.String()) + + assert.Panics(t, func() { + MustNewEntityName(7, "int") + }) +} diff --git a/corim/extensions.go b/corim/extensions.go new file mode 100644 index 00000000..0bd3502f --- /dev/null +++ b/corim/extensions.go @@ -0,0 +1,68 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package corim + +import ( + "github.com/veraison/corim/extensions" +) + +type IEntityConstrainer interface { + ConstrainEntity(*Entity) error +} + +type ICorimConstrainer interface { + ConstrainCorim(*UnsignedCorim) error +} + +type ISignerConstrainer interface { + ConstrainSigner(*Signer) error +} + +type Extensions struct { + extensions.Extensions +} + +func (o *Extensions) validEntity(entity *Entity) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IEntityConstrainer) + if ok { + if err := ev.ConstrainEntity(entity); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) validCorim(c *UnsignedCorim) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(ICorimConstrainer) + if ok { + if err := ev.ConstrainCorim(c); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) validSigner(signer *Signer) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(ISignerConstrainer) + if ok { + if err := ev.ConstrainSigner(signer); err != nil { + return err + } + } + + return nil +} diff --git a/corim/extensions_test.go b/corim/extensions_test.go new file mode 100644 index 00000000..043d9a52 --- /dev/null +++ b/corim/extensions_test.go @@ -0,0 +1,86 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package corim + +import ( + "errors" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type TestExtensions struct { + Address string `cbor:"-1,keyasint,omitempty" json:"address,omitempty"` + Size int `cbor:"-2,keyasint,omitempty" json:"size,omitempty"` +} + +func (o TestExtensions) ConstrainEntity(ent *Entity) error { + if ent.EntityName.String() != "Futurama" { + return errors.New(`EntityName must be "Futurama"`) // nolint:golint + } + + return nil +} + +func (o TestExtensions) ConstrainCorim(c *UnsignedCorim) error { + return errors.New("invalid") +} + +func (o TestExtensions) ConstrainSigner(s *Signer) error { + return errors.New("invalid") +} + +func TestEntityExtensions_Valid(t *testing.T) { + ent := NewEntity() + ent.SetEntityName("The Simpsons") + ent.SetRoles(RoleManifestCreator) + + err := ent.Valid() + assert.NoError(t, err) + + ent.RegisterExtensions(&TestExtensions{}) + err = ent.Valid() + assert.EqualError(t, err, `EntityName must be "Futurama"`) + + ent.SetEntityName("Futurama") + err = ent.Valid() + assert.NoError(t, err) + + assert.EqualError(t, ent.Extensions.validCorim(nil), "invalid") + assert.EqualError(t, ent.Extensions.validSigner(nil), "invalid") +} + +func TestEntityExtensions_CBOR(t *testing.T) { + data := []byte{ + 0xa4, // map(4) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x02, // key 2 + 0x81, // array(1) + 0x01, // 1 + + 0x20, // key -1 + 0x63, // val tstr(3) + 0x66, 0x6f, 0x6f, // "foo" + + 0x21, // key -2 + 0x06, // val 6 + } + + ent := NewEntity() + ent.RegisterExtensions(&TestExtensions{}) + + err := cbor.Unmarshal(data, &ent) + assert.NoError(t, err) + + assert.Equal(t, ent.EntityName.String(), "acme") + + address, err := ent.Get("address") + require.NoError(t, err) + assert.Equal(t, address, "foo") +} diff --git a/corim/meta.go b/corim/meta.go index 45f1ea49..9ce834e5 100644 --- a/corim/meta.go +++ b/corim/meta.go @@ -10,17 +10,31 @@ import ( "time" "github.com/veraison/corim/comid" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) type Signer struct { Name string `cbor:"0,keyasint" json:"name"` URI *comid.TaggedURI `cbor:"1,keyasint,omitempty" json:"uri,omitempty"` + + Extensions } func NewSigner() *Signer { return &Signer{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *Signer) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns previously registered extension +func (o *Signer) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetName sets the target Signer's name to the supplied value func (o *Signer) SetName(name string) *Signer { if o != nil { @@ -61,7 +75,27 @@ func (o Signer) Valid() error { } } - return nil + return o.Extensions.validSigner(&o) +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Signer) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Signer) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Signer) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Signer) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Meta stores a corim-meta-map with JSON and CBOR serializations. It carries diff --git a/corim/meta_test.go b/corim/meta_test.go index a4dd34cd..811495cf 100644 --- a/corim/meta_test.go +++ b/corim/meta_test.go @@ -183,3 +183,15 @@ func TestMeta_FromCBOR_full(t *testing.T) { assert.Equal(t, notBefore.Unix(), actual.Validity.NotBefore.Unix()) assert.Equal(t, notAfter.Unix(), actual.Validity.NotAfter.Unix()) } + +func Test_Signer_Valid(t *testing.T) { + var signer Signer + + assert.EqualError(t, signer.Valid(), "empty name") + + signer.Name = "test-signer" + uri := comid.TaggedURI("@@@") + signer.URI = &uri + + assert.EqualError(t, signer.Valid(), `invalid URI: "@@@" is not an absolute URI`) +} diff --git a/corim/role.go b/corim/role.go index 9bfc1276..7ef276f7 100644 --- a/corim/role.go +++ b/corim/role.go @@ -24,6 +24,35 @@ var ( } ) +// String returns the string representation of the Role. +func (o Role) String() string { + text, ok := roleToString[o] + if ok { + return text + } + + return fmt.Sprintf("Role(%d)", o) +} + +// RegisterRole creates a new Role association between the provided value and +// name. An error is returned if either clashes with any of the existing roles. +func RegisterRole(val int64, name string) error { + role := Role(val) + + if _, ok := roleToString[role]; ok { + return fmt.Errorf("role with value %d already exists", val) + } + + if _, ok := stringToRole[name]; ok { + return fmt.Errorf("role with name %q already exists", name) + } + + roleToString[role] = name + stringToRole[name] = role + + return nil +} + type Roles []Role func NewRoles() *Roles { @@ -44,7 +73,8 @@ func (o *Roles) Add(roles ...Role) *Roles { } func isRole(r Role) bool { - return r == RoleManifestCreator + _, ok := roleToString[r] + return ok } // Valid iterates over the range of individual roles to check for validity diff --git a/corim/role_test.go b/corim/role_test.go index 66df92b0..54b4433c 100644 --- a/corim/role_test.go +++ b/corim/role_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRoles_ToJSON_ok(t *testing.T) { @@ -77,3 +78,25 @@ func TestRoles_FromJSON_fail(t *testing.T) { assert.EqualError(t, err, tv.expectedErr) } } + +func Test_Role_String(t *testing.T) { + assert.Equal(t, "manifestCreator", RoleManifestCreator.String()) + assert.Equal(t, "Role(9999)", Role(9999).String()) +} + +func Test_RegisterRole(t *testing.T) { + err := RegisterRole(1, "owner") + assert.EqualError(t, err, "role with value 1 already exists") + + err = RegisterRole(3, "manifestCreator") + assert.EqualError(t, err, `role with name "manifestCreator" already exists`) + + err = RegisterRole(3, "owner") + assert.NoError(t, err) + + roles := NewRoles().Add(Role(3)) + + out, err := roles.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, `["owner"]`, string(out)) +} diff --git a/corim/signedcorim_test.go b/corim/signedcorim_test.go index 9e030b3c..eae15d45 100644 --- a/corim/signedcorim_test.go +++ b/corim/signedcorim_test.go @@ -246,7 +246,7 @@ func TestSignedCorim_FromCOSE_fail_invalid_corim(t *testing.T) { var actual SignedCorim err := actual.FromCOSE(tv) - assert.EqualError(t, err, "failed validation of unsigned CoRIM: tags validation failed: no tags") + assert.EqualError(t, err, `failed CBOR decoding of unsigned CoRIM: missing mandatory field "Tags" (1)`) } func TestSignedCorim_FromCOSE_fail_no_content_type(t *testing.T) { diff --git a/corim/unsignedcorim.go b/corim/unsignedcorim.go index 68ca6786..1247527e 100644 --- a/corim/unsignedcorim.go +++ b/corim/unsignedcorim.go @@ -4,12 +4,13 @@ package corim import ( - "encoding/json" "errors" "fmt" "time" "github.com/veraison/corim/cots" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/corim/comid" "github.com/veraison/eat" @@ -19,12 +20,22 @@ import ( // UnsignedCorim is the top-level representation of the unsigned-corim-map with // CBOR and JSON serialization. type UnsignedCorim struct { - ID swid.TagID `cbor:"0,keyasint" json:"corim-id"` - Tags []Tag `cbor:"1,keyasint" json:"tags"` + ID swid.TagID `cbor:"0,keyasint" json:"corim-id"` + // note: even though tags are mandatory for CoRIM, we allow omitting + // them in our JSON templates for cocli (the min template just has + // corim-id). Since we're never writing JSON (so far), this normally + // wouldn't matter, however the custom serialization code we use to + // handle embedded structs relies on the omitempty entry to determine + // if a fieled is optional, so we use it during unmarshaling as well as + // marshaling. Hence omitempty is present for the json tag, but not + // cbor. + Tags []Tag `cbor:"1,keyasint" json:"tags,omitempty"` DependentRims *[]Locator `cbor:"2,keyasint,omitempty" json:"dependent-rims,omitempty"` Profiles *[]eat.Profile `cbor:"3,keyasint,omitempty" json:"profiles,omitempty"` RimValidity *Validity `cbor:"4,keyasint,omitempty" json:"validity,omitempty"` Entities *Entities `cbor:"5,keyasint,omitempty" json:"entities,omitempty"` + + Extensions } // NewUnsignedCorim instantiates an empty UnsignedCorim @@ -32,6 +43,16 @@ func NewUnsignedCorim() *UnsignedCorim { return &UnsignedCorim{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *UnsignedCorim) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *UnsignedCorim) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetID sets the corim-id in the unsigned-corim-map to the supplied value. The // corim-id can be passed as UUID in string or binary form (i.e., byte array), // or as a (non-empty) string @@ -239,22 +260,22 @@ func (o UnsignedCorim) Valid() error { } } - return nil + return o.Extensions.validCorim(&o) } // ToCBOR serializes the target unsigned CoRIM to CBOR -func (o UnsignedCorim) ToCBOR() ([]byte, error) { - return em.Marshal(&o) +func (o *UnsignedCorim) ToCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) } // FromCBOR deserializes a CBOR-encoded unsigned CoRIM into the target UnsignedCorim func (o *UnsignedCorim) FromCBOR(data []byte) error { - return dm.Unmarshal(data, o) + return encoding.PopulateStructFromCBOR(dm, data, o) } // FromJSON deserializes a JSON-encoded unsigned CoRIM into the target UnsignedCorim func (o *UnsignedCorim) FromJSON(data []byte) error { - return json.Unmarshal(data, o) + return encoding.PopulateStructFromJSON(data, o) } // Tag is either a CBOR-encoded CoMID, CoSWID or CoTS diff --git a/corim/unsignedcorim_test.go b/corim/unsignedcorim_test.go index d9b4d105..918dde0f 100644 --- a/corim/unsignedcorim_test.go +++ b/corim/unsignedcorim_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/veraison/corim/comid" @@ -182,7 +181,7 @@ func TestUnsignedCorim_Valid_ok(t *testing.T) { AddAttestVerifKey( comid.AttestVerifKey{ Environment: comid.Environment{ - Instance: comid.NewInstanceUUID(uuid.UUID(comid.TestUUID)), + Instance: comid.MustNewUUIDInstance(comid.TestUUID), }, VerifKeys: *comid.NewCryptoKeys(). Add( @@ -264,7 +263,7 @@ func TestUnsignedCorim_AddEntity_full(t *testing.T) { expected := UnsignedCorim{ Entities: &Entities{ Entity{ - EntityName: name, + EntityName: MustNewStringEntityName(name), Roles: Roles{role}, RegID: &taggedRegID, }, @@ -299,3 +298,11 @@ func TestUnsignedCorim_AddEntity_non_nil_empty_URI(t *testing.T) { assert.Nil(t, tv) } + +func TestUnsignedCorim_FromJSON(t *testing.T) { + data := []byte(`{"corim-id": "5c57e8f4-46cd-421b-91c9-08cf93e13cfc"}`) + + err := NewUnsignedCorim().FromJSON(data) + + assert.NoError(t, err) +} diff --git a/encoding/cbor.go b/encoding/cbor.go new file mode 100644 index 00000000..5d7a3533 --- /dev/null +++ b/encoding/cbor.go @@ -0,0 +1,409 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 + +package encoding + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "reflect" + "strconv" + "strings" + + cbor "github.com/fxamacker/cbor/v2" +) + +func SerializeStructToCBOR(em cbor.EncMode, source any) ([]byte, error) { + rawMap := newStructFieldsCBOR() + + structType := reflect.TypeOf(source) + structVal := reflect.ValueOf(source) + + if err := doSerializeStructToCBOR(em, rawMap, structType, structVal); err != nil { + return nil, err + } + + return rawMap.ToCBOR(em) +} + +func doSerializeStructToCBOR( + em cbor.EncMode, + rawMap *structFieldsCBOR, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("cbor") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + keyString := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + // do not serialize zero values if the corresponding field is + // omitempty + if isOmitEmpty && valField.IsZero() { + continue + } + + keyInt, err := strconv.Atoi(keyString) + if err != nil { + return fmt.Errorf("non-integer cbor key: %s", keyString) + } + + data, err := em.Marshal(valField.Interface()) + if err != nil { + return fmt.Errorf("error marshaling field %q: %w", + typeField.Name, + err, + ) + } + + if err := rawMap.Add(keyInt, cbor.RawMessage(data)); err != nil { + return err + } + } + + for _, emb := range embeds { + if err := doSerializeStructToCBOR(em, rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +func PopulateStructFromCBOR(dm cbor.DecMode, data []byte, dest any) error { + rawMap := newStructFieldsCBOR() + + if err := rawMap.FromCBOR(dm, data); err != nil { + return err + } + + structType := reflect.TypeOf(dest) + structVal := reflect.ValueOf(dest) + + return doPopulateStructFromCBOR(dm, rawMap, structType, structVal) +} + +func doPopulateStructFromCBOR( + dm cbor.DecMode, + rawMap *structFieldsCBOR, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("cbor") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + keyString := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + keyInt, err := strconv.Atoi(keyString) + if err != nil { + return fmt.Errorf("non-integer cbor key %s", keyString) + } + + rawVal, ok := rawMap.Get(keyInt) + if !ok { + if isOmitEmpty { + continue + } + + return fmt.Errorf("missing mandatory field %q (%d)", + typeField.Name, keyInt) + } + + fieldPtr := valField.Addr().Interface() + if err := dm.Unmarshal(rawVal, fieldPtr); err != nil { + return fmt.Errorf("error unmarshalling field %q: %w", + typeField.Name, + err, + ) + } + + rawMap.Delete(keyInt) + } + + for _, emb := range embeds { + if err := doPopulateStructFromCBOR(dm, rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +// structFieldsCBOR is a specialized implementation of "OrderedMap", where the +// order of the keys is kept track of, and used when serializing the map to +// CBOR. While CBOR maps do not mandate any particular ordering, and so this +// isn't strictly necessary, it is useful to have a _stable_ serialization +// order for map keys to be compatible with regular Go struct serialization +// behavior. This is also useful for tests/examples that compare encoded +// []byte's. +type structFieldsCBOR struct { + Fields map[int]cbor.RawMessage + Keys []int +} + +func newStructFieldsCBOR() *structFieldsCBOR { + return &structFieldsCBOR{ + Fields: make(map[int]cbor.RawMessage), + } +} + +func (o structFieldsCBOR) Has(key int) bool { + _, ok := o.Fields[key] + return ok +} + +func (o *structFieldsCBOR) Add(key int, val cbor.RawMessage) error { + if o.Has(key) { + return fmt.Errorf("duplicate cbor key: %d", key) + } + + o.Fields[key] = val + o.Keys = append(o.Keys, key) + + return nil +} + +func (o *structFieldsCBOR) Get(key int) (cbor.RawMessage, bool) { + val, ok := o.Fields[key] + return val, ok +} + +func (o *structFieldsCBOR) Delete(key int) { + delete(o.Fields, key) + + for i, existing := range o.Keys { + if existing == key { + o.Keys = append(o.Keys[:i], o.Keys[i+1:]...) + } + } +} + +func (o *structFieldsCBOR) ToCBOR(em cbor.EncMode) ([]byte, error) { + var out []byte + + header := byte(0xa0) // 0b101_00000 -- Major Type 5 == map + mapLen := len(o.Keys) + if mapLen == 0 { + return []byte{header}, nil + } else if mapLen < 24 { + header = header | byte(mapLen) + out = append(out, header) + } else if mapLen <= math.MaxUint8 { + header = header | byte(24) + out = append(out, header, uint8(mapLen)) + } else if mapLen <= math.MaxUint16 { + header = header | byte(25) + out = append(out, header) + out = binary.BigEndian.AppendUint16(out, uint16(mapLen)) + } else { + header = header | byte(26) + out = append(out, header) + out = binary.BigEndian.AppendUint32(out, uint32(mapLen)) + } + // Since len() returns an int, the value cannot exceed MaxUint32, so + // the 8-byte length variant cannot occur. + + for _, key := range o.Keys { + marshalledKey, err := em.Marshal(key) + if err != nil { + return nil, fmt.Errorf("problem marshaling key %d: %w", key, err) + } + + out = append(out, marshalledKey...) + out = append(out, o.Fields[key]...) + } + + return out, nil +} + +func (o *structFieldsCBOR) FromCBOR(dm cbor.DecMode, data []byte) error { + if len(data) == 0 { + return errors.New("empty input") + } + + header := data[0] + rest := data[1:] + additionalInfo := 0x1f & header + + var err error + + majorType := (0xe0 & header) >> 5 + if majorType == 6 { // tag + _, rest, err = processAdditionalInfo(additionalInfo, rest) + if err != nil { + return err + } + + header = rest[0] + rest = rest[1:] + majorType = (0xe0 & header) >> 5 + additionalInfo = 0x1f & header + } + + if majorType != 5 { + return fmt.Errorf("expected map (CBOR Major Type 5), found Major Type %d", majorType) + } + + var mapLen int + + mapLen, rest, err = processAdditionalInfo(additionalInfo, rest) + if err != nil { + return err + } + + if mapLen != 0 { + o.Fields = make(map[int]cbor.RawMessage, mapLen) + + for i := 0; i < mapLen; i++ { + rest, err = o.unmarshalKeyValue(dm, rest) + if err != nil { + return fmt.Errorf("map item %d: %w", i, err) + } + } + } else { // mapLen == 0 --> indefinite encoding + o.Fields = make(map[int]cbor.RawMessage) + + i := 0 + done := false + for len(rest) > 0 { + if rest[0] == 0xFF { + done = true + break + } + + rest, err = o.unmarshalKeyValue(dm, rest) + if err != nil { + return fmt.Errorf("map item %d: %w", i, err) + } + + i++ + } + + if !done { + return errors.New("unexpected EOF") + } + } + + return nil +} + +func (o *structFieldsCBOR) unmarshalKeyValue(dm cbor.DecMode, rest []byte) ([]byte, error) { + var key int + var val cbor.RawMessage + var err error + + rest, err = dm.UnmarshalFirst(rest, &key) + if err != nil { + return rest, fmt.Errorf("could not unmarshal key: %w", err) + } + + rest, err = dm.UnmarshalFirst(rest, &val) + if err != nil { + return rest, fmt.Errorf("could not unmarshal value: %w", err) + } + + if err = o.Add(key, val); err != nil { + return rest, err + } + + return rest, nil +} + +func processAdditionalInfo( + additionalInfo byte, + data []byte, +) (int, []byte, error) { + var val int + rest := data + + if additionalInfo < 24 { + val = int(additionalInfo) + } else if additionalInfo < 28 { + switch additionalInfo - 23 { + case 1: + if len(data) < 1 { + return 0, nil, errors.New("unexpected EOF") + } + val = int(data[0]) + rest = data[1:] + case 2: + if len(data) < 2 { + return 0, nil, errors.New("unexpected EOF") + } + val = int(binary.BigEndian.Uint16(data[:2])) + rest = data[2:] + case 3: + if len(data) < 4 { + return 0, nil, errors.New("unexpected EOF") + } + val = int(binary.BigEndian.Uint32(data[:4])) + rest = data[4:] + default: + return 0, nil, errors.New("cbor: cannot decode length value of 8 bytes") + } + } else if additionalInfo == 31 { + val = 0 // indefinite encoding + } else { + return 0, nil, fmt.Errorf("cbor: unexpected additional information value %d", additionalInfo) + } + + return val, rest, nil +} diff --git a/encoding/cbor_test.go b/encoding/cbor_test.go new file mode 100644 index 00000000..1dda9146 --- /dev/null +++ b/encoding/cbor_test.go @@ -0,0 +1,279 @@ +// Copyright 2021 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package encoding + +import ( + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PopulateStructFromCBOR_simple(t *testing.T) { + type SimpleStruct struct { + FieldOne string `cbor:"0,keyasint,omitempty"` + FieldTwo int `cbor:"1,keyasint"` + } + + var v SimpleStruct + + data := []byte{ + 0xa2, // map(2) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x01, // key 1 + 0x06, // val 6 + } + + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + err = PopulateStructFromCBOR(dm, data, &v) + require.NoError(t, err) + assert.Equal(t, "acme", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte{ + 0xa1, // map(1) + + 0x01, // key 1 + 0x06, // val 6 + } + v = SimpleStruct{} + + err = PopulateStructFromCBOR(dm, data, &v) + require.NoError(t, err) + assert.Equal(t, "", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte{ + 0xa1, // map(1) + + 0x02, // key 2 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + } + v = SimpleStruct{} + + err = PopulateStructFromCBOR(dm, data, &v) + assert.EqualError(t, err, `missing mandatory field "FieldTwo" (1)`) + + err = PopulateStructFromCBOR(dm, []byte{0x01}, &v) + assert.EqualError(t, err, `expected map (CBOR Major Type 5), found Major Type 0`) + + type CompositeStruct struct { + FieldThree string `cbor:"2,keyasint"` + SimpleStruct + } + + var c CompositeStruct + + data = []byte{ + 0xa3, // map(3) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x01, // key 1 + 0x06, // val 6 + + 0x02, // key 2 + 0x63, // val tstr(3) + 0x66, 0x6f, 0x6f, // "foo" + } + + err = PopulateStructFromCBOR(dm, data, &c) + require.NoError(t, err) + assert.Equal(t, "acme", c.FieldOne) + assert.Equal(t, 6, c.FieldTwo) + assert.Equal(t, "foo", c.FieldThree) + + em, err := cbor.EncOptions{}.EncMode() + require.NoError(t, err) + + res, err := SerializeStructToCBOR(em, &c) + require.NoError(t, err) + + var c2 CompositeStruct + err = PopulateStructFromCBOR(dm, res, &c2) + require.NoError(t, err) + assert.EqualValues(t, c, c2) + +} + +func Test_structFieldsCBOR_CRUD(t *testing.T) { + sf := newStructFieldsCBOR() + + err := sf.Add(2, cbor.RawMessage{0x02}) + assert.NoError(t, err) + + err = sf.Add(1, cbor.RawMessage{0x01}) + assert.NoError(t, err) + + err = sf.Add(3, cbor.RawMessage{0x03}) + assert.NoError(t, err) + + assert.Equal(t, []int{2, 1, 3}, sf.Keys) + assert.True(t, sf.Has(3)) + assert.False(t, sf.Has(4)) + + val, ok := sf.Get(2) + assert.True(t, ok) + assert.Equal(t, cbor.RawMessage{0x2}, val) + + _, ok = sf.Get(4) + assert.False(t, ok) + + sf.Delete(2) + _, ok = sf.Get(2) + assert.False(t, ok) + + err = sf.Add(1, cbor.RawMessage{0x11}) + assert.EqualError(t, err, "duplicate cbor key: 1") +} + +func Test_structFieldsCBOR_CBOR_roundtrip(t *testing.T) { + em, err := cbor.EncOptions{}.EncMode() + require.NoError(t, err) + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sf := newStructFieldsCBOR() + + data, err := sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data, []byte{0xa0}) // empty map + + for i := 0; i < 5; i++ { + err = sf.Add(i, cbor.RawMessage{0x00}) + require.NoError(t, err) + } + + data, err = sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xa5, // map 5 + 0x00, 0x00, + 0x01, 0x00, + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + }) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, sf, sfOut) + + for i := 5; i < 200; i++ { + err = sf.Add(i, cbor.RawMessage{0x00}) + require.NoError(t, err) + } + + data, err = sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data[:2], []byte{ + 0xb8, 0xc8, // map 200 + }) + + sfOut = newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, sf, sfOut) + + for i := 200; i < 2048; i++ { + err = sf.Add(i, cbor.RawMessage{0x00}) + require.NoError(t, err) + } + + data, err = sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data[:3], []byte{ + 0xb9, 0x08, 0x00, // map 2048 + }) + + sfOut = newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, sf, sfOut) +} + +func Test_structFieldsCBOR_CBOR_decode_tagged(t *testing.T) { + data := []byte{ + 0xc1, // tag 1 + 0xa5, // map 5 + 0x00, 0x00, + 0x01, 0x00, + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + } + + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, []int{0, 1, 2, 3, 4}, sfOut.Keys) +} + +func Test_structFieldsCBOR_CBOR_decode_indefinite(t *testing.T) { + data := []byte{ + 0xbf, // indefinite map + 0x00, 0x00, + 0x01, 0x00, + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + 0xff, // break + } + + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, []int{0, 1, 2, 3, 4}, sfOut.Keys) +} + +func Test_structFieldsCBOR_CBOR_decode_negative(t *testing.T) { + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, []byte{0xa1, 0xff, 0x00}) + assert.EqualError(t, err, `map item 0: could not unmarshal key: cbor: unexpected "break" code`) + err = sfOut.FromCBOR(dm, []byte{0xbf, 0x00, 0x00}) + assert.EqualError(t, err, `unexpected EOF`) + err = sfOut.FromCBOR(dm, []byte{0xa1, 0x00, 0xff}) + assert.EqualError(t, err, `map item 0: could not unmarshal value: cbor: unexpected "break" code`) + + err = sfOut.FromCBOR(dm, []byte{0x00}) + assert.EqualError(t, err, `expected map (CBOR Major Type 5), found Major Type 0`) +} + +func Test_processAdditionalInfo(t *testing.T) { + addInfo := byte(26) + data := []byte{0x00, 0x00, 0x00, 0x01} + + val, rest, err := processAdditionalInfo(addInfo, data) + require.NoError(t, err) + assert.Equal(t, 1, val) + assert.Equal(t, []byte{}, rest) + + _, _, err = processAdditionalInfo(byte(27), data) + assert.EqualError(t, err, "cbor: cannot decode length value of 8 bytes") + + _, _, err = processAdditionalInfo(byte(28), data) + assert.EqualError(t, err, "cbor: unexpected additional information value 28") + + _, _, err = processAdditionalInfo(addInfo, []byte{}) + assert.EqualError(t, err, "unexpected EOF") +} diff --git a/encoding/embedded.go b/encoding/embedded.go new file mode 100644 index 00000000..1bfeb999 --- /dev/null +++ b/encoding/embedded.go @@ -0,0 +1,49 @@ +package encoding + +import "reflect" + +const omitempty = "omitempty" + +type embedded struct { + Type reflect.Type + Value reflect.Value +} + +// collectEmbedded returns true if the Field is embedded (regardless of +// whether or not it was collected). +func collectEmbedded( + typeField *reflect.StructField, + valField reflect.Value, + embeds *[]embedded, +) bool { + // embedded fields are alway anonymous:w + if !typeField.Anonymous { + return false + } + + if typeField.Name == typeField.Type.Name() && + (typeField.Type.Kind() == reflect.Struct || + typeField.Type.Kind() == reflect.Interface) { + + var fieldType reflect.Type + var fieldValue reflect.Value + + if typeField.Type.Kind() == reflect.Interface { + fieldValue = valField.Elem() + if fieldValue.Kind() == reflect.Invalid { + // no value underlying the interface + return true + } + // use the interface's underlying value's real type + fieldType = valField.Elem().Type() + } else { + fieldType = typeField.Type + fieldValue = valField + } + + *embeds = append(*embeds, embedded{Type: fieldType, Value: fieldValue}) + return true + } + + return false +} diff --git a/encoding/json.go b/encoding/json.go new file mode 100644 index 00000000..f9b24f89 --- /dev/null +++ b/encoding/json.go @@ -0,0 +1,351 @@ +package encoding + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" +) + +func SerializeStructToJSON(source any) ([]byte, error) { + rawMap := newStructFieldsJSON() + + structType := reflect.TypeOf(source) + structVal := reflect.ValueOf(source) + + if err := doSerializeStructToJSON(rawMap, structType, structVal); err != nil { + return nil, err + } + + return rawMap.ToJSON() +} + +func doSerializeStructToJSON( + rawMap *structFieldsJSON, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("json") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + key := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + // do not serialize zero values if the corresponding field is + // omitempty + if isOmitEmpty && valField.IsZero() { + continue + } + + data, err := json.Marshal(valField.Interface()) + if err != nil { + return fmt.Errorf("error marshaling field %q: %w", + typeField.Name, + err, + ) + } + + if err := rawMap.Add(key, json.RawMessage(data)); err != nil { + return err + } + } + + for _, emb := range embeds { + if err := doSerializeStructToJSON(rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +func PopulateStructFromJSON(data []byte, dest any) error { + rawMap := newStructFieldsJSON() + + if err := rawMap.FromJSON(data); err != nil { + return err + } + + structType := reflect.TypeOf(dest) + structVal := reflect.ValueOf(dest) + + return doPopulateStructFromJSON(rawMap, structType, structVal) +} + +func doPopulateStructFromJSON( + rawMap *structFieldsJSON, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("json") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + key := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + rawVal, ok := rawMap.Get(key) + if !ok { + if isOmitEmpty { + continue + } + + return fmt.Errorf("missing mandatory field %q (%q)", + typeField.Name, key) + } + + fieldPtr := valField.Addr().Interface() + if err := json.Unmarshal(rawVal, fieldPtr); err != nil { + return fmt.Errorf("error unmarshalling field %q: %w", + typeField.Name, + err, + ) + } + + rawMap.Delete(key) + } + + for _, emb := range embeds { + if err := doPopulateStructFromJSON(rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +// structFieldsJSON is a specialized implementation of "OrderedMap", where the +// order of the keys is kept track of, and used when serializing the map to +// JSON. While JSON maps do not mandate any particular ordering, and so this +// isn't strictly necessary, it is useful to have a _stable_ serialization +// order for map keys to be compatible with regular Go struct serialization +// behavior. This is also useful for tests/examples that compare encoded +// []byte's. +type structFieldsJSON struct { + Fields map[string]json.RawMessage + Keys []string +} + +func newStructFieldsJSON() *structFieldsJSON { + return &structFieldsJSON{ + Fields: make(map[string]json.RawMessage), + } +} + +func (o structFieldsJSON) Has(key string) bool { + _, ok := o.Fields[key] + return ok +} + +func (o *structFieldsJSON) Add(key string, val json.RawMessage) error { + if o.Has(key) { + return fmt.Errorf("duplicate JSON key: %q", key) + } + + o.Fields[key] = val + o.Keys = append(o.Keys, key) + + return nil +} + +func (o *structFieldsJSON) Get(key string) (json.RawMessage, bool) { + val, ok := o.Fields[key] + return val, ok +} + +func (o *structFieldsJSON) Delete(key string) { + delete(o.Fields, key) + + for i, existing := range o.Keys { + if existing == key { + o.Keys = append(o.Keys[:i], o.Keys[i+1:]...) + } + } +} + +func (o *structFieldsJSON) ToJSON() ([]byte, error) { + var out bytes.Buffer + + out.Write([]byte("{")) + + first := true + for _, key := range o.Keys { + if first { + first = false + } else { + out.Write([]byte(",")) + } + marshaledKey, err := json.Marshal(key) + if err != nil { + return nil, fmt.Errorf("problem marshaling key %s: %w", key, err) + } + out.Write(marshaledKey) + out.Write([]byte(":")) + out.Write(o.Fields[key]) + } + + out.Write([]byte("}")) + + return out.Bytes(), nil +} + +func (o *structFieldsJSON) FromJSON(data []byte) error { + if err := json.Unmarshal(data, &o.Fields); err != nil { + return err + } + + return o.unmarshalKeys(data) +} + +func (o *structFieldsJSON) unmarshalKeys(data []byte) error { + + decoder := json.NewDecoder(bytes.NewReader(data)) + + token, err := decoder.Token() + if err != nil { + return err + } + + if token != json.Delim('{') { + return errors.New("expected start of object") + } + + var keys []string + + for { + token, err = decoder.Token() + if err != nil { + return err + } + + if token == json.Delim('}') { + break + } + + key, ok := token.(string) + if !ok { + return fmt.Errorf("expected string, found %T", token) + } + + keys = append(keys, key) + + if err := skipValue(decoder); err != nil { + return err + } + } + + o.Keys = keys + + return nil +} + +var errEndOfStream = errors.New("invalid end of array or object") + +func skipValue(decoder *json.Decoder) error { + + token, err := decoder.Token() + if err != nil { + return err + } + switch token { + case json.Delim('['), json.Delim('{'): + for { + if err := skipValue(decoder); err != nil { + if err == errEndOfStream { + break + } + return err + } + } + case json.Delim(']'), json.Delim('}'): + return errEndOfStream + } + return nil +} + +// TypeAndValue stores a JSON object with two attributes: a string "type" +// and a generic "value" (string) defined by type. This type is used in +// a few places to implement the choice types that CBOR handles using tags. +type TypeAndValue struct { + Type string `json:"type"` + Value json.RawMessage `json:"value"` +} + +func (o *TypeAndValue) UnmarshalJSON(data []byte) error { + var temp struct { + Type string `json:"type"` + Value json.RawMessage `json:"value"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + if temp.Type == "" { + return errors.New("type not set") + } + + if len(temp.Value) == 0 { + return fmt.Errorf("no value provided for %s", temp.Type) + } + + o.Type = temp.Type + o.Value = temp.Value + + return nil +} diff --git a/encoding/json_test.go b/encoding/json_test.go new file mode 100644 index 00000000..85eca9ff --- /dev/null +++ b/encoding/json_test.go @@ -0,0 +1,152 @@ +package encoding + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PopulateStructFromJSON(t *testing.T) { + type SimpleStruct struct { + FieldOne string `json:"field-one,omitempty"` + FieldTwo int `json:"field-two"` + } + + var v SimpleStruct + + data := []byte(`{"field-one": "acme", "field-two": 6}`) + + err := PopulateStructFromJSON(data, &v) + require.NoError(t, err) + assert.Equal(t, "acme", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte(`{"field-two": 6}`) + v = SimpleStruct{} + + err = PopulateStructFromJSON(data, &v) + require.NoError(t, err) + assert.Equal(t, "", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte(`{"field-one": "acme"}`) + v = SimpleStruct{} + + err = PopulateStructFromJSON(data, &v) + assert.EqualError(t, err, `missing mandatory field "FieldTwo" ("field-two")`) + + err = PopulateStructFromJSON([]byte("7"), &v) + assert.EqualError(t, err, `json: cannot unmarshal number into Go value of type map[string]json.RawMessage`) + + type CompositeStruct struct { + FieldThree string `json:"field-three"` + SimpleStruct + } + + var c CompositeStruct + + data = []byte(`{"field-one": "acme", "field-two": 6, "field-three": "foo"}`) + + err = PopulateStructFromJSON(data, &c) + require.NoError(t, err) + assert.Equal(t, "acme", c.FieldOne) + assert.Equal(t, 6, c.FieldTwo) + assert.Equal(t, "foo", c.FieldThree) + + res, err := SerializeStructToJSON(&c) + require.NoError(t, err) + + var c2 CompositeStruct + err = PopulateStructFromJSON(res, &c2) + require.NoError(t, err) + assert.EqualValues(t, c, c2) +} + +func Test_structFieldsJSON_CRUD(t *testing.T) { + sf := newStructFieldsJSON() + + err := sf.Add("two", json.RawMessage("2")) + assert.NoError(t, err) + + err = sf.Add("one", json.RawMessage("1")) + assert.NoError(t, err) + + err = sf.Add("three", json.RawMessage("3")) + assert.NoError(t, err) + + assert.Equal(t, []string{"two", "one", "three"}, sf.Keys) + assert.True(t, sf.Has("three")) + assert.False(t, sf.Has("four")) + + val, ok := sf.Get("two") + assert.True(t, ok) + assert.Equal(t, json.RawMessage("2"), val) + + _, ok = sf.Get("four") + assert.False(t, ok) + + sf.Delete("two") + _, ok = sf.Get("two") + assert.False(t, ok) + + err = sf.Add("one", json.RawMessage("4")) + assert.EqualError(t, err, `duplicate JSON key: "one"`) +} + +func Test_skipValue(t *testing.T) { + text := "" + decoder := json.NewDecoder(strings.NewReader(text)) + err := skipValue(decoder) + assert.EqualError(t, err, "EOF") + + text = "[]" + decoder = json.NewDecoder(strings.NewReader(text)) + _, _ = decoder.Token() // skip the '[' + err = skipValue(decoder) + assert.EqualError(t, err, "invalid end of array or object") + + text = `{"embed": {"one": 1, "two": [1,2,3]}, "other": 1}` + decoder = json.NewDecoder(strings.NewReader(text)) + _, _ = decoder.Token() // skip the '{' + _, _ = decoder.Token() // skip the '"embed"' + err = skipValue(decoder) + assert.NoError(t, err) + + token, err := decoder.Token() + assert.NoError(t, err) + assert.Equal(t, "other", token) +} + +func TestTypeAndValue_UnmarshalJSON(t *testing.T) { + for _, tv := range []struct { + Input string + Expected TypeAndValue + Err string + }{ + { + Input: `{"type": "test", "value": "test"}`, + Expected: TypeAndValue{Type: "test", Value: []byte(`"test"`)}, + }, + { + Input: `{"type": "test"}`, + Err: "no value provided for test", + }, + { + Input: `{"value": "test"}`, + Err: "type not set", + }, + } { + var out TypeAndValue + err := out.UnmarshalJSON([]byte(tv.Input)) + + if tv.Err != "" { + assert.EqualError(t, err, tv.Err) + } else { + assert.NoError(t, err) + assert.Equal(t, tv.Expected, out) + } + } +} diff --git a/extensions/README.md b/extensions/README.md new file mode 100644 index 00000000..aeffe0f7 --- /dev/null +++ b/extensions/README.md @@ -0,0 +1,392 @@ +[CoRIM +specification](https://datatracker.ietf.org/doc/draft-ietf-rats-corim/02) +may be extended by CoRIM Profiles documented in other specifications at well +defined points identified in the base CoRIM spec using CDDL extension sockets. +CoRIM profiles may + +1. Introduce new data fields to certain objects. +2. Further constrain allowed values for existing fields. +3. Introduce new type choices for fields whose values can be one of several + types. + +This implementation, likewise, allows dependent code to register extension +types. This is done via three distinct extension mechanisms: + +1. Structures that allow extensions embed an `Extensions` object, which allows + registering a user-provided struct. The user-provided struct can extend + their containing structures in two ways: + + - the fields of the user-provided struct become additional fields in + the containing structure. + - the user-provided struct may define additional constraints on the + containing structure by defining an appropriate validation method for it. + + This corresponds to the map-extension sockets in the spec. +2. Some fields may have values chosen from a set of pre-determined types (e.g., an + instance ID may be either a UUID or UEID). These (mostly) correspond to + type-choice sockets in the CoRIM spec. The set of allowed types for a field + may be extended by registering a factory function for a new type, using the + registration function associated with the type choice. +3. A couple of type-choice sockets (`$tag-rel-type-choice`, + `$corim-role-type-choice` and `$comid-role-type-choice`) define what, in + effect, are extensible enums. They allow providing additional values, rather + than types. This implementation provides registration functions for new + values for those types. + +> [!NOTE] +> CoRIM also "imports" CDDL from the CoSWID spec. Some of +> these CoSWID CDDL definitions also feature extension sockets. +> However, as they are defined in a different spec and are implemented +> in the [`veraison/swid`](https://github.com/veraison/swid) library, they cannot be extended +> using the extension feature provided in the CoRIM library. The extension +> support in CoRIM library is applicable ONLY to CoRIM and CoMID maps and +> type choices + + +## Map Extensions + +Map extensions allow extending CoRIM maps with additional keys, effectively +defining new fields for the corresponding structures. In the code base, these +can be identified by the embedded `Extensions` struct. These are + +- `comid.Comid` +- `comid.Entity` +- `comid.FlagsMap` +- `comid.Mval` +- `comid.Triples` +- `corim.Entity` +- `corim.Signer` +- `corim.UnsignedCorim` + +To extend the above types, you need to define a struct containing your +extensions and pass a pointer to an instance of that struct to the +`RegisterExtensions()` method of the corresponding instance of the type that is +being extended. This should be done as early as possible, before any marshaling +is performed. + +These types can be extended in two ways: by adding additional fields, and by +introducing additional constraints over existing fields. + +### Adding new fields + +To add new fields, simply add them to your extensions struct, ensuring that +the `cbor` and `json` tags on those fields are set correctly. As CoRIM +mandates integer keys, you must use the `keyasint` option for the `cbor` tag. + +To access the values of those fields, you can call the extended type instance's +`Extensions.Get()` passing in the name of the field you want to access. The +name can be either the Go struct's field name, the name specified in the `json` +tag, or (a string containing) the integer specified in the `cbor` tag. + +`Get()` returns an `interface{}`. There are equivalent `GetInt()`, +`GetString()`, etc. methods that perform the required conversions, and return +the value of the indicated type, along with possible errors. ("Must" versions +of these also exist, e.g. `MustGetString()`, that do not return an error and +instead call `panic()`). + +You can also get the pointer to your extension's instance itself by calling +the extended type instance's `GetExtensions()`. This returns an `interface{}`, so +you will need to type assert to be able to access the fields directly. + +### Introducing additional constraints + +To introduce new constraints, add a method called `Constrain(v *)` +to your extensions struct, where `` is the name of the type being +extended (one of the ones listed above) -- e.g. +`ConstrainComid(v *comid.Comid)` when extending `comid.Comid`. This method, if +it exists, will be invoked inside the extended type instance's `Valid()` +method, passing itself as the parameter. + +You do not need to define this method unless you actually want to enforce some +constraints (i.e., if you just want to define additional fields). + +### Example + +The following example illustrates how to implement a map extension by extending +`comid.Entity` with the following features: + +1. an optional "email" field +2. additional constraint on the existing "name" field that it contains a + valid UUID (note: since `NameEntry` is a type choice extensible, this can + also be done by defining a new value type for `NameEntry` -- see the + following section). + +```go +package main + +import ( + "encoding/json" + "fmt" + "log" + + "github.com/google/uuid" + "github.com/veraison/corim/comid" +) + +// the struct containing the extensions +type EntityExtensions struct { + // a string field extension + Email string `cbor:"-1,keyasint,omitempty" json:"email,omitempty"` +} + +// custom constraints for Entity +func (o EntityExtensions) ConstrainEntity(val *comid.Entity) error { + _, err := uuid.Parse(val.EntityName.String()) + if err != nil { + return fmt.Errorf("invalid UUID: %w", err) + } + + return nil +} + +var sampleText = ` +{ + "name": "31fb5abf-023e-4992-aa4e-95f9c1503bfa", + "regid": "https://acme.example", + "email": "info@acme.com", + "roles": [ + "tagCreator", + "creator", + "maintainer" + ] +} +` + +func main() { + var entity comid.Entity + entity.RegisterExtensions(&EntityExtensions{}) + + if err := json.Unmarshal([]byte(sampleText), &entity); err != nil { + log.Fatalf("ERROR: %s", err.Error()) + } + + if err := entity.Valid(); err != nil { + log.Fatalf("failed to validate: %s", err.Error()) + } else { + fmt.Println("validation succeeded") + } + + // obtain the extension field value via a generic getter + email := entity.Extensions.MustGetString("email") + fmt.Printf("entity email: %s\n", email) + + // retrive the extensions struct and get value via its field. + exts := entity.GetExtensions().(*EntityExtensions) + fmt.Printf("also entity email: %s\n", exts.Email) +} +``` + + +## Type Choice Extensions + +Type Choice extensions allow specifying alternative types for existing CoRIM +fields by defining a type that implements an appropriate interface and +registering it with a CBOR tag. + +A type choice struct contains a single field, `Value`, that contains the actual +object represented by the type choice. The `Value` implements an interface +that is specific to the type choice and is derived from `ITypeChoiceValue`: + +```go +type ITypeChoiceValue interface { + // String returns the string representation of the ITypeChoiceValue. + String() string + // Valid returns an error if validation of the ITypeChoiceValue fails, + // or nil if it succeeds. + Valid() error + // Type returns the type name of this ITypeChoiceValue implementation. + Type() string +} +``` + +The following is the full list of type choice structs: + +- `comid.ClassID` +- `comid.CryptoKey` +- `comid.EntityName` +- `comid.Group` +- `comid.Instance` +- `comid.Mkey` +- `comid.SVN` +- `corim.EntityName` + +To provide a new value type, the following is required: + +1. Define a type that implements the value interface for the type choice you + want to extend. This interface is called `IValue`, where `` is + the name of the type choice type(e.g. `IClassIDValue`). These interfaces + always embed `ITypeChoiceValue` and possibly define additional methods. +2. Create a factory function for your type, with the signature `func (any) + (*, error)`, where `` is the name of the type choice type that + will contain your value. (Note that the function must return a pointer to + the container type choice struct, _not_ to the value type you define.) This + function should create an instance of your value type from the provided + input and return a new type choice struct instance containing it. The range + of valid inputs is up to you, however it _must_ handle `null`, returning the + [zero-value](https://go.dev/ref/spec#The_zero_value) for your type in that + case. +3. Register your factory function with the CBOR tag for your new type by + passing it to the registration function corresponding to the type choice + struct. It will have the name `RegisterType`, where `` is the + name of the type choice struct that will contain your value (e.g. + `RegisterClassIDType`). + +### Example + +The following example illustrates how to add a new type choice value +implementation by extending the `CryptoKey` type to support DER values. + +```go +package main + +import ( + "crypto" + "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "log" + + "github.com/veraison/corim/comid" +) + +// the CBOR tag to be used for the new type +var DerKeyTag = uint64(9999) + +// new implementation of ICryptoKeyValue type +type TaggedDerKey []byte + +// The factory function for the new type +func NewTaggedDerKey(k any) (*comid.CryptoKey, error) { + var b []byte + var err error + + if k == nil { + k = *new([]byte) + } + + switch t := k.(type) { + case []byte: + b = t + case string: + b, err = base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("value must be a []byte; found %T", k) + } + + key := TaggedDerKey(b) + + return &comid.CryptoKey{Value: key}, nil +} + +func (o TaggedDerKey) String() string { + return base64.StdEncoding.EncodeToString(o) +} + +func (o TaggedDerKey) Valid() error { + _, err := o.PublicKey() + return err +} + +func (o TaggedDerKey) Type() string { + return "pkix-der-key" +} + +func (o TaggedDerKey) PublicKey() (crypto.PublicKey, error) { + if len(o) == 0 { + return nil, errors.New("key value not set") + } + + key, err := x509.ParsePKIXPublicKey(o) + if err != nil { + return nil, fmt.Errorf("unable to parse public key: %w", err) + } + + return key, nil +} + +var testKeyJSON = ` +{ + "type": "pkix-der-key", + "value": "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8BlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==" +} +` + +func main() { + // register the factory function under the CBOR tag. + if err := comid.RegisterCryptoKeyType(DerKeyTag, NewTaggedDerKey); err != nil { + log.Fatal(err) + } + + var key comid.CryptoKey + + if err := json.Unmarshal([]byte(testKeyJSON), &key); err != nil { + log.Fatal(err) + } + + fmt.Printf("Decoded DER key: %x\n", key) +} +``` + + +## Enum extensions + +The following enum types may be extended with additional values: + +- `comid.Rel` +- `comid.Role` +- `corim.Role` + +This can be done by calling `RegisterRel` or `RegisterRole`, as appropriate, +and providing a new `uint64` value and corresponding `string` name. + +### Example + +```go +package main + +import ( + "encoding/json" + "fmt" + "log" + + "github.com/veraison/corim/comid" +) + +var sampleText = ` +{ + "name": "Acme Ltd.", + "regid": "https://acme.example", + "roles": [ + "tagCreator", + "owner" + ] +} +` + +func main() { + // associate role value 4 with the name "owner" + comid.RegisterRole(4, "owner") + + var entity comid.Entity + + if err := json.Unmarshal([]byte(sampleText), &entity); err != nil { + log.Fatalf("ERROR: %s", err.Error()) + } + + if err := entity.Valid(); err != nil { + log.Fatalf("failed to validate: %s", err.Error()) + } else { + fmt.Println("validation succeeded") + } + + fmt.Println("roles:") + for _, role := range entity.Roles { + fmt.Printf("\t%s\n", role.String()) + } +} +``` diff --git a/extensions/extensions.go b/extensions/extensions.go new file mode 100644 index 00000000..4be60fd1 --- /dev/null +++ b/extensions/extensions.go @@ -0,0 +1,379 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package extensions + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/spf13/cast" +) + +var ErrExtensionNotFound = errors.New("extension not found") + +type IExtensionsValue any + +type Extensions struct { + IExtensionsValue `json:"extensions,omitempty"` +} + +func (o *Extensions) Register(exts IExtensionsValue) { + if reflect.TypeOf(exts).Kind() != reflect.Pointer { + panic("attempting to register a non-pointer IExtensionsValue") + } + + o.IExtensionsValue = exts +} + +func (o *Extensions) HaveExtensions() bool { + return o.IExtensionsValue != nil +} + +func (o *Extensions) Get(name string) (any, error) { + if o.IExtensionsValue == nil { + return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name) + } + + extType := reflect.TypeOf(o.IExtensionsValue) + extVal := reflect.ValueOf(o.IExtensionsValue) + if extType.Kind() == reflect.Pointer { + extType = extType.Elem() + extVal = extVal.Elem() + } + + var fieldName, fieldJSONTag, fieldCBORTag string + for i := 0; i < extVal.NumField(); i++ { + typeField := extType.Field(i) + fieldName = typeField.Name + + tag, ok := typeField.Tag.Lookup("json") + if ok { + fieldJSONTag = strings.Split(tag, ",")[0] + } + + tag, ok = typeField.Tag.Lookup("cbor") + if ok { + fieldCBORTag = strings.Split(tag, ",")[0] + } + + if fieldName == name || fieldJSONTag == name || fieldCBORTag == name { + return extVal.Field(i).Interface(), nil + } + } + + return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name) +} + +func (o *Extensions) MustGetString(name string) string { + v, _ := o.GetString(name) + return v +} + +func (o *Extensions) GetString(name string) (string, error) { + v, err := o.Get(name) + if err != nil { + return "", err + } + + return cast.ToStringE(v) +} + +func (o *Extensions) MustGetInt(name string) int { + v, _ := o.GetInt(name) + return v +} + +func (o *Extensions) GetInt(name string) (int, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToIntE(v) +} + +func (o *Extensions) MustGetInt64(name string) int64 { + v, _ := o.GetInt64(name) + return v +} + +func (o *Extensions) GetInt64(name string) (int64, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt64E(v) +} + +func (o *Extensions) MustGetInt32(name string) int32 { + v, _ := o.GetInt32(name) + return v +} + +func (o *Extensions) GetInt32(name string) (int32, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt32E(v) +} + +func (o *Extensions) MustGetInt16(name string) int16 { + v, _ := o.GetInt16(name) + return v +} + +func (o *Extensions) GetInt16(name string) (int16, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt16E(v) +} + +func (o *Extensions) MustGetInt8(name string) int8 { + v, _ := o.GetInt8(name) + return v +} + +func (o *Extensions) GetInt8(name string) (int8, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt8E(v) +} + +func (o *Extensions) MustGetUint(name string) uint { + v, _ := o.GetUint(name) + return v +} + +func (o *Extensions) GetUint(name string) (uint, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUintE(v) +} + +func (o *Extensions) MustGetUint64(name string) uint64 { + v, _ := o.GetUint64(name) + return v +} + +func (o *Extensions) GetUint64(name string) (uint64, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint64E(v) +} + +func (o *Extensions) MustGetUint32(name string) uint32 { + v, _ := o.GetUint32(name) + return v +} + +func (o *Extensions) GetUint32(name string) (uint32, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint32E(v) +} + +func (o *Extensions) MustGetUint16(name string) uint16 { + v, _ := o.GetUint16(name) + return v +} + +func (o *Extensions) GetUint16(name string) (uint16, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint16E(v) +} + +func (o *Extensions) MustGetUint8(name string) uint8 { + v, _ := o.GetUint8(name) + return v +} + +func (o *Extensions) GetUint8(name string) (uint8, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint8E(v) +} + +func (o *Extensions) MustGetFloat32(name string) float32 { + v, _ := o.GetFloat32(name) + return v +} + +func (o *Extensions) GetFloat32(name string) (float32, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToFloat32E(v) +} + +func (o *Extensions) MustGetFloat64(name string) float64 { + v, _ := o.GetFloat64(name) + return v +} + +func (o *Extensions) GetFloat64(name string) (float64, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToFloat64E(v) +} + +func (o *Extensions) MustGetBool(name string) bool { + v, _ := o.GetBool(name) + return v +} + +func (o *Extensions) GetBool(name string) (bool, error) { + v, err := o.Get(name) + if err != nil { + return false, err + } + + return cast.ToBoolE(v) +} + +func (o *Extensions) MustGetSlice(name string) []any { + v, _ := o.GetSlice(name) + return v +} + +func (o *Extensions) GetSlice(name string) ([]any, error) { + v, err := o.Get(name) + if err != nil { + return []any{}, err + } + + return cast.ToSliceE(v) +} + +func (o *Extensions) MustGetIntSlice(name string) []int { + v, _ := o.GetIntSlice(name) + return v +} + +func (o *Extensions) GetIntSlice(name string) ([]int, error) { + v, err := o.Get(name) + if err != nil { + return []int{}, err + } + + return cast.ToIntSliceE(v) +} + +func (o *Extensions) MustGetStringSlice(name string) []string { + v, _ := o.GetStringSlice(name) + return v +} + +func (o *Extensions) GetStringSlice(name string) ([]string, error) { + v, err := o.Get(name) + if err != nil { + return []string{}, err + } + + return cast.ToStringSliceE(v) +} + +func (o *Extensions) MustGetStringMap(name string) map[string]any { + v, _ := o.GetStringMap(name) + return v +} + +func (o *Extensions) GetStringMap(name string) (map[string]any, error) { + v, err := o.Get(name) + if err != nil { + return map[string]any{}, err + } + + return cast.ToStringMapE(v) +} + +func (o *Extensions) MustGetStringMapString(name string) map[string]string { + v, _ := o.GetStringMapString(name) + return v +} + +func (o *Extensions) GetStringMapString(name string) (map[string]string, error) { + v, err := o.Get(name) + if err != nil { + return map[string]string{}, err + } + + return cast.ToStringMapStringE(v) +} + +func (o *Extensions) Set(name string, value any) error { + if o.IExtensionsValue == nil { + return fmt.Errorf("%w: %s", ErrExtensionNotFound, name) + } + + extType := reflect.TypeOf(o.IExtensionsValue) + extVal := reflect.ValueOf(o.IExtensionsValue) + if extType.Kind() == reflect.Pointer { + extType = extType.Elem() + extVal = extVal.Elem() + } + + var fieldName, fieldJSONTag, fieldCBORTag string + for i := 0; i < extVal.NumField(); i++ { + typeField := extType.Field(i) + valField := extVal.Field(i) + fieldName = typeField.Name + + tag, ok := typeField.Tag.Lookup("json") + if ok { + fieldJSONTag = strings.Split(tag, ",")[0] + } + + tag, ok = typeField.Tag.Lookup("cbor") + if ok { + fieldCBORTag = strings.Split(tag, ",")[0] + } + + if fieldName == name || fieldJSONTag == name || fieldCBORTag == name { + newVal := reflect.ValueOf(value) + if newVal.CanConvert(valField.Type()) { + valField.Set(newVal.Convert(valField.Type())) + return nil + } + + return fmt.Errorf( + "cannot set field %q (of type %s) to %v (%T)", + name, typeField.Type.Name(), + value, value, + ) + } + } + + return fmt.Errorf("%w: %s", ErrExtensionNotFound, name) +} diff --git a/extensions/extensions_test.go b/extensions/extensions_test.go new file mode 100644 index 00000000..74d44f52 --- /dev/null +++ b/extensions/extensions_test.go @@ -0,0 +1,122 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package extensions + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type Entity struct { + EntityName string + Roles []int64 + + Extensions +} + +type TestExtensions struct { + Address string `cbor:"-1,keyasint,omitempty" json:"address,omitempty"` + Size int `cbor:"-2,keyasint,omitempty" json:"size,omitempty"` + YearsOnAir float32 `cbor:"-3,keyasint,omitempty" json:"years-on-air,omitempty"` + StillAiring bool `cbor:"-4,keyasint,omitempty" json:"still-airing,omitempty"` + Ages []int `cbor:"-5,keyasint,omitempty" json:"ages,omitempty"` + Jobs map[string]string `cbor:"-6,keyasint,omitempty" json:"jobs,omitempty"` +} + +func TestExtensions_Register(t *testing.T) { + exts := Extensions{} + assert.False(t, exts.HaveExtensions()) + + exts.Register(&TestExtensions{}) + assert.True(t, exts.HaveExtensions()) + + badRegister := func() { + exts.Register(TestExtensions{}) + } + + assert.Panics(t, badRegister) +} + +func TestExtensions_GetSet(t *testing.T) { + extsVal := TestExtensions{ + Address: "742 Evergreen Terrace", + Size: 6, + YearsOnAir: 33.8, + StillAiring: true, + Ages: []int{2, 7, 8, 10, 37, 38}, + Jobs: map[string]string{ + "Homer": "safety inspector", + "Marge": "housewife", + "Bart": "elementary school student", + "Lisa": "elementary school student", + }, + } + exts := Extensions{IExtensionsValue: &extsVal} + + v, err := exts.GetInt("size") + assert.NoError(t, err) + assert.Equal(t, 6, v) + + assert.Equal(t, 6, exts.MustGetInt("size")) + assert.Equal(t, int64(6), exts.MustGetInt64("size")) + assert.Equal(t, int32(6), exts.MustGetInt32("size")) + assert.Equal(t, int16(6), exts.MustGetInt16("size")) + assert.Equal(t, int8(6), exts.MustGetInt8("size")) + + assert.Equal(t, uint(6), exts.MustGetUint("size")) + assert.Equal(t, uint64(6), exts.MustGetUint64("size")) + assert.Equal(t, uint32(6), exts.MustGetUint32("size")) + assert.Equal(t, uint16(6), exts.MustGetUint16("size")) + assert.Equal(t, uint8(6), exts.MustGetUint8("size")) + + assert.InEpsilon(t, float32(33.8), exts.MustGetFloat32("years-on-air"), 0.000001) + assert.InEpsilon(t, float64(33.8), exts.MustGetFloat64("-3"), 0.000001) + + assert.Equal(t, true, exts.MustGetBool("StillAiring")) + + _, err = exts.GetSlice("ages") + assert.EqualError(t, err, + `unable to cast []int{2, 7, 8, 10, 37, 38} of type []int to []interface{}`) + assert.Nil(t, exts.MustGetSlice("ages")) + + assert.EqualValues(t, []int{2, 7, 8, 10, 37, 38}, exts.MustGetIntSlice("ages")) + assert.EqualValues(t, []string{"2", "7", "8", "10", "37", "38"}, + exts.MustGetStringSlice("ages")) + + assert.EqualValues(t, map[string]string{ + "Homer": "safety inspector", + "Marge": "housewife", + "Bart": "elementary school student", + "Lisa": "elementary school student", + }, exts.MustGetStringMapString("jobs")) + + _, err = exts.GetStringMap("jobs") + assert.EqualError(t, err, + `unable to cast map[string]string{"Bart":"elementary school student", "Homer":"safety inspector", "Lisa":"elementary school student", "Marge":"housewife"} of type map[string]string to map[string]interface{}`) + m := exts.MustGetStringMap("jobs") + assert.Equal(t, map[string]any{}, m) + + s, err := exts.GetString("address") + assert.NoError(t, err) + assert.Equal(t, "742 Evergreen Terrace", s) + + _, err = exts.GetInt("address") + assert.EqualError(t, err, `unable to cast "742 Evergreen Terrace" of type string to int`) + + _, err = exts.GetInt("foo") + assert.EqualError(t, err, "extension not found: foo") + + err = exts.Set("-1", "123 Fake Street") + assert.NoError(t, err) + + s, err = exts.GetString("address") + assert.NoError(t, err) + assert.Equal(t, "123 Fake Street", s) + + err = exts.Set("Size", "foo") + assert.EqualError(t, err, `cannot set field "Size" (of type int) to foo (string)`) + + assert.Equal(t, "", exts.MustGetString("does-not-exist")) + assert.Equal(t, 0, exts.MustGetInt("does-not-exist")) +} diff --git a/extensions/typechoice.go b/extensions/typechoice.go new file mode 100644 index 00000000..5510c42a --- /dev/null +++ b/extensions/typechoice.go @@ -0,0 +1,36 @@ +package extensions + +import ( + "encoding/json" + + "github.com/veraison/corim/encoding" +) + +var StringType = "string" + +// ITypeChoiceValue is the interface that is implemented by all concrete type +// choice value types. Specific type choices define their own value interfaces +// that embed this one (and possibly include additional methods). +type ITypeChoiceValue interface { + // String returns the string representation of the ITypeChoiceValue. + String() string + // Valid returns an error if validation of the ITypeChoiceValue fails, + // or nil if it succeeds. + Valid() error + // Type returns the type name of this ITypeChoiceValue implementation. + Type() string +} + +func TypeChoiceValueMarshalJSON(v ITypeChoiceValue) ([]byte, error) { + valueBytes, err := json.Marshal(v) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: v.Type(), + Value: valueBytes, + } + + return json.Marshal(value) +} diff --git a/go.mod b/go.mod index 9d95f047..96a5aa49 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,12 @@ module github.com/veraison/corim go 1.18 require ( - github.com/fxamacker/cbor/v2 v2.4.0 + github.com/fxamacker/cbor/v2 v2.5.0 github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 github.com/lestrrat-go/jwx/v2 v2.0.8 github.com/spf13/afero v1.9.2 + github.com/spf13/cast v1.4.1 github.com/spf13/cobra v1.2.1 github.com/spf13/viper v1.9.0 github.com/stretchr/testify v1.8.2 @@ -35,7 +36,6 @@ require ( github.com/moogar0880/problems v0.1.1 // indirect github.com/pelletier/go-toml v1.9.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/spf13/cast v1.4.1 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.2.0 // indirect diff --git a/go.sum b/go.sum index c6347523..e1ae4e55 100644 --- a/go.sum +++ b/go.sum @@ -90,8 +90,8 @@ github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWp github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= github.com/fxamacker/cbor/v2 v2.2.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/fxamacker/cbor/v2 v2.3.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= -github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= -github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= +github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=