Skip to content

Commit

Permalink
refactor(db): some changes, add InMemoryDb
Browse files Browse the repository at this point in the history
  • Loading branch information
mempirate committed Jan 9, 2025
1 parent c85c435 commit 76bcf19
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 48 deletions.
61 changes: 61 additions & 0 deletions src/db/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use std::sync::{Arc, RwLock};

use alloy::primitives::Address;
use tracing::info;

use super::{BlsPublicKey, DbResult, Operator, Registration, RegistryDb};

#[derive(Debug, Clone)]
pub(crate) struct InMemoryDb {
validator_registrations: Arc<RwLock<Vec<Registration>>>,
operator_registrations: Arc<RwLock<Vec<Operator>>>,
}

#[async_trait::async_trait]
impl RegistryDb for InMemoryDb {
async fn register_validators(&self, registration: Registration) -> DbResult<()> {
info!(
keys_count = registration.validator_pubkeys.len(),
sig_count = registration.signatures.len(),
digest = ?registration.digest(),
"NoOpDb: register_validators"
);

let mut registrations = self.validator_registrations.write().unwrap();
registrations.push(registration);

Ok(())
}

async fn register_operator(&self, operator: Operator) -> DbResult<()> {
info!(signer = %operator.signer, "NoOpDb: register_operator");

let mut operators = self.operator_registrations.write().unwrap();
operators.push(operator);

Ok(())
}

async fn get_operator(&self, signer: Address) -> DbResult<Option<Operator>> {
let operators = self.operator_registrations.read().unwrap();
let operator = operators.iter().find(|op| op.signer == signer);

match operator {
Some(op) => Ok(Some(op.clone())),
None => Ok(None),
}
}

async fn get_validator_registration(
&self,
pubkey: BlsPublicKey,
) -> DbResult<Option<Registration>> {
let registrations = self.validator_registrations.read().unwrap();
let registration = registrations.iter().find(|reg| reg.validator_pubkeys.contains(&pubkey));

match registration {
Some(reg) => Ok(Some(reg.clone())),
None => Ok(None),
}
}
}
11 changes: 7 additions & 4 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use crate::primitives::{
mod types;

/// No-op database implementation.
mod noop;
pub(crate) use noop::NoOpDb;
mod memory;
pub(crate) use memory::InMemoryDb;

/// SQL database backend implementation.
mod sql;
Expand Down Expand Up @@ -50,8 +50,11 @@ pub(crate) trait RegistryDb: Clone {
async fn register_operator(&self, operator: Operator) -> DbResult<()>;

/// Get an operator from the database.
async fn get_operator(&self, signer: Address) -> DbResult<Operator>;
async fn get_operator(&self, signer: Address) -> DbResult<Option<Operator>>;

/// Get a validator registration from the database.
async fn get_validator_registration(&self, pubkey: BlsPublicKey) -> DbResult<Registration>;
async fn get_validator_registration(
&self,
pubkey: BlsPublicKey,
) -> DbResult<Option<Registration>>;
}
36 changes: 0 additions & 36 deletions src/db/noop.rs

This file was deleted.

27 changes: 19 additions & 8 deletions src/db/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,33 +97,44 @@ impl RegistryDb for SQLDb<Postgres> {
Ok(())
}

async fn get_operator(&self, signer: Address) -> DbResult<Operator> {
let row: OperatorRow = sqlx::query_as(
async fn get_operator(&self, signer: Address) -> DbResult<Option<Operator>> {
let row: Option<OperatorRow> = sqlx::query_as(
"
SELECT signer, rpc, protocol, source, collateral_tokens, collateral_amounts, last_update
FROM operators
WHERE signer = $1
",
)
.bind(signer.to_vec())
.fetch_one(&self.conn)
.fetch_optional(&self.conn)
.await?;

row.try_into()
let Some(row) = row else {
return Ok(None);
};

Ok(Some(row.try_into()?))
}

async fn get_validator_registration(&self, pubkey: BlsPublicKey) -> DbResult<Registration> {
let row: ValidatorRegistrationRow = sqlx::query_as(
async fn get_validator_registration(
&self,
pubkey: BlsPublicKey,
) -> DbResult<Option<Registration>> {
let row: Option<ValidatorRegistrationRow> = sqlx::query_as(
"
SELECT pubkey, signature, expiry, gas_limit, operator, priority, source, last_update
FROM validator_registrations
WHERE pubkey = $1
",
)
.bind(pubkey.serialize().to_vec())
.fetch_one(&self.conn)
.fetch_optional(&self.conn)
.await?;

row.try_into()
let Some(row) = row else {
return Ok(None);
};

Ok(Some(row.try_into()?))
}
}

0 comments on commit 76bcf19

Please sign in to comment.