diff --git a/crypto/bls/aggregation.go b/crypto/bls/aggregation.go index 346e704c..d8f96d70 100644 --- a/crypto/bls/aggregation.go +++ b/crypto/bls/aggregation.go @@ -134,7 +134,7 @@ func (a *StdSignatureAggregator) AggregateSignatures( aggSigs[ind] = &Signature{sig.Deserialize(sig.Serialize())} aggPubKeys[ind] = op.PubkeyG2.Deserialize(op.PubkeyG2.Serialize()) } else { - aggSigs[ind].Add(sig.G1Point) + aggSigs[ind].Add(sig) aggPubKeys[ind].Add(op.PubkeyG2) } @@ -203,7 +203,7 @@ func (a *StdSignatureAggregator) AggregateSignatures( // Aggregate the aggregated signatures. We reuse the first aggregated signature as the accumulator for i := 1; i < len(aggSigs); i++ { - aggSigs[0].Add(aggSigs[i].G1Point) + aggSigs[0].Add(aggSigs[i]) } // Aggregate the aggregated public keys. We reuse the first aggregated public key as the accumulator diff --git a/crypto/bls/attestation.go b/crypto/bls/attestation.go index d0ee11ef..e86b9426 100644 --- a/crypto/bls/attestation.go +++ b/crypto/bls/attestation.go @@ -47,14 +47,20 @@ func NewG1Point(x, y *big.Int) *G1Point { } } +func NewZeroG1Point() *G1Point { + return NewG1Point(big.NewInt(0), big.NewInt(0)) +} + // Add another G1 point to this one -func (p *G1Point) Add(p2 *G1Point) { +func (p *G1Point) Add(p2 *G1Point) *G1Point { p.G1Affine.Add(p.G1Affine, p2.G1Affine) + return p } // Sub another G1 point from this one -func (p *G1Point) Sub(p2 *G1Point) { +func (p *G1Point) Sub(p2 *G1Point) *G1Point { p.G1Affine.Sub(p.G1Affine, p2.G1Affine) + return p } // VerifyEquivalence verifies G1Point is equivalent the G2Point @@ -90,14 +96,20 @@ func NewG2Point(X, Y [2]*big.Int) *G2Point { } } +func NewZeroG2Point() *G2Point { + return NewG2Point([2]*big.Int{big.NewInt(0), big.NewInt(0)}, [2]*big.Int{big.NewInt(0), big.NewInt(0)}) +} + // Add another G2 point to this one -func (p *G2Point) Add(p2 *G2Point) { +func (p *G2Point) Add(p2 *G2Point) *G2Point { p.G2Affine.Add(p.G2Affine, p2.G2Affine) + return p } // Sub another G2 point from this one -func (p *G2Point) Sub(p2 *G2Point) { +func (p *G2Point) Sub(p2 *G2Point) *G2Point { p.G2Affine.Sub(p.G2Affine, p2.G2Affine) + return p } func (p *G2Point) Serialize() []byte { @@ -112,6 +124,15 @@ type Signature struct { *G1Point `json:"g1_point"` } +func NewZeroSignature() *Signature { + return &Signature{NewZeroG1Point()} +} + +func (s *Signature) Add(otherS *Signature) *Signature { + s.G1Point.Add(otherS.G1Point) + return s +} + // Verify a message against a public key func (s *Signature) Verify(pubkey *G2Point, message [32]byte) (bool, error) { ok, err := bn254utils.VerifySig(s.G1Affine, pubkey.G2Affine, message) diff --git a/services/avsregistry/avsregistry.go b/services/avsregistry/avsregistry.go index 9484c1c1..cbf67fb3 100644 --- a/services/avsregistry/avsregistry.go +++ b/services/avsregistry/avsregistry.go @@ -3,7 +3,7 @@ package avsregistry import ( "context" - "github.com/Layr-Labs/eigensdk-go/crypto/bls" + blsoperatorstateretrievar "github.com/Layr-Labs/eigensdk-go/contracts/bindings/BLSOperatorStateRetriever" "github.com/Layr-Labs/eigensdk-go/types" ) @@ -12,6 +12,11 @@ import ( type AvsRegistryService interface { // GetOperatorsAvsState returns the state of an avs wrt to a list of quorums at a certain block. // The state includes the operatorId, pubkey, and staking amount in each quorum. 
- GetOperatorsAvsStateAtBlock(ctx context.Context, quorumNumbers []types.QuorumNum, blockNumber types.BlockNumber) (map[types.OperatorId]types.OperatorAvsState, map[types.QuorumNum]*bls.G1Point, error) - GetOperatorPubkeys(ctx context.Context, operatorId types.OperatorId) (types.OperatorPubkeys, error) + GetOperatorsAvsStateAtBlock(ctx context.Context, quorumNumbers []types.QuorumNum, blockNumber types.BlockNum) (map[types.OperatorId]types.OperatorAvsState, error) + // GetQuorumsAvsStateAtBlock returns the aggregated data for a list of quorums at a certain block. + // The aggregated data includes the aggregated pubkey and total stake in each quorum. + // This information is derivable from the Operators Avs State (returned from GetOperatorsAvsStateAtBlock), but this function is provided for convenience. + GetQuorumsAvsStateAtBlock(ctx context.Context, quorumNumbers []types.QuorumNum, blockNumber types.BlockNum) (map[types.QuorumNum]types.QuorumAvsState, error) + // GetCheckSignaturesIndices returns the registry indices of the nonsigner operators specified by nonSignerOperatorIds who were registered at referenceBlockNumber. + GetCheckSignaturesIndices(ctx context.Context, referenceBlockNumber types.BlockNum, quorumNumbers []types.QuorumNum, nonSignerOperatorIds []types.OperatorId) (blsoperatorstateretrievar.BLSOperatorStateRetrieverCheckSignaturesIndices, error) } diff --git a/services/avsregistry/avsregistry_chaincaller.go b/services/avsregistry/avsregistry_chaincaller.go index 298d96fd..94374e4a 100644 --- a/services/avsregistry/avsregistry_chaincaller.go +++ b/services/avsregistry/avsregistry_chaincaller.go @@ -12,9 +12,11 @@ import ( "github.com/Layr-Labs/eigensdk-go/types" ) +// AvsRegistryServiceChainCaller is a wrapper around AvsRegistryReader that transforms the data into +// nicer golang types that are easier to work with type AvsRegistryServiceChainCaller struct { + avsregistry.AvsRegistryReader elReader elcontracts.ELReader - avsRegistryReader avsregistry.AvsRegistryReader pubkeyCompendiumService pcservice.PubkeyCompendiumService logger logging.Logger } @@ -24,19 +26,19 @@ var _ AvsRegistryService = (*AvsRegistryServiceChainCaller)(nil) func NewAvsRegistryServiceChainCaller(avsRegistryReader avsregistry.AvsRegistryReader, elReader elcontracts.ELReader, pubkeyCompendiumService pcservice.PubkeyCompendiumService, logger logging.Logger) *AvsRegistryServiceChainCaller { return &AvsRegistryServiceChainCaller{ elReader: elReader, - avsRegistryReader: avsRegistryReader, + AvsRegistryReader: avsRegistryReader, pubkeyCompendiumService: pubkeyCompendiumService, logger: logger, } } -func (ar *AvsRegistryServiceChainCaller) GetOperatorsAvsStateAtBlock(ctx context.Context, quorumNumbers []types.QuorumNum, blockNumber types.BlockNumber) (map[types.OperatorId]types.OperatorAvsState, map[types.QuorumNum]*bls.G1Point, error) { +func (ar *AvsRegistryServiceChainCaller) GetOperatorsAvsStateAtBlock(ctx context.Context, quorumNumbers []types.QuorumNum, blockNumber types.BlockNum) (map[types.OperatorId]types.OperatorAvsState, error) { operatorsAvsState := make(map[types.OperatorId]types.OperatorAvsState) // Get operator state for each quorum by querying BLSOperatorStateRetriever (this call is why this service implementation is called ChainCaller) - operatorsStakesInQuorums, err := ar.avsRegistryReader.GetOperatorsStakeInQuorumsAtBlock(ctx, quorumNumbers, blockNumber) + operatorsStakesInQuorums, err := ar.AvsRegistryReader.GetOperatorsStakeInQuorumsAtBlock(ctx, quorumNumbers, blockNumber) if err != nil { 
ar.logger.Error("Failed to get operator state", "err", err, "service", "AvsRegistryServiceChainCaller") - return nil, nil, err + return nil, err } numquorums := len(quorumNumbers) if len(operatorsStakesInQuorums) != numquorums { @@ -45,10 +47,10 @@ func (ar *AvsRegistryServiceChainCaller) GetOperatorsAvsStateAtBlock(ctx context for quorumIdx, quorumNum := range quorumNumbers { for _, operator := range operatorsStakesInQuorums[quorumIdx] { - pubkeys, err := ar.GetOperatorPubkeys(ctx, operator.OperatorId) + pubkeys, err := ar.getOperatorPubkeys(ctx, operator.OperatorId) if err != nil { ar.logger.Error("Failed find pubkeys for operator while building operatorsAvsState", "err", err, "service", "AvsRegistryServiceChainCaller") - return nil, nil, err + return nil, err } if operatorAvsState, ok := operatorsAvsState[operator.OperatorId]; ok { operatorAvsState.StakePerQuorum[quorumNum] = operator.Stake @@ -66,28 +68,44 @@ func (ar *AvsRegistryServiceChainCaller) GetOperatorsAvsStateAtBlock(ctx context } } - aggG1PubkeyPerQuorum := make(map[types.QuorumNum]*bls.G1Point) + return operatorsAvsState, nil +} + +func (ar *AvsRegistryServiceChainCaller) GetQuorumsAvsStateAtBlock(ctx context.Context, quorumNumbers []types.QuorumNum, blockNumber types.BlockNum) (map[types.QuorumNum]types.QuorumAvsState, error) { + operatorsAvsState, err := ar.GetOperatorsAvsStateAtBlock(ctx, quorumNumbers, blockNumber) + if err != nil { + ar.logger.Error("Failed to get operator state", "err", err, "service", "AvsRegistryServiceChainCaller") + return nil, err + } + quorumsAvsState := make(map[types.QuorumNum]types.QuorumAvsState) for _, quorumNum := range quorumNumbers { - aggG1Pubkey := bls.NewG1Point(big.NewInt(0), big.NewInt(0)) + aggPubkeyG1 := bls.NewG1Point(big.NewInt(0), big.NewInt(0)) + totalStake := big.NewInt(0) for _, operator := range operatorsAvsState { // only include operators that have a stake in this quorum - if operator.StakePerQuorum != nil { - aggG1Pubkey.Add(operator.Pubkeys.G1Pubkey) + if stake, ok := operator.StakePerQuorum[quorumNum]; ok { + aggPubkeyG1.Add(operator.Pubkeys.G1Pubkey) + totalStake.Add(totalStake, stake) } } - aggG1PubkeyPerQuorum[quorumNum] = aggG1Pubkey + quorumsAvsState[quorumNum] = types.QuorumAvsState{ + QuorumNumber: quorumNum, + AggPubkeyG1: aggPubkeyG1, + TotalStake: totalStake, + BlockNumber: blockNumber, + } } - return operatorsAvsState, aggG1PubkeyPerQuorum, nil + return quorumsAvsState, nil } -func (ar *AvsRegistryServiceChainCaller) GetOperatorPubkeys(ctx context.Context, operatorId types.OperatorId) (types.OperatorPubkeys, error) { - // TODO(samlaf): This is a temporary hack until we implement GetOperatorAddr on the BLSpubkeyregistry contract (shortly) - // we need operatorId -> operatorAddr so that we can query the pubkeyCompendiumService - // this inverse mapping (only operatorAddr->operatorId is stored in registryCoordinator) is not stored, - // but we know that the current implementation uses the hash of the G1 pubkey as the operatorId, - // and the pubkeycompendium contract stores the mapping G1pubkeyHash -> operatorAddr - // When the above PR is merged, we should change this to instead call GetOperatorAddressFromOperatorId on the avsRegistryReader - // and not hardcode the definition of the operatorId here +// getOperatorPubkeys is a temporary hack until we implement GetOperatorAddr on the BLSpubkeyregistry contract +// TODO(samlaf): we need operatorId -> operatorAddr so that we can query the pubkeyCompendiumService +// this inverse mapping (only 
operatorAddr->operatorId is stored in registryCoordinator) is not stored, +// but we know that the current implementation uses the hash of the G1 pubkey as the operatorId, +// and the pubkeycompendium contract stores the mapping G1pubkeyHash -> operatorAddr +// When the above PR is merged, we should change this to instead call GetOperatorAddressFromOperatorId on the avsRegistryReader +// and not hardcode the definition of the operatorId here +func (ar *AvsRegistryServiceChainCaller) getOperatorPubkeys(ctx context.Context, operatorId types.OperatorId) (types.OperatorPubkeys, error) { operatorAddr, err := ar.elReader.GetOperatorAddressFromPubkeyHash(ctx, operatorId) if err != nil { ar.logger.Error("Failed to get operator address from pubkey hash", "err", err, "service", "AvsRegistryServiceChainCaller") diff --git a/services/avsregistry/avsregistry_chaincaller_test.go b/services/avsregistry/avsregistry_chaincaller_test.go index 2a3b8854..f58b92af 100644 --- a/services/avsregistry/avsregistry_chaincaller_test.go +++ b/services/avsregistry/avsregistry_chaincaller_test.go @@ -22,7 +22,7 @@ type testOperator struct { pubkeys types.OperatorPubkeys } -func TestAvsRegistryServiceChainCaller_GetOperatorPubkeys(t *testing.T) { +func TestAvsRegistryServiceChainCaller_getOperatorPubkeys(t *testing.T) { logger := logging.NewNoopLogger() testOperator := testOperator{ operatorAddr: common.HexToAddress("0x1"), @@ -33,6 +33,7 @@ func TestAvsRegistryServiceChainCaller_GetOperatorPubkeys(t *testing.T) { }, } + // TODO(samlaf): add error test cases var tests = []struct { name string mocksInitializationFunc func(*chainiomocks.MockAvsRegistryReader, *chainiomocks.MockELReader, *servicemocks.MockPubkeyCompendiumService) @@ -67,7 +68,7 @@ func TestAvsRegistryServiceChainCaller_GetOperatorPubkeys(t *testing.T) { service := NewAvsRegistryServiceChainCaller(mockAvsRegistryReader, mockElReader, mockPubkeyCompendium, logger) // Call the GetOperatorPubkeys method with the test operator address - gotOperatorPubkeys, gotErr := service.GetOperatorPubkeys(context.Background(), tt.queryOperatorId) + gotOperatorPubkeys, gotErr := service.getOperatorPubkeys(context.Background(), tt.queryOperatorId) if tt.wantErr != gotErr { t.Fatalf("GetOperatorPubkeys returned wrong error. 
Got: %v, want: %v.", gotErr, tt.wantErr) } @@ -93,15 +94,14 @@ func TestAvsRegistryServiceChainCaller_GetOperatorsAvsState(t *testing.T) { name string mocksInitializationFunc func(*chainiomocks.MockAvsRegistryReader, *chainiomocks.MockELReader, *servicemocks.MockPubkeyCompendiumService) queryQuorumNumbers []types.QuorumNum - queryBlockNum types.BlockNumber + queryBlockNum types.BlockNum wantErr error wantOperatorsAvsStateDict map[types.OperatorId]types.OperatorAvsState - wantAggG1PubkeyPerQuorum map[types.QuorumNum]*bls.G1Point }{ { name: "should return operatorsAvsState", mocksInitializationFunc: func(mockAvsRegistryReader *chainiomocks.MockAvsRegistryReader, mockElReader *chainiomocks.MockELReader, mockPubkeyCompendiumService *servicemocks.MockPubkeyCompendiumService) { - mockAvsRegistryReader.EXPECT().GetOperatorsStakeInQuorumsAtBlock(context.Background(), []types.QuorumNum{1}, types.BlockNumber(1)).Return([][]blsoperatorstateretrievar.BLSOperatorStateRetrieverOperator{ + mockAvsRegistryReader.EXPECT().GetOperatorsStakeInQuorumsAtBlock(context.Background(), []types.QuorumNum{1}, types.BlockNum(1)).Return([][]blsoperatorstateretrievar.BLSOperatorStateRetrieverOperator{ { { OperatorId: testOperator.operatorId, @@ -123,9 +123,6 @@ func TestAvsRegistryServiceChainCaller_GetOperatorsAvsState(t *testing.T) { BlockNumber: 1, }, }, - wantAggG1PubkeyPerQuorum: map[types.QuorumNum]*bls.G1Point{ - 1: bls.NewG1Point(big.NewInt(1), big.NewInt(1)), - }, }, } @@ -144,15 +141,85 @@ func TestAvsRegistryServiceChainCaller_GetOperatorsAvsState(t *testing.T) { service := NewAvsRegistryServiceChainCaller(mockAvsRegistryReader, mockElReader, mockPubkeyCompendium, logger) // Call the GetOperatorPubkeys method with the test operator address - gotOperatorsAvsStateDict, aggG1PubkeyPerQuorum, gotErr := service.GetOperatorsAvsStateAtBlock(context.Background(), tt.queryQuorumNumbers, tt.queryBlockNum) + gotOperatorsAvsStateDict, gotErr := service.GetOperatorsAvsStateAtBlock(context.Background(), tt.queryQuorumNumbers, tt.queryBlockNum) if tt.wantErr != gotErr { t.Fatalf("GetOperatorsAvsState returned wrong error. Got: %v, want: %v.", gotErr, tt.wantErr) } if tt.wantErr == nil && !reflect.DeepEqual(tt.wantOperatorsAvsStateDict, gotOperatorsAvsStateDict) { t.Fatalf("GetOperatorsAvsState returned wrong operatorsAvsStateDict. Got: %v, want: %v.", gotOperatorsAvsStateDict, tt.wantOperatorsAvsStateDict) } - if tt.wantErr == nil && !reflect.DeepEqual(tt.wantAggG1PubkeyPerQuorum, aggG1PubkeyPerQuorum) { - t.Fatalf("GetOperatorsAvsState returned wrong aggG1PubkeyPerQuorum. 
Got: %v, want: %v.", aggG1PubkeyPerQuorum, tt.wantAggG1PubkeyPerQuorum) + }) + } +} + +func TestAvsRegistryServiceChainCaller_GetQuorumsAvsState(t *testing.T) { + logger := logging.NewNoopLogger() + testOperator := testOperator{ + operatorAddr: common.HexToAddress("0x1"), + operatorId: types.OperatorId{1}, + pubkeys: types.OperatorPubkeys{ + G1Pubkey: bls.NewG1Point(big.NewInt(1), big.NewInt(1)), + G2Pubkey: bls.NewG2Point([2]*big.Int{big.NewInt(1), big.NewInt(1)}, [2]*big.Int{big.NewInt(1), big.NewInt(1)}), + }, + } + + var tests = []struct { + name string + mocksInitializationFunc func(*chainiomocks.MockAvsRegistryReader, *chainiomocks.MockELReader, *servicemocks.MockPubkeyCompendiumService) + queryQuorumNumbers []types.QuorumNum + queryBlockNum types.BlockNum + wantErr error + wantQuorumsAvsStateDict map[types.QuorumNum]types.QuorumAvsState + }{ + { + name: "should return operatorsAvsState", + mocksInitializationFunc: func(mockAvsRegistryReader *chainiomocks.MockAvsRegistryReader, mockElReader *chainiomocks.MockELReader, mockPubkeyCompendiumService *servicemocks.MockPubkeyCompendiumService) { + mockAvsRegistryReader.EXPECT().GetOperatorsStakeInQuorumsAtBlock(context.Background(), []types.QuorumNum{1}, types.BlockNum(1)).Return([][]blsoperatorstateretrievar.BLSOperatorStateRetrieverOperator{ + { + { + OperatorId: testOperator.operatorId, + Stake: big.NewInt(123), + }, + }, + }, nil) + mockElReader.EXPECT().GetOperatorAddressFromPubkeyHash(context.Background(), testOperator.operatorId).Return(testOperator.operatorAddr, nil) + mockPubkeyCompendiumService.EXPECT().GetOperatorPubkeys(context.Background(), testOperator.operatorAddr).Return(testOperator.pubkeys, true) + }, + queryQuorumNumbers: []types.QuorumNum{1}, + queryBlockNum: 1, + wantErr: nil, + wantQuorumsAvsStateDict: map[types.QuorumNum]types.QuorumAvsState{ + 1: types.QuorumAvsState{ + QuorumNumber: types.QuorumNum(1), + TotalStake: big.NewInt(123), + AggPubkeyG1: bls.NewG1Point(big.NewInt(1), big.NewInt(1)), + BlockNumber: 1, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mocks + mockCtrl := gomock.NewController(t) + mockAvsRegistryReader := chainiomocks.NewMockAvsRegistryReader(mockCtrl) + mockElReader := chainiomocks.NewMockELReader(mockCtrl) + mockPubkeyCompendium := servicemocks.NewMockPubkeyCompendiumService(mockCtrl) + + if tt.mocksInitializationFunc != nil { + tt.mocksInitializationFunc(mockAvsRegistryReader, mockElReader, mockPubkeyCompendium) + } + // Create a new instance of the avsregistry service + service := NewAvsRegistryServiceChainCaller(mockAvsRegistryReader, mockElReader, mockPubkeyCompendium, logger) + + // Call the GetOperatorPubkeys method with the test operator address + aggG1PubkeyPerQuorum, gotErr := service.GetQuorumsAvsStateAtBlock(context.Background(), tt.queryQuorumNumbers, tt.queryBlockNum) + if tt.wantErr != gotErr { + t.Fatalf("GetOperatorsAvsState returned wrong error. Got: %v, want: %v.", gotErr, tt.wantErr) + } + if tt.wantErr == nil && !reflect.DeepEqual(tt.wantQuorumsAvsStateDict, aggG1PubkeyPerQuorum) { + t.Fatalf("GetOperatorsAvsState returned wrong aggG1PubkeyPerQuorum. 
Got: %v, want: %v.", aggG1PubkeyPerQuorum, tt.wantQuorumsAvsStateDict) } }) } diff --git a/services/avsregistry/avsregistry_fake.go b/services/avsregistry/avsregistry_fake.go new file mode 100644 index 00000000..25c22a8a --- /dev/null +++ b/services/avsregistry/avsregistry_fake.go @@ -0,0 +1,72 @@ +package avsregistry + +import ( + "context" + "errors" + "math/big" + + blsoperatorstateretrievar "github.com/Layr-Labs/eigensdk-go/contracts/bindings/BLSOperatorStateRetriever" + "github.com/Layr-Labs/eigensdk-go/crypto/bls" + "github.com/Layr-Labs/eigensdk-go/types" +) + +type FakeAvsRegistryService struct { + operators map[types.BlockNum]map[types.OperatorId]types.OperatorAvsState +} + +func NewFakeAvsRegistryService(blockNum types.BlockNum, operators []types.TestOperator) *FakeAvsRegistryService { + fakeAvsRegistryService := &FakeAvsRegistryService{ + operators: map[types.BlockNum]map[types.OperatorId]types.OperatorAvsState{ + blockNum: {}, + }, + } + for _, operator := range operators { + fakeAvsRegistryService.operators[blockNum][operator.OperatorId] = types.OperatorAvsState{ + OperatorId: operator.OperatorId, + Pubkeys: types.OperatorPubkeys{G1Pubkey: operator.BlsKeypair.GetPubKeyG1(), G2Pubkey: operator.BlsKeypair.GetPubKeyG2()}, + StakePerQuorum: operator.StakePerQuorum, + BlockNumber: blockNum, + } + } + return fakeAvsRegistryService +} + +var _ AvsRegistryService = (*FakeAvsRegistryService)(nil) + +func (f *FakeAvsRegistryService) GetOperatorsAvsStateAtBlock(ctx context.Context, quorumNumbers []types.QuorumNum, blockNumber types.BlockNum) (map[types.OperatorId]types.OperatorAvsState, error) { + operatorsAvsState, ok := f.operators[blockNumber] + if !ok { + return nil, errors.New("block number not found") + } + return operatorsAvsState, nil +} + +func (f *FakeAvsRegistryService) GetQuorumsAvsStateAtBlock(ctx context.Context, quorumNumbers []types.QuorumNum, blockNumber types.BlockNum) (map[types.QuorumNum]types.QuorumAvsState, error) { + operatorsAvsState, ok := f.operators[blockNumber] + if !ok { + return nil, errors.New("block number not found") + } + quorumsAvsState := make(map[types.QuorumNum]types.QuorumAvsState) + for _, quorumNum := range quorumNumbers { + aggPubkeyG1 := bls.NewG1Point(big.NewInt(0), big.NewInt(0)) + totalStake := big.NewInt(0) + for _, operator := range operatorsAvsState { + // only include operators that have a stake in this quorum + if stake, ok := operator.StakePerQuorum[quorumNum]; ok { + aggPubkeyG1.Add(operator.Pubkeys.G1Pubkey) + totalStake.Add(totalStake, stake) + } + } + quorumsAvsState[quorumNum] = types.QuorumAvsState{ + QuorumNumber: quorumNum, + AggPubkeyG1: aggPubkeyG1, + TotalStake: totalStake, + BlockNumber: blockNumber, + } + } + return quorumsAvsState, nil +} + +func (f *FakeAvsRegistryService) GetCheckSignaturesIndices(ctx context.Context, referenceBlockNumber types.BlockNum, quorumNumbers []types.QuorumNum, nonSignerOperatorIds []types.OperatorId) (blsoperatorstateretrievar.BLSOperatorStateRetrieverCheckSignaturesIndices, error) { + return blsoperatorstateretrievar.BLSOperatorStateRetrieverCheckSignaturesIndices{}, nil +} diff --git a/services/bls_aggregation/blsagg.go b/services/bls_aggregation/blsagg.go new file mode 100644 index 00000000..a529d002 --- /dev/null +++ b/services/bls_aggregation/blsagg.go @@ -0,0 +1,387 @@ +package blsagg + +import ( + "context" + "errors" + "fmt" + "math/big" + "sync" + "time" + + "github.com/Layr-Labs/eigensdk-go/crypto/bls" + "github.com/Layr-Labs/eigensdk-go/logging" + 
"github.com/Layr-Labs/eigensdk-go/services/avsregistry" + "github.com/Layr-Labs/eigensdk-go/types" +) + +var ( + TaskAlreadyInitializedErrorFn = func(taskIndex types.TaskIndex) error { + return fmt.Errorf("task %d already initialized", taskIndex) + } + TaskExpiredError = fmt.Errorf("task expired") + TaskNotFoundErrorFn = func(taskIndex types.TaskIndex) error { + return fmt.Errorf("task %d not initialized or already completed", taskIndex) + } + OperatorNotPartOfTaskQuorumErrorFn = func(operatorId types.OperatorId, taskIndex types.TaskIndex) error { + return fmt.Errorf("operator %x not part of task %d's quorum", operatorId, taskIndex) + } + SignatureVerificationError = func(err error) error { + return fmt.Errorf("Failed to verify signature: %w", err) + } + IncorrectSignatureError = errors.New("Signature verification failed. Incorrect Signature.") +) + +// BlsAggregationServiceResponse is the response from the bls aggregation service +// it's half of the data needed to build the NonSignerStakesAndSignature struct +type BlsAggregationServiceResponse struct { + Err error + TaskIndex types.TaskIndex + TaskResponseDigest types.TaskResponseDigest + NonSignersPubkeysG1 []*bls.G1Point + QuorumApksG1 []*bls.G1Point + SignersApkG2 *bls.G2Point + SignersAggSigG1 *bls.Signature + NonSignerQuorumBitmapIndices []uint32 + QuorumApkIndices []uint32 + TotalStakeIndices []uint32 + NonSignerStakeIndices [][]uint32 +} + +// aggregatedOperators is meant to be used as a value in a map +// map[taskResponseDigest]aggregatedOperators +type aggregatedOperators struct { + // aggregate g2 pubkey of all operatos who signed on this taskResponseDigest + signersApkG2 *bls.G2Point + // aggregate signature of all operators who signed on this taskResponseDigest + signersAggSigG1 *bls.Signature + // aggregate stake of all operators who signed on this header for each quorum + signersTotalStakePerQuorum map[types.QuorumNum]*big.Int + // set of OperatorId of operators who signed on this header + signersOperatorIdsSet map[types.OperatorId]bool +} + +// BlsAggregationService is the interface provided to avs aggregator code for doing bls aggregation +// Currently its only implementation is the BlsAggregatorService, so see the comment there for more details +type BlsAggregationService interface { + // InitializeNewTask should be called whenever a new task is created. ProcessNewSignature will return an error + // if the task it is trying to process has not been initialized yet. + // quorumNumbers and quorumThresholdPercentages set the requirements for this task to be considered complete, which happens + // when a particular TaskResponseDigest (received via the a.taskChans[taskIndex]) has been signed by signers whose stake + // in each of the listed quorums adds up to at least quorumThresholdPercentages[i] of the total stake in that quorum + InitializeNewTask( + taskIndex types.TaskIndex, + taskCreatedBlock uint32, + quorumNumbers []types.QuorumNum, + quorumThresholdPercentages []types.QuorumThresholdPercentage, + timeToExpiry time.Duration, + ) error + + // ProcessNewSignature processes a new signature over a taskResponseDigest for a particular taskIndex by a particular operator + // It verifies that the signature is correct and returns an error if it is not, and then aggregates the signature and stake of + // the operator with all other signatures for the same taskIndex and taskResponseDigest pair. 
+ // Note: This function currently only verifies signatures over the taskResponseDigest directly, so avs code needs to verify that the digest + // passed to ProcessNewSignature is indeed the digest of a valid taskResponse (that is, BlsAggregationService does not verify semantic integrity of the taskResponses) + ProcessNewSignature( + ctx context.Context, + taskIndex types.TaskIndex, + taskResponseDigest types.TaskResponseDigest, + blsSignature *bls.Signature, + operatorId bls.OperatorId, + ) error + + // GetResponseChannel returns the single channel that is meant to be used as the response channel + // Any task that is completed (see the completion criterion in the comment above InitializeNewTask) + // will be sent on this channel along with all the necessary information to call BLSSignatureChecker onchain + GetResponseChannel() <-chan BlsAggregationServiceResponse +} + +// BlsAggregatorService is a service that performs BLS signature aggregation for an AVS' tasks +// Assumptions: +// 1. BlsAggregatorService only verifies digest signatures, so avs code needs to verify that the digest +// passed to ProcessNewSignature is indeed the digest of a valid taskResponse +// (see the comment above verifySignature for more details) +// 2. BlsAggregatorService is VERY generic and makes very few assumptions about the task structure or +// the time at which operators will send their signatures. It is mostly suitable for offchain computation +// oracle (a la truebit) type of AVS, where tasks are sent onchain by users sporadically, and where +// new tasks can start even before the previous ones have finished aggregation. +// AVSs like eigenDA that have a much more controlled task submission schedule and where new tasks are +// only submitted after the previous one's response has been aggregated and responded onchain, could have +// a much simpler AggregationService without all the complicated parallel goroutines.
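As a rough illustration of how AVS aggregator code is expected to drive the interface above, here is a hedged wiring sketch, not part of this PR: it assumes an already-constructed avsregistry.AvsRegistryService and logging.Logger, and the task index, quorum numbers, threshold, and expiry values are purely illustrative.

package example

import (
	"context"
	"time"

	"github.com/Layr-Labs/eigensdk-go/crypto/bls"
	"github.com/Layr-Labs/eigensdk-go/logging"
	"github.com/Layr-Labs/eigensdk-go/services/avsregistry"
	blsagg "github.com/Layr-Labs/eigensdk-go/services/bls_aggregation"
	"github.com/Layr-Labs/eigensdk-go/types"
)

// runAggregation sketches the intended call order; the registry service, logger,
// digest, signature, and operatorId are assumed to come from the AVS' own aggregator code.
func runAggregation(
	ctx context.Context,
	avsRegistryService avsregistry.AvsRegistryService,
	logger logging.Logger,
	taskCreatedBlock uint32,
	digest types.TaskResponseDigest,
	sig *bls.Signature,
	operatorId bls.OperatorId,
) {
	blsAggServ := blsagg.NewBlsAggregatorService(avsRegistryService, logger)

	// Illustrative task parameters: a single quorum with a 67% stake threshold.
	taskIndex := types.TaskIndex(0)
	quorumNumbers := []types.QuorumNum{0}
	thresholds := []types.QuorumThresholdPercentage{67}

	// 1. Register the task before processing any signatures for it.
	if err := blsAggServ.InitializeNewTask(taskIndex, taskCreatedBlock, quorumNumbers, thresholds, 30*time.Second); err != nil {
		logger.Error("failed to initialize task", "err", err)
		return
	}

	// 2. Forward each operator's signed digest; verification errors are returned here.
	if err := blsAggServ.ProcessNewSignature(ctx, taskIndex, digest, sig, operatorId); err != nil {
		logger.Error("signature rejected", "err", err)
	}

	// 3. Completed (or expired) tasks arrive on the shared response channel.
	resp := <-blsAggServ.GetResponseChannel()
	if resp.Err != nil {
		logger.Error("aggregation failed", "err", resp.Err)
		return
	}
	// resp now carries the aggregated pubkeys, signature, and indices needed to call
	// BLSSignatureChecker onchain.
}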
+type BlsAggregatorService struct { + // aggregatedResponsesC is the channel which all goroutines share to send their responses back to the + // main thread after they are done aggregating (either they reached the threshold, or timeout expired) + aggregatedResponsesC chan BlsAggregationServiceResponse + // signedTaskRespsCs are the channels to send the signed task responses to the goroutines processing them + // each new task is assigned a new goroutine and a new channel + signedTaskRespsCs map[types.TaskIndex]chan types.SignedTaskResponseDigest + // we add chans to taskChans from the main thread (InitializeNewTask) when we create new tasks, + // we read them in ProcessNewSignature from the main thread when we receive new signed tasks, + // and remove them from its respective goroutine when the task is completed or reached timeout + // we thus need a mutex to protect taskChans + taskChansMutex sync.RWMutex + avsRegistryService avsregistry.AvsRegistryService + logger logging.Logger +} + +var _ BlsAggregationService = (*BlsAggregatorService)(nil) + +func NewBlsAggregatorService(avsRegistryService avsregistry.AvsRegistryService, logger logging.Logger) *BlsAggregatorService { + return &BlsAggregatorService{ + aggregatedResponsesC: make(chan BlsAggregationServiceResponse), + signedTaskRespsCs: make(map[types.TaskIndex]chan types.SignedTaskResponseDigest), + taskChansMutex: sync.RWMutex{}, + avsRegistryService: avsRegistryService, + logger: logger, + } +} + +func (a *BlsAggregatorService) GetResponseChannel() <-chan BlsAggregationServiceResponse { + return a.aggregatedResponsesC +} + +// InitializeNewTask creates a new task goroutine meant to process new signed task responses for that task +// (that are sent via ProcessNewSignature) and adds a channel to a.taskChans to send the signed task responses to it +// quorumNumbers and quorumThresholdPercentages set the requirements for this task to be considered complete, which happens +// when a particular TaskResponseDigest (received via the a.taskChans[taskIndex]) has been signed by signers whose stake +// in each of the listed quorums adds up to at least quorumThresholdPercentages[i] of the total stake in that quorum +func (a *BlsAggregatorService) InitializeNewTask( + taskIndex types.TaskIndex, + taskCreatedBlock uint32, + quorumNumbers []types.QuorumNum, + quorumThresholdPercentages []types.QuorumThresholdPercentage, + timeToExpiry time.Duration, +) error { + if _, taskExists := a.signedTaskRespsCs[taskIndex]; taskExists { + return TaskAlreadyInitializedErrorFn(taskIndex) + } + signedTaskRespsC := make(chan types.SignedTaskResponseDigest) + a.taskChansMutex.Lock() + a.signedTaskRespsCs[taskIndex] = signedTaskRespsC + a.taskChansMutex.Unlock() + go a.singleTaskAggregatorGoroutineFunc(taskIndex, taskCreatedBlock, quorumNumbers, quorumThresholdPercentages, timeToExpiry, signedTaskRespsC) + return nil +} + +func (a *BlsAggregatorService) ProcessNewSignature( + ctx context.Context, + taskIndex types.TaskIndex, + taskResponseDigest types.TaskResponseDigest, + blsSignature *bls.Signature, + operatorId bls.OperatorId, +) error { + a.taskChansMutex.Lock() + taskC, taskInitialized := a.signedTaskRespsCs[taskIndex] + a.taskChansMutex.Unlock() + if !taskInitialized { + return TaskNotFoundErrorFn(taskIndex) + } + signatureVerificationErrorC := make(chan error) + // send the task to the goroutine processing this task + // and return the error (if any) returned by the signature verification routine + select { + // we need to send this as part of select because 
if the goroutine is processing another SignedTaskResponseDigest + // and cannot receive this one, we want the context to be able to cancel the request + case taskC <- types.SignedTaskResponseDigest{ + TaskResponseDigest: taskResponseDigest, + BlsSignature: blsSignature, + OperatorId: operatorId, + SignatureVerificationErrorC: signatureVerificationErrorC, + }: + // note that we need to wait synchronously here for this response because we want to + // send back an informative error message to the operator who sent his signature to the aggregator + return <-signatureVerificationErrorC + case <-ctx.Done(): + return ctx.Err() + } +} + +func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( + taskIndex types.TaskIndex, + taskCreatedBlock uint32, + quorumNumbers []types.QuorumNum, + quorumThresholdPercentages []types.QuorumThresholdPercentage, + timeToExpiry time.Duration, + signedTaskRespsC <-chan types.SignedTaskResponseDigest, +) { + defer a.closeTaskGoroutine(taskIndex) + + quorumThresholdPercentagesMap := make(map[types.QuorumNum]types.QuorumThresholdPercentage) + for i, quorumNumber := range quorumNumbers { + quorumThresholdPercentagesMap[quorumNumber] = quorumThresholdPercentages[i] + } + operatorsAvsStateDict, err := a.avsRegistryService.GetOperatorsAvsStateAtBlock(context.Background(), quorumNumbers, taskCreatedBlock) + if err != nil { + // TODO: how should we handle such an error? + a.logger.Fatal("Aggregator failed to get operators state from avs registry", "err", err) + } + quorumsAvsStakeDict, err := a.avsRegistryService.GetQuorumsAvsStateAtBlock(context.Background(), quorumNumbers, taskCreatedBlock) + if err != nil { + a.logger.Fatal("Aggregator failed to get quorums state from avs registry", "err", err) + } + totalStakePerQuorum := make(map[types.QuorumNum]*big.Int) + for quorumNum, quorumAvsState := range quorumsAvsStakeDict { + totalStakePerQuorum[quorumNum] = quorumAvsState.TotalStake + } + quorumApksG1 := []*bls.G1Point{} + for _, quorumNumber := range quorumNumbers { + quorumApksG1 = append(quorumApksG1, quorumsAvsStakeDict[quorumNumber].AggPubkeyG1) + } + + // TODO(samlaf): instead of taking a TTE, we should take a block as input + // and monitor the chain and only close the task goroutine when that block is reached + taskExpiredTimer := time.NewTimer(timeToExpiry) + + aggregatedOperatorsDict := map[types.TaskResponseDigest]aggregatedOperators{} + for { + select { + case signedTaskResponseDigest := <-signedTaskRespsC: + a.logger.Debug("Task goroutine received new signed task response digest", "taskIndex", taskIndex, "signedTaskResponseDigest", signedTaskResponseDigest) + signedTaskResponseDigest.SignatureVerificationErrorC <- a.verifySignature(taskIndex, signedTaskResponseDigest, operatorsAvsStateDict) + // after verifying signature we aggregate its sig and pubkey, and update the signed stake amount + digestAggregatedOperators, ok := aggregatedOperatorsDict[signedTaskResponseDigest.TaskResponseDigest] + if !ok { + // first operator to sign on this digest + digestAggregatedOperators = aggregatedOperators{ + // we've already verified that the operator is part of the task's quorum, so we don't need checks here + signersApkG2: bls.NewZeroG2Point().Add(operatorsAvsStateDict[signedTaskResponseDigest.OperatorId].Pubkeys.G2Pubkey), + signersAggSigG1: signedTaskResponseDigest.BlsSignature, + signersOperatorIdsSet: map[types.OperatorId]bool{signedTaskResponseDigest.OperatorId: true}, + signersTotalStakePerQuorum: 
operatorsAvsStateDict[signedTaskResponseDigest.OperatorId].StakePerQuorum, + } + } else { + digestAggregatedOperators.signersAggSigG1.Add(signedTaskResponseDigest.BlsSignature) + digestAggregatedOperators.signersApkG2.Add(operatorsAvsStateDict[signedTaskResponseDigest.OperatorId].Pubkeys.G2Pubkey) + digestAggregatedOperators.signersOperatorIdsSet[signedTaskResponseDigest.OperatorId] = true + for quorumNum, stake := range operatorsAvsStateDict[signedTaskResponseDigest.OperatorId].StakePerQuorum { + if _, ok := digestAggregatedOperators.signersTotalStakePerQuorum[quorumNum]; !ok { + // if we haven't seen this quorum before, initialize its signed stake to 0 + // possible if previous operators who sent us signatures were not part of this quorum + digestAggregatedOperators.signersTotalStakePerQuorum[quorumNum] = big.NewInt(0) + } + digestAggregatedOperators.signersTotalStakePerQuorum[quorumNum].Add(digestAggregatedOperators.signersTotalStakePerQuorum[quorumNum], stake) + } + } + // update the aggregatedOperatorsDict. Note that we need to assign the whole struct value at once, + // because of https://github.com/golang/go/issues/3117 + aggregatedOperatorsDict[signedTaskResponseDigest.TaskResponseDigest] = digestAggregatedOperators + + if checkIfStakeThresholdsMet(digestAggregatedOperators.signersTotalStakePerQuorum, totalStakePerQuorum, quorumThresholdPercentagesMap) { + nonSignersOperatorIds := []types.OperatorId{} + for operatorId := range operatorsAvsStateDict { + if _, operatorSigned := digestAggregatedOperators.signersOperatorIdsSet[operatorId]; !operatorSigned { + nonSignersOperatorIds = append(nonSignersOperatorIds, operatorId) + } + } + indices, err := a.avsRegistryService.GetCheckSignaturesIndices(context.Background(), taskCreatedBlock, quorumNumbers, nonSignersOperatorIds) + if err != nil { + a.logger.Error("Failed to get check signatures indices", "err", err) + a.aggregatedResponsesC <- BlsAggregationServiceResponse{ + Err: err, + } + return + } + blsAggregationServiceResponse := BlsAggregationServiceResponse{ + Err: nil, + TaskIndex: taskIndex, + TaskResponseDigest: signedTaskResponseDigest.TaskResponseDigest, + NonSignersPubkeysG1: getG1PubkeysOfNonSigners(digestAggregatedOperators.signersOperatorIdsSet, operatorsAvsStateDict), + QuorumApksG1: quorumApksG1, + SignersApkG2: digestAggregatedOperators.signersApkG2, + SignersAggSigG1: digestAggregatedOperators.signersAggSigG1, + NonSignerQuorumBitmapIndices: indices.NonSignerQuorumBitmapIndices, + QuorumApkIndices: indices.QuorumApkIndices, + TotalStakeIndices: indices.TotalStakeIndices, + NonSignerStakeIndices: indices.NonSignerStakeIndices, + } + a.aggregatedResponsesC <- blsAggregationServiceResponse + return + } + case <-taskExpiredTimer.C: + a.aggregatedResponsesC <- BlsAggregationServiceResponse{ + Err: TaskExpiredError, + } + return + } + } + +} + +// closeTaskGoroutine is run when the goroutine processing taskIndex's task responses ends (for whatever reason) +// it deletes the response channel for taskIndex from a.taskChans +// so that the main thread knows that this task goroutine is no longer running +// and doesn't try to send new signatures to it +func (a *BlsAggregatorService) closeTaskGoroutine(taskIndex types.TaskIndex) { + a.taskChansMutex.Lock() + delete(a.signedTaskRespsCs, taskIndex) + a.taskChansMutex.Unlock() +} + +// verifySignature verifies that a signature is valid against the operator pubkey stored in the +// operatorsAvsStateDict for that particular task +// TODO(samlaf): right now we are only checking that the 
*digest* is signed correctly!! +// we could be sent a signature of any kind of garbage and we would happily aggregate it +// this forces the avs code to verify that the digest is indeed the digest of a valid taskResponse +// we could take taskResponse as an interface{} and have avs code pass us a taskResponseHashFunction +// that we could use to hash and verify the taskResponse itself +func (a *BlsAggregatorService) verifySignature( + taskIndex types.TaskIndex, + signedTaskResponseDigest types.SignedTaskResponseDigest, + operatorsAvsStateDict map[types.OperatorId]types.OperatorAvsState, +) error { + _, ok := operatorsAvsStateDict[signedTaskResponseDigest.OperatorId] + if !ok { + a.logger.Warnf("Operator %#v not found. Skipping message.", signedTaskResponseDigest.OperatorId) + return OperatorNotPartOfTaskQuorumErrorFn(signedTaskResponseDigest.OperatorId, taskIndex) + } + + // 0. verify that the msg actually came from the correct operator + operatorG2Pubkey := operatorsAvsStateDict[signedTaskResponseDigest.OperatorId].Pubkeys.G2Pubkey + if operatorG2Pubkey == nil { + a.logger.Fatal("Operator G2 pubkey not found") + } + a.logger.Debug("Verifying signed task response digest signature", + "operatorG2Pubkey", operatorG2Pubkey, + "taskResponseDigest", signedTaskResponseDigest.TaskResponseDigest, + "blsSignature", signedTaskResponseDigest.BlsSignature, + ) + signatureVerified, err := signedTaskResponseDigest.BlsSignature.Verify(operatorG2Pubkey, signedTaskResponseDigest.TaskResponseDigest) + if err != nil { + a.logger.Error(SignatureVerificationError(err).Error()) + return SignatureVerificationError(err) + } + if !signatureVerified { + a.logger.Error(IncorrectSignatureError.Error()) + return IncorrectSignatureError + } + return nil +} + +// checkIfStakeThresholdsMet checks at least quorumThresholdPercentage of stake +// has signed for each quorum. 
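Before the implementation that follows, a tiny self-contained sketch of the same integer comparison, with made-up stake numbers (not from this PR): 200 of 300 total stake signed, against a 67% threshold.

package main

import (
	"fmt"
	"math/big"
)

func main() {
	// Made-up numbers: 200 of 300 total stake signed, 67% threshold required.
	signedStake := big.NewInt(200)
	totalStake := big.NewInt(300)
	thresholdPercentage := int64(67)

	// Same integer comparison as the function below (and the contracts), avoiding division:
	// signedStake * 100 >= totalStake * thresholdPercentage
	lhs := new(big.Int).Mul(signedStake, big.NewInt(100))                 // 20000
	rhs := new(big.Int).Mul(totalStake, big.NewInt(thresholdPercentage)) // 20100
	fmt.Println("threshold met:", lhs.Cmp(rhs) >= 0) // false: 200/300 is about 66.7% < 67%
}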
+func checkIfStakeThresholdsMet( + signedStakePerQuorum map[types.QuorumNum]*big.Int, + totalStakePerQuorum map[types.QuorumNum]*big.Int, + quorumThresholdPercentagesMap map[types.QuorumNum]types.QuorumThresholdPercentage, +) bool { + for quorumNum, quorumThresholdPercentage := range quorumThresholdPercentagesMap { + // we check that signedStake >= totalStake * quorumThresholdPercentage / 100 + // to be exact (and do like the contracts), we actually check that + // signedStake * 100 >= totalStake * quorumThresholdPercentage + signedStake := big.NewInt(0).Mul(signedStakePerQuorum[quorumNum], big.NewInt(100)) + thresholdStake := big.NewInt(0).Mul(totalStakePerQuorum[quorumNum], big.NewInt(int64(quorumThresholdPercentage))) + if signedStake.Cmp(thresholdStake) < 0 { + return false + } + } + return true +} + +func getG1PubkeysOfNonSigners(signersOperatorIdsSet map[types.OperatorId]bool, operatorAvsStateDict map[[32]byte]types.OperatorAvsState) []*bls.G1Point { + nonSignersG1Pubkeys := []*bls.G1Point{} + for operatorId, operator := range operatorAvsStateDict { + if _, operatorSigned := signersOperatorIdsSet[operatorId]; !operatorSigned { + nonSignersG1Pubkeys = append(nonSignersG1Pubkeys, operator.Pubkeys.G1Pubkey) + } + } + return nonSignersG1Pubkeys +} diff --git a/services/bls_aggregation/blsagg_test.go b/services/bls_aggregation/blsagg_test.go new file mode 100644 index 00000000..358e5fd3 --- /dev/null +++ b/services/bls_aggregation/blsagg_test.go @@ -0,0 +1,447 @@ +package blsagg + +import ( + "context" + "math/big" + "testing" + "time" + + "github.com/Layr-Labs/eigensdk-go/crypto/bls" + "github.com/Layr-Labs/eigensdk-go/logging" + "github.com/Layr-Labs/eigensdk-go/services/avsregistry" + "github.com/Layr-Labs/eigensdk-go/types" + "github.com/stretchr/testify/require" +) + +// TestBlsAgg is a suite of tests that exercises the main aggregation logic of the aggregation service +// it doesn't check any of the indices fields because those are just provided as a convenience to the caller +// and aren't related to the main logic which we actually need to test +// they are obtained from a call to the chain at the end of the aggregation so we should test that elsewhere +func TestBlsAgg(t *testing.T) { + + // we hardcode this for now, until we implement this feature properly + // 1 second seems to be enough for tests to pass.
Currently takes 5s to run all tests + tasksTimeToExpiry := 1 * time.Second + + t.Run("1 quorum 1 operator 1 correct signature", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + blockNum := uint32(1) + taskIndex := types.TaskIndex(0) + quorumNumbers := []types.QuorumNum{0} + quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} + taskResponseDigest := types.TaskResponseDigest{123} + blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) + + fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) + noopLogger := logging.NewNoopLogger() + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + + err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + require.Nil(t, err) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + require.Nil(t, err) + wantAggregationServiceResponse := BlsAggregationServiceResponse{ + Err: nil, + TaskIndex: taskIndex, + TaskResponseDigest: taskResponseDigest, + NonSignersPubkeysG1: []*bls.G1Point{}, + QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1()}, + SignersApkG2: testOperator1.BlsKeypair.GetPubKeyG2(), + SignersAggSigG1: testOperator1.BlsKeypair.SignMessage(taskResponseDigest), + } + gotAggregationServiceResponse := <-blsAggServ.aggregatedResponsesC + require.Equal(t, wantAggregationServiceResponse, gotAggregationServiceResponse) + }) + + t.Run("1 quorum 3 operator 3 correct signatures", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + testOperator2 := types.TestOperator{ + OperatorId: types.OperatorId{2}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x2"), + } + testOperator3 := types.TestOperator{ + OperatorId: types.OperatorId{3}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(300), 1: big.NewInt(100)}, + BlsKeypair: newBlsKeyPairPanics("0x3"), + } + taskIndex := types.TaskIndex(0) + quorumNumbers := []types.QuorumNum{0} + quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} + taskResponseDigest := types.TaskResponseDigest{123} + blockNum := uint32(1) + + fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) + noopLogger := logging.NewNoopLogger() + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + + err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + require.Nil(t, err) + blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + require.Nil(t, err) + blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + require.Nil(t, err) + blsSigOp3 := 
testOperator3.BlsKeypair.SignMessage(taskResponseDigest) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp3, testOperator3.OperatorId) + require.Nil(t, err) + + wantAggregationServiceResponse := BlsAggregationServiceResponse{ + Err: nil, + TaskIndex: taskIndex, + TaskResponseDigest: taskResponseDigest, + NonSignersPubkeysG1: []*bls.G1Point{}, + QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1(). + Add(testOperator2.BlsKeypair.GetPubKeyG1()). + Add(testOperator3.BlsKeypair.GetPubKeyG1()), + }, + SignersApkG2: testOperator1.BlsKeypair.GetPubKeyG2(). + Add(testOperator2.BlsKeypair.GetPubKeyG2()). + Add(testOperator3.BlsKeypair.GetPubKeyG2()), + SignersAggSigG1: testOperator1.BlsKeypair.SignMessage(taskResponseDigest). + Add(testOperator2.BlsKeypair.SignMessage(taskResponseDigest)). + Add(testOperator3.BlsKeypair.SignMessage(taskResponseDigest)), + } + gotAggregationServiceResponse := <-blsAggServ.aggregatedResponsesC + require.Equal(t, wantAggregationServiceResponse, gotAggregationServiceResponse) + }) + + t.Run("2 quorums 2 operators 2 correct signatures", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + testOperator2 := types.TestOperator{ + OperatorId: types.OperatorId{2}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x2"), + } + taskIndex := types.TaskIndex(0) + quorumNumbers := []types.QuorumNum{0, 1} + quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} + taskResponseDigest := types.TaskResponseDigest{123} + blockNum := uint32(1) + + fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) + noopLogger := logging.NewNoopLogger() + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + + err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + require.Nil(t, err) + blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + require.Nil(t, err) + blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + require.Nil(t, err) + + wantAggregationServiceResponse := BlsAggregationServiceResponse{ + Err: nil, + TaskIndex: taskIndex, + TaskResponseDigest: taskResponseDigest, + NonSignersPubkeysG1: []*bls.G1Point{}, + QuorumApksG1: []*bls.G1Point{ + bls.NewZeroG1Point().Add(testOperator1.BlsKeypair.GetPubKeyG1()).Add(testOperator2.BlsKeypair.GetPubKeyG1()), + bls.NewZeroG1Point().Add(testOperator1.BlsKeypair.GetPubKeyG1()).Add(testOperator2.BlsKeypair.GetPubKeyG1()), + }, + SignersApkG2: testOperator1.BlsKeypair.GetPubKeyG2().Add(testOperator2.BlsKeypair.GetPubKeyG2()), + SignersAggSigG1: testOperator1.BlsKeypair.SignMessage(taskResponseDigest).Add(testOperator2.BlsKeypair.SignMessage(taskResponseDigest)), + } + gotAggregationServiceResponse := <-blsAggServ.aggregatedResponsesC + require.EqualValues(t, wantAggregationServiceResponse, gotAggregationServiceResponse) + }) + + t.Run("2 concurrent tasks 2 
quorums 2 operators 2 correct signatures", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + testOperator2 := types.TestOperator{ + OperatorId: types.OperatorId{2}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x2"), + } + quorumNumbers := []types.QuorumNum{0, 1} + quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} + blockNum := uint32(1) + + fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) + noopLogger := logging.NewNoopLogger() + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + + // initialize 2 concurrent tasks + task1Index := types.TaskIndex(1) + task1ResponseDigest := types.TaskResponseDigest{123} + err := blsAggServ.InitializeNewTask(task1Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + require.Nil(t, err) + task2Index := types.TaskIndex(2) + task2ResponseDigest := types.TaskResponseDigest{230} + err = blsAggServ.InitializeNewTask(task2Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + require.Nil(t, err) + + // Don't change the order of these, as the checks below assume task1 is completed first + blsSigTask1Op1 := testOperator1.BlsKeypair.SignMessage(task1ResponseDigest) + err = blsAggServ.ProcessNewSignature(context.Background(), task1Index, task1ResponseDigest, blsSigTask1Op1, testOperator1.OperatorId) + require.Nil(t, err) + blsSigTask2Op1 := testOperator1.BlsKeypair.SignMessage(task2ResponseDigest) + err = blsAggServ.ProcessNewSignature(context.Background(), task2Index, task2ResponseDigest, blsSigTask2Op1, testOperator1.OperatorId) + require.Nil(t, err) + blsSigTask1Op2 := testOperator2.BlsKeypair.SignMessage(task1ResponseDigest) + err = blsAggServ.ProcessNewSignature(context.Background(), task1Index, task1ResponseDigest, blsSigTask1Op2, testOperator2.OperatorId) + require.Nil(t, err) + blsSigTask2Op2 := testOperator2.BlsKeypair.SignMessage(task2ResponseDigest) + err = blsAggServ.ProcessNewSignature(context.Background(), task2Index, task2ResponseDigest, blsSigTask2Op2, testOperator2.OperatorId) + require.Nil(t, err) + + wantAggregationServiceResponseTask1 := BlsAggregationServiceResponse{ + Err: nil, + TaskIndex: task1Index, + TaskResponseDigest: task1ResponseDigest, + NonSignersPubkeysG1: []*bls.G1Point{}, + QuorumApksG1: []*bls.G1Point{ + bls.NewZeroG1Point().Add(testOperator1.BlsKeypair.GetPubKeyG1()).Add(testOperator2.BlsKeypair.GetPubKeyG1()), + bls.NewZeroG1Point().Add(testOperator1.BlsKeypair.GetPubKeyG1()).Add(testOperator2.BlsKeypair.GetPubKeyG1()), + }, + SignersApkG2: bls.NewZeroG2Point().Add(testOperator1.BlsKeypair.GetPubKeyG2().Add(testOperator2.BlsKeypair.GetPubKeyG2())), + SignersAggSigG1: testOperator1.BlsKeypair.SignMessage(task1ResponseDigest).Add(testOperator2.BlsKeypair.SignMessage(task1ResponseDigest)), + } + gotAggregationServiceResponseTask1 := <-blsAggServ.aggregatedResponsesC + require.EqualValues(t, wantAggregationServiceResponseTask1, gotAggregationServiceResponseTask1) + + wantAggregationServiceResponseTask2 := BlsAggregationServiceResponse{ + Err: nil, + TaskIndex: task2Index, + TaskResponseDigest: task2ResponseDigest, + NonSignersPubkeysG1: []*bls.G1Point{}, + QuorumApksG1: []*bls.G1Point{ + 
bls.NewZeroG1Point().Add(testOperator1.BlsKeypair.GetPubKeyG1()).Add(testOperator2.BlsKeypair.GetPubKeyG1()), + bls.NewZeroG1Point().Add(testOperator1.BlsKeypair.GetPubKeyG1()).Add(testOperator2.BlsKeypair.GetPubKeyG1()), + }, + SignersApkG2: testOperator1.BlsKeypair.GetPubKeyG2().Add(testOperator2.BlsKeypair.GetPubKeyG2()), + SignersAggSigG1: testOperator1.BlsKeypair.SignMessage(task2ResponseDigest).Add(testOperator2.BlsKeypair.SignMessage(task2ResponseDigest)), + } + gotAggregationServiceResponseTask2 := <-blsAggServ.aggregatedResponsesC + require.EqualValues(t, wantAggregationServiceResponseTask2, gotAggregationServiceResponseTask2) + }) + + t.Run("1 quorum 1 operator 0 signatures - task expired", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + taskIndex := types.TaskIndex(0) + quorumNumbers := []types.QuorumNum{0} + quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} + blockNum := uint32(1) + + fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) + noopLogger := logging.NewNoopLogger() + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + + err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + require.Nil(t, err) + wantAggregationServiceResponse := BlsAggregationServiceResponse{ + Err: TaskExpiredError, + } + gotAggregationServiceResponse := <-blsAggServ.aggregatedResponsesC + require.Equal(t, wantAggregationServiceResponse, gotAggregationServiceResponse) + }) + + t.Run("1 quorum 2 operator 1 correct signature quorumThreshold 50% - verified", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + testOperator2 := types.TestOperator{ + OperatorId: types.OperatorId{2}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x2"), + } + taskIndex := types.TaskIndex(0) + quorumNumbers := []types.QuorumNum{0} + quorumThresholdPercentages := []types.QuorumThresholdPercentage{50} + taskResponseDigest := types.TaskResponseDigest{123} + blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) + blockNum := uint32(1) + + fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) + noopLogger := logging.NewNoopLogger() + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + + err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + require.Nil(t, err) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + require.Nil(t, err) + wantAggregationServiceResponse := BlsAggregationServiceResponse{ + Err: nil, + TaskIndex: taskIndex, + TaskResponseDigest: taskResponseDigest, + NonSignersPubkeysG1: []*bls.G1Point{testOperator2.BlsKeypair.GetPubKeyG1()}, + QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1().Add(testOperator2.BlsKeypair.GetPubKeyG1())}, + SignersApkG2: testOperator1.BlsKeypair.GetPubKeyG2(), + SignersAggSigG1: 
testOperator1.BlsKeypair.SignMessage(taskResponseDigest), + } + gotAggregationServiceResponse := <-blsAggServ.aggregatedResponsesC + require.Equal(t, wantAggregationServiceResponse, gotAggregationServiceResponse) + }) + + t.Run("1 quorum 2 operator 1 correct signature quorumThreshold 60% - task expired", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + testOperator2 := types.TestOperator{ + OperatorId: types.OperatorId{2}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x2"), + } + blockNum := uint32(1) + taskIndex := types.TaskIndex(0) + quorumNumbers := []types.QuorumNum{0} + quorumThresholdPercentages := []types.QuorumThresholdPercentage{60} + taskResponseDigest := types.TaskResponseDigest{123} + blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) + + fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) + noopLogger := logging.NewNoopLogger() + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + + err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + require.Nil(t, err) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + require.Nil(t, err) + wantAggregationServiceResponse := BlsAggregationServiceResponse{ + Err: TaskExpiredError, + } + gotAggregationServiceResponse := <-blsAggServ.aggregatedResponsesC + require.Equal(t, wantAggregationServiceResponse, gotAggregationServiceResponse) + }) + + t.Run("send signature of task that isn't initialized - task not found error", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + blockNum := uint32(1) + taskIndex := types.TaskIndex(0) + taskResponseDigest := types.TaskResponseDigest{123} + blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) + + fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) + noopLogger := logging.NewNoopLogger() + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + + err := blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + require.Equal(t, TaskNotFoundErrorFn(taskIndex), err) + }) + + // this is an edge case as typically we would send new tasks and listen for task responses in a for select loop + // but this test makes sure the context deadline exceeded can get us out of a deadlock + t.Run("send new signedTaskDigest before listen on responseChan - context timeout cancels the request to prevent deadlock", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + testOperator2 := types.TestOperator{ + OperatorId: types.OperatorId{2}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x2"), + } + 
+		blockNum := uint32(1)
+		taskIndex := types.TaskIndex(0)
+		quorumNumbers := []types.QuorumNum{0}
+		quorumThresholdPercentages := []types.QuorumThresholdPercentage{100}
+
+		fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1})
+		noopLogger := logging.NewNoopLogger()
+		blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger)
+
+		err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
+		require.Nil(t, err)
+		taskResponseDigest1 := types.TaskResponseDigest{1}
+		blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1)
+		err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest1, blsSigOp1, testOperator1.OperatorId)
+		require.Nil(t, err)
+
+		taskResponseDigest2 := types.TaskResponseDigest{2}
+		blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2)
+		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+		defer cancel()
+		err = blsAggServ.ProcessNewSignature(ctx, taskIndex, taskResponseDigest2, blsSigOp2, testOperator2.OperatorId)
+		// this should timeout because the task goroutine is blocked on the response channel (since we only listen for it below)
+		require.Equal(t, context.DeadlineExceeded, err)
+
+		wantAggregationServiceResponse := BlsAggregationServiceResponse{
+			Err:                 nil,
+			TaskIndex:           taskIndex,
+			TaskResponseDigest:  taskResponseDigest1,
+			NonSignersPubkeysG1: []*bls.G1Point{},
+			QuorumApksG1:        []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1()},
+			SignersApkG2:        testOperator1.BlsKeypair.GetPubKeyG2(),
+			SignersAggSigG1:     testOperator1.BlsKeypair.SignMessage(taskResponseDigest1),
+		}
+		gotAggregationServiceResponse := <-blsAggServ.aggregatedResponsesC
+		require.Equal(t, wantAggregationServiceResponse, gotAggregationServiceResponse)
+	})
+
+	t.Run("1 quorum 2 operator 2 signatures on 2 different msgs - task expired", func(t *testing.T) {
+		testOperator1 := types.TestOperator{
+			OperatorId:     types.OperatorId{1},
+			StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)},
+			BlsKeypair:     newBlsKeyPairPanics("0x1"),
+		}
+		testOperator2 := types.TestOperator{
+			OperatorId:     types.OperatorId{2},
+			StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)},
+			BlsKeypair:     newBlsKeyPairPanics("0x2"),
+		}
+		blockNum := uint32(1)
+		taskIndex := types.TaskIndex(0)
+		quorumNumbers := []types.QuorumNum{0}
+		quorumThresholdPercentages := []types.QuorumThresholdPercentage{100}
+
+		fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2})
+		noopLogger := logging.NewNoopLogger()
+		blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger)
+
+		err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
+		require.Nil(t, err)
+		taskResponseDigest1 := types.TaskResponseDigest{1}
+		blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1)
+		err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest1, blsSigOp1, testOperator1.OperatorId)
+		require.Nil(t, err)
+		taskResponseDigest2 := types.TaskResponseDigest{2}
+		blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2)
+		err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest2, blsSigOp2, testOperator2.OperatorId)
+		require.Nil(t, err)
+
+		wantAggregationServiceResponse := BlsAggregationServiceResponse{
+			Err: TaskExpiredError,
+		}
+		gotAggregationServiceResponse := <-blsAggServ.aggregatedResponsesC
+		require.Equal(t, wantAggregationServiceResponse, gotAggregationServiceResponse)
+	})
+}
+
+func newBlsKeyPairPanics(hexKey string) *bls.KeyPair {
+	keypair, err := bls.NewKeyPairFromString(hexKey)
+	if err != nil {
+		panic(err)
+	}
+	return keypair
+}
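The "edge case" comment in the test above points at the intended wiring: an aggregator normally initializes tasks, feeds operator signatures into the service, and drains the aggregated responses in a single for-select loop, so the response channel never backs up. Below is a minimal sketch of that loop, assuming a hypothetical signedResponse struct and submission channel; only the BlsAggregationService methods and the types come from this diff.

package aggregator // hypothetical package, for illustration only

import (
	"context"
	"time"

	"github.com/Layr-Labs/eigensdk-go/crypto/bls"
	blsagg "github.com/Layr-Labs/eigensdk-go/services/bls_aggregation"
	"github.com/Layr-Labs/eigensdk-go/types"
)

// signedResponse is an assumed shape for what operators send back to the aggregator.
type signedResponse struct {
	TaskIndex  types.TaskIndex
	Digest     types.TaskResponseDigest
	Sig        *bls.Signature
	OperatorId types.OperatorId
}

func runAggregationLoop(ctx context.Context, svc blsagg.BlsAggregationService, signedC <-chan signedResponse) {
	for {
		select {
		case r := <-signedC:
			// forward each operator signature; the per-call timeout is what the
			// "context timeout cancels the request" test above exercises
			sigCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
			if err := svc.ProcessNewSignature(sigCtx, r.TaskIndex, r.Digest, r.Sig, r.OperatorId); err != nil {
				// e.g. log and drop the signature
			}
			cancel()
		case aggResp := <-svc.GetResponseChannel():
			// quorum thresholds met or task expired: act on the aggregated BLS response
			_ = aggResp
		case <-ctx.Done():
			return
		}
	}
}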
diff --git a/services/gen.go b/services/gen.go
index 99cc5299..88ddba28 100644
--- a/services/gen.go
+++ b/services/gen.go
@@ -2,3 +2,12 @@ package services
 
 //go:generate mockgen -destination=./mocks/pubkeycompendium.go -package=mocks github.com/Layr-Labs/eigensdk-go/services/pubkeycompendium PubkeyCompendiumService
 //go:generate mockgen -destination=./mocks/avsregistry.go -package=mocks github.com/Layr-Labs/eigensdk-go/services/avsregistry AvsRegistryService
+
+// We generate this mock in ./mocks/blsagg/ instead of ./mocks like the others because otherwise we get a circular dependency
+// avsregistry -> mocks -> avsregistry:
+// avsregistry_chaincaller_test imports mocks (for the pubkeycompendium mock),
+// and the blsaggregation mock imports the avsregistry interface.
+// TODO: are there better ways to organize these dependencies? Maybe by following Ben Johnson's standard package layout,
+// having the avs registry interface in the /avsregistry dir but the avsregistry_chaincaller
+// and its test in a subdir?
+//go:generate mockgen -destination=./mocks/blsagg/blsaggregation.go -package=mocks github.com/Layr-Labs/eigensdk-go/services/bls_aggregation BlsAggregationService
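To make the comment above concrete: a test in a downstream package can import just the generated BlsAggregationService mock without pulling in ./mocks and, through it, the avsregistry service. A rough sketch of such a test follows; the package name, test name, and expectation values are illustrative assumptions, only NewMockBlsAggregationService and the mocked methods come from the generated file in this diff.

package aggregator // hypothetical consumer package, for illustration only

import (
	"testing"
	"time"

	blsaggmocks "github.com/Layr-Labs/eigensdk-go/services/mocks/blsagg"
	"go.uber.org/mock/gomock"
)

func TestInitializesTaskOnNewTaskEvent(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// importing only services/mocks/blsagg avoids the services/mocks -> avsregistry edge
	mockBlsAggServ := blsaggmocks.NewMockBlsAggregationService(ctrl)
	mockBlsAggServ.EXPECT().
		InitializeNewTask(uint32(0), uint32(1), []byte{0}, []uint32{100}, 100*time.Second).
		Return(nil)

	// stand-in for the code under test, which should initialize the task exactly once
	if err := mockBlsAggServ.InitializeNewTask(0, 1, []byte{0}, []uint32{100}, 100*time.Second); err != nil {
		t.Fatal(err)
	}
}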
diff --git a/services/mocks/avsregistry.go b/services/mocks/avsregistry.go
index 0fa3842d..7888b5d5 100644
--- a/services/mocks/avsregistry.go
+++ b/services/mocks/avsregistry.go
@@ -8,7 +8,7 @@ import (
 	context "context"
 	reflect "reflect"
 
-	bls "github.com/Layr-Labs/eigensdk-go/crypto/bls"
+	contractBLSOperatorStateRetriever "github.com/Layr-Labs/eigensdk-go/contracts/bindings/BLSOperatorStateRetriever"
 	types "github.com/Layr-Labs/eigensdk-go/types"
 	gomock "go.uber.org/mock/gomock"
 )
@@ -36,29 +36,28 @@ func (m *MockAvsRegistryService) EXPECT() *MockAvsRegistryServiceMockRecorder {
 	return m.recorder
 }
 
-// GetOperatorPubkeys mocks base method.
-func (m *MockAvsRegistryService) GetOperatorPubkeys(arg0 context.Context, arg1 [32]byte) (types.OperatorPubkeys, error) {
+// GetCheckSignaturesIndices mocks base method.
+func (m *MockAvsRegistryService) GetCheckSignaturesIndices(arg0 context.Context, arg1 uint32, arg2 []byte, arg3 [][32]byte) (contractBLSOperatorStateRetriever.BLSOperatorStateRetrieverCheckSignaturesIndices, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetOperatorPubkeys", arg0, arg1)
-	ret0, _ := ret[0].(types.OperatorPubkeys)
+	ret := m.ctrl.Call(m, "GetCheckSignaturesIndices", arg0, arg1, arg2, arg3)
+	ret0, _ := ret[0].(contractBLSOperatorStateRetriever.BLSOperatorStateRetrieverCheckSignaturesIndices)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
-// GetOperatorPubkeys indicates an expected call of GetOperatorPubkeys.
-func (mr *MockAvsRegistryServiceMockRecorder) GetOperatorPubkeys(arg0, arg1 interface{}) *gomock.Call {
+// GetCheckSignaturesIndices indicates an expected call of GetCheckSignaturesIndices.
+func (mr *MockAvsRegistryServiceMockRecorder) GetCheckSignaturesIndices(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOperatorPubkeys", reflect.TypeOf((*MockAvsRegistryService)(nil).GetOperatorPubkeys), arg0, arg1)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCheckSignaturesIndices", reflect.TypeOf((*MockAvsRegistryService)(nil).GetCheckSignaturesIndices), arg0, arg1, arg2, arg3)
 }
 
 // GetOperatorsAvsStateAtBlock mocks base method.
-func (m *MockAvsRegistryService) GetOperatorsAvsStateAtBlock(arg0 context.Context, arg1 []byte, arg2 uint32) (map[[32]byte]types.OperatorAvsState, map[byte]*bls.G1Point, error) {
+func (m *MockAvsRegistryService) GetOperatorsAvsStateAtBlock(arg0 context.Context, arg1 []byte, arg2 uint32) (map[[32]byte]types.OperatorAvsState, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "GetOperatorsAvsStateAtBlock", arg0, arg1, arg2)
 	ret0, _ := ret[0].(map[[32]byte]types.OperatorAvsState)
-	ret1, _ := ret[1].(map[byte]*bls.G1Point)
-	ret2, _ := ret[2].(error)
-	return ret0, ret1, ret2
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
 }
 
 // GetOperatorsAvsStateAtBlock indicates an expected call of GetOperatorsAvsStateAtBlock.
@@ -66,3 +65,18 @@ func (mr *MockAvsRegistryServiceMockRecorder) GetOperatorsAvsStateAtBlock(arg0, arg1, arg2 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOperatorsAvsStateAtBlock", reflect.TypeOf((*MockAvsRegistryService)(nil).GetOperatorsAvsStateAtBlock), arg0, arg1, arg2)
 }
+
+// GetQuorumsAvsStateAtBlock mocks base method.
+func (m *MockAvsRegistryService) GetQuorumsAvsStateAtBlock(arg0 context.Context, arg1 []byte, arg2 uint32) (map[byte]types.QuorumAvsState, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetQuorumsAvsStateAtBlock", arg0, arg1, arg2)
+	ret0, _ := ret[0].(map[byte]types.QuorumAvsState)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetQuorumsAvsStateAtBlock indicates an expected call of GetQuorumsAvsStateAtBlock.
+func (mr *MockAvsRegistryServiceMockRecorder) GetQuorumsAvsStateAtBlock(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQuorumsAvsStateAtBlock", reflect.TypeOf((*MockAvsRegistryService)(nil).GetQuorumsAvsStateAtBlock), arg0, arg1, arg2)
+}
diff --git a/services/mocks/blsagg/blsaggregation.go b/services/mocks/blsagg/blsaggregation.go
new file mode 100644
index 00000000..4df7ab87
--- /dev/null
+++ b/services/mocks/blsagg/blsaggregation.go
@@ -0,0 +1,81 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/Layr-Labs/eigensdk-go/services/bls_aggregation (interfaces: BlsAggregationService)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+	context "context"
+	reflect "reflect"
+	time "time"
+
+	bls "github.com/Layr-Labs/eigensdk-go/crypto/bls"
+	blsagg "github.com/Layr-Labs/eigensdk-go/services/bls_aggregation"
+	types "github.com/Layr-Labs/eigensdk-go/types"
+	gomock "go.uber.org/mock/gomock"
+)
+
+// MockBlsAggregationService is a mock of BlsAggregationService interface.
+type MockBlsAggregationService struct {
+	ctrl     *gomock.Controller
+	recorder *MockBlsAggregationServiceMockRecorder
+}
+
+// MockBlsAggregationServiceMockRecorder is the mock recorder for MockBlsAggregationService.
+type MockBlsAggregationServiceMockRecorder struct {
+	mock *MockBlsAggregationService
+}
+
+// NewMockBlsAggregationService creates a new mock instance.
+func NewMockBlsAggregationService(ctrl *gomock.Controller) *MockBlsAggregationService {
+	mock := &MockBlsAggregationService{ctrl: ctrl}
+	mock.recorder = &MockBlsAggregationServiceMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockBlsAggregationService) EXPECT() *MockBlsAggregationServiceMockRecorder {
+	return m.recorder
+}
+
+// GetResponseChannel mocks base method.
+func (m *MockBlsAggregationService) GetResponseChannel() <-chan blsagg.BlsAggregationServiceResponse {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetResponseChannel")
+	ret0, _ := ret[0].(<-chan blsagg.BlsAggregationServiceResponse)
+	return ret0
+}
+
+// GetResponseChannel indicates an expected call of GetResponseChannel.
+func (mr *MockBlsAggregationServiceMockRecorder) GetResponseChannel() *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponseChannel", reflect.TypeOf((*MockBlsAggregationService)(nil).GetResponseChannel))
+}
+
+// InitializeNewTask mocks base method.
+func (m *MockBlsAggregationService) InitializeNewTask(arg0, arg1 uint32, arg2 []byte, arg3 []uint32, arg4 time.Duration) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "InitializeNewTask", arg0, arg1, arg2, arg3, arg4)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// InitializeNewTask indicates an expected call of InitializeNewTask.
+func (mr *MockBlsAggregationServiceMockRecorder) InitializeNewTask(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitializeNewTask", reflect.TypeOf((*MockBlsAggregationService)(nil).InitializeNewTask), arg0, arg1, arg2, arg3, arg4)
+}
+
+// ProcessNewSignature mocks base method.
+func (m *MockBlsAggregationService) ProcessNewSignature(arg0 context.Context, arg1 uint32, arg2 types.TaskResponseDigest, arg3 *bls.Signature, arg4 [32]byte) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ProcessNewSignature", arg0, arg1, arg2, arg3, arg4)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// ProcessNewSignature indicates an expected call of ProcessNewSignature.
+func (mr *MockBlsAggregationServiceMockRecorder) ProcessNewSignature(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProcessNewSignature", reflect.TypeOf((*MockBlsAggregationService)(nil).ProcessNewSignature), arg0, arg1, arg2, arg3, arg4)
+}
diff --git a/types/avs.go b/types/avs.go
new file mode 100644
index 00000000..92f02fb3
--- /dev/null
+++ b/types/avs.go
@@ -0,0 +1,13 @@
+package types
+
+import "github.com/Layr-Labs/eigensdk-go/crypto/bls"
+
+type TaskIndex = uint32
+type TaskResponseDigest [32]byte
+
+type SignedTaskResponseDigest struct {
+	TaskResponseDigest          TaskResponseDigest
+	BlsSignature                *bls.Signature
+	OperatorId                  bls.OperatorId
+	SignatureVerificationErrorC chan error
+}
diff --git a/types/operator.go b/types/operator.go
index 232f21e7..173ebdcc 100644
--- a/types/operator.go
+++ b/types/operator.go
@@ -96,7 +96,8 @@ type StakeAmount = *big.Int
 // It is currently the hash of the operator's G1 pubkey (in the bls pubkey registry)
 type OperatorId = [32]byte
 type QuorumNum = uint8
-type BlockNumber = uint32
+type QuorumThresholdPercentage = uint32
+type BlockNum = uint32
 
 // AvsOperator represents the operator state in AVS registries
 type OperatorAvsState struct {
@@ -104,7 +105,7 @@ type OperatorAvsState struct {
 	Pubkeys OperatorPubkeys
 	// Stake of the operator for each quorum
 	StakePerQuorum map[QuorumNum]StakeAmount
-	BlockNumber    BlockNumber
+	BlockNumber    BlockNum
 }
 
 var (
@@ -121,3 +122,10 @@ func BitmapToQuorumIds(bitmap *big.Int) []QuorumNum {
 	}
 	return quorumIds
 }
+
+type QuorumAvsState struct {
+	QuorumNumber QuorumNum
+	TotalStake   StakeAmount
+	AggPubkeyG1  *bls.G1Point
+	BlockNumber  BlockNum
+}
diff --git a/types/test.go b/types/test.go
new file mode 100644
index 00000000..01a754b9
--- /dev/null
+++ b/types/test.go
@@ -0,0 +1,9 @@
+package types
+
+import "github.com/Layr-Labs/eigensdk-go/crypto/bls"
+
+type TestOperator struct {
+	OperatorId     OperatorId
+	StakePerQuorum map[QuorumNum]StakeAmount
+	BlsKeypair     *bls.KeyPair
+}
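The new QuorumAvsState type pairs naturally with the chainable G1 helpers this PR adds: the aggregate quorum pubkey can be folded with bls.NewZeroG1Point().Add(...), exactly as the expected QuorumApksG1 values are built in the tests above. A small sketch of that derivation for a single quorum, assuming the per-operator stake and G1 pubkey maps are already at hand (the maps are illustrative inputs, not an SDK API):

package avsstate // hypothetical helper package, for illustration only

import (
	"math/big"

	"github.com/Layr-Labs/eigensdk-go/crypto/bls"
	"github.com/Layr-Labs/eigensdk-go/types"
)

func quorumStateFromOperators(
	quorum types.QuorumNum,
	blockNum types.BlockNum,
	stakes map[types.OperatorId]types.StakeAmount, // each operator's stake in this quorum
	pubkeysG1 map[types.OperatorId]*bls.G1Point, // each operator's G1 pubkey (assumed input)
) types.QuorumAvsState {
	totalStake := big.NewInt(0)
	// Add now returns the receiver, so the aggregate pubkey accumulates in place.
	aggPubkeyG1 := bls.NewZeroG1Point()
	for opId, stake := range stakes {
		totalStake.Add(totalStake, stake)
		aggPubkeyG1.Add(pubkeysG1[opId])
	}
	return types.QuorumAvsState{
		QuorumNumber: quorum,
		TotalStake:   totalStake,
		AggPubkeyG1:  aggPubkeyG1,
		BlockNumber:  blockNum,
	}
}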