Skip to content

Commit

Permalink
Merge pull request #138 from TimeleapLabs/new-pos-schnorr
Browse files Browse the repository at this point in the history
Integrate the new PoS contract with Schnorr
  • Loading branch information
pouya-eghbali authored Jun 3, 2024
2 parents 6e0ecd3 + 1fa5ac6 commit 1bf6f91
Show file tree
Hide file tree
Showing 11 changed files with 1,533 additions and 2,803 deletions.
1,476 changes: 1,476 additions & 0 deletions internal/crypto/ethereum/contracts/ProofOfStake.go

Large diffs are not rendered by default.

2,644 changes: 0 additions & 2,644 deletions internal/crypto/ethereum/contracts/UnchainedStaking.go

This file was deleted.

6 changes: 3 additions & 3 deletions internal/crypto/ethereum/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
type RPC interface {
RefreshRPC(network string)
GetClient(network string) *ethclient.Client
GetNewStakingContract(network string, address string, refresh bool) (*contracts.UnchainedStaking, error)
GetNewStakingContract(network string, address string, refresh bool) (*contracts.ProofOfStake, error)
GetNewUniV3Contract(network string, address string, refresh bool) (*contracts.UniV3, error)
GetBlockNumber(ctx context.Context, network string) (uint64, error)
}
Expand Down Expand Up @@ -72,7 +72,7 @@ func (r *repository) RefreshRPC(network string) {
r.refreshRPCWithRetries(network, len(r.list))
}

func (r *repository) GetNewStakingContract(network string, address string, refresh bool) (*contracts.UnchainedStaking, error) {
func (r *repository) GetNewStakingContract(network string, address string, refresh bool) (*contracts.ProofOfStake, error) {
if refresh {
r.RefreshRPC(network)
}
Expand All @@ -82,7 +82,7 @@ func (r *repository) GetNewStakingContract(network string, address string, refre
return nil, consts.ErrClientNotFound
}

return contracts.NewUnchainedStaking(common.HexToAddress(address), client)
return contracts.NewProofOfStake(common.HexToAddress(address), client)
}

func (r *repository) GetNewUniV3Contract(network string, address string, refresh bool) (*contracts.UniV3, error) {
Expand Down
4 changes: 2 additions & 2 deletions internal/crypto/ethereum/rpc_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func (m mockRPC) GetClient(_ string) *ethclient.Client {

func (m mockRPC) RefreshRPC(_ string) {}

func (m mockRPC) GetNewStakingContract(_ string, address string, _ bool) (*contracts.UnchainedStaking, error) {
return contracts.NewUnchainedStaking(
func (m mockRPC) GetNewStakingContract(_ string, address string, _ bool) (*contracts.ProofOfStake, error) {
return contracts.NewProofOfStake(
common.HexToAddress(address),
m.backend,
)
Expand Down
2 changes: 1 addition & 1 deletion internal/service/correctness/correctness.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (s *service) RecordSignature(
voted = *big.NewInt(0)
}

votingPower, err := s.pos.GetVotingPowerOfPublicKey(ctx, signer.PublicKey)
votingPower, err := s.pos.GetVotingPowerOfEvm(ctx, signer.EvmAddress)
if err != nil {
utils.Logger.
With("Address", address.Calculate(signer.PublicKey[:])).
Expand Down
2 changes: 1 addition & 1 deletion internal/service/evmlog/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (s *service) RecordSignature(
s.consensus.Add(key, make(map[bls12381.G1Affine]big.Int))
}

votingPower, err := s.pos.GetVotingPowerOfPublicKey(ctx, signer.PublicKey)
votingPower, err := s.pos.GetVotingPowerOfEvm(ctx, signer.EvmAddress)
if err != nil {
utils.Logger.
With("Address", address.Calculate(signer.PublicKey[:])).
Expand Down
75 changes: 1 addition & 74 deletions internal/service/pos/eip712.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,82 +4,9 @@ import (
"context"
"math/big"

"github.com/TimeleapLabs/unchained/internal/utils"

"github.com/TimeleapLabs/unchained/internal/crypto"
"github.com/TimeleapLabs/unchained/internal/crypto/ethereum/contracts"

"github.com/TimeleapLabs/unchained/internal/config"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
)

func (s *service) Slash(ctx context.Context, address [20]byte, to common.Address, amount *big.Int, nftIDs []*big.Int) error {
evmAddress, err := s.posContract.EvmAddressOf(nil, address)

if err != nil {
utils.Logger.
With("Error", err).
Error("Failed to get EVM address of the staker")
return err
}

transfer := contracts.UnchainedStakingEIP712Transfer{
From: evmAddress,
To: to,
Amount: amount,
NftIds: nftIDs,
}

signature, err := s.eip712Signer.SignTransferRequest(crypto.Identity.Eth, &transfer)

if err != nil {
utils.Logger.
With("Error", err).
Error("Failed to sign transfer request")
return err
}

tx, err := s.posContract.Transfer(
nil,
[]contracts.UnchainedStakingEIP712Transfer{transfer},
[]contracts.UnchainedStakingSignature{*signature},
)

if err != nil {
utils.Logger.
With("Error", err).
Error("Failed to transfer")
return err
}

receipt, err := bind.WaitMined(
ctx,
s.ethRPC.GetClient(config.App.ProofOfStake.Chain),
tx,
)

if err != nil {
utils.Logger.
With("Error", err).
Error("Failed to wait for transaction to be mined")
return err
}

if receipt.Status != types.ReceiptStatusSuccessful {
utils.Logger.
With("Error", err).
Error("Transaction failed")
return err
}

utils.Logger.
With("Address", evmAddress.Hex()).
With("To", to.Hex()).
With("Amount", amount.String()).
With("NftIds", nftIDs).
Info("Slashed")

func (s *service) Slash(_ context.Context, _ [20]byte, _ common.Address, _ *big.Int, _ []*big.Int) error {
return nil
}
68 changes: 4 additions & 64 deletions internal/service/pos/eip712/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"math/big"

"github.com/TimeleapLabs/unchained/internal/crypto/ethereum"
"github.com/TimeleapLabs/unchained/internal/crypto/ethereum/contracts"

"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/crypto"
Expand All @@ -16,15 +15,8 @@ type Signer struct {
domain apitypes.TypedDataDomain
}

func (s *Signer) bytesToUnchainedSignature(signature []byte) *contracts.UnchainedStakingSignature {
return &contracts.UnchainedStakingSignature{
V: signature[64],
R: [32]byte(signature[:32]),
S: [32]byte(signature[32:64]),
}
}

func (s *Signer) signEip712Message(evmSigner *ethereum.Signer, data *apitypes.TypedData) (*contracts.UnchainedStakingSignature, error) {
// TODO: Rewrite to use Schnorr signature scheme
func (s *Signer) SignEip712Message(evmSigner *ethereum.Signer, data *apitypes.TypedData) ([]byte, error) {
domainSeparator, err := data.HashStruct("EIP712Domain", data.Domain.Map())
if err != nil {
return nil, err
Expand All @@ -38,6 +30,7 @@ func (s *Signer) signEip712Message(evmSigner *ethereum.Signer, data *apitypes.Ty
message := []byte(fmt.Sprintf("\x19\x01%s%s", string(domainSeparator), string(typedDataHash)))
messageHash := crypto.Keccak256(message)

// This should be replaced with Schnorr signature scheme
signature, err := crypto.Sign(messageHash, evmSigner.PrivateKey)
if err != nil {
return nil, err
Expand All @@ -47,60 +40,7 @@ func (s *Signer) signEip712Message(evmSigner *ethereum.Signer, data *apitypes.Ty
signature[64] += 27
}

return s.bytesToUnchainedSignature(signature), nil
}

func (s *Signer) SignTransferRequest(evmSigner *ethereum.Signer, request *contracts.UnchainedStakingEIP712Transfer) (*contracts.UnchainedStakingSignature, error) {
data := &apitypes.TypedData{
Types: Types,
PrimaryType: "Transfer",
Domain: s.domain,
Message: map[string]interface{}{
"signer": evmSigner.Address,
"from": request.From,
"to": request.To,
"amount": request.Amount,
"nftIds": request.NftIds,
"nonces": request.Nonces,
},
}

return s.signEip712Message(evmSigner, data)
}

func (s *Signer) SignSetParamsRequest(evmSigner *ethereum.Signer, request *contracts.UnchainedStakingEIP712SetParams) (*contracts.UnchainedStakingSignature, error) {
data := &apitypes.TypedData{
Types: Types,
PrimaryType: "SetParams",
Domain: s.domain,
Message: map[string]interface{}{
"requester": evmSigner.Address,
"token": request.Token,
"nft": request.Nft,
"nftTracker": request.NftTracker,
"threshold": request.Threshold,
"expiration": request.Expiration,
"nonce": request.Nonce,
},
}

return s.signEip712Message(evmSigner, data)
}

func (s *Signer) SignSetNftPriceRequest(evmSigner *ethereum.Signer, request *contracts.UnchainedStakingEIP712SetNftPrice) (*contracts.UnchainedStakingSignature, error) {
data := &apitypes.TypedData{
Types: Types,
PrimaryType: "SetNftPrice",
Domain: s.domain,
Message: map[string]interface{}{
"requester": evmSigner.Address,
"nftId": request.NftId,
"price": request.Price,
"nonce": request.Nonce,
},
}

return s.signEip712Message(evmSigner, data)
return signature, nil
}

func New(chainID *big.Int, verifyingContract string) *Signer {
Expand Down
46 changes: 33 additions & 13 deletions internal/service/pos/pos.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,45 @@ import (
"github.com/TimeleapLabs/unchained/internal/service/pos/eip712"
"github.com/TimeleapLabs/unchained/internal/utils"
"github.com/TimeleapLabs/unchained/internal/utils/address"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/puzpuzpuz/xsync/v3"
)

type Service interface {
GetTotalVotingPower() (*big.Int, error)
GetVotingPowerFromContract(address [20]byte, block *big.Int) (*big.Int, error)
GetVotingPower(address [20]byte, block *big.Int) (*big.Int, error)
GetVotingPowerOfEvm(ctx context.Context, evmAddress string) (*big.Int, error)
GetVotingPowerOfPublicKey(ctx context.Context, pkBytes [96]byte) (*big.Int, error)
GetSchnorrSigners(ctx context.Context) ([]common.Address, error)
}

type service struct {
ethRPC ethereum.RPC
posContract *contracts.UnchainedStaking
posContract *contracts.ProofOfStake
votingPowers *xsync.MapOf[[20]byte, *big.Int]
lastUpdated *xsync.MapOf[[20]byte, *big.Int]
base *big.Int
eip712Signer *eip712.Signer
}

func (s *service) GetTotalVotingPower() (*big.Int, error) {
return s.posContract.GetTotalVotingPower(nil)
return new(big.Int).Mul(big.NewInt(5e10), big.NewInt(1e18)), nil
// return s.posContract.GetTotalVotingPower(nil)
}

func (s *service) GetVotingPowerFromContract(address [20]byte, block *big.Int) (*big.Int, error) {
votingPower, err := s.posContract.GetVotingPower(nil, address)
stake, err := s.posContract.GetStake(nil, address)
// votingPower, err := s.posContract.GetVotingPower(nil, address)
if err != nil {
return votingPower, err
return stake.Amount, err
}

s.votingPowers.Store(address, votingPower)
s.votingPowers.Store(address, stake.Amount)
s.lastUpdated.Store(address, block)

return votingPower, nil
return stake.Amount, nil
}

func (s *service) minBase(power *big.Int) *big.Int {
Expand Down Expand Up @@ -74,6 +80,19 @@ func (s *service) GetVotingPower(address [20]byte, block *big.Int) (*big.Int, er
return s.base, nil
}

func (s *service) GetVotingPowerOfEvm(ctx context.Context, evmAddress string) (*big.Int, error) {
block, err := s.ethRPC.GetBlockNumber(ctx, config.App.ProofOfStake.Chain)
if err != nil {
return nil, err
}
address := common.HexToAddress(evmAddress)
return s.GetVotingPower(address, big.NewInt(int64(block)))
}

func (s *service) GetSchnorrSigners(ctx context.Context) ([]common.Address, error) {
return s.posContract.GetValidators(&bind.CallOpts{Context: ctx})
}

func (s *service) GetVotingPowerOfPublicKey(ctx context.Context, pkBytes [96]byte) (*big.Int, error) {
_, addrHex := address.CalculateHex(pkBytes[:])
block, err := s.ethRPC.GetBlockNumber(ctx, config.App.ProofOfStake.Chain)
Expand Down Expand Up @@ -136,15 +155,16 @@ func New(ethRPC ethereum.RPC) Service {
With("Network", utils.BigIntToFloat(total)).
Info("PoS")

chainID, err := s.posContract.GetChainId(nil)
if err != nil {
utils.Logger.
With("Error", err).
Error("Failed to get chain ID")
// chainID, err := s.posContract.GetChainId(nil)
// if err != nil {
// utils.Logger.
// With("Error", err).
// Error("Failed to get chain ID")

panic(err)
}
// panic(err)
// }

chainID := big.NewInt(421614)
s.eip712Signer = eip712.New(chainID, config.App.ProofOfStake.Address)

return s
Expand Down
11 changes: 11 additions & 0 deletions internal/service/pos/pos_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"math/big"

"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/mock"
)

Expand All @@ -26,7 +27,17 @@ func (m *MockService) GetVotingPower(address [20]byte, block *big.Int) (*big.Int
return big.NewInt(int64(args.Int(0))), args.Error(1)
}

func (m *MockService) GetVotingPowerOfEvm(ctx context.Context, evmAddress string) (*big.Int, error) {
args := m.Called(ctx, evmAddress)
return big.NewInt(int64(args.Int(0))), args.Error(1)
}

func (m *MockService) GetVotingPowerOfPublicKey(_ context.Context, pkBytes [96]byte) (*big.Int, error) {
args := m.Called(pkBytes)
return big.NewInt(int64(args.Int(0))), args.Error(1)
}

func (m *MockService) GetSchnorrSigners(_ context.Context) ([]common.Address, error) {
args := m.Called()
return args.Get(0).([]common.Address), args.Error(1)
}
2 changes: 1 addition & 1 deletion internal/service/uniswap/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (s *service) RecordSignature(
voted = *big.NewInt(0)
}

votingPower, err := s.pos.GetVotingPowerOfPublicKey(ctx, signer.PublicKey)
votingPower, err := s.pos.GetVotingPowerOfEvm(ctx, signer.EvmAddress)
if err != nil {
utils.Logger.
With("Address", address.Calculate(signer.PublicKey[:])).
Expand Down

0 comments on commit 1bf6f91

Please sign in to comment.