diff --git a/ognl.go b/ognl.go index 2ca040f..960dee4 100644 --- a/ognl.go +++ b/ognl.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "strconv" + "strings" "unsafe" ) @@ -26,6 +27,8 @@ var ErrUnableExpand = errors.New("unable to expand") var ErrInvalidValue = errors.New("invalid value") +var ErrInvalidSet = errors.New("invalid set") + // Type is Result type type Type int @@ -525,6 +528,21 @@ loop: return string(key), index } +func parseLastKeyIndex(selector string) int { + idx := len(selector) - 1 + for ; idx > 0; idx-- { + switch selector[idx] { + case ' ', '\t', '\n', '\r', '.', '#': + if idx-1 >= 0 && selector[idx-1] == '\\' { + continue + } else { + return idx + } + } + } + return idx +} + func GetMany(value interface{}, path ...string) []Result { results := make([]Result, 0, len(path)) for _, s := range path { @@ -749,9 +767,216 @@ func warpError(err error, object interface{}, path string) error { return fmt.Errorf("object:%v,path:%s,err: %w", object, path, err) } -func min(a, b int) int { - if a > b { - return b +func Set(obj interface{}, path string, value interface{}) error { + idx := parseLastKeyIndex(path) + parentPath := path[:idx] + offset := 1 + if idx == 0 && len(path) > 0 && (path[0] != '.' && path[0] != '#' && path[0] != ' ') { + offset = 0 + } + key := strings.ReplaceAll(path[idx+offset:], "\\", "") + + if key == "" { + return fmt.Errorf("path:%s target path is empty", parentPath) + } + result, err := GetE(obj, parentPath) + if err != nil { + return err + } + + if !result.Effective() { + return fmt.Errorf("path:%s, invalid parent obj", parentPath) + } + + if result.deployment { + list := result.raw.([]interface{}) + ln := len(list) + for i := 0; i < ln; i++ { + err = set(list[i], key, value) + if err != nil { + return err + } + } + return nil + } + return set(result.raw, key, value) +} + +func set(obj interface{}, key string, value interface{}) error { + + if IsNil(obj) { + return ErrInvalidValue + } + + v, err := strconv.Atoi(key) + digit := err == nil && v >= 0 + + t, f := reflect.TypeOf(obj), reflect.ValueOf(obj) + if digit { + return setInt(t, f, v, value) + } + return setString(t, f, key, value) +} + +func setString(t reflect.Type, v reflect.Value, key string, value interface{}) error { + if !v.IsValid() { + return ErrInvalidValue + } + + switch v.Kind() { + case reflect.Ptr, reflect.Interface: + if v.IsNil() { + return ErrInvalidValue + } + if !v.Elem().IsValid() { + return ErrInvalidValue + } + return setString(t.Elem(), v.Elem(), key, value) + + case reflect.Map: + if t.Key().Kind() != reflect.String { + return ErrMapKeyMustString + } + + newValue := reflect.ValueOf(value) + if t.Elem() == newValue.Type() { + v.SetMapIndex(reflect.ValueOf(key), newValue) + return nil + } else if newValue.Type().ConvertibleTo(t.Elem()) { + v.SetMapIndex(reflect.ValueOf(key), newValue.Convert(t.Elem())) + return nil + } else { + return fmt.Errorf("type mismatch in map assignment: want %s, got %s", t.Elem().String(), newValue.Type().String()) + } + + case reflect.Struct: + field := v.FieldByName(key) + if !field.IsValid() { + return ErrStructIndexOutOfBounds + } + if !field.CanSet() { + return ErrInvalidSet + } + + newValue := reflect.ValueOf(value) + if newValue.Type() == field.Type() { + field.Set(newValue) + return nil + } else if newValue.Type().ConvertibleTo(field.Type()) { + field.Set(newValue.Convert(field.Type())) + return nil + } else { + return fmt.Errorf("type mismatch in struct assignment: want %s, got %s", field.Type().String(), newValue.Type().String()) + } + + default: + return ErrInvalidStructure + } +} + +func setInt(t reflect.Type, v reflect.Value, key int, value interface{}) error { + if !v.IsValid() { + return ErrInvalidValue + } + + switch v.Kind() { + case reflect.Ptr, reflect.Interface: + if v.IsNil() { + return ErrInvalidValue + } + if !v.Elem().IsValid() { + return ErrInvalidValue + } + + return setInt(t.Elem(), v.Elem(), key, value) + case reflect.Map: + if t.Key().Kind() != reflect.Int { + return ErrMapKeyMustInt + } + + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + + newValue := reflect.ValueOf(value) + if newValue.Type() == t.Elem() { + v.SetMapIndex(reflect.ValueOf(key), newValue) + return nil + } else if newValue.Type().ConvertibleTo(t.Elem()) { + v.SetMapIndex(reflect.ValueOf(key), newValue.Convert(t.Elem())) + return nil + } else { + return fmt.Errorf("type mismatch in map assignment: want %s, got %s", t.Elem().String(), newValue.Type().String()) + } + + case reflect.Slice, reflect.Array: + if key < 0 || key >= v.Len() { + return ErrIndexOutOfBounds + } + + field := v.Index(key) + if !field.IsValid() { + return ErrInvalidSet + } + if !field.CanSet() { + return ErrInvalidSet + } + + newValue := reflect.ValueOf(value) + if newValue.Type() == field.Type() { + field.Set(newValue) + return nil + } else if newValue.Type().ConvertibleTo(field.Type()) { + field.Set(newValue.Convert(field.Type())) + return nil + } else { + return fmt.Errorf("type mismatch in slice assignment: want %s, got %s", field.Type().String(), newValue.Type().String()) + } + + case reflect.Struct: + if key < 0 || key >= v.NumField() { + return ErrStructIndexOutOfBounds + } + + field := v.Field(key) + if !field.IsValid() { + return ErrInvalidSet + } + if !field.CanSet() { + return ErrInvalidSet + } + + newValue := reflect.ValueOf(value) + if newValue.Type() == field.Type() { + field.Set(newValue) + return nil + } else if newValue.Type().ConvertibleTo(field.Type()) { + field.Set(newValue.Convert(field.Type())) + return nil + } else { + return fmt.Errorf("type mismatch in struct assignment: want %s, got %s", field.Type().String(), newValue.Type().String()) + } + + default: + return ErrInvalidStructure + } +} + +func IsNil(value interface{}) bool { + if value == nil { + return true + } + v := reflect.ValueOf(value) + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Slice: + return v.IsNil() + case reflect.Ptr: + elem := v.Elem() + if !elem.IsValid() { + return true + } + return IsNil(elem.Interface()) + default: + return false } - return a } diff --git a/ognl_test.go b/ognl_test.go index a9977d1..daf4461 100644 --- a/ognl_test.go +++ b/ognl_test.go @@ -459,3 +459,228 @@ func TestDeep(t *testing.T) { t.Log(Get(t1, "##").Value()) // []interface{}{"first","first",t1,7,7} t.Log(Get(t1, "##").Values()) // []interface{}{"first","first",t1,7,7} } + +func Test_parseLastKeyIndex(t *testing.T) { + type args struct { + selector string + } + tests := []struct { + name string + args args + wantVal string + }{ + { + name: "", + args: args{ + selector: ".name", + }, + wantVal: "", + }, + { + name: "", + args: args{ + selector: "First", + }, + wantVal: "", + }, + { + name: "", + args: args{ + selector: "#", + }, + wantVal: "", + }, + { + name: "", + args: args{ + selector: "##", + }, + wantVal: "#", + }, + { + name: "", + args: args{ + selector: "###", + }, + wantVal: "##", + }, + { + name: "", + args: args{ + selector: "Middle.Middle", + }, + wantVal: "Middle", + }, + { + name: "", + args: args{ + selector: "Middle.Middle#", + }, + wantVal: "Middle.Middle", + }, + { + name: "", + args: args{ + selector: "Foo\\.Bar\\.Name", + }, + wantVal: "", + }, + { + name: "", + args: args{ + selector: "Foo\\.\\.\\.\\.\\.Bar", + }, + wantVal: "", + }, + { + name: "", + args: args{ + selector: "Foo\\.\\.\\.\\.\\.Bar.1", + }, + wantVal: "Foo\\.\\.\\.\\.\\.Bar", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idx := parseLastKeyIndex(tt.args.selector) + assert.Equalf(t, tt.wantVal, tt.args.selector[:idx], "parseLastKeyIndex(%v)", tt.args.selector[:idx]) + }) + } +} + +func TestSet(t *testing.T) { + var ( + t2 = &Mock{ + Name: "t2", + Age: 2, + } + hash1 = map[string]interface{}{ + "string1": "string", + "int1": 1, + "t2": t2, + "Foo.Bar.Name": "bar name 001", + } + t3 = &Mock{ + Name: "t3", + Age: 3, + } + t4 = &Mock{ + Name: "t4", + Age: 4, + Hash1: map[string]interface{}{}, + } + hash2 = map[int]interface{}{ + 2: t2, + 3: t3, + 4: t4, + } + list = []*Mock{t2, t3, t4} + array = [3]*Mock{t2, t3, t4} + t1 = &Mock{ + Name: "t1", + lName: "lt1", + Age: 1, + lAge: 11, + Hash1: hash1, + lHash1: hash1, + Hash2: hash2, + lHash2: hash2, + List: list, + lList: list, + Array: array, + lArray: array, + } + ) + hash1["t1"] = t1 + + type args struct { + path string + value interface{} + } + tests := []struct { + name string + args args + wantErr assert.ErrorAssertionFunc + }{ + { + name: "", + args: args{ + path: "Name", + value: "t1change", + }, + wantErr: nil, + }, + { + name: "", + args: args{ + path: "Age", + value: uint(10), + }, + wantErr: nil, + }, + { + name: "", + args: args{ + path: "Hash1", + value: map[string]interface{}{ + "key": "value", + }, + }, + wantErr: nil, + }, + { + name: "", + args: args{ + path: "Array", + value: [3]*Mock{ + t1, + }, + }, + wantErr: nil, + }, + { + name: "", + args: args{ + path: "List", + value: []*Mock{ + t3, + }, + }, + wantErr: nil, + }, + { + name: "", + args: args{ + path: "List.0.Name", + value: "hhh", + }, + wantErr: nil, + }, + { + name: "", + args: args{ + path: "Array.0.Name", + value: "hhh", + }, + wantErr: nil, + }, + { + name: "", + args: args{ + path: "Hash1.Foo\\.Bar\\.Name", + value: "bar name 002", + }, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := Set(t1, tt.args.path, tt.args.value) + if tt.wantErr != nil { + tt.wantErr(t, err) + } else { + assert.NoError(t, err) + } + }) + } +}