From c55d6ad13f328a10e3dc83d7cdcaca7468d109b4 Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Thu, 20 Feb 2025 10:06:58 -0600 Subject: [PATCH] property extractor path traversal --- pkg/codec/config.go | 5 +- pkg/codec/modifier_base.go | 13 ++ pkg/codec/property_extractor.go | 177 ++++++++++++++---- pkg/codec/property_extractor_test.go | 17 ++ .../chain_components_interface_tests.go | 8 + 5 files changed, 178 insertions(+), 42 deletions(-) diff --git a/pkg/codec/config.go b/pkg/codec/config.go index 8632f4018..9b006cf59 100644 --- a/pkg/codec/config.go +++ b/pkg/codec/config.go @@ -337,11 +337,12 @@ func (e *EpochToTimeModifierConfig) MarshalJSON() ([]byte, error) { } type PropertyExtractorConfig struct { - FieldName string + FieldName string + EnablePathTraverse bool } func (c *PropertyExtractorConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Modifier, error) { - return NewPropertyExtractor(upperFirstCharacter(c.FieldName)), nil + return NewPathTraversePropertyExtractor(upperFirstCharacter(c.FieldName), c.EnablePathTraverse), nil } func (c *PropertyExtractorConfig) MarshalJSON() ([]byte, error) { diff --git a/pkg/codec/modifier_base.go b/pkg/codec/modifier_base.go index d1d939f03..52a91bfa4 100644 --- a/pkg/codec/modifier_base.go +++ b/pkg/codec/modifier_base.go @@ -422,6 +422,19 @@ func valueForPath(from reflect.Value, itemType string) (any, error) { } return valueForPath(field, tail) + case reflect.Map: + head, tail := ItemTyper(itemType).Next() + + field := from.MapIndex(reflect.ValueOf(head)) + if !field.IsValid() { + return nil, fmt.Errorf("%w: field not found for path %s and itemType %s", types.ErrInvalidType, from, itemType) + } + + if tail == "" { + return field.Interface(), nil + } + + return valueForPath(reflect.ValueOf(field.Interface()), tail) default: return nil, fmt.Errorf("%w: cannot extract a field from kind %s", types.ErrInvalidType, from.Kind()) } diff --git a/pkg/codec/property_extractor.go b/pkg/codec/property_extractor.go index a5e7c38c9..0a6d9c9ef 100644 --- a/pkg/codec/property_extractor.go +++ b/pkg/codec/property_extractor.go @@ -14,19 +14,25 @@ import ( // This modifier is lossy, as TransformToOffchain will discard unwanted struct properties and // return a single element. Calling TransformToOnchain will result in unset properties. func NewPropertyExtractor(fieldName string) Modifier { - m := &propertyExtractor{ - onToOffChainType: map[reflect.Type]reflect.Type{}, - offToOnChainType: map[reflect.Type]reflect.Type{}, - fieldName: fieldName, - } + return NewPathTraversePropertyExtractor(fieldName, false) +} - return m +func NewPathTraversePropertyExtractor(fieldName string, enablePathTraverse bool) Modifier { + return &propertyExtractor{ + onToOffChainType: map[reflect.Type]reflect.Type{}, + offToOnChainType: map[reflect.Type]reflect.Type{}, + fieldName: fieldName, + enablePathTraverse: enablePathTraverse, + } } type propertyExtractor struct { - onToOffChainType map[reflect.Type]reflect.Type - offToOnChainType map[reflect.Type]reflect.Type - fieldName string + fieldName string + enablePathTraverse bool + onToOffChainType map[reflect.Type]reflect.Type + offToOnChainType map[reflect.Type]reflect.Type + onChainStructType reflect.Type + offChainStructType reflect.Type } func (e *propertyExtractor) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) { @@ -34,57 +40,142 @@ func (e *propertyExtractor) RetypeToOffChain(onChainType reflect.Type, itemType return nil, fmt.Errorf("%w: field name required for extraction", types.ErrInvalidConfig) } - if cached, ok := e.onToOffChainType[onChainType]; ok { + // path traverse allows an item type of Struct.FieldA.NestedField to isolate modifiers + // associated with the nested field `NestedField`. + if !e.enablePathTraverse { + itemType = "" + } + + // if itemType is empty, store the type mappings + // if itemType is not empty, assume a sub-field property is expected to be extracted + onChainStructType := onChainType + if itemType != "" { + onChainStructType = e.onChainStructType + } + + if cached, ok := e.onToOffChainType[onChainStructType]; ok { return cached, nil } + var ( + offChainType reflect.Type + err error + ) + switch onChainType.Kind() { case reflect.Pointer: - elm, err := e.RetypeToOffChain(onChainType.Elem(), "") - if err != nil { + var elm reflect.Type + + if elm, err = e.RetypeToOffChain(onChainStructType.Elem(), ""); err != nil { return nil, err } - ptr := reflect.PointerTo(elm) - e.onToOffChainType[onChainType] = ptr - e.offToOnChainType[ptr] = onChainType - - return ptr, nil + offChainType = reflect.PointerTo(elm) case reflect.Slice: - elm, err := e.RetypeToOffChain(onChainType.Elem(), "") - if err != nil { + var elm reflect.Type + + if elm, err = e.RetypeToOffChain(onChainStructType.Elem(), ""); err != nil { return nil, err } - sliceType := reflect.SliceOf(elm) - e.onToOffChainType[onChainType] = sliceType - e.offToOnChainType[sliceType] = onChainType - - return sliceType, nil + offChainType = reflect.SliceOf(elm) case reflect.Array: - elm, err := e.RetypeToOffChain(onChainType.Elem(), "") - if err != nil { + var elm reflect.Type + + if elm, err = e.RetypeToOffChain(onChainStructType.Elem(), ""); err != nil { return nil, err } - arrayType := reflect.ArrayOf(onChainType.Len(), elm) - e.onToOffChainType[onChainType] = arrayType - e.offToOnChainType[arrayType] = onChainType - - return arrayType, nil + offChainType = reflect.ArrayOf(onChainStructType.Len(), elm) case reflect.Struct: - return e.getPropTypeFromStruct(onChainType) + if offChainType, err = e.getPropTypeFromStruct(onChainStructType); err != nil { + return nil, err + } default: + // if the types don't match, it means we are attempting to traverse the main struct + if onChainType != e.onChainStructType { + return onChainType, nil + } + return nil, fmt.Errorf("%w: cannot retype the kind %v", types.ErrInvalidType, onChainType.Kind()) } + + e.onToOffChainType[onChainStructType] = offChainType + e.offToOnChainType[offChainType] = onChainStructType + + if e.onChainStructType == nil { + e.onChainStructType = onChainType + e.offChainStructType = offChainType + } + + return typeForPath(offChainType, itemType) } -func (e *propertyExtractor) TransformToOnChain(offChainValue any, _ string) (any, error) { - return extractOrExpandWithMaps(offChainValue, e.offToOnChainType, e.fieldName, expandWithMapsHelper) +func (e *propertyExtractor) TransformToOnChain(offChainValue any, itemType string) (any, error) { + offChainValue, itemType, err := e.selectType(offChainValue, e.offChainStructType, itemType) + if err != nil { + return nil, err + } + + modified, err := extractOrExpandWithMaps(offChainValue, e.offToOnChainType, e.fieldName, expandWithMapsHelper) + if err != nil { + return nil, err + } + + if itemType != "" { + // add the field name because the offChainType was nested into a new struct + itemType = fmt.Sprintf("%s.%s", e.fieldName, itemType) + + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil +} + +func (e *propertyExtractor) TransformToOffChain(onChainValue any, itemType string) (any, error) { + onChainValue, itemType, err := e.selectType(onChainValue, e.onChainStructType, itemType) + if err != nil { + return nil, err + } + + modified, err := extractOrExpandWithMaps(onChainValue, e.onToOffChainType, e.fieldName, extractWithMapsHelper) + if err != nil { + return nil, err + } + + if itemType != "" { + // remove the head from the itemType because a field was extracted + _, tail := ItemTyper(itemType).Next() + + return valueForPath(reflect.ValueOf(modified), tail) + } + + return modified, nil } -func (e *propertyExtractor) TransformToOffChain(onChainValue any, _ string) (any, error) { - return extractOrExpandWithMaps(onChainValue, e.onToOffChainType, e.fieldName, extractWithMapsHelper) +func (e *propertyExtractor) selectType(inputValue any, savedType reflect.Type, itemType string) (any, string, error) { + // set itemType to an ignore value if path traversal is not enabled + if !e.enablePathTraverse { + return inputValue, "", nil + } + + // the offChainValue might be a subfield value; get the true offChainStruct type already stored and set the value + baseStructValue := inputValue + + // path traversal is expected, but offChainValue is the value of a field, not the actual struct + // create a new struct from the stored offChainStruct with the provided value applied and all other fields set to + // their zero value. + if itemType != "" { + into := reflect.New(savedType) + + if err := applyValueForPath(into, reflect.ValueOf(inputValue), itemType); err != nil { + return nil, itemType, err + } + + baseStructValue = reflect.Indirect(into).Interface() + } + + return baseStructValue, itemType, nil } func (e *propertyExtractor) getPropTypeFromStruct(onChainType reflect.Type) (reflect.Type, error) { @@ -110,9 +201,6 @@ func (e *propertyExtractor) getPropTypeFromStruct(onChainType reflect.Type) (ref return nil, fmt.Errorf("%w: field not found in on-chain type %s", types.ErrInvalidType, e.fieldName) } - e.onToOffChainType[onChainType] = field.Type - e.offToOnChainType[field.Type] = onChainType - return field.Type, nil } @@ -186,9 +274,18 @@ func extractWithMapsHelper(rItem reflect.Value, toType reflect.Type, field strin case reflect.Pointer: elm := rItem.Elem() if elm.Kind() == reflect.Struct { - tmp, err := extractElement(rItem.Interface(), field) + var ( + tmp reflect.Value + err error + ) + + if tmp, err = extractElement(rItem.Interface(), field); err != nil { + return rItem, err + } + result := reflect.New(toType.Elem()) err = mapstructure.Decode(tmp.Interface(), result.Interface()) + return result, err } diff --git a/pkg/codec/property_extractor_test.go b/pkg/codec/property_extractor_test.go index 6f58a5b81..8a1381425 100644 --- a/pkg/codec/property_extractor_test.go +++ b/pkg/codec/property_extractor_test.go @@ -28,6 +28,7 @@ func TestPropertyExtractor(t *testing.T) { extractor := codec.NewPropertyExtractor("A") invalidExtractor := codec.NewPropertyExtractor("A.B") nestedExtractor := codec.NewPropertyExtractor("B.B") + pathTraverseExt := codec.NewPathTraversePropertyExtractor("B", true) t.Run("RetypeToOffChain sets the type for offchain to the onchain property", func(t *testing.T) { offChainType, err := extractor.RetypeToOffChain(reflect.TypeOf(nestedTestStruct{}), "") @@ -246,4 +247,20 @@ func TestPropertyExtractor(t *testing.T) { assert.Equal(t, expectedLossy, lossyOnChain) }) + + t.Run("TransformToOnChain and TransformToOffChain works for path traversal", func(t *testing.T) { + _, err := pathTraverseExt.RetypeToOffChain(reflect.PointerTo(onChainType), "") + require.NoError(t, err) + + offChainValue, err := pathTraverseExt.TransformToOffChain(int64(42), "B.B") + require.NoError(t, err) + + expectedVal := int64(42) + require.Equal(t, expectedVal, offChainValue) + + lossyOnChain, err := pathTraverseExt.TransformToOnChain(int64(42), "B") + require.NoError(t, err) + + assert.Equal(t, int64(42), lossyOnChain) + }) } diff --git a/pkg/types/interfacetests/chain_components_interface_tests.go b/pkg/types/interfacetests/chain_components_interface_tests.go index c65f9b6ff..85e31ecbb 100644 --- a/pkg/types/interfacetests/chain_components_interface_tests.go +++ b/pkg/types/interfacetests/chain_components_interface_tests.go @@ -3,6 +3,7 @@ package interfacetests import ( "errors" "fmt" + "log" "reflect" "time" @@ -701,6 +702,13 @@ func runContractReaderGetLatestValueInterfaceTests[T TestingT[T]](t T, tester Ch result := &TestStruct{} require.Eventually(t, func() bool { err := cr.GetLatestValue(ctx, bound.ReadIdentifier(EventName), primitives.Unconfirmed, nil, &result) + + if err == nil { + log.Println("test struct", ts.BigField, result.BigField) + } else { + log.Println("error", err) + } + return err == nil && reflect.DeepEqual(result, &ts) }, tester.MaxWaitTimeForEvents(), time.Millisecond*10) },