fix: dangling open transaction for cdc-only
feat: add nats connector (for cdc events)

For the first part, fix the "START_REPLICATION cannot run inside a
transaction block" error for CDC-only pipelines. It is caused by
PostgresSource::new opening a transaction but never closing it in
the CDC-only case. Remedy this by committing that transaction inside
new and offering an explicit start_transaction function to be called
in the non-CDC-only cases.

For the second part, introduce a NATS connector (WIP: it should
probably create or get a stream in JetStream for persistence and
needs some configuration options) so we can propagate CDC events
onto a message broker.

Refs: supabase#80
cawfeecoder committed Jan 1, 2025
1 parent 1f7c0f7 commit 57380ea
Showing 10 changed files with 307 additions and 0 deletions.
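For context, here is a minimal sketch (not part of the commit) of the call pattern the fix enables, using only the Source trait methods added below; copy_phase and its body are illustrative:

use pg_replicate::pipeline::sources::Source;

// Sketch only: a pipeline now brackets the snapshot phase explicitly,
// because PostgresSource::new commits its own transaction. A CDC-only
// run simply never calls this, so no transaction is left open when
// START_REPLICATION is issued.
async fn copy_phase<Src: Source>(source: &Src) -> Result<(), Src::Error> {
    source.start_transaction().await?;
    // ... copy table schemas and rows inside the read-only snapshot ...
    source.commit_transaction().await?;
    Ok(())
}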
1 change: 1 addition & 0 deletions Cargo.toml
@@ -49,6 +49,7 @@ utoipa = { version = "4.2.3", default-features = false }
utoipa-swagger-ui = { version = "7.1.0", default-features = false }
uuid = { version = "1.10.0", default-features = false }
deltalake = { version = "0.22.0", default-features = false }
sha256 = { version = "1.5.0", default-features = false }


# [patch."https://github.com/imor/gcp-bigquery-client"]
4 changes: 4 additions & 0 deletions pg_replicate/Cargo.toml
@@ -52,6 +52,9 @@ tokio-postgres = { workspace = true, features = [
] }
tracing = { workspace = true, default-features = true }
uuid = { workspace = true, features = ["v4"] }
async-nats = { version = "0.38.0", optional = true }
anyhow = { workspace = true }
sha256 = { workspace = true }

[dev-dependencies]
clap = { workspace = true, default-features = true, features = [
@@ -66,6 +69,7 @@ tracing-subscriber = { workspace = true, default-features = true, features = [
bigquery = ["dep:gcp-bigquery-client", "dep:prost"]
duckdb = ["dep:duckdb"]
stdout = []
nats = ["dep:async-nats"]
delta = ["dep:deltalake"]
# When enabled converts unknown types to bytes
unknown_types_to_bytes = []
2 changes: 2 additions & 0 deletions pg_replicate/src/clients/mod.rs
@@ -4,4 +4,6 @@ pub mod bigquery;
pub mod delta;
#[cfg(feature = "duckdb")]
pub mod duckdb;
#[cfg(feature = "nats")]
pub mod nats;
pub mod postgres;
128 changes: 128 additions & 0 deletions pg_replicate/src/clients/nats.rs
@@ -0,0 +1,128 @@
use std::time::Duration;

use anyhow::anyhow;
use async_nats::{header::NATS_MESSAGE_ID, jetstream, ConnectError, ConnectOptions, HeaderMap};
use bytes::{Buf, BufMut, BytesMut};
use sha256::digest;
use tokio_postgres::types::PgLsn;
use tracing::{error, info, warn};

use crate::{conversions::table_row::TableRow, table::TableSchema};

pub trait MessageMapper {
    fn map(
        &self,
        table_id: u32,
        row: TableRow,
        schema: &TableSchema,
    ) -> Result<serde_json::Value, serde_json::Error>;
}

pub struct NatsClient<M: MessageMapper + Send + Sync> {
    conn: jetstream::Context,
    message_mapper: M,
}

impl<M: MessageMapper + Send + Sync> NatsClient<M> {
    pub async fn new(address: String, message_mapper: M) -> Result<NatsClient<M>, ConnectError> {
        let client = async_nats::connect_with_options(
            address,
            ConnectOptions::new()
                .no_echo()
                .ping_interval(Duration::from_secs(5))
                .connection_timeout(Duration::from_secs(5))
                .event_callback(|e| async move {
                    match e {
                        async_nats::Event::Connected => info!("{e}"),
                        async_nats::Event::Disconnected => error!("{e}"),
                        async_nats::Event::ServerError(_) => error!("{e}"),
                        async_nats::Event::ClientError(_) => error!("{e}"),
                        _ => warn!("{e}"),
                    }
                }),
        )
        .await?;
        let jetstream = async_nats::jetstream::new(client);

        Ok(Self {
            conn: jetstream,
            message_mapper,
        })
    }

    pub async fn bucket_exists(&self) -> bool {
        self.conn.get_key_value("postgres_cdc_lsn").await.is_ok()
    }

    pub async fn create_bucket(&self) -> Result<(), async_nats::Error> {
        self.conn
            .create_key_value(jetstream::kv::Config {
                bucket: "postgres_cdc_lsn".into(),
                ..Default::default()
            })
            .await?;
        Ok(())
    }

    pub async fn insert_last_lsn_row(&self) -> Result<(), async_nats::Error> {
        let store = self.conn.get_key_value("postgres_cdc_lsn").await?;
        let mut buf = BytesMut::with_capacity(8);
        buf.put_u64(0);
        store.put("last_lsn", buf.freeze()).await?;
        Ok(())
    }

    pub async fn get_last_lsn(&self) -> Result<PgLsn, async_nats::Error> {
        let store = self.conn.get_key_value("postgres_cdc_lsn").await?;
        let entry = store
            .get("last_lsn")
            .await?
            .ok_or_else(|| anyhow!("no data in the 'last_lsn' key/value"))?;
        let mut buf = BytesMut::with_capacity(8);
        buf.put_slice(&entry);
        let mut buf = buf.freeze();
        let lsn = buf.get_u64();
        Ok(lsn.into())
    }

    pub async fn set_last_lsn(&self, lsn: PgLsn) -> Result<(), async_nats::Error> {
        let store = self.conn.get_key_value("postgres_cdc_lsn").await?;
        let mut buf = BytesMut::with_capacity(8);
        buf.put_u64(lsn.into());
        store.put("last_lsn", buf.freeze()).await?;
        Ok(())
    }

    pub async fn publish(
        &self,
        table_id: u32,
        row: TableRow,
        schema: &TableSchema,
    ) -> Result<(), async_nats::Error> {
        let payload = self.message_mapper.map(table_id, row, schema)?;
        let serialized = payload.to_string();

        // Use a content hash as the NATS message id so JetStream can
        // deduplicate redelivered events.
        let mut headers = HeaderMap::new();
        let sha256 = digest(serialized.as_str());
        headers.insert(NATS_MESSAGE_ID, sha256.as_str());

        let topic = format!("postgres.table.{}", table_id);

        self.conn
            .publish_with_headers(topic, headers, serialized.into())
            .await?;

        Ok(())
    }
}
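No MessageMapper implementation ships in this commit. A minimal sketch of one, assuming (not confirmed by this diff) that TableRow exposes its cells as a values vector of Debug-printable Cells and that TableSchema has column_schemas entries with a name field:

use serde_json::{json, Map, Value};

use crate::{conversions::table_row::TableRow, table::TableSchema};

// Hypothetical mapper: emits one flat JSON object per row, keyed by
// column name. The field accesses are assumptions about pg_replicate's
// types, not guarantees.
pub struct DebugJsonMapper;

impl MessageMapper for DebugJsonMapper {
    fn map(
        &self,
        table_id: u32,
        row: TableRow,
        schema: &TableSchema,
    ) -> Result<Value, serde_json::Error> {
        let columns: Map<String, Value> = schema
            .column_schemas
            .iter()
            .zip(row.values)
            .map(|(col, cell)| (col.name.clone(), json!(format!("{cell:?}"))))
            .collect();
        Ok(json!({ "table_id": table_id, "row": columns }))
    }
}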
5 changes: 5 additions & 0 deletions pg_replicate/src/pipeline/batching/data_pipeline.rs
@@ -37,6 +37,7 @@ impl<Src: Source, Snk: BatchSink> BatchDataPipeline<Src, Snk> {

    async fn copy_table_schemas(&mut self) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
        let table_schemas = self.source.get_table_schemas().clone();

        if !table_schemas.is_empty() {
@@ -54,6 +55,10 @@ impl<Src: Source, Snk: BatchSink> BatchDataPipeline<Src, Snk> {
        copied_tables: &HashSet<TableId>,
    ) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
        let start = Instant::now();
        self.source
            .start_transaction()
            .await
            .map_err(PipelineError::Source)?;
        let table_schemas = self.source.get_table_schemas();

        let mut keys: Vec<u32> = table_schemas.keys().copied().collect();
2 changes: 2 additions & 0 deletions pg_replicate/src/pipeline/sinks/mod.rs
@@ -17,6 +17,8 @@ pub mod bigquery;
pub mod delta;
#[cfg(feature = "duckdb")]
pub mod duckdb;
#[cfg(feature = "nats")]
pub mod nats;
#[cfg(feature = "stdout")]
pub mod stdout;

3 changes: 3 additions & 0 deletions pg_replicate/src/pipeline/sinks/nats/mod.rs
@@ -0,0 +1,3 @@
pub use sink::{NatsBatchSink, NatsSinkError};

mod sink;
151 changes: 151 additions & 0 deletions pg_replicate/src/pipeline/sinks/nats/sink.rs
@@ -0,0 +1,151 @@
use std::collections::{HashMap, HashSet};

use async_trait::async_trait;
use thiserror::Error;
use tokio_postgres::types::PgLsn;
use tracing::info;

use crate::{
    clients::nats::{MessageMapper, NatsClient},
    conversions::{cdc_event::CdcEvent, table_row::TableRow},
    pipeline::{
        sinks::{BatchSink, SinkError},
        PipelineResumptionState,
    },
    table::{TableId, TableSchema},
};

#[derive(Debug, Error)]
pub enum NatsSinkError {
    #[error("incorrect commit lsn: {0} (expected: {1})")]
    IncorrectCommitLsn(PgLsn, PgLsn),

    #[error("commit message without begin message")]
    CommitWithoutBegin,

    #[error("nats error: {0}")]
    Nats(#[from] async_nats::Error),

    #[error("missing table schemas")]
    MissingTableSchemas,
}

pub struct NatsBatchSink<M: MessageMapper + Send + Sync> {
    client: NatsClient<M>,
    committed_lsn: Option<PgLsn>,
    final_lsn: Option<PgLsn>,
    table_schemas: HashMap<TableId, TableSchema>,
}

impl<M: MessageMapper + Send + Sync> NatsBatchSink<M> {
    pub async fn new(
        address: &str,
        message_mapper: M,
    ) -> Result<NatsBatchSink<M>, async_nats::ConnectError> {
        let client = NatsClient::new(address.to_string(), message_mapper).await?;
        Ok(NatsBatchSink {
            client,
            committed_lsn: None,
            final_lsn: None,
            table_schemas: HashMap::new(),
        })
    }
}

impl SinkError for NatsSinkError {}

#[async_trait]
impl<M: MessageMapper + Send + Sync> BatchSink for NatsBatchSink<M> {
    type Error = NatsSinkError;

    async fn get_resumption_state(&mut self) -> Result<PipelineResumptionState, Self::Error> {
        if !self.client.bucket_exists().await {
            self.client.create_bucket().await?;
            self.client.insert_last_lsn_row().await?;
        } else {
            info!("bucket already exists")
        }

        let last_lsn = self.client.get_last_lsn().await?;
        self.committed_lsn = Some(last_lsn);

        Ok(PipelineResumptionState {
            copied_tables: HashSet::new(),
            last_lsn,
        })
    }

    async fn write_table_schemas(
        &mut self,
        table_schemas: HashMap<TableId, TableSchema>,
    ) -> Result<(), Self::Error> {
        self.table_schemas = table_schemas;
        Ok(())
    }

    async fn write_table_rows(
        &mut self,
        _rows: Vec<TableRow>,
        _table_id: TableId,
    ) -> Result<(), Self::Error> {
        // CDC-only sink: rows from the initial table copy are not published.
        Ok(())
    }

    async fn write_cdc_events(&mut self, events: Vec<CdcEvent>) -> Result<PgLsn, Self::Error> {
        let mut new_last_lsn = PgLsn::from(0);

        for event in events {
            match event {
                CdcEvent::Begin(begin_body) => {
                    let final_lsn_u64 = begin_body.final_lsn();
                    self.final_lsn = Some(final_lsn_u64.into());
                }
                CdcEvent::Commit(commit_body) => {
                    let commit_lsn: PgLsn = commit_body.commit_lsn().into();
                    if let Some(final_lsn) = self.final_lsn {
                        if commit_lsn == final_lsn {
                            new_last_lsn = commit_lsn;
                        } else {
                            Err(NatsSinkError::IncorrectCommitLsn(commit_lsn, final_lsn))?
                        }
                    } else {
                        Err(NatsSinkError::CommitWithoutBegin)?
                    }
                }
                CdcEvent::Insert(insert) => {
                    let (table_id, table_row) = insert;
                    let schema = self
                        .table_schemas
                        .get(&table_id)
                        .ok_or(NatsSinkError::MissingTableSchemas)?;

                    self.client.publish(table_id, table_row, schema).await?;
                }
                CdcEvent::Update(_) => {}
                CdcEvent::Delete(_) => {}
                CdcEvent::Relation(_) => {}
                CdcEvent::KeepAliveRequested { reply: _ } => {}
                CdcEvent::Type(_) => {}
            };
        }

        if new_last_lsn != PgLsn::from(0) {
            self.client.set_last_lsn(new_last_lsn).await?;
            self.committed_lsn = Some(new_last_lsn);
        }

        let committed_lsn = self.committed_lsn.expect("committed lsn is none");
        Ok(committed_lsn)
    }

    async fn table_copied(&mut self, table_id: TableId) -> Result<(), Self::Error> {
        info!("table {table_id} copied");
        Ok(())
    }

    async fn truncate_table(&mut self, table_id: TableId) -> Result<(), Self::Error> {
        info!("table {table_id} truncated");
        Ok(())
    }
}
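And a hedged usage sketch for the new sink, reusing the hypothetical DebugJsonMapper above; the NATS URL is a local default, not something this commit mandates:

use pg_replicate::pipeline::sinks::nats::NatsBatchSink;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // DebugJsonMapper is the sketch above, not part of this commit.
    let sink = NatsBatchSink::new("nats://127.0.0.1:4222", DebugJsonMapper).await?;
    // From here the sink would be handed to a BatchDataPipeline alongside
    // a PostgresSource, mirroring the existing bigquery/duckdb examples.
    let _ = sink;
    Ok(())
}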
2 changes: 2 additions & 0 deletions pg_replicate/src/pipeline/sources/mod.rs
@@ -49,6 +49,8 @@ pub trait Source {
        column_schemas: &[ColumnSchema],
    ) -> Result<TableCopyStream, Self::Error>;

    async fn start_transaction(&self) -> Result<(), Self::Error>;

    async fn commit_transaction(&self) -> Result<(), Self::Error>;

    async fn get_cdc_stream(&self, start_lsn: PgLsn) -> Result<CdcStream, Self::Error>;
9 changes: 9 additions & 0 deletions pg_replicate/src/pipeline/sources/postgres.rs
@@ -69,6 +69,7 @@ impl PostgresSource {
        let (table_names, publication) =
            Self::get_table_names_and_publication(&replication_client, table_names_from).await?;
        let table_schemas = replication_client.get_table_schemas(&table_names).await?;
        replication_client.commit_txn().await?;
        Ok(PostgresSource {
            replication_client,
            table_schemas,
@@ -135,6 +136,14 @@ impl Source for PostgresSource {
        })
    }

    async fn start_transaction(&self) -> Result<(), Self::Error> {
        self.replication_client
            .begin_readonly_transaction()
            .await
            .map_err(PostgresSourceError::ReplicationClient)?;
        Ok(())
    }

    async fn commit_transaction(&self) -> Result<(), Self::Error> {
        self.replication_client
            .commit_txn()
