diff --git a/offchainreporting2plus/internal/config/ocr3config/shared_config.go b/offchainreporting2plus/internal/config/ocr3config/shared_config.go index 283424f..6e4353f 100644 --- a/offchainreporting2plus/internal/config/ocr3config/shared_config.go +++ b/offchainreporting2plus/internal/config/ocr3config/shared_config.go @@ -1,7 +1,6 @@ package ocr3config import ( - "bytes" "crypto/hmac" "crypto/rand" "crypto/sha256" @@ -48,7 +47,7 @@ func SharedConfigFromContractConfig[RI any]( skipResourceExhaustionChecks bool, change types.ContractConfig, offchainKeyring types.OffchainKeyring, - onchainKeyring ocr3types.OnchainKeyring[RI], + onchainKeyring ocr3types.ComparableOnchainKeyring[RI], peerID string, transmitAccount types.Account, ) (SharedConfig, commontypes.OracleID, error) { @@ -59,32 +58,35 @@ func SharedConfigFromContractConfig[RI any]( oracleID := commontypes.OracleID(math.MaxUint8) { - onchainPublicKey := onchainKeyring.PublicKey() + var onchainPublicKey types.OnchainPublicKey //:= onchainKeyring.PublicKey() offchainPublicKey := offchainKeyring.OffchainPublicKey() var found bool for i, identity := range publicConfig.OracleIdentities { - if bytes.Equal(identity.OnchainPublicKey, onchainPublicKey) { - if identity.OffchainPublicKey != offchainPublicKey { - return SharedConfig{}, 0, errors.Errorf( - "OnchainPublicKey %x in publicConfig matches "+ - "mine, but OffchainPublicKey does not: %v (config) vs %v (mine)", - onchainPublicKey, identity.OffchainPublicKey, offchainPublicKey) - } - if identity.PeerID != peerID { - return SharedConfig{}, 0, errors.Errorf( - "OnchainPublicKey %x in publicConfig matches "+ - "mine, but PeerID does not: %v (config) vs %v (mine)", - onchainPublicKey, identity.PeerID, peerID) - } - if identity.TransmitAccount != transmitAccount { - return SharedConfig{}, 0, errors.Errorf( - "OnchainPublicKey %x in publicConfig matches "+ - "mine, but TransmitAccount does not: %v (config) vs %v (mine)", - onchainPublicKey, identity.TransmitAccount, transmitAccount) - } - oracleID = commontypes.OracleID(i) - found = true + //if bytes.Equal(identity.OnchainPublicKey, onchainPublicKey) { + if !onchainKeyring.Equal(identity.OnchainPublicKey) { + continue } + onchainPublicKey = identity.OnchainPublicKey + if identity.OffchainPublicKey != offchainPublicKey { + return SharedConfig{}, 0, errors.Errorf( + "OnchainPublicKey %x in publicConfig matches "+ + "mine, but OffchainPublicKey does not: %v (config) vs %v (mine)", + onchainPublicKey, identity.OffchainPublicKey, offchainPublicKey) + } + if identity.PeerID != peerID { + return SharedConfig{}, 0, errors.Errorf( + "OnchainPublicKey %x in publicConfig matches "+ + "mine, but PeerID does not: %v (config) vs %v (mine)", + onchainPublicKey, identity.PeerID, peerID) + } + if identity.TransmitAccount != transmitAccount { + return SharedConfig{}, 0, errors.Errorf( + "OnchainPublicKey %x in publicConfig matches "+ + "mine, but TransmitAccount does not: %v (config) vs %v (mine)", + onchainPublicKey, identity.TransmitAccount, transmitAccount) + } + oracleID = commontypes.OracleID(i) + found = true } if !found { diff --git a/offchainreporting2plus/internal/managed/managed_mercury_oracle.go b/offchainreporting2plus/internal/managed/managed_mercury_oracle.go index 6949705..8a67c52 100644 --- a/offchainreporting2plus/internal/managed/managed_mercury_oracle.go +++ b/offchainreporting2plus/internal/managed/managed_mercury_oracle.go @@ -1,6 +1,7 @@ package managed import ( + "bytes" "context" "errors" "fmt" @@ -75,7 +76,7 @@ func RunManagedMercuryOracle( skipResourceExhaustionChecks, contractConfig, offchainKeyring, - ocr3OnchainKeyring, + &shimComparableKeyRing{ocr3OnchainKeyring}, netEndpointFactory.PeerID(), fromAccount, ) @@ -251,3 +252,11 @@ func validateMercuryPluginLimits(limits ocr3types.MercuryPluginLimits) error { } return err } + +type shimComparableKeyRing struct { + ocr3types.OnchainKeyring[mercuryshim.MercuryReportInfo] +} + +func (s *shimComparableKeyRing) Equal(other types.OnchainPublicKey) bool { + return bytes.Equal(other, s.PublicKey()) +} diff --git a/offchainreporting2plus/internal/managed/managed_ocr3_oracle.go b/offchainreporting2plus/internal/managed/managed_ocr3_oracle.go index 366c18c..9a0e7ee 100644 --- a/offchainreporting2plus/internal/managed/managed_ocr3_oracle.go +++ b/offchainreporting2plus/internal/managed/managed_ocr3_oracle.go @@ -38,7 +38,7 @@ func RunManagedOCR3Oracle[RI any]( netEndpointFactory types.BinaryNetworkEndpointFactory, offchainConfigDigester types.OffchainConfigDigester, offchainKeyring types.OffchainKeyring, - onchainKeyring ocr3types.OnchainKeyring[RI], + onchainKeyring ocr3types.ComparableOnchainKeyring[RI], reportingPluginFactory ocr3types.ReportingPluginFactory[RI], ) { subs := subprocesses.Subprocesses{} diff --git a/offchainreporting2plus/internal/ocr3/protocol/oracle.go b/offchainreporting2plus/internal/ocr3/protocol/oracle.go index 5a19e80..b437681 100644 --- a/offchainreporting2plus/internal/ocr3/protocol/oracle.go +++ b/offchainreporting2plus/internal/ocr3/protocol/oracle.go @@ -31,7 +31,7 @@ func RunOracle[RI any]( metricsRegisterer prometheus.Registerer, netEndpoint NetworkEndpoint[RI], offchainKeyring types.OffchainKeyring, - onchainKeyring ocr3types.OnchainKeyring[RI], + onchainKeyring ocr3types.SignVerifier[RI], //ocr3types.OnchainKeyring[RI], reportingPlugin ocr3types.ReportingPlugin[RI], telemetrySender TelemetrySender, ) { @@ -66,7 +66,7 @@ type oracleState[RI any] struct { metricsRegisterer prometheus.Registerer netEndpoint NetworkEndpoint[RI] offchainKeyring types.OffchainKeyring - onchainKeyring ocr3types.OnchainKeyring[RI] + onchainKeyring ocr3types.SignVerifier[RI] //ocr3types.OnchainKeyring[RI] reportingPlugin ocr3types.ReportingPlugin[RI] telemetrySender TelemetrySender diff --git a/offchainreporting2plus/internal/ocr3/protocol/report_attestation.go b/offchainreporting2plus/internal/ocr3/protocol/report_attestation.go index f0a760c..e75b7e7 100644 --- a/offchainreporting2plus/internal/ocr3/protocol/report_attestation.go +++ b/offchainreporting2plus/internal/ocr3/protocol/report_attestation.go @@ -29,7 +29,7 @@ func RunReportAttestation[RI any]( contractTransmitter ocr3types.ContractTransmitter[RI], logger loghelper.LoggerWithContext, netSender NetworkSender[RI], - onchainKeyring ocr3types.OnchainKeyring[RI], + onchainSignVerifier ocr3types.SignVerifier[RI], reportingPlugin ocr3types.ReportingPlugin[RI], ) { sched := scheduler.NewScheduler[EventMissingOutcome[RI]]() @@ -37,7 +37,7 @@ func RunReportAttestation[RI any]( newReportAttestationState(ctx, chNetToReportAttestation, chOutcomeGenerationToReportAttestation, chReportAttestationToTransmission, - config, contractTransmitter, logger, netSender, onchainKeyring, reportingPlugin, sched).run() + config, contractTransmitter, logger, netSender, onchainSignVerifier, reportingPlugin, sched).run() } const expiryMinRounds int = 10 @@ -59,7 +59,7 @@ type reportAttestationState[RI any] struct { contractTransmitter ocr3types.ContractTransmitter[RI] logger loghelper.LoggerWithContext netSender NetworkSender[RI] - onchainKeyring ocr3types.OnchainKeyring[RI] + onchainSignVerifier ocr3types.SignVerifier[RI] reportingPlugin ocr3types.ReportingPlugin[RI] scheduler *scheduler.Scheduler[EventMissingOutcome[RI]] @@ -436,7 +436,7 @@ func (repatt *reportAttestationState[RI]) verifySignatures(publicKey types.Oncha return } - if !repatt.onchainKeyring.Verify(publicKey, repatt.config.ConfigDigest, seqNr, reportsPlus[i].ReportWithInfo, signatures[i]) { + if !repatt.onchainSignVerifier.Verify(publicKey, repatt.config.ConfigDigest, seqNr, reportsPlus[i].ReportWithInfo, signatures[i]) { mutex.Lock() allValid = false mutex.Unlock() @@ -526,7 +526,7 @@ func (repatt *reportAttestationState[RI]) eventComputedReports(ev EventComputedR var sigs [][]byte for i, reportPlus := range ev.ReportsPlus { - sig, err := repatt.onchainKeyring.Sign(repatt.config.ConfigDigest, ev.SeqNr, reportPlus.ReportWithInfo) + sig, err := repatt.onchainSignVerifier.Sign(repatt.config.ConfigDigest, ev.SeqNr, reportPlus.ReportWithInfo) if err != nil { repatt.logger.Error("error while signing report", commontypes.LogFields{ "seqNr": ev.SeqNr, @@ -627,7 +627,7 @@ func newReportAttestationState[RI any]( contractTransmitter ocr3types.ContractTransmitter[RI], logger loghelper.LoggerWithContext, netSender NetworkSender[RI], - onchainKeyring ocr3types.OnchainKeyring[RI], + onchainSignVerifier ocr3types.SignVerifier[RI], reportingPlugin ocr3types.ReportingPlugin[RI], sched *scheduler.Scheduler[EventMissingOutcome[RI]], ) *reportAttestationState[RI] { @@ -642,7 +642,7 @@ func newReportAttestationState[RI any]( contractTransmitter, logger.MakeUpdated(commontypes.LogFields{"proto": "repatt"}), netSender, - onchainKeyring, + onchainSignVerifier, reportingPlugin, sched, diff --git a/offchainreporting2plus/ocr3types/types.go b/offchainreporting2plus/ocr3types/types.go index b9de200..8ac2e12 100644 --- a/offchainreporting2plus/ocr3types/types.go +++ b/offchainreporting2plus/ocr3types/types.go @@ -53,3 +53,32 @@ type OnchainKeyring[RI any] interface { // Maximum length of a signature MaxSignatureLength() int } + +type ComparableOnchainKeyring[RI any] interface { + SignVerifier[RI] + // Equal returns true if the public keys are equal + // Implementations that wrap a single public key are encouraged to use + // [OnchainPublicKey].Equal directly. + // Implementations that wrap multiple public keys should implement + // return true if any of their wrapped public keys are equal to the argument + Equal(types.OnchainPublicKey) bool + + // Maximum length of a signature + MaxSignatureLength() int +} + +// SignVerifier provides cryptographic signatures that need to be verifiable +// on the targeted blockchain. The underlying cryptographic primitives may be +// different on each chain; for example, on Ethereum one would use ECDSA over +// secp256k1 and Keccak256, whereas on Solana one would use Ed25519 and SHA256. +type SignVerifier[RI any] interface { + // Sign returns a signature over Report. + Sign(types.ConfigDigest, uint64, ReportWithInfo[RI]) (signature []byte, err error) + + // Verify verifies a signature over ReportContext and Report allegedly + // created from OnchainPublicKey. + // + // Implementations of this function must gracefully handle malformed or + // adversarially crafted inputs. + Verify(_ types.OnchainPublicKey, _ types.ConfigDigest, seqNr uint64, _ ReportWithInfo[RI], signature []byte) bool +} diff --git a/offchainreporting2plus/oracle.go b/offchainreporting2plus/oracle.go index 92f6a33..b9dfa0d 100644 --- a/offchainreporting2plus/oracle.go +++ b/offchainreporting2plus/oracle.go @@ -1,6 +1,7 @@ package offchainreporting2plus import ( + "bytes" "context" "fmt" "sync" @@ -18,6 +19,7 @@ type OracleArgs interface { oracleArgsMarker() localConfig() types.LocalConfig runManaged(ctx context.Context) + validate() error } // OCR2OracleArgs contains the configuration and services a caller must provide, in @@ -69,6 +71,8 @@ type OCR2OracleArgs struct { ReportingPluginFactory types.ReportingPluginFactory } +func (OCR2OracleArgs) validate() error { return nil } // No validation needed for OCR2 + func (OCR2OracleArgs) oracleArgsMarker() {} func (args OCR2OracleArgs) localConfig() types.LocalConfig { return args.LocalConfig } @@ -144,6 +148,8 @@ type MercuryOracleArgs struct { MercuryPluginFactory ocr3types.MercuryPluginFactory } +func (MercuryOracleArgs) validate() error { return nil } // No validation needed for Mercury + func (MercuryOracleArgs) oracleArgsMarker() {} func (args MercuryOracleArgs) localConfig() types.LocalConfig { return args.LocalConfig } @@ -210,8 +216,16 @@ type OCR3OracleArgs[RI any] struct { // OnchainKeyring is used to sign reports that can be validated // offchain and by the target contract. + // + // it is an error if both OnchainKeyring and ComparableOnchainKeyring are set OnchainKeyring ocr3types.OnchainKeyring[RI] + // ComparableOnchainKeyring is used to verify that the onchain keyring + // It enable custom equality check when comparing onchain keys fetch from the contract + // + // it is an error if both OnchainKeyring and ComparableOnchainKeyring are set + ComparableOnchainKeyring ocr3types.ComparableOnchainKeyring[RI] + // PluginFactory creates Plugins that determine the "application logic" used // in a protocol instance. ReportingPluginFactory ocr3types.ReportingPluginFactory[RI] @@ -221,6 +235,20 @@ func (OCR3OracleArgs[RI]) oracleArgsMarker() {} func (args OCR3OracleArgs[RI]) localConfig() types.LocalConfig { return args.LocalConfig } +func (args OCR3OracleArgs[RI]) validate() error { + if args.OnchainKeyring != nil && args.ComparableOnchainKeyring != nil { + return fmt.Errorf("cannot set both OnchainKeyring and ComparableOnchainKeyring") + } + return nil +} + +func (args OCR3OracleArgs[RI]) onchainKeyRing() ocr3types.ComparableOnchainKeyring[RI] { + if args.ComparableOnchainKeyring != nil { + return args.ComparableOnchainKeyring + } + return &shimComparableKeyRing[RI]{args.OnchainKeyring} +} + func (args OCR3OracleArgs[RI]) runManaged(ctx context.Context) { logger := loghelper.MakeRootLoggerWithContext(args.Logger) @@ -238,11 +266,19 @@ func (args OCR3OracleArgs[RI]) runManaged(ctx context.Context) { args.BinaryNetworkEndpointFactory, args.OffchainConfigDigester, args.OffchainKeyring, - args.OnchainKeyring, + args.onchainKeyRing(), args.ReportingPluginFactory, ) } +type shimComparableKeyRing[RI any] struct { + ocr3types.OnchainKeyring[RI] +} + +func (k *shimComparableKeyRing[RI]) Equal(onchainPublicKey types.OnchainPublicKey) bool { + return bytes.Equal(k.PublicKey(), onchainPublicKey) +} + type oracleState int const ( @@ -276,6 +312,9 @@ func NewOracle(args OracleArgs) (Oracle, error) { if err := SanityCheckLocalConfig(args.localConfig()); err != nil { return nil, fmt.Errorf("bad local config while creating new oracle: %w", err) } + if err := args.validate(); err != nil { + return nil, fmt.Errorf("bad oracle args while creating new oracle: %w", err) + } return &oracle{ sync.Mutex{}, oracleStateUnstarted, diff --git a/offchainreporting2plus/types/types.go b/offchainreporting2plus/types/types.go index 96b061e..28033f5 100644 --- a/offchainreporting2plus/types/types.go +++ b/offchainreporting2plus/types/types.go @@ -378,6 +378,10 @@ type OffchainPublicKey [ed25519.PublicKeySize]byte // oracle to the on-chain smart contract. type OnchainPublicKey []byte +func (opk OnchainPublicKey) Equal(other OnchainPublicKey) bool { + return bytes.Equal(opk, other) +} + // ConfigEncryptionPublicKey is the public key used to receive an encrypted // version of the secret shared amongst all oracles on a common contract. type ConfigEncryptionPublicKey [curve25519.PointSize]byte // X25519