Skip to content

Commit

Permalink
Expose sqlite's register_collation_function()
Browse files Browse the repository at this point in the history
  • Loading branch information
z33ky committed Sep 4, 2020
1 parent 44d0096 commit 79a45f1
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 0 deletions.
7 changes: 7 additions & 0 deletions diesel/src/sqlite/connection/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ where
Ok(())
}

pub fn register_collation<F>(conn: &RawConnection, fn_name: &str, f: F) -> QueryResult<()>
where
F: FnMut(&str, &str) -> std::cmp::Ordering + Send + 'static,
{
conn.register_collation_function(fn_name, f)
}

pub(crate) fn build_sql_function_args<ArgsSqlType, Args>(
args: &[*mut ffi::sqlite3_value],
) -> Result<Args, Error>
Expand Down
104 changes: 104 additions & 0 deletions diesel/src/sqlite/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,39 @@ impl SqliteConnection {
functions::register_aggregate::<_, _, _, _, A>(&self.raw_connection, fn_name)
}

/// Register a collation function.
///
/// `f` must always return the same answer given the same inputs.
///
/// This method will return an error if registering the function fails.
///
/// The collation can be specified when creating a table:
/// `CREATE TABLE my_table ( str TEXT COLLATE MY_COLLATION )`,
/// where `MY_COLLATION` corresponds to `fn_name`.
///
/// # Example
///
/// ```rust
/// # include!("../../doctest_setup.rs");
/// #
/// # fn main() {
/// # run_test().unwrap();
/// # }
/// #
/// # fn run_test() -> QueryResult<()> {
/// # let conn = SqliteConnection::establish(":memory:").unwrap();
/// conn.register_collation("RUSTNOCASE", |rhs, lhs| {
/// rhs.to_lowercase().cmp(&lhs.to_lowercase())
/// }).unwrap();
/// # }
/// ```
pub fn register_collation<F>(&self, fn_name: &str, f: F) -> QueryResult<()>
where
F: FnMut(&str, &str) -> std::cmp::Ordering + Send + 'static,
{
functions::register_collation(&self.raw_connection, fn_name, f)
}

fn register_diesel_sql_functions(&self) -> QueryResult<()> {
use crate::sql_types::{Integer, Text};

Expand Down Expand Up @@ -514,4 +547,75 @@ mod tests {
.unwrap();
assert_eq!(Some(3), result);
}

table! {
my_collation_example {
id -> Integer,
value -> Text,
}
}

#[test]
fn register_collation_function() {
use self::my_collation_example::dsl::*;

let connection = SqliteConnection::establish(":memory:").unwrap();

connection
.register_collation("RUSTNOCASE", |rhs, lhs| {
rhs.to_lowercase().cmp(&lhs.to_lowercase())
})
.unwrap();

connection
.execute(
"CREATE TABLE my_collation_example (id integer primary key autoincrement, value text collate RUSTNOCASE)",
)
.unwrap();
connection
.execute("INSERT INTO my_collation_example (value) VALUES ('foo'), ('FOo'), ('f00')")
.unwrap();

let result = my_collation_example
.filter(value.eq("foo"))
.select(value)
.load::<String>(&connection);
assert_eq!(
Ok(&["foo".to_owned(), "FOo".to_owned()][..]),
result.as_ref().map(|vec| vec.as_ref())
);

let result = my_collation_example
.filter(value.eq("FOO"))
.select(value)
.load::<String>(&connection);
assert_eq!(
Ok(&["foo".to_owned(), "FOo".to_owned()][..]),
result.as_ref().map(|vec| vec.as_ref())
);

let result = my_collation_example
.filter(value.eq("f00"))
.select(value)
.load::<String>(&connection);
assert_eq!(
Ok(&["f00".to_owned()][..]),
result.as_ref().map(|vec| vec.as_ref())
);

let result = my_collation_example
.filter(value.eq("F00"))
.select(value)
.load::<String>(&connection);
assert_eq!(
Ok(&["f00".to_owned()][..]),
result.as_ref().map(|vec| vec.as_ref())
);

let result = my_collation_example
.filter(value.eq("oof"))
.select(value)
.load::<String>(&connection);
assert_eq!(Ok(&[][..]), result.as_ref().map(|vec| vec.as_ref()));
}
}
63 changes: 63 additions & 0 deletions diesel/src/sqlite/connection/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,27 @@ impl RawConnection {
Self::process_sql_function_result(result)
}

pub fn register_collation_function<F>(&self, fn_name: &str, f: F) -> QueryResult<()>
where
F: FnMut(&str, &str) -> std::cmp::Ordering + Send + 'static,
{
let fn_name = Self::get_fn_name(fn_name)?;
let callback_fn = Box::into_raw(Box::new(f));

let result = unsafe {
ffi::sqlite3_create_collation_v2(
self.internal_connection.as_ptr(),
fn_name.as_ptr(),
ffi::SQLITE_UTF8,
callback_fn as *mut _,
Some(run_collation_function::<F>),
Some(destroy_boxed_collation_fn::<F>),
)
};

Self::process_sql_function_result(result)
}

fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
Ok(CString::new(fn_name)?)
}
Expand Down Expand Up @@ -379,6 +400,40 @@ unsafe fn null_aggregate_context_error(ctx: *mut ffi::sqlite3_context) {
);
}

#[allow(warnings)]
extern "C" fn run_collation_function<F>(
user_ptr: *mut libc::c_void,
len_rhs: libc::c_int,
rhs_ptr: *const libc::c_void,
len_lhs: libc::c_int,
lhs_ptr: *const libc::c_void,
) -> libc::c_int
where
F: FnMut(&str, &str) -> std::cmp::Ordering + Send + 'static,
{
unsafe {
let user_ptr = user_ptr as *mut F;
let f = match user_ptr.as_mut() {
Some(f) => f,
None => {
//FIXME
return 0;
}
};

//FIXME: check args
let rhs =
str::from_utf8_unchecked(slice::from_raw_parts(rhs_ptr as *const u8, len_rhs as _));
let lhs =
str::from_utf8_unchecked(slice::from_raw_parts(lhs_ptr as *const u8, len_lhs as _));
match f(rhs, lhs) {
std::cmp::Ordering::Greater => 1,
std::cmp::Ordering::Equal => 0,
std::cmp::Ordering::Less => -1,
}
}
}

extern "C" fn destroy_boxed_fn<F>(data: *mut libc::c_void)
where
F: FnMut(&RawConnection, &[*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
Expand All @@ -388,3 +443,11 @@ where
let ptr = data as *mut F;
unsafe { Box::from_raw(ptr) };
}

extern "C" fn destroy_boxed_collation_fn<F>(data: *mut libc::c_void)
where
F: FnMut(&str, &str) -> std::cmp::Ordering + Send + 'static,
{
let ptr = data as *mut F;
unsafe { Box::from_raw(ptr) };
}

0 comments on commit 79a45f1

Please sign in to comment.