Skip to content

Commit

Permalink
property extractor path traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
EasterTheBunny committed Feb 20, 2025
1 parent 1c57750 commit 5773256
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 42 deletions.
5 changes: 3 additions & 2 deletions pkg/codec/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,12 @@ func (e *EpochToTimeModifierConfig) MarshalJSON() ([]byte, error) {
}

type PropertyExtractorConfig struct {
FieldName string
FieldName string
EnablePathTraverse bool
}

func (c *PropertyExtractorConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Modifier, error) {
return NewPropertyExtractor(upperFirstCharacter(c.FieldName)), nil
return NewPathTraversePropertyExtractor(upperFirstCharacter(c.FieldName), c.EnablePathTraverse), nil
}

func (c *PropertyExtractorConfig) MarshalJSON() ([]byte, error) {
Expand Down
13 changes: 13 additions & 0 deletions pkg/codec/modifier_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,19 @@ func valueForPath(from reflect.Value, itemType string) (any, error) {
}

return valueForPath(field, tail)
case reflect.Map:
head, tail := ItemTyper(itemType).Next()

field := from.MapIndex(reflect.ValueOf(head))
if !field.IsValid() {
return nil, fmt.Errorf("%w: field not found for path %s and itemType %s", types.ErrInvalidType, from, itemType)
}

if tail == "" {
return field.Interface(), nil
}

return valueForPath(reflect.ValueOf(field.Interface()), tail)
default:
return nil, fmt.Errorf("%w: cannot extract a field from kind %s", types.ErrInvalidType, from.Kind())
}
Expand Down
177 changes: 137 additions & 40 deletions pkg/codec/property_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,77 +14,168 @@ import (
// This modifier is lossy, as TransformToOffchain will discard unwanted struct properties and
// return a single element. Calling TransformToOnchain will result in unset properties.
func NewPropertyExtractor(fieldName string) Modifier {
m := &propertyExtractor{
onToOffChainType: map[reflect.Type]reflect.Type{},
offToOnChainType: map[reflect.Type]reflect.Type{},
fieldName: fieldName,
}
return NewPathTraversePropertyExtractor(fieldName, false)
}

return m
func NewPathTraversePropertyExtractor(fieldName string, enablePathTraverse bool) Modifier {
return &propertyExtractor{
onToOffChainType: map[reflect.Type]reflect.Type{},
offToOnChainType: map[reflect.Type]reflect.Type{},
fieldName: fieldName,
enablePathTraverse: enablePathTraverse,
}
}

type propertyExtractor struct {
onToOffChainType map[reflect.Type]reflect.Type
offToOnChainType map[reflect.Type]reflect.Type
fieldName string
fieldName string
enablePathTraverse bool
onToOffChainType map[reflect.Type]reflect.Type
offToOnChainType map[reflect.Type]reflect.Type
onChainStructType reflect.Type
offChainStructType reflect.Type
}

func (e *propertyExtractor) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) {
if e.fieldName == "" {
return nil, fmt.Errorf("%w: field name required for extraction", types.ErrInvalidConfig)
}

if cached, ok := e.onToOffChainType[onChainType]; ok {
// path traverse allows an item type of Struct.FieldA.NestedField to isolate modifiers
// associated with the nested field `NestedField`.
if !e.enablePathTraverse {
itemType = ""
}

// if itemType is empty, store the type mappings
// if itemType is not empty, assume a sub-field property is expected to be extracted
onChainStructType := onChainType
if itemType != "" {
onChainStructType = e.onChainStructType
}

if cached, ok := e.onToOffChainType[onChainStructType]; ok {
return cached, nil
}

var (
offChainType reflect.Type
err error
)

switch onChainType.Kind() {
case reflect.Pointer:
elm, err := e.RetypeToOffChain(onChainType.Elem(), "")
if err != nil {
var elm reflect.Type

if elm, err = e.RetypeToOffChain(onChainStructType.Elem(), ""); err != nil {
return nil, err
}

ptr := reflect.PointerTo(elm)
e.onToOffChainType[onChainType] = ptr
e.offToOnChainType[ptr] = onChainType

return ptr, nil
offChainType = reflect.PointerTo(elm)
case reflect.Slice:
elm, err := e.RetypeToOffChain(onChainType.Elem(), "")
if err != nil {
var elm reflect.Type

if elm, err = e.RetypeToOffChain(onChainStructType.Elem(), ""); err != nil {
return nil, err
}

sliceType := reflect.SliceOf(elm)
e.onToOffChainType[onChainType] = sliceType
e.offToOnChainType[sliceType] = onChainType

return sliceType, nil
offChainType = reflect.SliceOf(elm)
case reflect.Array:
elm, err := e.RetypeToOffChain(onChainType.Elem(), "")
if err != nil {
var elm reflect.Type

if elm, err = e.RetypeToOffChain(onChainStructType.Elem(), ""); err != nil {
return nil, err
}

arrayType := reflect.ArrayOf(onChainType.Len(), elm)
e.onToOffChainType[onChainType] = arrayType
e.offToOnChainType[arrayType] = onChainType

return arrayType, nil
offChainType = reflect.ArrayOf(onChainStructType.Len(), elm)
case reflect.Struct:
return e.getPropTypeFromStruct(onChainType)
if offChainType, err = e.getPropTypeFromStruct(onChainStructType); err != nil {
return nil, err
}
default:
// if the types don't match, it means we are attempting to traverse the main struct
if onChainType != e.onChainStructType {
return onChainType, nil
}

return nil, fmt.Errorf("%w: cannot retype the kind %v", types.ErrInvalidType, onChainType.Kind())
}

e.onToOffChainType[onChainStructType] = offChainType
e.offToOnChainType[offChainType] = onChainStructType

if e.onChainStructType == nil {
e.onChainStructType = onChainType
e.offChainStructType = offChainType
}

return typeForPath(offChainType, itemType)
}

func (e *propertyExtractor) TransformToOnChain(offChainValue any, _ string) (any, error) {
return extractOrExpandWithMaps(offChainValue, e.offToOnChainType, e.fieldName, expandWithMapsHelper)
func (e *propertyExtractor) TransformToOnChain(offChainValue any, itemType string) (any, error) {
offChainValue, itemType, err := e.selectType(offChainValue, e.offChainStructType, itemType)
if err != nil {
return nil, err
}

modified, err := extractOrExpandWithMaps(offChainValue, e.offToOnChainType, e.fieldName, expandWithMapsHelper)
if err != nil {
return nil, err
}

if itemType != "" {
// add the field name because the offChainType was nested into a new struct
itemType = fmt.Sprintf("%s.%s", e.fieldName, itemType)

return valueForPath(reflect.ValueOf(modified), itemType)
}

return modified, nil
}

func (e *propertyExtractor) TransformToOffChain(onChainValue any, itemType string) (any, error) {
onChainValue, itemType, err := e.selectType(onChainValue, e.onChainStructType, itemType)
if err != nil {
return nil, err
}

modified, err := extractOrExpandWithMaps(onChainValue, e.onToOffChainType, e.fieldName, extractWithMapsHelper)
if err != nil {
return nil, err
}

if itemType != "" {
// remove the head from the itemType because a field was extracted
_, tail := ItemTyper(itemType).Next()

return valueForPath(reflect.ValueOf(modified), tail)
}

return modified, nil
}

func (e *propertyExtractor) TransformToOffChain(onChainValue any, _ string) (any, error) {
return extractOrExpandWithMaps(onChainValue, e.onToOffChainType, e.fieldName, extractWithMapsHelper)
func (e *propertyExtractor) selectType(inputValue any, savedType reflect.Type, itemType string) (any, string, error) {
// set itemType to an ignore value if path traversal is not enabled
if !e.enablePathTraverse {
return inputValue, "", nil
}

// the offChainValue might be a subfield value; get the true offChainStruct type already stored and set the value
baseStructValue := inputValue

// path traversal is expected, but offChainValue is the value of a field, not the actual struct
// create a new struct from the stored offChainStruct with the provided value applied and all other fields set to
// their zero value.
if itemType != "" {
into := reflect.New(savedType)

if err := applyValueForPath(into, reflect.ValueOf(inputValue), itemType); err != nil {
return nil, itemType, err
}

baseStructValue = reflect.Indirect(into).Interface()
}

return baseStructValue, itemType, nil
}

func (e *propertyExtractor) getPropTypeFromStruct(onChainType reflect.Type) (reflect.Type, error) {
Expand All @@ -110,9 +201,6 @@ func (e *propertyExtractor) getPropTypeFromStruct(onChainType reflect.Type) (ref
return nil, fmt.Errorf("%w: field not found in on-chain type %s", types.ErrInvalidType, e.fieldName)
}

e.onToOffChainType[onChainType] = field.Type
e.offToOnChainType[field.Type] = onChainType

return field.Type, nil
}

Expand Down Expand Up @@ -186,9 +274,18 @@ func extractWithMapsHelper(rItem reflect.Value, toType reflect.Type, field strin
case reflect.Pointer:
elm := rItem.Elem()
if elm.Kind() == reflect.Struct {
tmp, err := extractElement(rItem.Interface(), field)
var (
tmp reflect.Value
err error
)

if tmp, err = extractElement(rItem.Interface(), field); err != nil {
return rItem, err
}

result := reflect.New(toType.Elem())
err = mapstructure.Decode(tmp.Interface(), result.Interface())

return result, err
}

Expand Down
17 changes: 17 additions & 0 deletions pkg/codec/property_extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func TestPropertyExtractor(t *testing.T) {
extractor := codec.NewPropertyExtractor("A")
invalidExtractor := codec.NewPropertyExtractor("A.B")
nestedExtractor := codec.NewPropertyExtractor("B.B")
pathTraverseExt := codec.NewPathTraversePropertyExtractor("B", true)

t.Run("RetypeToOffChain sets the type for offchain to the onchain property", func(t *testing.T) {
offChainType, err := extractor.RetypeToOffChain(reflect.TypeOf(nestedTestStruct{}), "")
Expand Down Expand Up @@ -246,4 +247,20 @@ func TestPropertyExtractor(t *testing.T) {

assert.Equal(t, expectedLossy, lossyOnChain)
})

t.Run("TransformToOnChain and TransformToOffChain works for path traversal", func(t *testing.T) {
_, err := pathTraverseExt.RetypeToOffChain(reflect.PointerTo(onChainType), "")
require.NoError(t, err)

offChainValue, err := pathTraverseExt.TransformToOffChain(int64(42), "B.B")
require.NoError(t, err)

expectedVal := int64(42)
require.Equal(t, expectedVal, offChainValue)

lossyOnChain, err := pathTraverseExt.TransformToOnChain(int64(42), "B")
require.NoError(t, err)

assert.Equal(t, int64(42), lossyOnChain)
})
}
8 changes: 8 additions & 0 deletions pkg/types/interfacetests/chain_components_interface_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package interfacetests
import (
"errors"
"fmt"
"log"
"reflect"
"time"

Expand Down Expand Up @@ -701,6 +702,13 @@ func runContractReaderGetLatestValueInterfaceTests[T TestingT[T]](t T, tester Ch
result := &TestStruct{}
require.Eventually(t, func() bool {
err := cr.GetLatestValue(ctx, bound.ReadIdentifier(EventName), primitives.Unconfirmed, nil, &result)

if err == nil {
log.Println("test struct", ts.BigField, result.BigField)
} else {
log.Println("error", err)
}

return err == nil && reflect.DeepEqual(result, &ts)
}, tester.MaxWaitTimeForEvents(), time.Millisecond*10)
},
Expand Down

0 comments on commit 5773256

Please sign in to comment.