Skip to content

PoC: support custom keyring equality in ocr3 #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 26 additions & 24 deletions offchainreporting2plus/internal/config/ocr3config/shared_config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ocr3config

import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
Expand Down Expand Up @@ -48,7 +47,7 @@ func SharedConfigFromContractConfig[RI any](
skipResourceExhaustionChecks bool,
change types.ContractConfig,
offchainKeyring types.OffchainKeyring,
onchainKeyring ocr3types.OnchainKeyring[RI],
onchainKeyring ocr3types.ComparableOnchainKeyring[RI],
peerID string,
transmitAccount types.Account,
) (SharedConfig, commontypes.OracleID, error) {
Expand All @@ -59,32 +58,35 @@ func SharedConfigFromContractConfig[RI any](

oracleID := commontypes.OracleID(math.MaxUint8)
{
onchainPublicKey := onchainKeyring.PublicKey()
var onchainPublicKey types.OnchainPublicKey //:= onchainKeyring.PublicKey()
offchainPublicKey := offchainKeyring.OffchainPublicKey()
var found bool
for i, identity := range publicConfig.OracleIdentities {
if bytes.Equal(identity.OnchainPublicKey, onchainPublicKey) {
if identity.OffchainPublicKey != offchainPublicKey {
return SharedConfig{}, 0, errors.Errorf(
"OnchainPublicKey %x in publicConfig matches "+
"mine, but OffchainPublicKey does not: %v (config) vs %v (mine)",
onchainPublicKey, identity.OffchainPublicKey, offchainPublicKey)
}
if identity.PeerID != peerID {
return SharedConfig{}, 0, errors.Errorf(
"OnchainPublicKey %x in publicConfig matches "+
"mine, but PeerID does not: %v (config) vs %v (mine)",
onchainPublicKey, identity.PeerID, peerID)
}
if identity.TransmitAccount != transmitAccount {
return SharedConfig{}, 0, errors.Errorf(
"OnchainPublicKey %x in publicConfig matches "+
"mine, but TransmitAccount does not: %v (config) vs %v (mine)",
onchainPublicKey, identity.TransmitAccount, transmitAccount)
}
oracleID = commontypes.OracleID(i)
found = true
//if bytes.Equal(identity.OnchainPublicKey, onchainPublicKey) {
if !onchainKeyring.Equal(identity.OnchainPublicKey) {
continue
}
onchainPublicKey = identity.OnchainPublicKey
if identity.OffchainPublicKey != offchainPublicKey {
return SharedConfig{}, 0, errors.Errorf(
"OnchainPublicKey %x in publicConfig matches "+
"mine, but OffchainPublicKey does not: %v (config) vs %v (mine)",
onchainPublicKey, identity.OffchainPublicKey, offchainPublicKey)
}
if identity.PeerID != peerID {
return SharedConfig{}, 0, errors.Errorf(
"OnchainPublicKey %x in publicConfig matches "+
"mine, but PeerID does not: %v (config) vs %v (mine)",
onchainPublicKey, identity.PeerID, peerID)
}
if identity.TransmitAccount != transmitAccount {
return SharedConfig{}, 0, errors.Errorf(
"OnchainPublicKey %x in publicConfig matches "+
"mine, but TransmitAccount does not: %v (config) vs %v (mine)",
onchainPublicKey, identity.TransmitAccount, transmitAccount)
}
oracleID = commontypes.OracleID(i)
found = true
}

if !found {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package managed

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -75,7 +76,7 @@ func RunManagedMercuryOracle(
skipResourceExhaustionChecks,
contractConfig,
offchainKeyring,
ocr3OnchainKeyring,
&shimComparableKeyRing{ocr3OnchainKeyring},
netEndpointFactory.PeerID(),
fromAccount,
)
Expand Down Expand Up @@ -251,3 +252,11 @@ func validateMercuryPluginLimits(limits ocr3types.MercuryPluginLimits) error {
}
return err
}

type shimComparableKeyRing struct {
ocr3types.OnchainKeyring[mercuryshim.MercuryReportInfo]
}

func (s *shimComparableKeyRing) Equal(other types.OnchainPublicKey) bool {
return bytes.Equal(other, s.PublicKey())
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func RunManagedOCR3Oracle[RI any](
netEndpointFactory types.BinaryNetworkEndpointFactory,
offchainConfigDigester types.OffchainConfigDigester,
offchainKeyring types.OffchainKeyring,
onchainKeyring ocr3types.OnchainKeyring[RI],
onchainKeyring ocr3types.ComparableOnchainKeyring[RI],
reportingPluginFactory ocr3types.ReportingPluginFactory[RI],
) {
subs := subprocesses.Subprocesses{}
Expand Down
4 changes: 2 additions & 2 deletions offchainreporting2plus/internal/ocr3/protocol/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func RunOracle[RI any](
metricsRegisterer prometheus.Registerer,
netEndpoint NetworkEndpoint[RI],
offchainKeyring types.OffchainKeyring,
onchainKeyring ocr3types.OnchainKeyring[RI],
onchainKeyring ocr3types.SignVerifier[RI], //ocr3types.OnchainKeyring[RI],
reportingPlugin ocr3types.ReportingPlugin[RI],
telemetrySender TelemetrySender,
) {
Expand Down Expand Up @@ -66,7 +66,7 @@ type oracleState[RI any] struct {
metricsRegisterer prometheus.Registerer
netEndpoint NetworkEndpoint[RI]
offchainKeyring types.OffchainKeyring
onchainKeyring ocr3types.OnchainKeyring[RI]
onchainKeyring ocr3types.SignVerifier[RI] //ocr3types.OnchainKeyring[RI]
reportingPlugin ocr3types.ReportingPlugin[RI]
telemetrySender TelemetrySender

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ func RunReportAttestation[RI any](
contractTransmitter ocr3types.ContractTransmitter[RI],
logger loghelper.LoggerWithContext,
netSender NetworkSender[RI],
onchainKeyring ocr3types.OnchainKeyring[RI],
onchainSignVerifier ocr3types.SignVerifier[RI],
reportingPlugin ocr3types.ReportingPlugin[RI],
) {
sched := scheduler.NewScheduler[EventMissingOutcome[RI]]()
defer sched.Close()

newReportAttestationState(ctx, chNetToReportAttestation,
chOutcomeGenerationToReportAttestation, chReportAttestationToTransmission,
config, contractTransmitter, logger, netSender, onchainKeyring, reportingPlugin, sched).run()
config, contractTransmitter, logger, netSender, onchainSignVerifier, reportingPlugin, sched).run()
}

const expiryMinRounds int = 10
Expand All @@ -59,7 +59,7 @@ type reportAttestationState[RI any] struct {
contractTransmitter ocr3types.ContractTransmitter[RI]
logger loghelper.LoggerWithContext
netSender NetworkSender[RI]
onchainKeyring ocr3types.OnchainKeyring[RI]
onchainSignVerifier ocr3types.SignVerifier[RI]
reportingPlugin ocr3types.ReportingPlugin[RI]

scheduler *scheduler.Scheduler[EventMissingOutcome[RI]]
Expand Down Expand Up @@ -436,7 +436,7 @@ func (repatt *reportAttestationState[RI]) verifySignatures(publicKey types.Oncha
return
}

if !repatt.onchainKeyring.Verify(publicKey, repatt.config.ConfigDigest, seqNr, reportsPlus[i].ReportWithInfo, signatures[i]) {
if !repatt.onchainSignVerifier.Verify(publicKey, repatt.config.ConfigDigest, seqNr, reportsPlus[i].ReportWithInfo, signatures[i]) {
mutex.Lock()
allValid = false
mutex.Unlock()
Expand Down Expand Up @@ -526,7 +526,7 @@ func (repatt *reportAttestationState[RI]) eventComputedReports(ev EventComputedR

var sigs [][]byte
for i, reportPlus := range ev.ReportsPlus {
sig, err := repatt.onchainKeyring.Sign(repatt.config.ConfigDigest, ev.SeqNr, reportPlus.ReportWithInfo)
sig, err := repatt.onchainSignVerifier.Sign(repatt.config.ConfigDigest, ev.SeqNr, reportPlus.ReportWithInfo)
if err != nil {
repatt.logger.Error("error while signing report", commontypes.LogFields{
"seqNr": ev.SeqNr,
Expand Down Expand Up @@ -627,7 +627,7 @@ func newReportAttestationState[RI any](
contractTransmitter ocr3types.ContractTransmitter[RI],
logger loghelper.LoggerWithContext,
netSender NetworkSender[RI],
onchainKeyring ocr3types.OnchainKeyring[RI],
onchainSignVerifier ocr3types.SignVerifier[RI],
reportingPlugin ocr3types.ReportingPlugin[RI],
sched *scheduler.Scheduler[EventMissingOutcome[RI]],
) *reportAttestationState[RI] {
Expand All @@ -642,7 +642,7 @@ func newReportAttestationState[RI any](
contractTransmitter,
logger.MakeUpdated(commontypes.LogFields{"proto": "repatt"}),
netSender,
onchainKeyring,
onchainSignVerifier,
reportingPlugin,

sched,
Expand Down
29 changes: 29 additions & 0 deletions offchainreporting2plus/ocr3types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,32 @@ type OnchainKeyring[RI any] interface {
// Maximum length of a signature
MaxSignatureLength() int
}

type ComparableOnchainKeyring[RI any] interface {
SignVerifier[RI]
// Equal returns true if the public keys are equal
// Implementations that wrap a single public key are encouraged to use
// [OnchainPublicKey].Equal directly.
// Implementations that wrap multiple public keys should implement
// return true if any of their wrapped public keys are equal to the argument
Equal(types.OnchainPublicKey) bool

// Maximum length of a signature
MaxSignatureLength() int
}

// SignVerifier provides cryptographic signatures that need to be verifiable
// on the targeted blockchain. The underlying cryptographic primitives may be
// different on each chain; for example, on Ethereum one would use ECDSA over
// secp256k1 and Keccak256, whereas on Solana one would use Ed25519 and SHA256.
type SignVerifier[RI any] interface {
// Sign returns a signature over Report.
Sign(types.ConfigDigest, uint64, ReportWithInfo[RI]) (signature []byte, err error)

// Verify verifies a signature over ReportContext and Report allegedly
// created from OnchainPublicKey.
//
// Implementations of this function must gracefully handle malformed or
// adversarially crafted inputs.
Verify(_ types.OnchainPublicKey, _ types.ConfigDigest, seqNr uint64, _ ReportWithInfo[RI], signature []byte) bool
}
41 changes: 40 additions & 1 deletion offchainreporting2plus/oracle.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package offchainreporting2plus

import (
"bytes"
"context"
"fmt"
"sync"
Expand All @@ -18,6 +19,7 @@ type OracleArgs interface {
oracleArgsMarker()
localConfig() types.LocalConfig
runManaged(ctx context.Context)
validate() error
}

// OCR2OracleArgs contains the configuration and services a caller must provide, in
Expand Down Expand Up @@ -69,6 +71,8 @@ type OCR2OracleArgs struct {
ReportingPluginFactory types.ReportingPluginFactory
}

func (OCR2OracleArgs) validate() error { return nil } // No validation needed for OCR2

func (OCR2OracleArgs) oracleArgsMarker() {}

func (args OCR2OracleArgs) localConfig() types.LocalConfig { return args.LocalConfig }
Expand Down Expand Up @@ -144,6 +148,8 @@ type MercuryOracleArgs struct {
MercuryPluginFactory ocr3types.MercuryPluginFactory
}

func (MercuryOracleArgs) validate() error { return nil } // No validation needed for Mercury

func (MercuryOracleArgs) oracleArgsMarker() {}

func (args MercuryOracleArgs) localConfig() types.LocalConfig { return args.LocalConfig }
Expand Down Expand Up @@ -210,8 +216,16 @@ type OCR3OracleArgs[RI any] struct {

// OnchainKeyring is used to sign reports that can be validated
// offchain and by the target contract.
//
// it is an error if both OnchainKeyring and ComparableOnchainKeyring are set
OnchainKeyring ocr3types.OnchainKeyring[RI]

// ComparableOnchainKeyring is used to verify that the onchain keyring
// It enable custom equality check when comparing onchain keys fetch from the contract
//
// it is an error if both OnchainKeyring and ComparableOnchainKeyring are set
ComparableOnchainKeyring ocr3types.ComparableOnchainKeyring[RI]

// PluginFactory creates Plugins that determine the "application logic" used
// in a protocol instance.
ReportingPluginFactory ocr3types.ReportingPluginFactory[RI]
Expand All @@ -221,6 +235,20 @@ func (OCR3OracleArgs[RI]) oracleArgsMarker() {}

func (args OCR3OracleArgs[RI]) localConfig() types.LocalConfig { return args.LocalConfig }

func (args OCR3OracleArgs[RI]) validate() error {
if args.OnchainKeyring != nil && args.ComparableOnchainKeyring != nil {
return fmt.Errorf("cannot set both OnchainKeyring and ComparableOnchainKeyring")
}
return nil
}

func (args OCR3OracleArgs[RI]) onchainKeyRing() ocr3types.ComparableOnchainKeyring[RI] {
if args.ComparableOnchainKeyring != nil {
return args.ComparableOnchainKeyring
}
return &shimComparableKeyRing[RI]{args.OnchainKeyring}
}

func (args OCR3OracleArgs[RI]) runManaged(ctx context.Context) {
logger := loghelper.MakeRootLoggerWithContext(args.Logger)

Expand All @@ -238,11 +266,19 @@ func (args OCR3OracleArgs[RI]) runManaged(ctx context.Context) {
args.BinaryNetworkEndpointFactory,
args.OffchainConfigDigester,
args.OffchainKeyring,
args.OnchainKeyring,
args.onchainKeyRing(),
args.ReportingPluginFactory,
)
}

type shimComparableKeyRing[RI any] struct {
ocr3types.OnchainKeyring[RI]
}

func (k *shimComparableKeyRing[RI]) Equal(onchainPublicKey types.OnchainPublicKey) bool {
return bytes.Equal(k.PublicKey(), onchainPublicKey)
}

type oracleState int

const (
Expand Down Expand Up @@ -276,6 +312,9 @@ func NewOracle(args OracleArgs) (Oracle, error) {
if err := SanityCheckLocalConfig(args.localConfig()); err != nil {
return nil, fmt.Errorf("bad local config while creating new oracle: %w", err)
}
if err := args.validate(); err != nil {
return nil, fmt.Errorf("bad oracle args while creating new oracle: %w", err)
}
return &oracle{
sync.Mutex{},
oracleStateUnstarted,
Expand Down
4 changes: 4 additions & 0 deletions offchainreporting2plus/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,10 @@ type OffchainPublicKey [ed25519.PublicKeySize]byte
// oracle to the on-chain smart contract.
type OnchainPublicKey []byte

func (opk OnchainPublicKey) Equal(other OnchainPublicKey) bool {
return bytes.Equal(opk, other)
}

// ConfigEncryptionPublicKey is the public key used to receive an encrypted
// version of the secret shared amongst all oracles on a common contract.
type ConfigEncryptionPublicKey [curve25519.PointSize]byte // X25519
Expand Down