Skip to content

Commit

Permalink
Fix get token prices read handling (#1075)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilija42 authored Feb 13, 2025
1 parent 376726a commit a95bf7c
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 62 deletions.
46 changes: 27 additions & 19 deletions integration-tests/relayinterface/chain_components_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,19 +799,19 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T
ResponseAddressHardCoder: &commoncodec.HardCodeModifierConfig{
// placeholder values, whatever is put as value gets replaced with a solana pub key anyway
OffChainValues: map[string]any{
"SharedAddress": solana.PublicKey{},
"AddressToShare": solana.PublicKey{},
},
},
OutputModifications: commoncodec.ModifiersConfig{
&commoncodec.HardCodeModifierConfig{
OffChainValues: map[string]any{"U": "", "V": false},
"SharedAddress": "",
"AddressToShare": "",
},
},
}

multiReadDef := readWithAddressHardCodedIntoResponseDef
multiReadDef.ResponseAddressHardCoder = nil
multiReadDef.OutputModifications = commoncodec.ModifiersConfig{
&commoncodec.HardCodeModifierConfig{
OffChainValues: map[string]any{"U": "", "V": false},
},
}
multiReadDef.MultiReader = &config.MultiReader{
Reads: []config.ReadDefinition{{
ChainSpecificName: "MultiRead2",
Expand All @@ -820,14 +820,31 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T
}},
}

idl := mustUnmarshalIDL(t, string(it.Helper.GetPrimaryIDL(t)))
idl.Accounts = append(idl.Accounts, codec.IdlTypeDef{
Name: "USDPerToken",
Type: codec.IdlTypeDefTy{
Kind: codec.IdlTypeDefTyKindStruct,
Fields: &codec.IdlTypeDefStruct{
{
Name: "tokenPrices",
Type: codec.IdlType{
AsIdlTypeVec: &codec.IdlTypeVec{Vec: codec.IdlType{AsIdlTypeDefined: &codec.IdlTypeDefined{Defined: "TimestampedPackedU224"}}},
},
},
},
},
})

return config.ContractReader{
Namespaces: map[string]config.ChainContractReader{
AnyContractName: {
IDL: mustUnmarshalIDL(t, string(it.Helper.GetPrimaryIDL(t))),
IDL: idl,
Reads: map[string]config.ReadDefinition{
ReadWithAddressHardCodedIntoResponse: readWithAddressHardCodedIntoResponseDef,
GetTokenPrices: {
ChainSpecificName: "BillingTokenConfigWrapper",
ChainSpecificName: "USDPerToken",
ReadType: config.Account,
PDADefinition: codec.PDATypeDef{
Prefix: []byte("fee_billing_token_config"),
Seeds: []codec.PDASeed{
Expand All @@ -842,17 +859,8 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T
},
},
OutputModifications: commoncodec.ModifiersConfig{
&commoncodec.DropModifierConfig{
Fields: []string{"Config"},
},
&commoncodec.HardCodeModifierConfig{
OffChainValues: map[string]any{
"Response": make([]TimestampedUnixBig, 1000),
},
},
&commoncodec.PropertyExtractorConfig{FieldName: "Response"},
&commoncodec.PropertyExtractorConfig{FieldName: "TokenPrices"},
},
ReadType: config.Account,
},
MultiRead: multiReadDef,
MultiReadWithParamsReuse: {
Expand Down
93 changes: 54 additions & 39 deletions pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,10 @@ func (s *ContractReaderService) GetLatestValue(ctx context.Context, readIdentifi

// TODO this is a temporary edge case - NONEVM-1320
if values.reads[0].readName == GetTokenPrices {
return s.handleGetTokenPricesGetLatestValue(ctx, params, values, returnVal)
if err := s.handleGetTokenPricesGetLatestValue(ctx, params, values, returnVal); err != nil {
return fmt.Errorf("failed to read contract: %q, account: %q err: %w", values.contract, values.reads[0].readName, err)
}
return nil
}

batch := []call{
Expand Down Expand Up @@ -635,72 +638,84 @@ func (s *ContractReaderService) handleGetTokenPricesGetLatestValue(
params any,
values readValues,
returnVal any,
) error {
) (err error) {
// shouldn't happen, but just to be sure
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic recovered: %v", r)
}
}()

pdaAddresses, err := s.getPDAsForGetTokenPrices(params, values)
if err != nil {
return err
}

data, err := s.client.GetMultipleAccountData(ctx, pdaAddresses...)
if err != nil {
return fmt.Errorf(
"for contract %q read %q: failed to get multiple account data: %w",
values.contract, values.reads[0].readName, err,
)
return err
}

// -------------- Fill out the returnVal slice with data --------------
// can't typecast returnVal so we have to use reflection here

// Ensure `returnVal` is a pointer to a slice we can populate.
returnSliceVal := reflect.ValueOf(returnVal)
if returnSliceVal.Kind() == reflect.Ptr {
if returnSliceVal.Kind() != reflect.Ptr {
return fmt.Errorf("expected <**[]*struct { Value *big.Int; Timestamp *int64 } Value>, got %q", returnSliceVal.String())
}
returnSliceVal = returnSliceVal.Elem()

returnSliceValType := returnSliceVal.Type()
if returnSliceValType.Kind() != reflect.Ptr {
return fmt.Errorf("expected <*[]*struct { Value *big.Int; Timestamp *int64 } Value>, got %q", returnSliceValType.String())
}

sliceType := returnSliceValType.Elem()
if sliceType.Kind() != reflect.Slice {
return fmt.Errorf("expected []*struct { Value *big.Int; Timestamp *int64 }, got %q", sliceType.String())
}

if returnSliceVal.IsNil() {
// init a slice
sliceVal := reflect.MakeSlice(sliceType, 0, 0)

// create a pointer to that slice to match what slicePtr
slicePtr := reflect.New(sliceType)
slicePtr.Elem().Set(sliceVal)

returnSliceVal.Set(slicePtr)
returnSliceVal = returnSliceVal.Elem()
if returnSliceVal.Kind() == reflect.Ptr {
returnSliceVal = returnSliceVal.Elem()
}
}
if returnSliceVal.Kind() != reflect.Slice {
return fmt.Errorf(
"for contract %q read %q: expected `returnVal` to be a slice, got %s",
values.contract, values.reads[0].readName, returnSliceVal.Kind(),
)

pointerType := sliceType.Elem()
if pointerType.Kind() != reflect.Ptr {
return fmt.Errorf("expected *struct { Value *big.Int; Timestamp *int64 }, got %q", pointerType.String())
}

underlyingStruct := pointerType.Elem()
if underlyingStruct.Kind() != reflect.Struct {
return fmt.Errorf("expected struct { Value *big.Int; Timestamp *int64 }, got %q", underlyingStruct.String())
}

elemType := returnSliceVal.Type().Elem()
for _, d := range data {
var wrapper fee_quoter.BillingTokenConfigWrapper
if err = wrapper.UnmarshalWithDecoder(bin.NewBorshDecoder(d)); err != nil {
return fmt.Errorf(
"for contract %q read %q: failed to unmarshal account data: %w",
values.contract, values.reads[0].readName, err,
)
return err
}

newElem := reflect.New(elemType).Elem()

newElemPtr := reflect.New(underlyingStruct)
newElem := newElemPtr.Elem()
valueField := newElem.FieldByName("Value")
if !valueField.IsValid() {
return fmt.Errorf(
"for contract %q read %q: struct type missing `Value` field",
values.contract, values.reads[0].readName,
)
return fmt.Errorf("field `Value` missing from %q", newElem.String())
}
valueField.Set(reflect.ValueOf(big.NewInt(0).SetBytes(wrapper.Config.UsdPerToken.Value[:])))

valueField.Set(reflect.ValueOf(big.NewInt(0).SetBytes(wrapper.Config.UsdPerToken.Value[:])))
timestampField := newElem.FieldByName("Timestamp")
if !timestampField.IsValid() {
return fmt.Errorf(
"for contract %q read %q: struct type missing `Timestamp` field",
values.contract, values.reads[0].readName,
)
return fmt.Errorf("field `Timestamp` missing from %q", newElem.String())
}

// nolint:gosec
// G115: integer overflow conversion int64 -&gt; uint32
timestampField.Set(reflect.ValueOf(uint32(wrapper.Config.UsdPerToken.Timestamp)))

returnSliceVal.Set(reflect.Append(returnSliceVal, newElem))
timestampField.Set(reflect.ValueOf(&wrapper.Config.UsdPerToken.Timestamp))
returnSliceVal.Set(reflect.Append(returnSliceVal, newElemPtr))
}

return nil
Expand Down
8 changes: 4 additions & 4 deletions pkg/solana/codec/anchoridl.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func (env *IdlType) UnmarshalJSON(data []byte) error {
if err := utilz.TranscodeJSON(temp, &target); err != nil {
return err
}
env.asIdlTypeDefined = &target
env.AsIdlTypeDefined = &target
}
if got, ok := v["array"]; ok {
if _, ok := got.([]interface{}); !ok {
Expand Down Expand Up @@ -303,7 +303,7 @@ type IdlType struct {
AsString IdlTypeAsString
AsIdlTypeVec *IdlTypeVec
asIdlTypeOption *IdlTypeOption
asIdlTypeDefined *IdlTypeDefined
AsIdlTypeDefined *IdlTypeDefined
AsIdlTypeArray *IdlTypeArray
}

Expand All @@ -323,7 +323,7 @@ func (env *IdlType) IsIdlTypeOption() bool {
return env.asIdlTypeOption != nil
}
func (env *IdlType) IsIdlTypeDefined() bool {
return env.asIdlTypeDefined != nil
return env.AsIdlTypeDefined != nil
}
func (env *IdlType) IsArray() bool {
return env.AsIdlTypeArray != nil
Expand All @@ -340,7 +340,7 @@ func (env *IdlType) GetIdlTypeOption() *IdlTypeOption {
return env.asIdlTypeOption
}
func (env *IdlType) GetIdlTypeDefined() *IdlTypeDefined {
return env.asIdlTypeDefined
return env.AsIdlTypeDefined
}
func (env *IdlType) GetArray() *IdlTypeArray {
return env.AsIdlTypeArray
Expand Down
11 changes: 11 additions & 0 deletions pkg/solana/codec/solana.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,17 @@ func asDefined(parentTypeName string, definedName *IdlTypeDefined, refs *codecRe
}

func asArray(parentTypeName string, idlArray *IdlTypeArray, refs *codecRefs) (commonencodings.TypeCodec, error) {
if idlArray == nil {
return nil, fmt.Errorf("%w: field type cannot be nil", commontypes.ErrInvalidConfig)
}

// better to implement bytes to big int codec modifiers, but this works fine
if idlArray.Num == 28 && idlArray.Thing.AsString == IdlTypeU8 {
// nolint:gosec
// G115: integer overflow conversion int -&gt; uint
return binary.BigEndian().BigInt(uint(idlArray.Num), false)
}

codec, err := processFieldType(parentTypeName, idlArray.Thing, refs)
if err != nil {
return nil, err
Expand Down

0 comments on commit a95bf7c

Please sign in to comment.