Skip to content

Commit 7810f7d

Browse files
agentsimmehcode
andauthored
Sqlite Collation Support (#446)
* Sqlite Collation Support Adds a method create_collation to SqliteConnection. Adds a unit test confirming the collation works as expected. * Fix formatting * Address feedback Co-authored-by: Ryan Leckey <[email protected]>
1 parent aaa8b25 commit 7810f7d

File tree

3 files changed

+121
-2
lines changed

3 files changed

+121
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
use std::cmp::Ordering;
2+
use std::ffi::CString;
3+
use std::os::raw::{c_char, c_int, c_void};
4+
use std::slice;
5+
use std::str::from_utf8_unchecked;
6+
7+
use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8};
8+
9+
use crate::error::Error;
10+
use crate::sqlite::connection::handle::ConnectionHandle;
11+
use crate::sqlite::SqliteError;
12+
13+
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
14+
drop(Box::from_raw(p as *mut T));
15+
}
16+
17+
pub(crate) fn create_collation<F>(
18+
handle: &ConnectionHandle,
19+
name: &str,
20+
compare: F,
21+
) -> Result<(), Error>
22+
where
23+
F: Fn(&str, &str) -> Ordering + Send + Sync + 'static,
24+
{
25+
unsafe extern "C" fn call_boxed_closure<C>(
26+
arg1: *mut c_void,
27+
arg2: c_int,
28+
arg3: *const c_void,
29+
arg4: c_int,
30+
arg5: *const c_void,
31+
) -> c_int
32+
where
33+
C: Fn(&str, &str) -> Ordering,
34+
{
35+
let boxed_f: *mut C = arg1 as *mut C;
36+
debug_assert!(!boxed_f.is_null());
37+
let s1 = {
38+
let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize);
39+
from_utf8_unchecked(c_slice)
40+
};
41+
let s2 = {
42+
let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize);
43+
from_utf8_unchecked(c_slice)
44+
};
45+
let t = (*boxed_f)(s1, s2);
46+
47+
match t {
48+
Ordering::Less => -1,
49+
Ordering::Equal => 0,
50+
Ordering::Greater => 1,
51+
}
52+
}
53+
54+
let boxed_f: *mut F = Box::into_raw(Box::new(compare));
55+
let c_name =
56+
CString::new(name).map_err(|_| err_protocol!("invalid collation name: {}", name))?;
57+
let flags = SQLITE_UTF8;
58+
let r = unsafe {
59+
sqlite3_create_collation_v2(
60+
handle.as_ptr(),
61+
c_name.as_ptr(),
62+
flags,
63+
boxed_f as *mut c_void,
64+
Some(call_boxed_closure::<F>),
65+
Some(free_boxed_value::<F>),
66+
)
67+
};
68+
69+
if r == SQLITE_OK {
70+
Ok(())
71+
} else {
72+
Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))))
73+
}
74+
}

sqlx-core/src/sqlite/connection/mod.rs

+10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::cmp::Ordering;
12
use std::fmt::{self, Debug, Formatter};
23
use std::sync::Arc;
34

@@ -15,6 +16,7 @@ use crate::sqlite::connection::establish::establish;
1516
use crate::sqlite::statement::{SqliteStatement, StatementWorker};
1617
use crate::sqlite::{Sqlite, SqliteConnectOptions};
1718

19+
mod collation;
1820
mod describe;
1921
mod establish;
2022
mod executor;
@@ -43,6 +45,14 @@ impl SqliteConnection {
4345
pub fn as_raw_handle(&mut self) -> *mut sqlite3 {
4446
self.handle.as_ptr()
4547
}
48+
49+
pub fn create_collation(
50+
&mut self,
51+
name: &str,
52+
compare: impl Fn(&str, &str) -> Ordering + Send + Sync + 'static,
53+
) -> Result<(), Error> {
54+
collation::create_collation(&self.handle, name, compare)
55+
}
4656
}
4757

4858
impl Debug for SqliteConnection {

tests/sqlite/sqlite.rs

+37-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use futures::TryStreamExt;
2-
use sqlx::sqlite::{Sqlite, SqliteConnection, SqlitePool, SqliteRow};
3-
use sqlx::{query, Connect, Connection, Executor, Row};
2+
use sqlx::{
3+
query, sqlite::Sqlite, sqlite::SqliteRow, Connect, Connection, Executor, Row, SqliteConnection,
4+
SqlitePool,
5+
};
46
use sqlx_test::new;
57

68
#[sqlx_macros::test]
@@ -303,6 +305,39 @@ SELECT id, text FROM _sqlx_test;
303305
Ok(())
304306
}
305307

308+
#[sqlx_macros::test]
309+
async fn it_supports_collations() -> anyhow::Result<()> {
310+
let mut conn = new::<Sqlite>().await?;
311+
312+
conn.create_collation("test_collation", |l, r| l.cmp(r).reverse())?;
313+
314+
let _ = conn
315+
.execute(
316+
r#"
317+
CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL COLLATE test_collation)
318+
"#,
319+
)
320+
.await?;
321+
322+
sqlx::query("INSERT INTO users (name) VALUES (?)")
323+
.bind("a")
324+
.execute(&mut conn)
325+
.await?;
326+
sqlx::query("INSERT INTO users (name) VALUES (?)")
327+
.bind("b")
328+
.execute(&mut conn)
329+
.await?;
330+
331+
let row: SqliteRow = conn
332+
.fetch_one("SELECT name FROM users ORDER BY name ASC")
333+
.await?;
334+
let name: &str = row.try_get(0)?;
335+
336+
assert_eq!(name, "b");
337+
338+
Ok(())
339+
}
340+
306341
#[sqlx_macros::test]
307342
async fn it_caches_statements() -> anyhow::Result<()> {
308343
let mut conn = new::<Sqlite>().await?;

0 commit comments

Comments
 (0)