From aeed806d55e542e47b07218d47f638d660586cf3 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sat, 21 Dec 2024 21:50:58 +0900 Subject: [PATCH] Fix cast process for decoding of anchor value (#602) * add test case * fix cast process for decoding of anchor value --- decode.go | 68 ++++++++++++++++++++++++++++++++------------------ decode_test.go | 40 ++++++++++++++++++++--------- 2 files changed, 72 insertions(+), 36 deletions(-) diff --git a/decode.go b/decode.go index ec60fc6a..7df6cb1a 100644 --- a/decode.go +++ b/decode.go @@ -997,7 +997,11 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No if err := d.decodeValue(ctx, v, src); err != nil { return err } - dst.Set(d.castToAssignableValue(v, dst.Type())) + castedValue, err := d.castToAssignableValue(v, dst.Type(), src) + if err != nil { + return err + } + dst.Set(castedValue) case reflect.Interface: if dst.Type() == astNodeType { dst.Set(reflect.ValueOf(src)) @@ -1121,23 +1125,26 @@ func (d *Decoder) createDecodableValue(typ reflect.Type) reflect.Value { return reflect.New(typ).Elem() } -func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type) reflect.Value { +func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type, src ast.Node) (reflect.Value, error) { if target.Kind() != reflect.Ptr { - return value - } - maxTryCount := 5 - tryCount := 0 - for { - if tryCount > maxTryCount { - return value + if !value.Type().AssignableTo(target) { + return reflect.Value{}, errors.ErrTypeMismatch(target, value.Type(), src.GetToken()) } + return value, nil + } + + const maxAddrCount = 5 + + for i := 0; i < maxAddrCount; i++ { if value.Type().AssignableTo(target) { break } value = value.Addr() - tryCount++ } - return value + if !value.Type().AssignableTo(target) { + return reflect.Value{}, errors.ErrTypeMismatch(target, value.Type(), src.GetToken()) + } + return value, nil } func (d *Decoder) createDecodedNewValue( @@ -1145,9 +1152,16 @@ func (d *Decoder) createDecodedNewValue( ) (reflect.Value, error) { if node.Type() == ast.AliasType { aliasName := node.(*ast.AliasNode).Value.GetToken().Value - newValue := d.anchorValueMap[aliasName] - if newValue.IsValid() { - return newValue, nil + value := d.anchorValueMap[aliasName] + if value.IsValid() { + v, err := d.castToAssignableValue(value, typ, node) + if err == nil { + return v, nil + } + } + anchor, exists := d.anchorNodeMap[aliasName] + if exists { + node = anchor } } var newValue reflect.Value @@ -1164,10 +1178,10 @@ func (d *Decoder) createDecodedNewValue( } if node.Type() != ast.NullType { if err := d.decodeValue(ctx, newValue, node); err != nil { - return newValue, err + return reflect.Value{}, err } } - return newValue, nil + return d.castToAssignableValue(newValue, typ, node) } func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValueNode func(*ast.MapNodeIter) ast.Node) (map[string]ast.Node, error) { @@ -1238,6 +1252,9 @@ func (d *Decoder) keyToValueNodeMap(node ast.Node, ignoreMergeKey bool) (map[str } func (d *Decoder) setDefaultValueIfConflicted(v reflect.Value, fieldMap StructFieldMap) error { + for v.Type().Kind() == reflect.Ptr { + v = v.Elem() + } typ := v.Type() if typ.Kind() != reflect.Struct { return nil @@ -1413,7 +1430,11 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N if aliasName != "" { newFieldValue := d.anchorValueMap[aliasName] if newFieldValue.IsValid() { - fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type())) + value, err := d.castToAssignableValue(newFieldValue, fieldValue.Type(), d.anchorNodeMap[aliasName]) + if err != nil { + return err + } + fieldValue.Set(value) } } continue @@ -1459,7 +1480,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N continue } _ = d.setDefaultValueIfConflicted(newFieldValue, structFieldMap) - fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type())) + fieldValue.Set(newFieldValue) continue } v, exists := keyToNodeMap[structField.RenderName] @@ -1488,7 +1509,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N } continue } - fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type())) + fieldValue.Set(newFieldValue) } if foundErr != nil { return foundErr @@ -1566,9 +1587,8 @@ func (d *Decoder) decodeArray(ctx context.Context, dst reflect.Value, src ast.No foundErr = err } continue - } else { - arrayValue.Index(idx).Set(d.castToAssignableValue(dstValue, elemType)) } + arrayValue.Index(idx).Set(dstValue) } idx++ } @@ -1613,7 +1633,7 @@ func (d *Decoder) decodeSlice(ctx context.Context, dst reflect.Value, src ast.No } continue } - sliceValue = reflect.Append(sliceValue, d.castToAssignableValue(dstValue, elemType)) + sliceValue = reflect.Append(sliceValue, dstValue) } dst.Set(sliceValue) if foundErr != nil { @@ -1796,7 +1816,7 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node } if !k.IsValid() { // expect nil key - mapValue.SetMapIndex(d.createDecodableValue(keyType), d.castToAssignableValue(dstValue, valueType)) + mapValue.SetMapIndex(d.createDecodableValue(keyType), dstValue) continue } if keyType.Kind() != k.Kind() { @@ -1805,7 +1825,7 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node key.GetToken(), ) } - mapValue.SetMapIndex(k, d.castToAssignableValue(dstValue, valueType)) + mapValue.SetMapIndex(k, dstValue) } dst.Set(mapValue) if foundErr != nil { diff --git a/decode_test.go b/decode_test.go index 4f0fc11e..d90587f9 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1321,12 +1321,6 @@ func TestDecoder_TypeConversionError(t *testing.T) { if !strings.Contains(err.Error(), msg) { t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } - if len(v) == 0 || len(v["v"]) == 0 { - t.Fatal("failed to decode value") - } - if v["v"][0] != 1 { - t.Fatal("failed to decode value") - } }) t.Run("string to int", func(t *testing.T) { var v map[string][]int @@ -1338,12 +1332,6 @@ func TestDecoder_TypeConversionError(t *testing.T) { if !strings.Contains(err.Error(), msg) { t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } - if len(v) == 0 || len(v["v"]) == 0 { - t.Fatal("failed to decode value") - } - if v["v"][0] != 1 { - t.Fatal("failed to decode value") - } }) }) t.Run("overflow error", func(t *testing.T) { @@ -2739,6 +2727,34 @@ func (u *unmarshalList) UnmarshalYAML(b []byte) error { return nil } +func TestDecoder_DecodeWithAnchorAnyValue(t *testing.T) { + type Config struct { + Env []string `json:"env"` + } + + type Schema struct { + Def map[string]any `json:"def"` + Config Config `json:"config"` + } + + data := ` +def: + myenv: &my_env + - VAR1=1 + - VAR2=2 +config: + env: *my_env +` + + var cfg Schema + if err := yaml.NewDecoder(strings.NewReader(data)).Decode(&cfg); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(cfg.Config.Env, []string{"VAR1=1", "VAR2=2"}) { + t.Fatalf("failed to decode value. actual = %+v", cfg) + } +} + func TestDecoder_UnmarshalBytesWithSeparatedList(t *testing.T) { yml := ` a: