Skip to content

Commit

Permalink
refactor: implement functional options pattern consistently
Browse files Browse the repository at this point in the history
Signed-off-by: Jonathan Howard <[email protected]>
  • Loading branch information
jhoward-lm committed Jan 16, 2025
1 parent 5ce1b78 commit 32da4ad
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 180 deletions.
184 changes: 93 additions & 91 deletions backends/ent/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@ import (
)

type (
contactOwnerIDKey struct{}
documentIDKey struct{}
metadataIDKey struct{}
nodeIDKey struct{}
nodeListIDKey struct{}
nodeNativeIDMappingKey struct{}

TxFunc func(*ent.Tx) error
Expand Down Expand Up @@ -117,7 +113,7 @@ func (backend *Backend) saveAnnotations(annotations ...*ent.Annotation) TxFunc {
}
}

func (backend *Backend) saveDocumentTypes(docTypes []*sbom.DocumentType) TxFunc {
func (backend *Backend) saveDocumentTypes(docTypes []*sbom.DocumentType, opts ...func(*ent.DocumentTypeCreate)) TxFunc {
return func(tx *ent.Tx) error {
for _, docType := range docTypes {
typeName := documenttype.Type(docType.GetType().String())
Expand All @@ -128,8 +124,9 @@ func (backend *Backend) saveDocumentTypes(docTypes []*sbom.DocumentType) TxFunc
SetNillableName(docType.Name). //nolint:protogetter
SetNillableDescription(docType.Description) //nolint:protogetter

setDocumentID(backend.ctx, newDocType)
setMetadataID(backend.ctx, newDocType)
for _, fn := range opts {
fn(newDocType)
}

if err := newDocType.OnConflict().Ignore().Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) {
return fmt.Errorf("saving document type: %w", err)
Expand All @@ -140,7 +137,7 @@ func (backend *Backend) saveDocumentTypes(docTypes []*sbom.DocumentType) TxFunc
}
}

func (backend *Backend) saveEdges(edges []*sbom.Edge) TxFunc { //nolint:gocognit
func (backend *Backend) saveEdges(edges []*sbom.Edge, opts ...func(*ent.EdgeTypeCreate)) TxFunc { //nolint:gocognit
return func(tx *ent.Tx) error {
nativeIDMap, ok := backend.ctx.Value(nodeNativeIDMappingKey{}).(map[string]uuid.UUID)
if !ok {
Expand All @@ -155,8 +152,9 @@ func (backend *Backend) saveEdges(edges []*sbom.Edge) TxFunc { //nolint:gocognit
SetFromID(nativeIDMap[edge.GetFrom()]).
SetToID(nativeIDMap[toID])

setDocumentID(backend.ctx, newEdgeType)
addNodeListIDs(backend.ctx, newEdgeType)
for _, fn := range opts {
fn(newEdgeType)
}

if err := newEdgeType.
OnConflict().
Expand Down Expand Up @@ -189,17 +187,16 @@ func (backend *Backend) saveExternalReferences(refs []*sbom.ExternalReference, o
SetAuthority(ref.GetAuthority()).
SetType(externalreference.Type(ref.GetType().String()))

setDocumentID(backend.ctx, newRef)

for _, fn := range opts {
fn(newRef)
}

builders = append(builders, newRef)

fns = append(fns, backend.saveHashes(ref.GetHashes(),
func(hec *ent.HashesEntryCreate) { hec.AddExternalReferenceIDs(extRefID) },
))
fns = append(fns, backend.saveHashes(ref.GetHashes(), func(hec *ent.HashesEntryCreate) {
hec.AddExternalReferenceIDs(extRefID)
setDocumentID(backend.ctx, hec)
}))
}

err := tx.ExternalReference.CreateBulk(builders...).
Expand Down Expand Up @@ -231,8 +228,6 @@ func (backend *Backend) saveHashes(hashes map[int32]string, opts ...func(*ent.Ha
SetHashAlgorithm(hashesentry.HashAlgorithm(alg.String())).
SetHashData(value)

setDocumentID(backend.ctx, hashesEntry)

for _, fn := range opts {
fn(hashesEntry)
}
Expand Down Expand Up @@ -262,8 +257,6 @@ func (backend *Backend) saveIdentifiers(idents map[int32]string, opts ...func(*e
SetType(identifiersentry.Type(identType.String())).
SetValue(value)

setDocumentID(backend.ctx, identEntry)

for _, fn := range opts {
fn(identEntry)
}
Expand Down Expand Up @@ -303,13 +296,23 @@ func (backend *Backend) saveMetadata(metadata *sbom.Metadata) TxFunc {
return fmt.Errorf("saving metadata: %w", err)
}

backend.ctx = context.WithValue(backend.ctx, metadataIDKey{}, id)

for _, fn := range []TxFunc{
backend.savePersons(metadata.GetAuthors()),
backend.saveDocumentTypes(metadata.GetDocumentTypes()),
backend.saveSourceData(metadata.GetSourceData()),
backend.saveTools(metadata.GetTools()),
backend.savePersons(metadata.GetAuthors(), func(pc *ent.PersonCreate) {
pc.SetMetadataID(id)
setDocumentID(backend.ctx, pc)
}),
backend.saveDocumentTypes(metadata.GetDocumentTypes(), func(dtc *ent.DocumentTypeCreate) {
dtc.SetMetadataID(id)
setDocumentID(backend.ctx, dtc)
}),
backend.saveSourceData(metadata.GetSourceData(), func(sdc *ent.SourceDataCreate) {
sdc.SetMetadataID(id)
setDocumentID(backend.ctx, sdc)
}),
backend.saveTools(metadata.GetTools(), func(tc *ent.ToolCreate) {
tc.SetMetadataID(id)
setDocumentID(backend.ctx, tc)
}),
} {
if err := fn(tx); err != nil {
return err
Expand Down Expand Up @@ -337,11 +340,15 @@ func (backend *Backend) saveNodeList(nodeList *sbom.NodeList) TxFunc {
return fmt.Errorf("saving node list: %w", err)
}

backend.ctx = context.WithValue(backend.ctx, nodeListIDKey{}, id)

for _, fn := range []TxFunc{
backend.saveNodes(nodeList.GetNodes()),
backend.saveEdges(nodeList.GetEdges()),
backend.saveNodes(nodeList.GetNodes(), func(nc *ent.NodeCreate) {
nc.AddNodeListIDs(id)
setDocumentID(backend.ctx, nc)
}),
backend.saveEdges(nodeList.GetEdges(), func(etc *ent.EdgeTypeCreate) {
etc.AddNodeListIDs(id)
setDocumentID(backend.ctx, etc)
}),
} {
if err := fn(tx); err != nil {
return err
Expand All @@ -352,7 +359,7 @@ func (backend *Backend) saveNodeList(nodeList *sbom.NodeList) TxFunc {
}
}

func (backend *Backend) saveNodes(nodes []*sbom.Node) TxFunc { //nolint:funlen,gocognit
func (backend *Backend) saveNodes(nodes []*sbom.Node, opts ...func(*ent.NodeCreate)) TxFunc { //nolint:funlen,gocognit
return func(tx *ent.Tx) error {
builders := []*ent.NodeCreate{}
fns := []TxFunc{}
Expand All @@ -366,10 +373,9 @@ func (backend *Backend) saveNodes(nodes []*sbom.Node) TxFunc { //nolint:funlen,g

nativeIDMap[srcNode.GetId()] = nodeID

backend.ctx = context.WithValue(backend.ctx, nodeIDKey{}, nodeID)
newNode := tx.Node.Create().
SetNativeID(srcNode.GetId()).
SetProtoMessage(srcNode).
SetNativeID(srcNode.GetId()).
SetAttribution(srcNode.GetAttribution()).
SetBuildDate(srcNode.GetBuildDate().AsTime()).
SetComment(srcNode.GetComment()).
Expand All @@ -390,25 +396,41 @@ func (backend *Backend) saveNodes(nodes []*sbom.Node) TxFunc { //nolint:funlen,g
SetValidUntilDate(srcNode.GetValidUntilDate().AsTime()).
SetVersion(srcNode.GetVersion())

addNodeListIDs(backend.ctx, newNode)
setDocumentID(backend.ctx, newNode)
for _, fn := range opts {
fn(newNode)
}

builders = append(builders, newNode)

fns = append(fns,
backend.saveExternalReferences(srcNode.GetExternalReferences(),
func(erc *ent.ExternalReferenceCreate) { erc.AddNodeIDs(nodeID) },
),
backend.saveHashes(srcNode.GetHashes(),
func(hec *ent.HashesEntryCreate) { hec.AddNodeIDs(nodeID) },
),
backend.saveIdentifiers(srcNode.GetIdentifiers(),
func(iec *ent.IdentifiersEntryCreate) { iec.AddNodeIDs(nodeID) },
),
backend.savePersons(srcNode.GetOriginators()),
backend.savePersons(srcNode.GetSuppliers()),
backend.saveProperties(srcNode.GetProperties(), nodeID),
backend.savePurposes(srcNode.GetPrimaryPurpose(), nodeID),
backend.saveExternalReferences(srcNode.GetExternalReferences(), func(erc *ent.ExternalReferenceCreate) {
erc.AddNodeIDs(nodeID)
setDocumentID(backend.ctx, erc)
}),
backend.saveHashes(srcNode.GetHashes(), func(hec *ent.HashesEntryCreate) {
hec.AddNodeIDs(nodeID)
setDocumentID(backend.ctx, hec)
}),
backend.saveIdentifiers(srcNode.GetIdentifiers(), func(iec *ent.IdentifiersEntryCreate) {
iec.AddNodeIDs(nodeID)
setDocumentID(backend.ctx, iec)
}),
backend.savePersons(srcNode.GetOriginators(), func(pc *ent.PersonCreate) {
pc.SetNodeID(nodeID)
setDocumentID(backend.ctx, pc)
}),
backend.savePersons(srcNode.GetSuppliers(), func(pc *ent.PersonCreate) {
pc.SetNodeID(nodeID)
setDocumentID(backend.ctx, pc)
}),
backend.saveProperties(srcNode.GetProperties(), func(pc *ent.PropertyCreate) {
pc.SetNodeID(nodeID)
setDocumentID(backend.ctx, pc)
}),
backend.savePurposes(srcNode.GetPrimaryPurpose(), func(pc *ent.PurposeCreate) {
pc.SetNodeID(nodeID)
setDocumentID(backend.ctx, pc)
}),
)
}

Expand All @@ -432,7 +454,7 @@ func (backend *Backend) saveNodes(nodes []*sbom.Node) TxFunc { //nolint:funlen,g
}
}

func (backend *Backend) savePersons(persons []*sbom.Person) TxFunc { //nolint:gocognit
func (backend *Backend) savePersons(persons []*sbom.Person, opts ...func(*ent.PersonCreate)) TxFunc { //nolint:gocognit
return func(tx *ent.Tx) error {
builders := []*ent.PersonCreate{}

Expand All @@ -450,15 +472,16 @@ func (backend *Backend) savePersons(persons []*sbom.Person) TxFunc { //nolint:go
SetPhone(person.GetPhone()).
SetURL(person.GetUrl())

setContactOwnerID(backend.ctx, newPerson)
setDocumentID(backend.ctx, newPerson)
setMetadataID(backend.ctx, newPerson)
setNodeID(backend.ctx, newPerson)
for _, fn := range opts {
fn(newPerson)
}

builders = append(builders, newPerson)
backend.ctx = context.WithValue(backend.ctx, contactOwnerIDKey{}, id)

if err := backend.savePersons(person.GetContacts())(tx); err != nil {
if err := backend.savePersons(person.GetContacts(), func(pc *ent.PersonCreate) {
pc.SetContactOwnerID(id)
setDocumentID(backend.ctx, pc)
})(tx); err != nil {
return err
}
}
Expand All @@ -474,19 +497,19 @@ func (backend *Backend) savePersons(persons []*sbom.Person) TxFunc { //nolint:go
}
}

func (backend *Backend) saveProperties(properties []*sbom.Property, nodeID uuid.UUID) TxFunc {
func (backend *Backend) saveProperties(properties []*sbom.Property, opts ...func(*ent.PropertyCreate)) TxFunc {
return func(tx *ent.Tx) error {
builders := []*ent.PropertyCreate{}

for _, prop := range properties {
newProp := tx.Property.Create().
SetProtoMessage(prop).
SetNodeID(nodeID).
SetName(prop.GetName()).
SetData(prop.GetData())

setDocumentID(backend.ctx, newProp)
setNodeID(backend.ctx, newProp)
for _, fn := range opts {
fn(newProp)
}

builders = append(builders, newProp)
}
Expand All @@ -503,16 +526,17 @@ func (backend *Backend) saveProperties(properties []*sbom.Property, nodeID uuid.
}
}

func (backend *Backend) savePurposes(purposes []sbom.Purpose, nodeID uuid.UUID) TxFunc {
func (backend *Backend) savePurposes(purposes []sbom.Purpose, opts ...func(*ent.PurposeCreate)) TxFunc {
return func(tx *ent.Tx) error {
builders := []*ent.PurposeCreate{}

for idx := range purposes {
newPurpose := tx.Purpose.Create().
SetNodeID(nodeID).
SetPrimaryPurpose(purpose.PrimaryPurpose(purposes[idx].String()))

setDocumentID(backend.ctx, newPurpose)
for _, fn := range opts {
fn(newPurpose)
}

builders = append(builders, newPurpose)
}
Expand All @@ -529,7 +553,7 @@ func (backend *Backend) savePurposes(purposes []sbom.Purpose, nodeID uuid.UUID)
}
}

func (backend *Backend) saveSourceData(sourceData *sbom.SourceData) TxFunc {
func (backend *Backend) saveSourceData(sourceData *sbom.SourceData, opts ...func(*ent.SourceDataCreate)) TxFunc {
return func(tx *ent.Tx) error {
newSourceData := tx.SourceData.Create().
SetProtoMessage(sourceData).
Expand All @@ -538,8 +562,9 @@ func (backend *Backend) saveSourceData(sourceData *sbom.SourceData) TxFunc {
SetSize(sourceData.GetSize()).
SetURI(sourceData.GetUri())

setDocumentID(backend.ctx, newSourceData)
setMetadataID(backend.ctx, newSourceData)
for _, fn := range opts {
fn(newSourceData)
}

if err := newSourceData.OnConflict().Ignore().Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) {
return fmt.Errorf("saving source data: %w", err)
Expand All @@ -549,7 +574,7 @@ func (backend *Backend) saveSourceData(sourceData *sbom.SourceData) TxFunc {
}
}

func (backend *Backend) saveTools(tools []*sbom.Tool) TxFunc {
func (backend *Backend) saveTools(tools []*sbom.Tool, opts ...func(*ent.ToolCreate)) TxFunc {
return func(tx *ent.Tx) error {
builders := []*ent.ToolCreate{}

Expand All @@ -560,8 +585,9 @@ func (backend *Backend) saveTools(tools []*sbom.Tool) TxFunc {
SetVersion(tool.GetVersion()).
SetVendor(tool.GetVendor())

setDocumentID(backend.ctx, newTool)
setMetadataID(backend.ctx, newTool)
for _, fn := range opts {
fn(newTool)
}

builders = append(builders, newTool)
}
Expand All @@ -588,32 +614,8 @@ func GenerateUUID(msg proto.Message) (uuid.UUID, error) {
return uuid.NewHash(sha256.New(), uuid.Max, data, int(uuid.Max.Version())), nil
}

func addNodeListIDs[T interface{ AddNodeListIDs(...uuid.UUID) T }](ctx context.Context, builder T) {
if nodeListID, ok := ctx.Value(nodeListIDKey{}).(uuid.UUID); ok {
builder.AddNodeListIDs(nodeListID)
}
}

func setContactOwnerID[T interface{ SetContactOwnerID(uuid.UUID) T }](ctx context.Context, builder T) {
if contactOwnerID, ok := ctx.Value(contactOwnerIDKey{}).(uuid.UUID); ok {
builder.SetContactOwnerID(contactOwnerID)
}
}

func setDocumentID[T interface{ SetDocumentID(uuid.UUID) T }](ctx context.Context, builder T) {
if documentID, ok := ctx.Value(documentIDKey{}).(uuid.UUID); ok {
builder.SetDocumentID(documentID)
}
}

func setMetadataID[T interface{ SetMetadataID(uuid.UUID) T }](ctx context.Context, builder T) {
if metadataID, ok := ctx.Value(metadataIDKey{}).(uuid.UUID); ok {
builder.SetMetadataID(metadataID)
}
}

func setNodeID[T interface{ SetNodeID(uuid.UUID) T }](ctx context.Context, builder T) {
if nodeID, ok := ctx.Value(nodeIDKey{}).(uuid.UUID); ok {
builder.SetNodeID(nodeID)
}
}
Loading

0 comments on commit 32da4ad

Please sign in to comment.