Skip to content

Commit

Permalink
Refactor MigrationRunner::run_migrations() to call a helper
Browse files Browse the repository at this point in the history
This will make it easier to add cluster migrations, such as that for
CVE-2024-4317.

Link: https://www.postgresql.org/support/security/CVE-2024-4317/
Signed-off-by: Tristan Partin <[email protected]>
  • Loading branch information
tristan957 committed Jan 6, 2025
1 parent fda52a0 commit f6daa2e
Showing 1 changed file with 43 additions and 49 deletions.
92 changes: 43 additions & 49 deletions compute_tools/src/migration.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::{Context, Result};
use fail::fail_point;
use postgres::Client;
use postgres::{Client, Transaction};
use tracing::info;

/// Runs a series of migrations on a target database
Expand All @@ -20,11 +20,9 @@ impl<'m> MigrationRunner<'m> {

/// Get the current value neon_migration.migration_id
fn get_migration_id(&mut self) -> Result<i64> {
let query = "SELECT id FROM neon_migration.migration_id";
let row = self
.client
.query_one(query, &[])
.context("run_migrations get migration_id")?;
.query_one("SELECT id FROM neon_migration.migration_id", &[])?;

Ok(row.get::<&str, i64>("id"))
}
Expand All @@ -34,7 +32,7 @@ impl<'m> MigrationRunner<'m> {
/// This function has a fail point called compute-migration, which can be
/// used if you would like to fail the application of a series of migrations
/// at some point.
fn update_migration_id(&mut self, migration_id: i64) -> Result<()> {
fn update_migration_id(txn: &mut Transaction, migration_id: i64) -> Result<()> {
// We use this fail point in order to check that failing in the
// middle of applying a series of migrations fails in an expected
// manner
Expand All @@ -55,12 +53,11 @@ impl<'m> MigrationRunner<'m> {
}
}

self.client
.query(
"UPDATE neon_migration.migration_id SET id = $1",
&[&migration_id],
)
.context("run_migrations update id")?;
txn.query(
"UPDATE neon_migration.migration_id SET id = $1",
&[&migration_id],
)
.with_context(|| format!("update neon_migration.migration_id to {migration_id}"))?;

Ok(())
}
Expand All @@ -81,53 +78,50 @@ impl<'m> MigrationRunner<'m> {
Ok(())
}

/// Run the configrured set of migrations
pub fn run_migrations(mut self) -> Result<()> {
self.prepare_database()?;
/// Run an individual migration
fn run_migration(txn: &mut Transaction, migration_id: i64, migration: &str) -> Result<()> {
if migration.starts_with("-- SKIP") {
info!("Skipping migration id={}", migration_id);

let mut current_migration = self.get_migration_id()? as usize;
while current_migration < self.migrations.len() {
macro_rules! migration_id {
($cm:expr) => {
($cm + 1) as i64
};
}
// Even though we are skipping the migration, updating the
// migration ID should help keep logic easy to understand when
// trying to understand the state of a cluster.
Self::update_migration_id(txn, migration_id)?;
} else {
info!("Running migration id={}:\n{}\n", migration_id, migration);

let migration = self.migrations[current_migration];
txn.simple_query(migration)
.with_context(|| format!("apply migration {migration_id}"))?;

if migration.starts_with("-- SKIP") {
info!("Skipping migration id={}", migration_id!(current_migration));
Self::update_migration_id(txn, migration_id)?;
}

// Even though we are skipping the migration, updating the
// migration ID should help keep logic easy to understand when
// trying to understand the state of a cluster.
self.update_migration_id(migration_id!(current_migration))?;
} else {
info!(
"Running migration id={}:\n{}\n",
migration_id!(current_migration),
migration
);
Ok(())
}

/// Run the configured set of migrations
pub fn run_migrations(mut self) -> Result<()> {
self.prepare_database()
.context("prepare database to handle migrations")?;

self.client
.simple_query("BEGIN")
.context("begin migration")?;
let mut current_migration = self.get_migration_id()? as usize;
while current_migration < self.migrations.len() {
// The index lags the migration ID by 1, so the current migration
// ID is also the next index
let migration_id = (current_migration + 1) as i64;

self.client.simple_query(migration).with_context(|| {
format!(
"run_migrations migration id={}",
migration_id!(current_migration)
)
})?;
let mut txn = self
.client
.transaction()
.with_context(|| format!("begin transaction for migration {migration_id}"))?;

self.update_migration_id(migration_id!(current_migration))?;
Self::run_migration(&mut txn, migration_id, self.migrations[current_migration])
.with_context(|| format!("running migration {migration_id}"))?;

self.client
.simple_query("COMMIT")
.context("commit migration")?;
txn.commit()
.with_context(|| format!("commit transaction for migration {migration_id}"))?;

info!("Finished migration id={}", migration_id!(current_migration));
}
info!("Finished migration id={}", migration_id);

current_migration += 1;
}
Expand Down

0 comments on commit f6daa2e

Please sign in to comment.