diff --git a/internal/core/arg_file_content.go b/internal/core/arg_file_content.go index 21a2a248b8..76385001de 100644 --- a/internal/core/arg_file_content.go +++ b/internal/core/arg_file_content.go @@ -19,7 +19,7 @@ func loadArgsFileContent(cmd *Command, cmdArgs interface{}) error { } fieldName := strcase.ToPublicGoName(argSpec.Name) - fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, ".")) + fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, ".")) if err != nil { continue } diff --git a/internal/core/reflect.go b/internal/core/reflect.go index 2180476fbc..0f5646b4b3 100644 --- a/internal/core/reflect.go +++ b/internal/core/reflect.go @@ -1,6 +1,7 @@ package core import ( + "errors" "fmt" "reflect" "sort" @@ -34,26 +35,33 @@ func newObjectWithForcedJSONTags(t reflect.Type) interface{} { return reflect.New(reflect.StructOf(structFieldsCopy)).Interface() } -// getValuesForFieldByName recursively search for fields in a cmdArgs' value and returns its values if they exist. +// GetValuesForFieldByName recursively search for fields in a cmdArgs' value and returns its values if they exist. // The search is based on the name of the field. -func getValuesForFieldByName(value reflect.Value, parts []string) (values []reflect.Value, err error) { +func GetValuesForFieldByName(value reflect.Value, parts []string) (values []reflect.Value, err error) { if len(parts) == 0 { return []reflect.Value{value}, nil } - switch value.Kind() { case reflect.Ptr: - return getValuesForFieldByName(value.Elem(), parts) + return GetValuesForFieldByName(value.Elem(), parts) case reflect.Slice: values := []reflect.Value(nil) + errs := []error(nil) + for i := 0; i < value.Len(); i++ { - newValues, err := getValuesForFieldByName(value.Index(i), parts[1:]) + newValues, err := GetValuesForFieldByName(value.Index(i), parts[1:]) if err != nil { - return nil, err + errs = append(errs, err) + } else { + values = append(values, newValues...) } - values = append(values, newValues...) } + + if len(values) == 0 && len(errs) != 0 { + return nil, errors.Join(errs...) + } + return values, nil case reflect.Map: @@ -70,7 +78,7 @@ func getValuesForFieldByName(value reflect.Value, parts []string) (values []refl for _, mapKey := range mapKeys { mapValue := value.MapIndex(mapKey) - newValues, err := getValuesForFieldByName(mapValue, parts[1:]) + newValues, err := GetValuesForFieldByName(mapValue, parts[1:]) if err != nil { return nil, err } @@ -93,12 +101,12 @@ func getValuesForFieldByName(value reflect.Value, parts []string) (values []refl fieldName := strcase.ToPublicGoName(parts[0]) if fieldIndex, exist := fieldIndexByName[fieldName]; exist { - return getValuesForFieldByName(value.Field(fieldIndex), parts[1:]) + return GetValuesForFieldByName(value.Field(fieldIndex), parts[1:]) } // If it does not exist we try to find it in nested anonymous field for _, fieldIndex := range anonymousFieldIndexes { - newValues, err := getValuesForFieldByName(value.Field(fieldIndex), parts) + newValues, err := GetValuesForFieldByName(value.Field(fieldIndex), parts) if err == nil { return newValues, nil } @@ -106,6 +114,5 @@ func getValuesForFieldByName(value reflect.Value, parts []string) (values []refl return nil, fmt.Errorf("field %v does not exist for %v", fieldName, value.Type().Name()) } - return nil, fmt.Errorf("case is not handled") } diff --git a/internal/core/reflect_test.go b/internal/core/reflect_test.go new file mode 100644 index 0000000000..288d622f12 --- /dev/null +++ b/internal/core/reflect_test.go @@ -0,0 +1,181 @@ +package core_test + +import ( + "net" + "reflect" + "strings" + "testing" + + "github.com/alecthomas/assert" + "github.com/scaleway/scaleway-cli/v2/internal/core" + "github.com/scaleway/scaleway-sdk-go/scw" +) + +type RequestEmbedding struct { + EmbeddingField1 string + EmbeddingField2 int +} + +type CreateRequest struct { + *RequestEmbedding + CreateField1 string + CreateField2 int +} + +type ExtendedRequest struct { + *CreateRequest + ExtendedField1 string + ExtendedField2 int +} + +type ArrowRequest struct { + PrivateNetwork *PrivateNetwork +} + +type SpecialRequest struct { + *RequestEmbedding + TabRequest []*ArrowRequest +} + +type EndpointSpecPrivateNetwork struct { + PrivateNetworkID string + ServiceIP *scw.IPNet +} + +type PrivateNetwork struct { + *EndpointSpecPrivateNetwork + OtherValue string +} + +func Test_getValuesForFieldByName(t *testing.T) { + type TestCase struct { + cmdArgs interface{} + fieldName string + expectedError string + expectedValues []reflect.Value + } + + expectedServiceIP := &scw.IPNet{ + IPNet: net.IPNet{ + IP: net.ParseIP("192.0.2.1"), + Mask: net.CIDRMask(24, 32), + }, + } + + tests := []struct { + name string + testCase TestCase + testFunc func(*testing.T, TestCase) + }{ + { + name: "Simple test", + testCase: TestCase{ + cmdArgs: &ExtendedRequest{ + CreateRequest: &CreateRequest{ + RequestEmbedding: &RequestEmbedding{ + EmbeddingField1: "value1", + EmbeddingField2: 2, + }, + CreateField1: "value3", + CreateField2: 4, + }, + ExtendedField1: "value5", + ExtendedField2: 6, + }, + fieldName: "EmbeddingField1", + expectedError: "", + expectedValues: []reflect.Value{reflect.ValueOf("value1")}, + }, + testFunc: func(t *testing.T, tc TestCase) { + values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, ".")) + if err != nil { + assert.Equal(t, tc.expectedError, err.Error()) + } else { + if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) { + t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface()) + } + } + }, + }, + { + name: "Error test", + testCase: TestCase{ + cmdArgs: &ExtendedRequest{ + CreateRequest: &CreateRequest{ + RequestEmbedding: &RequestEmbedding{ + EmbeddingField1: "value1", + EmbeddingField2: 2, + }, + CreateField1: "value3", + CreateField2: 4, + }, + ExtendedField1: "value5", + ExtendedField2: 6, + }, + fieldName: "NotExist", + expectedError: "field NotExist does not exist for ExtendedRequest", + expectedValues: []reflect.Value{reflect.ValueOf("value1")}, + }, + testFunc: func(t *testing.T, tc TestCase) { + values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, ".")) + if err != nil { + assert.Equal(t, tc.expectedError, err.Error()) + } else { + if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) { + t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface()) + } + } + }, + }, + { + + name: "Special test", + testCase: TestCase{ + cmdArgs: &SpecialRequest{ + RequestEmbedding: &RequestEmbedding{ + EmbeddingField1: "value1", + EmbeddingField2: 2, + }, + TabRequest: []*ArrowRequest{ + { + PrivateNetwork: &PrivateNetwork{ + EndpointSpecPrivateNetwork: &EndpointSpecPrivateNetwork{ + ServiceIP: &scw.IPNet{ + IPNet: net.IPNet{ + IP: net.ParseIP("192.0.2.1"), + Mask: net.CIDRMask(24, 32), + }, + }, + }, + }, + }, + { + PrivateNetwork: &PrivateNetwork{ + OtherValue: "hello", + }, + }, + }, + }, + fieldName: "tabRequest.{index}.privateNetwork.serviceIP", + expectedError: "", + expectedValues: []reflect.Value{reflect.ValueOf(expectedServiceIP)}, + }, + testFunc: func(t *testing.T, tc TestCase) { + values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, ".")) + if err != nil { + assert.Equal(t, nil, err.Error()) + } else { + if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) { + t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface()) + } + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.testFunc(t, tt.testCase) + }) + } +} diff --git a/internal/core/validate.go b/internal/core/validate.go index b6c56a78ca..5b98a95284 100644 --- a/internal/core/validate.go +++ b/internal/core/validate.go @@ -45,7 +45,7 @@ func DefaultCommandValidateFunc() CommandValidateFunc { func validateArgValues(cmd *Command, cmdArgs interface{}) error { for _, argSpec := range cmd.ArgSpecs { fieldName := strcase.ToPublicGoName(argSpec.Name) - fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, ".")) + fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, ".")) if err != nil { logger.Infof("could not validate arg value for '%v': invalid fieldName: %v: %v", argSpec.Name, fieldName, err.Error()) continue @@ -75,7 +75,7 @@ func validateRequiredArgs(cmd *Command, cmdArgs interface{}, rawArgs args.RawArg } fieldName := strcase.ToPublicGoName(arg.Name) - fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, ".")) + fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, ".")) if err != nil { validationErr := fmt.Errorf("could not validate arg value for '%v': invalid field name '%v': %v", arg.Name, fieldName, err.Error()) if !arg.Required { @@ -117,7 +117,7 @@ func validateDeprecated(ctx context.Context, cmd *Command, cmdArgs interface{}, deprecatedArgs := cmd.ArgSpecs.GetDeprecated(true) for _, arg := range deprecatedArgs { fieldName := strcase.ToPublicGoName(arg.Name) - fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, ".")) + fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, ".")) if err != nil { validationErr := fmt.Errorf("could not validate arg value for '%v': invalid field name '%v': %v", arg.Name, fieldName, err.Error()) if !arg.Required {