From e0d05b398683c63609c7801bc2580814b7193626 Mon Sep 17 00:00:00 2001 From: ilija Date: Tue, 14 Jan 2025 20:26:34 +0100 Subject: [PATCH] Make events discriminator testable in codec interface tests --- pkg/solana/codec/codec_entry.go | 4 ++-- pkg/solana/codec/codec_test.go | 12 ++++++----- pkg/solana/codec/testutils/types.go | 33 +++++++++++++++++++++-------- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/pkg/solana/codec/codec_entry.go b/pkg/solana/codec/codec_entry.go index 44a6b1d3b..be28d50e8 100644 --- a/pkg/solana/codec/codec_entry.go +++ b/pkg/solana/codec/codec_entry.go @@ -46,7 +46,7 @@ func NewAccountEntry(offchainName string, idlTypes AccountIDLTypes, includeDiscr var discriminator *Discriminator if includeDiscriminator { - discriminator = NewDiscriminator(offchainName, true) + discriminator = NewDiscriminator(idlTypes.Account.Name, true) } return newEntry( @@ -92,7 +92,7 @@ func NewEventArgsEntry(offChainName string, idlTypes EventIDLTypes, includeDiscr var discriminator *Discriminator if includeDiscriminator { - discriminator = NewDiscriminator(offChainName, false) + discriminator = NewDiscriminator(idlTypes.Event.Name, false) } return newEntry( diff --git a/pkg/solana/codec/codec_test.go b/pkg/solana/codec/codec_test.go index 50b0366ac..0b421695f 100644 --- a/pkg/solana/codec/codec_test.go +++ b/pkg/solana/codec/codec_test.go @@ -76,17 +76,19 @@ func (it *codecInterfaceTester) GetAccountString(i int) string { } func (it *codecInterfaceTester) EncodeFields(t *testing.T, request *EncodeRequest) []byte { - if request.TestOn == TestItemType || request.TestOn == testutils.TestEventItem { - return encodeFieldsOnItem(t, request) + if request.TestOn == TestItemType { + return encodeFieldsOnItem(t, request, true) + } else if request.TestOn == testutils.TestEventItem { + return encodeFieldsOnItem(t, request, false) } return encodeFieldsOnSliceOrArray(t, request) } -func encodeFieldsOnItem(t *testing.T, request *EncodeRequest) ocr2types.Report { +func encodeFieldsOnItem(t *testing.T, request *EncodeRequest, isAccount bool) ocr2types.Report { buf := new(bytes.Buffer) - // The underlying TestItemAsAccount adds a discriminator by default while being Borsh encoded. - if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil { + // The underlying TestItem adds a discriminator by default while being Borsh encoded. + if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0], isAccount).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil { require.NoError(t, err) } return buf.Bytes() diff --git a/pkg/solana/codec/testutils/types.go b/pkg/solana/codec/testutils/types.go index 3c52adb0f..41d09f7f3 100644 --- a/pkg/solana/codec/testutils/types.go +++ b/pkg/solana/codec/testutils/types.go @@ -158,7 +158,8 @@ var CodecDefs = map[string]CodecDef{ }, } -type TestItemAsAccount struct { +type TestItem struct { + IsAccount bool Field int32 OracleID uint8 OracleIDs [32]uint8 @@ -170,14 +171,20 @@ type TestItemAsAccount struct { NestedStaticStruct NestedStatic } -var TestItemDiscriminator = [8]byte{148, 105, 105, 155, 26, 167, 212, 149} +var TestItemAsAccountDiscriminator = [8]byte{148, 105, 105, 155, 26, 167, 212, 149} -func (obj TestItemAsAccount) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { - // Write account discriminator: - err = encoder.WriteBytes(TestItemDiscriminator[:], false) +var TestItemAsEventDiscriminator = [8]byte{119, 183, 160, 247, 84, 104, 222, 251} + +func (obj TestItem) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { + if obj.IsAccount { + err = encoder.WriteBytes(TestItemAsAccountDiscriminator[:], false) + } else { + err = encoder.WriteBytes(TestItemAsEventDiscriminator[:], false) + } if err != nil { return err } + // Serialize `Field` param: err = encoder.Encode(obj.Field) if err != nil { @@ -226,19 +233,26 @@ func (obj TestItemAsAccount) MarshalWithEncoder(encoder *agbinary.Encoder) (err return nil } -func (obj *TestItemAsAccount) UnmarshalWithDecoder(decoder *agbinary.Decoder) error { +func (obj *TestItem) UnmarshalWithDecoder(decoder *agbinary.Decoder) error { // Read and check account discriminator: { discriminator, err := decoder.ReadTypeID() if err != nil { return err } - if !discriminator.Equal(TestItemDiscriminator[:]) { + if obj.IsAccount && !discriminator.Equal(TestItemAsAccountDiscriminator[:]) { return fmt.Errorf( "wrong discriminator: wanted %s, got %s", "[148 105 105 155 26 167 212 149]", fmt.Sprint(discriminator[:])) } + + if !obj.IsAccount && !discriminator.Equal(TestItemAsEventDiscriminator[:]) { + return fmt.Errorf( + "wrong discriminator: wanted %s, got %s", + "[119, 183, 160, 247, 84, 104, 222, 251]", + fmt.Sprint(discriminator[:])) + } } // Deserialize `Field`: err := decoder.Decode(&obj.Field) @@ -563,8 +577,9 @@ func (obj *NestedStatic) UnmarshalWithDecoder(decoder *agbinary.Decoder) (err er return nil } -func EncodeRequestToTestItemAsAccount(testStruct interfacetests.TestStruct) TestItemAsAccount { - return TestItemAsAccount{ +func EncodeRequestToTestItemAsAccount(testStruct interfacetests.TestStruct, isAccount bool) TestItem { + return TestItem{ + IsAccount: isAccount, Field: *testStruct.Field, OracleID: uint8(testStruct.OracleID), OracleIDs: getOracleIDs(testStruct),