From 78820395456ecb56bfae31ca5811d4a689a852ca Mon Sep 17 00:00:00 2001 From: Miguel Piedrafita Date: Fri, 15 Dec 2023 00:25:15 +0100 Subject: [PATCH] fix where_or queries --- ensemble/docs/relationships.md | 6 ++-- ensemble/src/connection.rs | 6 ++-- ensemble/src/query.rs | 56 ++++++++++++++++++++-------------- 3 files changed, 39 insertions(+), 29 deletions(-) diff --git a/ensemble/docs/relationships.md b/ensemble/docs/relationships.md index ffa753b..d367dec 100644 --- a/ensemble/docs/relationships.md +++ b/ensemble/docs/relationships.md @@ -2,9 +2,9 @@ Database tables are often related to one another. For example, a blog post may have many comments or an order could be related to the user who placed it. Ensemble makes managing and working with these relationships easy, with native support for the three most common: -- [One To One](#one-to-one) -- [One To Many](#one-to-many) -- [Many To Many](#many-to-many-relationships) +- [One To One](#one-to-one) +- [One To Many](#one-to-many) +- [Many To Many](#many-to-many-relationships) ## Defining Relationships diff --git a/ensemble/src/connection.rs b/ensemble/src/connection.rs index c9e86a4..85b6a2f 100644 --- a/ensemble/src/connection.rs +++ b/ensemble/src/connection.rs @@ -80,13 +80,11 @@ pub async fn get() -> Result { } } -#[cfg(any(feature = "mysql", feature = "postgres"))] pub enum Database { MySQL, PostgreSQL, } -#[cfg(any(feature = "mysql", feature = "postgres"))] impl Database { pub const fn is_mysql(&self) -> bool { matches!(self, Self::MySQL) @@ -97,8 +95,10 @@ impl Database { } } -#[cfg(any(feature = "mysql", feature = "postgres"))] pub const fn which_db() -> Database { + #[cfg(all(not(feature = "mysql"), not(feature = "postgres")))] + panic!("Either the `mysql` or `postgres` feature must be enabled to use `ensemble`."); + #[cfg(all(feature = "mysql", feature = "postgres"))] panic!("Both the `mysql` and `postgres` features are enabled. Please enable only one of them."); diff --git a/ensemble/src/query.rs b/ensemble/src/query.rs index 631722d..4bfd0b4 100644 --- a/ensemble/src/query.rs +++ b/ensemble/src/query.rs @@ -6,7 +6,10 @@ use std::{ fmt::Display, }; -use crate::{connection, value, Error, Model}; +use crate::{ + connection::{self, Database}, + value, Error, Model, +}; /// The Query Builder. #[derive(Debug)] @@ -92,7 +95,7 @@ impl Builder { self.r#where.push(WhereClause::Simple(Where { boolean: Boolean::And, operator: operator.into(), - column: column.to_string(), + column: Columns::escape(column), value: Some(value::for_db(value).unwrap()), })); @@ -141,7 +144,7 @@ impl Builder { operator: op.into(), boolean: Boolean::Or, value: Some(value.into()), - column: column.to_string(), + column: Columns::escape(column), })); self @@ -153,8 +156,8 @@ impl Builder { self.r#where.push(WhereClause::Simple(Where { value: None, boolean: Boolean::And, - column: column.to_string(), operator: Operator::NotNull, + column: Columns::escape(column), })); self @@ -169,7 +172,7 @@ impl Builder { self.r#where.push(WhereClause::Simple(Where { boolean: Boolean::And, operator: Operator::In, - column: column.to_string(), + column: Columns::escape(column), value: Some(Value::Array(values.into_iter().map(Into::into).collect())), })); @@ -182,8 +185,8 @@ impl Builder { self.r#where.push(WhereClause::Simple(Where { value: None, boolean: Boolean::And, - column: column.to_string(), operator: Operator::IsNull, + column: Columns::escape(column), })); self @@ -200,10 +203,10 @@ impl Builder { ) -> Self { self.join.push(Join { operator: op.into(), - first: first.to_string(), - column: column.to_string(), r#type: JoinType::Inner, + first: first.to_string(), second: second.to_string(), + column: Columns::escape(column), }); self @@ -213,8 +216,8 @@ impl Builder { #[must_use] pub fn order_by>(mut self, column: &str, direction: Dir) -> Self { self.order.push(Order { - column: column.to_string(), direction: direction.into(), + column: Columns::escape(column), }); self @@ -254,7 +257,7 @@ impl Builder { sql.push_str(" WHERE "); for (i, where_clause) in self.r#where.iter().enumerate() { - sql.push_str(&where_clause.to_sql(i != self.r#where.len() - 1)); + sql.push_str(&where_clause.to_sql(i != 0)); } } @@ -430,8 +433,10 @@ impl Builder { let mut conn = connection::get().await?; let (sql, mut bindings) = ( format!( - "UPDATE {} SET {column} = {column} + ? {}", + "UPDATE {} SET {} = {} + ? {}", self.table, + Columns::escape(column), + Columns::escape(column), self.to_sql(Type::Update) ), self.get_bindings(), @@ -560,13 +565,22 @@ impl From> for EagerLoad { pub struct Columns(Vec<(String, Value)>); +impl Columns { + fn escape(column: &str) -> String { + match connection::which_db() { + Database::MySQL => format!("`{column}`"), + Database::PostgreSQL => format!("\"{column}\""), + } + } +} + #[allow(clippy::fallible_impl_from)] impl From for Columns { fn from(value: Value) -> Self { match value { Value::Map(map) => Self( map.into_iter() - .map(|(column, value)| (column.into_string().unwrap(), value)) + .map(|(column, value)| (Self::escape(&column.into_string().unwrap()), value)) .collect(), ), _ => panic!("The provided value is not a map."), @@ -579,7 +593,7 @@ impl From> for Columns { Self( values .iter() - .map(|(column, value)| ((*column).to_string(), value::for_db(value).unwrap())) + .map(|(column, value)| (Self::escape(column), value::for_db(value).unwrap())) .collect(), ) } @@ -589,7 +603,7 @@ impl From<&[(&str, T)]> for Columns { Self( values .iter() - .map(|(column, value)| ((*column).to_string(), value::for_db(value).unwrap())) + .map(|(column, value)| (Self::escape(column), value::for_db(value).unwrap())) .collect(), ) } @@ -683,17 +697,13 @@ impl WhereClause { let mut sql = String::new(); for (i, where_clause) in where_clauses.iter().enumerate() { - sql.push_str(&format!("({})", where_clause.to_sql(false))); - - if i != where_clauses.len() - 1 { - sql.push_str(" AND "); - } + sql.push_str(&where_clause.to_sql(i != 0)); } if add_boolean { - format!("{boolean} {sql}") + format!(" {boolean} ({sql})") } else { - sql + format!("({sql})") } }, } @@ -741,7 +751,7 @@ impl Where { ); if add_boolean { - format!("{sql} {} ", self.boolean) + format!(" {} {sql} ", self.boolean) } else { sql } @@ -839,7 +849,7 @@ impl From<&str> for Operator { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] enum Boolean { And, Or,