diff --git a/README.md b/README.md index 67be08a1..cba0af90 100644 --- a/README.md +++ b/README.md @@ -304,6 +304,25 @@ sqlite::sqlite_exception has a get_code() member function to get the SQLITE3 err catch(sqlite::exceptions::constraint e) { } */ ``` +Custom SQL functions +---- + +To extend SQLite with custom functions, you just implement them in C++: + +```c++ + database db(":memory:"); + db.define("tgamma", [](double i) {return std::tgamma(i);}); + db << "CREATE TABLE numbers (number INTEGER);"; + + for(auto i=0; i!=10; ++i) + db << "INSERT INTO numbers VALUES (?);" << i; + + db << "SELECT number, tgamma(number+1) FROM numbers;" >> [](double number, double factorial) { + cout << number << "! = " << factorial << '\n'; + }; +``` + + NDK support ---- Just Make sure you are using the full path of your database file : diff --git a/hdr/sqlite_modern_cpp.h b/hdr/sqlite_modern_cpp.h index 4b3682b7..ba07d507 100644 --- a/hdr/sqlite_modern_cpp.h +++ b/hdr/sqlite_modern_cpp.h @@ -318,6 +318,73 @@ namespace sqlite { } }; + namespace sql_function_binder { + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + } + struct sqlite_config { }; @@ -374,6 +441,37 @@ namespace sqlite { sqlite3_int64 last_insert_rowid() const { return sqlite3_last_insert_rowid(_db.get()); } + + template + void define(const std::string &name, Function&& func) { + typedef utility::function_traits traits; + + auto funcPtr = new auto(std::forward(func)); + if(int result = sqlite3_create_function_v2( + _db.get(), name.c_str(), traits::arity, SQLITE_UTF8, funcPtr, + sql_function_binder::scalar::type>, + nullptr, nullptr, [](void* ptr){ + delete static_cast(ptr); + })) + exceptions::throw_sqlite_error(result); + } + + template + void define(const std::string &name, StepFunction&& step, FinalFunction&& final) { + typedef utility::function_traits traits; + using ContextType = typename std::remove_reference>::type; + + auto funcPtr = new auto(std::make_pair(std::forward(step), std::forward(final))); + if(int result = sqlite3_create_function_v2( + _db.get(), name.c_str(), traits::arity - 1, SQLITE_UTF8, funcPtr, nullptr, + sql_function_binder::step::type>, + sql_function_binder::final::type>, + [](void* ptr){ + delete static_cast(ptr); + })) + exceptions::throw_sqlite_error(result); + } + }; template @@ -404,7 +502,7 @@ namespace sqlite { Function&& function, Values&&... values ) { - nth_argument_type value{}; + typename std::remove_cv>::type>::type value{}; get_col_from_db(db, sizeof...(Values), value); run(db, function, std::forward(values)..., std::move(value)); @@ -432,6 +530,9 @@ namespace sqlite { } ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const int& val) { + sqlite3_result_int(db, val); } inline void get_col_from_db(database_binder& db, int inx, int& val) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { @@ -440,6 +541,13 @@ namespace sqlite { val = sqlite3_column_int(db._stmt.get(), inx); } } + inline void get_val_from_db(sqlite3_value *value, int& val) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + val = 0; + } else { + val = sqlite3_value_int(value); + } + } // sqlite_int64 inline database_binder& operator <<(database_binder& db, const sqlite_int64& val) { @@ -450,6 +558,9 @@ namespace sqlite { ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const sqlite_int64& val) { + sqlite3_result_int64(db, val); } inline void get_col_from_db(database_binder& db, int inx, sqlite3_int64& i) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { @@ -458,6 +569,13 @@ namespace sqlite { i = sqlite3_column_int64(db._stmt.get(), inx); } } + inline void get_val_from_db(sqlite3_value *value, sqlite3_int64& i) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + i = 0; + } else { + i = sqlite3_value_int64(value); + } + } // float inline database_binder& operator <<(database_binder& db, const float& val) { @@ -468,6 +586,9 @@ namespace sqlite { ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const float& val) { + sqlite3_result_double(db, val); } inline void get_col_from_db(database_binder& db, int inx, float& f) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { @@ -476,6 +597,13 @@ namespace sqlite { f = float(sqlite3_column_double(db._stmt.get(), inx)); } } + inline void get_val_from_db(sqlite3_value *value, float& f) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + f = 0; + } else { + f = float(sqlite3_value_double(value)); + } + } // double inline database_binder& operator <<(database_binder& db, const double& val) { @@ -486,6 +614,9 @@ namespace sqlite { ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const double& val) { + sqlite3_result_double(db, val); } inline void get_col_from_db(database_binder& db, int inx, double& d) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { @@ -494,6 +625,13 @@ namespace sqlite { d = sqlite3_column_double(db._stmt.get(), inx); } } + inline void get_val_from_db(sqlite3_value *value, double& d) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + d = 0; + } else { + d = sqlite3_value_double(value); + } + } // vector template inline database_binder& operator<<(database_binder& db, const std::vector& vec) { @@ -506,6 +644,11 @@ namespace sqlite { ++db._inx; return db; } + template inline void store_result_in_db(sqlite3_context* db, const std::vector& vec) { + void const* buf = reinterpret_cast(vec.data()); + int bytes = vec.size() * sizeof(T); + sqlite3_result_blob(db, buf, bytes, SQLITE_TRANSIENT); + } template inline void get_col_from_db(database_binder& db, int inx, std::vector& vec) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { vec.clear(); @@ -515,6 +658,15 @@ namespace sqlite { vec = std::vector(buf, buf + bytes/sizeof(T)); } } + template inline void get_val_from_db(sqlite3_value *value, std::vector& vec) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + vec.clear(); + } else { + int bytes = sqlite3_value_bytes(value); + T const* buf = reinterpret_cast(sqlite3_value_blob(value)); + vec = std::vector(buf, buf + bytes/sizeof(T)); + } + } /* for nullptr support */ inline database_binder& operator <<(database_binder& db, std::nullptr_t) { @@ -524,13 +676,16 @@ namespace sqlite { } ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, std::nullptr_t) { + sqlite3_result_null(db); } /* for nullptr support */ template inline database_binder& operator <<(database_binder& db, const std::unique_ptr& val) { if(val) db << *val; else - db << nullptr; + db << nullptr; return db; } @@ -544,6 +699,15 @@ namespace sqlite { _ptr_.reset(underling_ptr); } } + template inline void get_val_from_db(sqlite3_value *value, std::unique_ptr& _ptr_) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + _ptr_ = nullptr; + } else { + auto underling_ptr = new T(); + get_val_from_db(value, *underling_ptr); + _ptr_.reset(underling_ptr); + } + } // std::string inline void get_col_from_db(database_binder& db, int inx, std::string & s) { @@ -554,6 +718,14 @@ namespace sqlite { s = std::string(reinterpret_cast(sqlite3_column_text(db._stmt.get(), inx))); } } + inline void get_val_from_db(sqlite3_value *value, std::string & s) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + s = std::string(); + } else { + sqlite3_value_bytes(value); + s = std::string(reinterpret_cast(sqlite3_value_text(value))); + } + } // Convert char* to string to trigger op<<(..., const std::string ) template inline database_binder& operator <<(database_binder& db, const char(&STR)[N]) { return db << std::string(STR); } @@ -567,6 +739,9 @@ namespace sqlite { ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const std::string& val) { + sqlite3_result_text(db, val.data(), -1, SQLITE_TRANSIENT); } // std::u16string inline void get_col_from_db(database_binder& db, int inx, std::u16string & w) { @@ -577,6 +752,14 @@ namespace sqlite { w = std::u16string(reinterpret_cast(sqlite3_column_text16(db._stmt.get(), inx))); } } + inline void get_val_from_db(sqlite3_value *value, std::u16string & w) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + w = std::u16string(); + } else { + sqlite3_value_bytes16(value); + w = std::u16string(reinterpret_cast(sqlite3_value_text16(value))); + } + } inline database_binder& operator <<(database_binder& db, const std::u16string& txt) { @@ -587,6 +770,9 @@ namespace sqlite { ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const std::u16string& val) { + sqlite3_result_text16(db, val.data(), -1, SQLITE_TRANSIENT); } // std::optional support for NULL values #ifdef _MODERN_SQLITE_STD_OPTIONAL_SUPPORT @@ -602,13 +788,28 @@ namespace sqlite { ++db._inx; return db; } + template inline void store_result_in_db(sqlite3_context* db, const std::optional& val) { + if(val) { + store_result_in_db(db, *val); + } + sqlite3_result_null(db); + } template inline void get_col_from_db(database_binder& db, int inx, std::optional& o) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { o.reset(); } else { OptionalT v; - get_col_from_db(db, inx, v); + get_col_from_db(value, v); + o = std::move(v); + } + } + template inline void get_val_from_db(sqlite3_value *value, std::optional& o) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + o.reset(); + } else { + OptionalT v; + get_val_from_db(value, v); o = std::move(v); } } @@ -628,6 +829,12 @@ namespace sqlite { ++db._inx; return db; } + template inline void store_result_in_db(sqlite3_context* db, const boost::optional& val) { + if(val) { + store_result_in_db(db, *val); + } + sqlite3_result_null(db); + } template inline void get_col_from_db(database_binder& db, int inx, boost::optional& o) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { @@ -638,6 +845,15 @@ namespace sqlite { o = std::move(v); } } + template inline void get_val_from_db(sqlite3_value *value, boost::optional& o) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + o.reset(); + } else { + BoostOptionalT v; + get_val_from_db(value, v); + o = std::move(v); + } + } #endif // Some ppl are lazy so we have a operator for proper prep. statemant handling. @@ -646,4 +862,130 @@ namespace sqlite { // Convert the rValue binder to a reference and call first op<<, its needed for the call that creates the binder (be carefull of recursion here!) template database_binder& operator << (database_binder&& db, const T& val) { return db << val; } + namespace sql_function_binder { + template + struct AggregateCtxt { + T obj; + bool constructed = true; + }; + + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ) { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + if(!ctxt) return; + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + step(db, count, vals, ctxt->obj); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits< + typename Functions::first_type + >::template argument + >::type + >::type value{}; + get_val_from_db(vals[sizeof...(Values) - 1], value); + + step(db, count, vals, std::forward(values)..., std::move(value)); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + static_cast(sqlite3_user_data(db))->first(std::forward(values)...); + }; + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db) { + try { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + if(!ctxt) return; + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + store_result_in_db(db, + static_cast(sqlite3_user_data(db))->second(ctxt->obj)); + } catch(sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits::template argument + >::type + >::type value{}; + get_val_from_db(vals[sizeof...(Values)], value); + + scalar(db, count, vals, std::forward(values)..., std::move(value)); + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + try { + store_result_in_db(db, + (*static_cast(sqlite3_user_data(db)))(std::forward(values)...)); + } catch(sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + } + } } diff --git a/tests/functions.cc b/tests/functions.cc new file mode 100644 index 00000000..f1fb073b --- /dev/null +++ b/tests/functions.cc @@ -0,0 +1,65 @@ +#include +#include +#include +#include +using namespace sqlite; +using namespace std; + +int main() +{ + try + { + database db(":memory:"); + + db.define("my_new_concat", [](std::string i, std::string j) {return i+j;}); + db.define("my_new_concat", [](std::string i, std::string j, std::string k) {return i+j+k;}); + db.define("add_integers", [](int i, int j) {return i+j;}); + std::string test1, test3; + int test2 = 0; + db << "select my_new_concat('Hello ','world!')" >> test1; + db << "select add_integers(1,1)" >> test2; + db << "select my_new_concat('a','b','c')" >> test3; + + if(test1 != "Hello world!" || test2 != 2 || test3 != "abc") { + cout << "Wrong result\n"; + exit(EXIT_FAILURE); + } + + db.define("my_count", [](int &i, int) {++i;}, [](int &i) {return i;}); + db.define("my_concat_aggregate", [](std::string &stored, std::string current) {stored += current;}, [](std::string &stored) {return stored;}); + db << "create table countable(i, s)"; + db << "insert into countable values(1, 'a')"; + db << "insert into countable values(2, 'b')"; + db << "insert into countable values(3, 'c')"; + db << "select my_count(i) from countable" >> test2; + db << "select my_concat_aggregate(s) from countable order by i" >> test3; + + if(test2 != 3 || test3 != "abc") { + cout << "Wrong result\n"; + exit(EXIT_FAILURE); + } + + db.define("tgamma", [](double i) {return std::tgamma(i);}); + db << "CREATE TABLE numbers (number INTEGER);"; + + for(auto i=0; i!=10; ++i) + db << "INSERT INTO numbers VALUES (?);" << i; + + db << "SELECT number, tgamma(number+1) FROM numbers;" >> [](double number, double factorial) { + cout << number << "! = " << factorial << '\n'; + }; + } + catch(sqlite_exception e) + { + cout << "Unexpected error " << e.what() << endl; + exit(EXIT_FAILURE); + } + catch(...) + { + cout << "Unknown error\n"; + exit(EXIT_FAILURE); + } + + cout << "OK\n"; + exit(EXIT_SUCCESS); +}