Skip to content

Commit

Permalink
Nonevm 1065/events for standard and batch reads (#1044)
Browse files Browse the repository at this point in the history
* functional connection of events and batch/single latest value calls

* fix read filter
  • Loading branch information
EasterTheBunny authored Feb 18, 2025
1 parent e726352 commit adc225c
Show file tree
Hide file tree
Showing 8 changed files with 435 additions and 25 deletions.
73 changes: 56 additions & 17 deletions pkg/solana/chainreader/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,43 @@ func doMultiRead(ctx context.Context, client MultipleAccountGetter, bdRegistry *
}

func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bdRegistry *bindingsRegistry, batch []call) ([]batchResultWithErr, error) {
// Create the list of public keys to fetch
keys := make([]solana.PublicKey, len(batch))
results := make([]batchResultWithErr, len(batch))

// create the list of public keys to fetch
keys := []solana.PublicKey{}

// map batch call index to key index (some calls are event reads and will be handled by a different binding)
dataMap := make(map[int]int)

for idx, batchCall := range batch {
rBinding, err := bdRegistry.GetReader(batchCall.Namespace, batchCall.ReadName)
if err != nil {
return nil, fmt.Errorf("%w: read binding not found for contract: %q read: %q: %w", types.ErrInvalidConfig, batchCall.Namespace, batchCall.ReadName, err)
}

keys[idx], err = rBinding.GetAddress(ctx, batchCall.Params)
key, err := rBinding.GetAddress(ctx, batchCall.Params)
if err != nil {
return nil, fmt.Errorf("failed to get address for contract: %q read: %q: %w", batchCall.Namespace, batchCall.ReadName, err)
}

eBinding, ok := rBinding.(eventBinding)
if ok {
results[idx] = batchResultWithErr{
address: key.String(),
namespace: batchCall.Namespace,
readName: batchCall.ReadName,
returnVal: batchCall.ReturnVal,
}

results[idx].err = eBinding.GetLatestValue(ctx, batchCall.Params, results[idx].returnVal)

continue
}

// map the idx to the key idx
dataMap[idx] = len(keys)

keys = append(keys, key)
}

// Fetch the account data
Expand All @@ -90,18 +115,21 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bdRegi
return nil, err
}

results := make([]batchResultWithErr, len(batch))

// decode batch call results
for idx, batchCall := range batch {
dataIdx, ok := dataMap[idx]
if !ok {
return nil, fmt.Errorf("%w: unexpected data index state", types.ErrInternal)
}

results[idx] = batchResultWithErr{
address: keys[idx].String(),
address: keys[dataIdx].String(),
namespace: batchCall.Namespace,
readName: batchCall.ReadName,
returnVal: batchCall.ReturnVal,
}

if data[idx] == nil || len(data[idx]) == 0 {
if data[dataIdx] == nil || len(data[dataIdx]) == 0 {
results[idx].err = ErrMissingAccountData

continue
Expand All @@ -114,30 +142,35 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bdRegi
continue
}

results[idx].err = errors.Join(
decodeReturnVal(ctx, rBinding, data[idx], results[idx].returnVal),
results[idx].err)
results[idx].err = asValueDotValue(
ctx,
rBinding,
results[dataIdx].returnVal,
wrapDecodeValuer(rBinding, data[dataIdx]),
)
}

return results, nil
}

// decodeReturnVal checks if returnVal is a *values.Value vs. a normal struct pointer, and decodes accordingly.
func decodeReturnVal(ctx context.Context, binding readBinding, raw []byte, returnVal any) error {
// If we are not dealing with a `*values.Value`, just decode directly.
// asValueDotValue checks if returnVal is a *values.Value vs. a normal struct pointer, and decodes accordingly.
func asValueDotValue(
ctx context.Context,
binding readBinding,
returnVal any,
op func(context.Context, any) error,
) error {
ptrToValue, isValue := returnVal.(*values.Value)
if !isValue {
return binding.Decode(ctx, raw, returnVal)
return op(ctx, returnVal)
}

// Otherwise, we need to create an intermediate type, decode into it,
// wrap it, and set it back into *values.Value
contractType, err := binding.CreateType(false)
if err != nil {
return err
}

if err = binding.Decode(ctx, raw, contractType); err != nil {
if err = op(ctx, contractType); err != nil {
return err
}

Expand All @@ -150,3 +183,9 @@ func decodeReturnVal(ctx context.Context, binding readBinding, raw []byte, retur

return nil
}

func wrapDecodeValuer(binding readBinding, data []byte) func(context.Context, any) error {
return func(ctx context.Context, returnVal any) error {
return binding.Decode(ctx, data, returnVal)
}
}
4 changes: 4 additions & 0 deletions pkg/solana/chainreader/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type readBinding interface {
Unregister(context.Context) error
CreateType(bool) (any, error)
Decode(context.Context, []byte, any) error
}

type eventBinding interface {
GetLatestValue(_ context.Context, params, returnVal any) error
QueryKey(context.Context, query.KeyFilter, query.LimitAndSort, any) ([]types.Sequence, error)
}

Expand Down
25 changes: 18 additions & 7 deletions pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,17 +290,22 @@ func (s *ContractReaderService) QueryKey(ctx context.Context, contract types.Bou
return nil, err
}

eBinding, ok := binding.(eventBinding)
if !ok {
return nil, fmt.Errorf("%w: invalid binding type for %s", types.ErrInvalidType, contract.Name)
}

_, isValuePtr := sequenceDataType.(*values.Value)
if !isValuePtr {
return binding.QueryKey(ctx, filter, limitAndSort, sequenceDataType)
return eBinding.QueryKey(ctx, filter, limitAndSort, sequenceDataType)
}

dataTypeFromReadIdentifier, err := s.CreateContractType(contract.ReadIdentifier(filter.Key), false)
if err != nil {
return nil, err
}

sequence, err := binding.QueryKey(ctx, filter, limitAndSort, dataTypeFromReadIdentifier)
sequence, err := eBinding.QueryKey(ctx, filter, limitAndSort, dataTypeFromReadIdentifier)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -463,7 +468,7 @@ func (s *ContractReaderService) addAccountRead(namespace string, genericName str
isPDA = true
}

if err := s.addAccountReadToCodec(s.parsed, namespace, genericName, idl, inputIDLDef, outputIDLDef, readDefinition); err != nil {
if err := s.addReadToCodec(s.parsed, namespace, genericName, idl, inputIDLDef, outputIDLDef, readDefinition); err != nil {
return err
}

Expand All @@ -473,7 +478,7 @@ func (s *ContractReaderService) addAccountRead(namespace string, genericName str
return nil
}

func (s *ContractReaderService) addAccountReadToCodec(parsed *codec.ParsedTypes, namespace string, genericName string, idl codec.IDL, inputIDLDef interface{}, outputIDLDef codec.IdlTypeDef, readDefinition config.ReadDefinition) error {
func (s *ContractReaderService) addReadToCodec(parsed *codec.ParsedTypes, namespace string, genericName string, idl codec.IDL, inputIDLDef interface{}, outputIDLDef interface{}, readDefinition config.ReadDefinition) error {
if err := s.addCodecDef(parsed, true, namespace, genericName, idl, inputIDLDef, readDefinition.InputModifications); err != nil {
return err
}
Expand Down Expand Up @@ -510,7 +515,7 @@ func (s *ContractReaderService) addMultiAccountReadToCodec(namespace string, rea
// multi read defs don't have a generic name as they are accessed from the parent read which does have a generic name.
// generic name is used everywhere, so add a prefix to avoid potential collision with generic names of other reads.
genericName := "multiread-" + mr.ChainSpecificName
if err = s.addAccountReadToCodec(s.parsed, namespace, genericName, idl, inputIDLDef, accountIDLDef, mr); err != nil {
if err = s.addReadToCodec(s.parsed, namespace, genericName, idl, inputIDLDef, accountIDLDef, mr); err != nil {
return nil, fmt.Errorf("failed to add read to multi read %q: %w", mr.ChainSpecificName, err)
}

Expand Down Expand Up @@ -554,7 +559,7 @@ func (s *ContractReaderService) addAddressResponseHardCoderModifier(namespace st

readDef := rb.GetReadDefinition()
readDef.OutputModifications = append(readDef.OutputModifications, hardCoder)
if err = s.addAccountReadToCodec(parsed, namespace, rb.GetGenericName(), idl, inputIDlType, outputIDLType, readDef); err != nil {
if err = s.addReadToCodec(parsed, namespace, rb.GetGenericName(), idl, inputIDlType, outputIDLType, readDef); err != nil {
return fmt.Errorf("failed to set codec with address response hardcoder for read: %q: %w", rb.GetGenericName(), err)
}

Expand Down Expand Up @@ -604,6 +609,12 @@ func (s *ContractReaderService) addEventRead(
applyIndexedFieldTuple(subkeys, conf.IndexedField2, 2)
applyIndexedFieldTuple(subkeys, conf.IndexedField3, 3)

eventDef := codec.EventIDLTypes{Event: eventIdl, Types: idl.Types}

if err := s.addReadToCodec(s.parsed, namespace, genericName, idl, eventIdl, eventIdl, readDefinition); err != nil {
return err
}

reader := newEventReadBinding(
namespace,
genericName,
Expand All @@ -614,7 +625,7 @@ func (s *ContractReaderService) addEventRead(
)

s.shouldStartLP = true
reader.SetFilter(toLPFilter(readDefinition.ChainSpecificName, pf, subkeys.subKeys[:], codec.EventIDLTypes{Event: eventIdl, Types: idl.Types}))
reader.SetFilter(toLPFilter(readDefinition.ChainSpecificName, pf, subkeys.subKeys[:], eventDef))

s.bdRegistry.AddReader(namespace, genericName, reader)

Expand Down
140 changes: 140 additions & 0 deletions pkg/solana/chainreader/event_read_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ func (b *eventReadBinding) Unregister(ctx context.Context) error {
return b.filter.Unregister(ctx, b.reader)
}

// GetAddress for events returns a static address. Since solana contracts emit events, and not accounts, PDAs are not
// valid for events.
func (b *eventReadBinding) GetAddress(_ context.Context, _ any) (solana.PublicKey, error) {
b.mu.RLock()
defer b.mu.RUnlock()
Expand Down Expand Up @@ -188,6 +190,52 @@ func (b *eventReadBinding) Decode(ctx context.Context, bts []byte, outVal any) e
return b.codec.Decode(ctx, bts, outVal, itemType)
}

func (b *eventReadBinding) GetLatestValue(ctx context.Context, params, returnVal any) error {
itemType := codec.WrapItemType(true, b.namespace, b.genericName)

pubKey, err := b.GetAddress(ctx, nil)
if err != nil {
return err
}

offChain, err := b.normalizeParams(params, itemType)
if err != nil {
return err
}

subkeyFilters, err := b.extractFilterSubkeys(offChain)
if err != nil {
return err
}

allFilters := []query.Expression{
logpoller.NewAddressFilter(pubKey),
logpoller.NewEventSigFilter(b.eventSig[:]),
}

if len(subkeyFilters) > 0 {
allFilters = append(allFilters, query.And(subkeyFilters...))
}

limiter := query.NewLimitAndSort(query.CountLimit(1), query.NewSortBySequence(query.Desc))

filter, err := logpoller.Where(allFilters...)
if err != nil {
return err
}

logs, err := b.reader.FilteredLogs(ctx, filter, limiter, b.namespace+"-"+pubKey.String()+"-"+b.genericName)
if err != nil {
return err
}

if len(logs) == 0 {
return fmt.Errorf("%w: no events found", types.ErrNotFound)
}

return asValueDotValue(ctx, b, returnVal, b.wrapDecoderForValuer(&logs[0]))
}

func (b *eventReadBinding) QueryKey(
ctx context.Context,
filter query.KeyFilter,
Expand Down Expand Up @@ -228,6 +276,59 @@ func (b *eventReadBinding) QueryKey(
return sequences, nil
}

func (b *eventReadBinding) normalizeParams(value any, itemType string) (any, error) {
offChain, err := b.codec.CreateType(itemType, true)
if err != nil {
return nil, fmt.Errorf("%w: failed to create type: %w", types.ErrInvalidType, err)
}

// params can be a singular primitive value, a map of values, or a struct
// in the case that the input params are presented as a map of values, apply the values to the off-chain type
// with solana hooks
if err = codec.MapstructureDecode(value, offChain); err != nil {
return nil, fmt.Errorf("%w: failed to decode offChain value: %s", types.ErrInternal, err.Error())
}

return offChain, nil
}

func (b *eventReadBinding) extractFilterSubkeys(offChainParams any) ([]query.Expression, error) {
var expressions []query.Expression

for offChainKey, idx := range b.indexedSubKeys.lookup {
itemType := codec.WrapItemType(true, b.namespace, b.genericName+"."+offChainKey)

fieldVal, err := valueForPath(reflect.ValueOf(offChainParams), offChainKey)
if err != nil {
return nil, fmt.Errorf("%w: no value for path %s", types.ErrInternal, b.genericName+"."+offChainKey)
}

onChainValue, err := b.modifier.TransformToOnChain(fieldVal, itemType)
if err != nil {
return nil, fmt.Errorf("%w: failed to apply on-chain transformation for key %s", types.ErrInternal, offChainKey)
}

valOf := reflect.ValueOf(onChainValue)

// check that onChainValue is not zero value for type
if valOf.IsZero() {
continue
}

expression, err := logpoller.NewEventBySubKeyFilter(
idx,
[]primitives.ValueComparator{{Value: reflect.Indirect(valOf).Interface(), Operator: primitives.Eq}},
)
if err != nil {
return nil, err
}

expressions = append(expressions, expression)
}

return expressions, nil
}

func (b *eventReadBinding) remapPrimitive(expression query.Expression) (query.Expression, error) {
var (
comp query.Expression
Expand Down Expand Up @@ -344,6 +445,12 @@ func (b *eventReadBinding) registered() bool {
return b.registerCalled
}

func (b *eventReadBinding) wrapDecoderForValuer(log *logpoller.Log) func(context.Context, any) error {
return func(ctx context.Context, returnVal any) error {
return b.decodeLog(ctx, log, returnVal)
}
}

type remapHelper struct {
primitive func(query.Expression) (query.Expression, error)
}
Expand Down Expand Up @@ -406,3 +513,36 @@ func (k *indexedSubkeys) indexForKey(key string) (uint64, bool) {

return idx, ok
}

func valueForPath(from reflect.Value, itemType string) (any, error) {
if itemType == "" {
return from.Interface(), nil
}

switch from.Kind() {
case reflect.Pointer:
elem, err := valueForPath(from.Elem(), itemType)
if err != nil {
return nil, err
}

return elem, nil
case reflect.Array, reflect.Slice:
return nil, fmt.Errorf("%w: cannot extract a field from an array or slice", types.ErrInvalidType)
case reflect.Struct:
head, tail := commoncodec.ItemTyper(itemType).Next()

field := from.FieldByName(head)
if !field.IsValid() {
return nil, fmt.Errorf("%w: field not found for path %s and itemType %s", types.ErrInvalidType, from, itemType)
}

if tail == "" {
return field.Interface(), nil
}

return valueForPath(field, tail)
default:
return nil, fmt.Errorf("%w: cannot extract a field from kind %s", types.ErrInvalidType, from.Kind())
}
}
Loading

0 comments on commit adc225c

Please sign in to comment.