Skip to content

Commit

Permalink
SUB-3589 - Risk factors
Browse files Browse the repository at this point in the history
  • Loading branch information
rinao12 committed Dec 28, 2023
1 parent 757c589 commit 3dc6690
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 0 deletions.
52 changes: 52 additions & 0 deletions armotypes/common.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package armotypes

import "strings"

// swagger:strfmt uuid4
// Example: 0f42fbe3-d81e-444d-8cc7-bc892c7623e9
type GUID string
Expand All @@ -13,3 +15,53 @@ const (
RiskFactorDataAccess RiskFactor = "Data access"
RiskFactorHostAccess RiskFactor = "Host access"
)

var RiskFactorMapping = map[string]RiskFactor{
"C-0256": RiskFactorInternetFacing,
"C-0046": RiskFactorPrivileged,
"C-0057": RiskFactorPrivileged,
"C-0255": RiskFactorSecretAccess,
"C-0257": RiskFactorDataAccess,
"C-0038": RiskFactorHostAccess,
"C-0041": RiskFactorHostAccess,
"C-0044": RiskFactorHostAccess,
"C-0048": RiskFactorHostAccess,
}

// GetRiskFactors returns a list of unique risk factors for given control IDs.
func GetRiskFactors(controlIDs []string) []RiskFactor {
riskFactorSet := make(map[RiskFactor]bool)
for _, id := range controlIDs {
if riskFactor, exists := RiskFactorMapping[id]; exists {
riskFactorSet[riskFactor] = true
}
}

var riskFactors []RiskFactor
for riskFactor := range riskFactorSet {
riskFactors = append(riskFactors, riskFactor)
}
return riskFactors
}

func GetControlIDsByRiskFactors(riskFactorsStr string) []string {
riskFactors := strings.Split(riskFactorsStr, ",")
controlIDSet := make(map[string]bool)

for _, rfStr := range riskFactors {
rf := RiskFactor(rfStr) // Assuming risk factor strings match the enum names
for controlID, mappedRF := range RiskFactorMapping {
if mappedRF == rf {
controlIDSet[controlID] = true
}
}
}

// Convert set to slice
var controlIDs []string
for controlID := range controlIDSet {
controlIDs = append(controlIDs, controlID)
}

return controlIDs
}
105 changes: 105 additions & 0 deletions armotypes/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package armotypes

import (
"github.com/stretchr/testify/assert"
"sort"
"testing"
)

func TestGetControlIDsByRiskFactors(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "Single Risk Factor",
input: "RiskFactorInternetFacing",
expected: []string{"C-0256"},
},
{
name: "Multiple Risk Factors",
input: "RiskFactorPrivileged,RiskFactorSecretAccess",
expected: []string{"C-0046", "C-0057", "C-0255"},
},
{
name: "No Risk Factors",
input: "",
expected: []string{},
},
{
name: "Invalid Risk Factor",
input: "RiskFactorNonExistent",
expected: []string{},
},
{
name: "Duplicate Risk Factors",
input: "RiskFactorHostAccess,RiskFactorHostAccess",
expected: []string{"C-0038", "C-0041", "C-0044", "C-0048"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetControlIDsByRiskFactors(tt.input)
sort.Strings(result)
sort.Strings(tt.expected)

assert.Equal(t, tt.expected, result)
})
}
}

func TestGetRiskFactors(t *testing.T) {
tests := []struct {
name string
input []string
expected []RiskFactor
}{
{
name: "Multiple Risk Factors",
input: []string{"C-0256", "C-0046", "C-0057", "C-0255"},
expected: []RiskFactor{RiskFactorInternetFacing, RiskFactorPrivileged, RiskFactorSecretAccess},
},
{
name: "Empty controls list",
input: []string{},
expected: nil,
},
{
name: "nil controls list",
input: nil,
expected: nil,
},
{
name: "Single Risk Factor",
input: []string{"C-0256"},
expected: []RiskFactor{RiskFactorInternetFacing},
},
{
name: "No Risk Factors",
input: []string{"C-0000"},
expected: nil,
},
{
name: "Duplicate Risk Factors",
input: []string{"C-0046", "C-0046"},
expected: []RiskFactor{RiskFactorPrivileged},
},
{
name: "Mixed Valid and Invalid IDs",
input: []string{"C-0046", "C-9999"},
expected: []RiskFactor{RiskFactorPrivileged},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetRiskFactors(tt.input)
sort.Slice(result, func(i, j int) bool { return result[i] < result[j] })
sort.Slice(tt.expected, func(i, j int) bool { return tt.expected[i] < tt.expected[j] })

assert.Equal(t, tt.expected, result, "GetRiskFactors(%v) = %v, want %v", tt.input, result, tt.expected)
})
}
}

0 comments on commit 3dc6690

Please sign in to comment.