diff --git a/armotypes/common.go b/armotypes/common.go index d263e8e..a734b04 100644 --- a/armotypes/common.go +++ b/armotypes/common.go @@ -1,5 +1,7 @@ package armotypes +import "strings" + // swagger:strfmt uuid4 // Example: 0f42fbe3-d81e-444d-8cc7-bc892c7623e9 type GUID string @@ -13,3 +15,53 @@ const ( RiskFactorDataAccess RiskFactor = "Data access" RiskFactorHostAccess RiskFactor = "Host access" ) + +var RiskFactorMapping = map[string]RiskFactor{ + "C-0256": RiskFactorInternetFacing, + "C-0046": RiskFactorPrivileged, + "C-0057": RiskFactorPrivileged, + "C-0255": RiskFactorSecretAccess, + "C-0257": RiskFactorDataAccess, + "C-0038": RiskFactorHostAccess, + "C-0041": RiskFactorHostAccess, + "C-0044": RiskFactorHostAccess, + "C-0048": RiskFactorHostAccess, +} + +// GetRiskFactors returns a list of unique risk factors for given control IDs. +func GetRiskFactors(controlIDs []string) []RiskFactor { + riskFactorSet := make(map[RiskFactor]bool) + for _, id := range controlIDs { + if riskFactor, exists := RiskFactorMapping[id]; exists { + riskFactorSet[riskFactor] = true + } + } + + var riskFactors []RiskFactor + for riskFactor := range riskFactorSet { + riskFactors = append(riskFactors, riskFactor) + } + return riskFactors +} + +func GetControlIDsByRiskFactors(riskFactorsStr string) []string { + riskFactors := strings.Split(riskFactorsStr, ",") + controlIDSet := make(map[string]bool) + + for _, rfStr := range riskFactors { + rf := RiskFactor(rfStr) // Assuming risk factor strings match the enum names + for controlID, mappedRF := range RiskFactorMapping { + if mappedRF == rf { + controlIDSet[controlID] = true + } + } + } + + // Convert set to slice + var controlIDs []string + for controlID := range controlIDSet { + controlIDs = append(controlIDs, controlID) + } + + return controlIDs +} diff --git a/armotypes/common_test.go b/armotypes/common_test.go new file mode 100644 index 0000000..d1ac3d6 --- /dev/null +++ b/armotypes/common_test.go @@ -0,0 +1,105 @@ +package armotypes + +import ( + "github.com/stretchr/testify/assert" + "sort" + "testing" +) + +func TestGetControlIDsByRiskFactors(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "Single Risk Factor", + input: "RiskFactorInternetFacing", + expected: []string{"C-0256"}, + }, + { + name: "Multiple Risk Factors", + input: "RiskFactorPrivileged,RiskFactorSecretAccess", + expected: []string{"C-0046", "C-0057", "C-0255"}, + }, + { + name: "No Risk Factors", + input: "", + expected: []string{}, + }, + { + name: "Invalid Risk Factor", + input: "RiskFactorNonExistent", + expected: []string{}, + }, + { + name: "Duplicate Risk Factors", + input: "RiskFactorHostAccess,RiskFactorHostAccess", + expected: []string{"C-0038", "C-0041", "C-0044", "C-0048"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetControlIDsByRiskFactors(tt.input) + sort.Strings(result) + sort.Strings(tt.expected) + + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGetRiskFactors(t *testing.T) { + tests := []struct { + name string + input []string + expected []RiskFactor + }{ + { + name: "Multiple Risk Factors", + input: []string{"C-0256", "C-0046", "C-0057", "C-0255"}, + expected: []RiskFactor{RiskFactorInternetFacing, RiskFactorPrivileged, RiskFactorSecretAccess}, + }, + { + name: "Empty controls list", + input: []string{}, + expected: nil, + }, + { + name: "nil controls list", + input: nil, + expected: nil, + }, + { + name: "Single Risk Factor", + input: []string{"C-0256"}, + expected: []RiskFactor{RiskFactorInternetFacing}, + }, + { + name: "No Risk Factors", + input: []string{"C-0000"}, + expected: nil, + }, + { + name: "Duplicate Risk Factors", + input: []string{"C-0046", "C-0046"}, + expected: []RiskFactor{RiskFactorPrivileged}, + }, + { + name: "Mixed Valid and Invalid IDs", + input: []string{"C-0046", "C-9999"}, + expected: []RiskFactor{RiskFactorPrivileged}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetRiskFactors(tt.input) + sort.Slice(result, func(i, j int) bool { return result[i] < result[j] }) + sort.Slice(tt.expected, func(i, j int) bool { return tt.expected[i] < tt.expected[j] }) + + assert.Equal(t, tt.expected, result, "GetRiskFactors(%v) = %v, want %v", tt.input, result, tt.expected) + }) + } +}