Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fast import: restore to neondb (not postgres) database #10251

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 120 additions & 81 deletions compute_tools/src/bin/fast_import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ struct Args {
#[clap(long)]
working_directory: Utf8PathBuf,
#[clap(long, env = "NEON_IMPORTER_S3_PREFIX")]
s3_prefix: s3_uri::S3Uri,
s3_prefix: Option<s3_uri::S3Uri>,
#[clap(long)]
source_connection_string: Option<String>,
#[clap(short, long)]
interactive: bool,
#[clap(long)]
pg_bin_dir: Utf8PathBuf,
#[clap(long)]
Expand Down Expand Up @@ -77,30 +81,67 @@ pub(crate) async fn main() -> anyhow::Result<()> {

info!("starting");

let Args {
working_directory,
s3_prefix,
pg_bin_dir,
pg_lib_dir,
} = Args::parse();

let aws_config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await;

let spec: Spec = {
let spec_key = s3_prefix.append("/spec.json");
let s3_client = aws_sdk_s3::Client::new(&aws_config);
let object = s3_client
.get_object()
.bucket(&spec_key.bucket)
.key(spec_key.key)
.send()
.await
.context("get spec from s3")?
.body
.collect()
.await
.context("download spec body")?;
serde_json::from_slice(&object.into_bytes()).context("parse spec as json")?
let args = Args::parse();

// Validate arguments
if args.s3_prefix.is_none() && args.source_connection_string.is_none() {
anyhow::bail!("either s3_prefix or source_connection_string must be specified");
}

let working_directory = args.working_directory;
let pg_bin_dir = args.pg_bin_dir;
let pg_lib_dir = args.pg_lib_dir;

// Initialize AWS clients only if s3_prefix is specified
let (aws_config, kms_client) = if args.s3_prefix.is_some() {
let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await;
let kms = aws_sdk_kms::Client::new(&config);
(Some(config), Some(kms))
} else {
(None, None)
};

// Get source connection string either from S3 spec or direct argument
let source_connection_string = if let Some(s3_prefix) = &args.s3_prefix {
let spec: Spec = {
let spec_key = s3_prefix.append("/spec.json");
let s3_client = aws_sdk_s3::Client::new(aws_config.as_ref().unwrap());
let object = s3_client
.get_object()
.bucket(&spec_key.bucket)
.key(spec_key.key)
.send()
.await
.context("get spec from s3")?
.body
.collect()
.await
.context("download spec body")?;
serde_json::from_slice(&object.into_bytes()).context("parse spec as json")?
};

match spec.encryption_secret {
EncryptionSecret::KMS { key_id } => {
let mut output = kms_client
.unwrap()
.decrypt()
.key_id(key_id)
.ciphertext_blob(aws_sdk_s3::primitives::Blob::new(
spec.source_connstring_ciphertext_base64,
))
.send()
.await
.context("decrypt source connection string")?;
let plaintext = output
.plaintext
.take()
.context("get plaintext source connection string")?;
String::from_utf8(plaintext.into_inner())
.context("parse source connection string as utf8")?
}
}
} else {
args.source_connection_string.unwrap()
};

match tokio::fs::create_dir(&working_directory).await {
Expand All @@ -123,15 +164,6 @@ pub(crate) async fn main() -> anyhow::Result<()> {
.await
.context("create pgdata directory")?;

//
// Setup clients
//
let aws_config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await;
let kms_client = aws_sdk_kms::Client::new(&aws_config);

//
// Initialize pgdata
//
let pgbin = pg_bin_dir.join("postgres");
let pg_version = match get_pg_version(pgbin.as_ref()) {
PostgresMajorVersion::V14 => 14,
Expand Down Expand Up @@ -170,7 +202,13 @@ pub(crate) async fn main() -> anyhow::Result<()> {
.args(["-c", &format!("max_parallel_workers={nproc}")])
.args(["-c", &format!("max_parallel_workers_per_gather={nproc}")])
.args(["-c", &format!("max_worker_processes={nproc}")])
.args(["-c", "effective_io_concurrency=100"])
.args([
"-c",
&format!(
"effective_io_concurrency={}",
if cfg!(target_os = "macos") { 0 } else { 100 }
),
])
.env_clear()
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
Expand All @@ -185,44 +223,36 @@ pub(crate) async fn main() -> anyhow::Result<()> {
)
.instrument(info_span!("postgres")),
);

// Create neondb database in the running postgres
let restore_pg_connstring =
format!("host=localhost port=5432 user={superuser} dbname=postgres");
loop {
let res = tokio_postgres::connect(&restore_pg_connstring, tokio_postgres::NoTls).await;
if res.is_ok() {
info!("postgres is ready, could connect to it");
break;
}
}

//
// Decrypt connection string
//
let source_connection_string = {
match spec.encryption_secret {
EncryptionSecret::KMS { key_id } => {
let mut output = kms_client
.decrypt()
.key_id(key_id)
.ciphertext_blob(aws_sdk_s3::primitives::Blob::new(
spec.source_connstring_ciphertext_base64,
))
.send()
.await
.context("decrypt source connection string")?;
let plaintext = output
.plaintext
.take()
.context("get plaintext source connection string")?;
String::from_utf8(plaintext.into_inner())
.context("parse source connection string as utf8")?
match tokio_postgres::connect(&restore_pg_connstring, tokio_postgres::NoTls).await {
Ok((client, connection)) => {
// Spawn the connection handling task to maintain the connection
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});

match client.simple_query("CREATE DATABASE neondb;").await {
Ok(_) => {
info!("created neondb database");
break;
}
Err(e) => {
info!("failed to create database: {}", e);
break;
}
}
}
Err(_) => continue,
}
};
}

//
// Start the work
//
let restore_pg_connstring = restore_pg_connstring.replace("dbname=postgres", "dbname=neondb");

let dumpdir = working_directory.join("dumpdir");

Expand Down Expand Up @@ -310,6 +340,12 @@ pub(crate) async fn main() -> anyhow::Result<()> {
}
}

// If interactive mode, wait for Ctrl+C
if args.interactive {
info!("Running in interactive mode. Press Ctrl+C to shut down.");
tokio::signal::ctrl_c().await.context("wait for ctrl-c")?;
}

info!("shutdown postgres");
{
nix::sys::signal::kill(
Expand All @@ -325,21 +361,24 @@ pub(crate) async fn main() -> anyhow::Result<()> {
.context("wait for postgres to shut down")?;
}

info!("upload pgdata");
aws_s3_sync::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/pgdata/"))
.await
.context("sync dump directory to destination")?;

info!("write status");
{
let status_dir = working_directory.join("status");
std::fs::create_dir(&status_dir).context("create status directory")?;
let status_file = status_dir.join("pgdata");
std::fs::write(&status_file, serde_json::json!({"done": true}).to_string())
.context("write status file")?;
aws_s3_sync::sync(&status_dir, &s3_prefix.append("/status/"))
// Only sync if s3_prefix was specified
if let Some(s3_prefix) = args.s3_prefix {
info!("upload pgdata");
aws_s3_sync::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/pgdata/"))
.await
.context("sync status directory to destination")?;
.context("sync dump directory to destination")?;

info!("write status");
{
let status_dir = working_directory.join("status");
std::fs::create_dir(&status_dir).context("create status directory")?;
let status_file = status_dir.join("pgdata");
std::fs::write(&status_file, serde_json::json!({"done": true}).to_string())
.context("write status file")?;
aws_s3_sync::sync(&status_dir, &s3_prefix.append("/status/"))
.await
.context("sync status directory to destination")?;
}
}

Ok(())
Expand Down
Loading