Skip to content

Commit

Permalink
Add increment method
Browse files Browse the repository at this point in the history
  • Loading branch information
m1guelpf committed Dec 14, 2023
1 parent b03681d commit 050c700
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
24 changes: 24 additions & 0 deletions ensemble/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ pub trait Model: DeserializeOwned + Serialize + Sized + Send + Sync + Debug + De
Self::query().with(eager_load)
}

/// Load a relationship for the model.
fn load<T: Into<EagerLoad> + Send>(
&mut self,
relation: T,
Expand All @@ -173,6 +174,29 @@ pub trait Model: DeserializeOwned + Serialize + Sized + Send + Sync + Debug + De
}
}

fn increment(
&mut self,
column: &str,
amount: u64,
) -> impl Future<Output = Result<(), Error>> + Send {
async move {
let rows_affected = Self::query()
.r#where(
Self::PRIMARY_KEY,
"=",
value::for_db(self.primary_key()).unwrap(),
)
.increment(column, amount)
.await?;

if rows_affected != 1 {
return Err(Error::UniqueViolation);
}

Ok(())
}
}

/// Convert the model to a JSON value.
///
/// # Panics
Expand Down
25 changes: 25 additions & 0 deletions ensemble/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,31 @@ impl Builder {
Ok(rbs::from_value(result.last_insert_id)?)
}

/// Increment a column's value by a given amount. Returns the number of affected rows.
///
/// # Errors
///
/// Returns an error if the query fails, or if a connection to the database cannot be established.
pub async fn increment(self, column: &str, amount: u64) -> Result<u64, Error> {
let mut conn = connection::get().await?;
let (sql, mut bindings) = (
format!(
"UPDATE {} SET {column} = {column} + ? {}",
self.table,
self.to_sql(Type::Update)
),
self.get_bindings(),
);
bindings.insert(0, amount.into());

tracing::debug!(sql = sql.as_str(), bindings = ?bindings, "Executing UPDATE SQL query for increment");

conn.exec(&sql, bindings)
.await
.map_err(|e| Error::Database(e.to_string()))
.map(|r| r.rows_affected)
}

/// Update records in the database. Returns the number of affected rows.
///
/// # Errors
Expand Down

0 comments on commit 050c700

Please sign in to comment.