From 85a7d023082cd77cbb8b697623efbb23c0e7d08a Mon Sep 17 00:00:00 2001 From: stdpain Date: Tue, 3 Sep 2024 11:43:41 +0800 Subject: [PATCH 1/3] [BugFix] add return type check in UDTF Signed-off-by: stdpain --- .../table_function/java_udtf_function.cpp | 6 ++ be/src/udf/java/java_data_converter.cpp | 44 ++++++++++++- be/src/udf/java/java_data_converter.h | 1 + be/src/udf/java/java_udf.cpp | 19 ++++-- be/src/udf/java/java_udf.h | 27 ++++---- test/sql/test_udf/R/test_jvm_udf | 66 +++++++++++++++++++ test/sql/test_udf/T/test_jvm_udf | 45 +++++++++++++ 7 files changed, 190 insertions(+), 18 deletions(-) diff --git a/be/src/exprs/table_function/java_udtf_function.cpp b/be/src/exprs/table_function/java_udtf_function.cpp index 9884ab79c1a13..56d3f73232229 100644 --- a/be/src/exprs/table_function/java_udtf_function.cpp +++ b/be/src/exprs/table_function/java_udtf_function.cpp @@ -21,6 +21,7 @@ #include "column/column_helper.h" #include "column/nullable_column.h" #include "column/vectorized_fwd.h" +#include "common/compiler_util.h" #include "exprs/table_function/table_function.h" #include "gutil/casts.h" #include "jni.h" @@ -185,6 +186,11 @@ std::pair JavaUDTFFunction::process(RuntimeState* ru for (int j = 0; j < len; ++j) { jobject vi = env->GetObjectArrayElement((jobjectArray)rets[i], j); LOCAL_REF_GUARD_ENV(env, vi); + auto st = check_type_matched(method_desc, vi); + if (UNLIKELY(!st.ok())) { + state->set_status(st); + return std::make_pair(Columns{}, nullptr); + } append_jvalue(method_desc, col.get(), {.l = vi}); release_jvalue(method_desc.is_box, {.l = vi}); } diff --git a/be/src/udf/java/java_data_converter.cpp b/be/src/udf/java/java_data_converter.cpp index b4030e60ed649..95e92ecade71a 100644 --- a/be/src/udf/java/java_data_converter.cpp +++ b/be/src/udf/java/java_data_converter.cpp @@ -21,6 +21,7 @@ #include "column/type_traits.h" #include "common/compiler_util.h" #include "common/status.h" +#include "util/defer_op.h" #define APPLY_FOR_NUMBERIC_TYPE(M) \ M(TYPE_BOOLEAN) \ @@ -161,7 +162,8 @@ void assign_jvalue(MethodTypeDescriptor method_type_desc, Column* col, int row_n if (val.l == nullptr) { col->append_nulls(1); } else { - auto slice = helper.sliceVal((jstring)val.l); + std::string buffer; + auto slice = helper.sliceVal((jstring)val.l, &buffer); col->append_datum(Datum(slice)); } break; @@ -209,7 +211,8 @@ void append_jvalue(MethodTypeDescriptor method_type_desc, Column* col, jvalue va APPEND_BOX_TYPE(TYPE_DOUBLE, double) case TYPE_VARCHAR: { - auto slice = helper.sliceVal((jstring)val.l); + std::string buffer; + auto slice = helper.sliceVal((jstring)val.l, &buffer); col->append_datum(Datum(slice)); break; } @@ -220,6 +223,43 @@ void append_jvalue(MethodTypeDescriptor method_type_desc, Column* col, jvalue va } } +Status check_type_matched(MethodTypeDescriptor method_type_desc, jobject val) { + if (val == nullptr) { + return Status::OK(); + } + auto& helper = JVMFunctionHelper::getInstance(); + auto* env = helper.getEnv(); + + switch (method_type_desc.type) { +#define INSTANCE_OF_TYPE(NAME, TYPE) \ + case NAME: { \ + if (!env->IsInstanceOf(val, helper.TYPE##_class())) { \ + return Status::InternalError("Type not matched"); \ + } \ + break; \ + } + INSTANCE_OF_TYPE(TYPE_BOOLEAN, uint8_t) + INSTANCE_OF_TYPE(TYPE_TINYINT, int8_t) + INSTANCE_OF_TYPE(TYPE_SMALLINT, int16_t) + INSTANCE_OF_TYPE(TYPE_INT, int32_t) + INSTANCE_OF_TYPE(TYPE_BIGINT, int64_t) + INSTANCE_OF_TYPE(TYPE_FLOAT, float) + INSTANCE_OF_TYPE(TYPE_DOUBLE, double) + case TYPE_VARCHAR: { + std::string buffer; + if (!env->IsInstanceOf(val, helper.string_clazz())) { + return Status::InternalError("Type not matched"); + } + break; + } + default: + DCHECK(false) << "unsupport UDF TYPE" << method_type_desc.type; + break; + } + + return Status::OK(); +} + Status ConvertDirectBufferVistor::do_visit(const NullableColumn& column) { const auto& null_data = column.immutable_null_column_data(); _buffers.emplace_back((void*)null_data.data(), null_data.size()); diff --git a/be/src/udf/java/java_data_converter.h b/be/src/udf/java/java_data_converter.h index a54818f096e75..e94f45d318952 100644 --- a/be/src/udf/java/java_data_converter.h +++ b/be/src/udf/java/java_data_converter.h @@ -63,4 +63,5 @@ template jvalue cast_to_jvalue(LogicalType type, bool is_boxed, const Column* col, int row_num); void release_jvalue(bool is_boxed, jvalue val); void append_jvalue(MethodTypeDescriptor method_type_desc, Column* col, jvalue val); +Status check_type_matched(MethodTypeDescriptor method_type_desc, jobject val); } // namespace starrocks diff --git a/be/src/udf/java/java_udf.cpp b/be/src/udf/java/java_udf.cpp index 77c9270b0dbaa..8df8a5d0c2914 100644 --- a/be/src/udf/java/java_udf.cpp +++ b/be/src/udf/java/java_udf.cpp @@ -102,6 +102,8 @@ void JVMFunctionHelper::_init() { CHECK(_list_class); CHECK(_exception_util_class); + _equals = _env->GetMethodID(_object_class, "equals", "(Ljava/lang/Object;)Z"); + ADD_NUMBERIC_CLASS(boolean, Boolean, Z); ADD_NUMBERIC_CLASS(byte, Byte, B); ADD_NUMBERIC_CLASS(short, Short, S); @@ -220,6 +222,19 @@ std::string JVMFunctionHelper::array_to_string(jobject object) { return value; } +bool JVMFunctionHelper::equals(jobject obj1, jobject obj2) { + if (obj1 == obj2) { + return true; + } + if (obj1 == nullptr) { + return false; + } + _env->ExceptionClear(); + bool res = _env->CallBooleanMethod(obj1, _equals, obj2); + CHECK_FUNCTION_EXCEPTION(_env, "equals") + return res; +} + std::string JVMFunctionHelper::to_string(jobject obj) { _env->ExceptionClear(); std::string value; @@ -417,10 +432,6 @@ Slice JVMFunctionHelper::sliceVal(jstring jstr, std::string* buffer) { return {buffer->data(), buffer->length()}; } -Slice JVMFunctionHelper::sliceVal(jstring jstr) { - return {_env->GetStringUTFChars(jstr, nullptr)}; -} - std::string JVMFunctionHelper::to_jni_class_name(const std::string& name) { std::string jni_class_name; auto inserter = std::inserter(jni_class_name, jni_class_name.end()); diff --git a/be/src/udf/java/java_udf.h b/be/src/udf/java/java_udf.h index cda2ebd47bf69..749f6be13706b 100644 --- a/be/src/udf/java/java_udf.h +++ b/be/src/udf/java/java_udf.h @@ -37,9 +37,10 @@ extern "C" JNIEnv* getJNIEnv(void); jmethodID _value_of_##TYPE; \ jmethodID _val_##TYPE; -#define DECLARE_NEW_BOX(TYPE, CLAZZ) \ - jobject new##CLAZZ(TYPE value); \ - TYPE val##TYPE(jobject obj); +#define DECLARE_NEW_BOX(PRIM_CLAZZ, TYPE, CLAZZ) \ + jobject new##CLAZZ(TYPE value); \ + TYPE val##TYPE(jobject obj); \ + jclass TYPE##_class() { return _class_##PRIM_CLAZZ; } namespace starrocks { class DirectByteBuffer; @@ -59,6 +60,7 @@ class JVMFunctionHelper { // Arrays.toString() std::string array_to_string(jobject object); // Object::toString() + bool equals(jobject obj1, jobject obj2); std::string to_string(jobject obj); std::string to_cxx_string(jstring str); std::string dumpExceptionString(jthrowable throwable); @@ -117,19 +119,19 @@ class JVMFunctionHelper { jobject list_get(jobject obj, int idx); int list_size(jobject obj); - DECLARE_NEW_BOX(uint8_t, Boolean) - DECLARE_NEW_BOX(int8_t, Byte) - DECLARE_NEW_BOX(int16_t, Short) - DECLARE_NEW_BOX(int32_t, Integer) - DECLARE_NEW_BOX(int64_t, Long) - DECLARE_NEW_BOX(float, Float) - DECLARE_NEW_BOX(double, Double) + DECLARE_NEW_BOX(boolean, uint8_t, Boolean) + DECLARE_NEW_BOX(byte, int8_t, Byte) + DECLARE_NEW_BOX(short, int16_t, Short) + DECLARE_NEW_BOX(int, int32_t, Integer) + DECLARE_NEW_BOX(long, int64_t, Long) + DECLARE_NEW_BOX(float, float, Float) + DECLARE_NEW_BOX(double, double, Double) jobject newString(const char* data, size_t size); - - Slice sliceVal(jstring jstr); size_t string_length(jstring jstr); + Slice sliceVal(jstring jstr, std::string* buffer); + jclass string_clazz() { return _string_class; } // replace '.' as '/' // eg: java.lang.Integer -> java/lang/Integer static std::string to_jni_class_name(const std::string& name); @@ -174,6 +176,7 @@ class JVMFunctionHelper { jobject _utf8_charsets; jclass _udf_helper_class; + jmethodID _equals; jmethodID _create_boxed_array; jmethodID _batch_update; jmethodID _batch_update_if_not_null; diff --git a/test/sql/test_udf/R/test_jvm_udf b/test/sql/test_udf/R/test_jvm_udf index e800f6223e5fc..51b233b9a7be8 100644 --- a/test/sql/test_udf/R/test_jvm_udf +++ b/test/sql/test_udf/R/test_jvm_udf @@ -9,6 +9,48 @@ type = "StarrocksJar" file = "${udf_url}/starrocks-jdbc%2FSumbigint.jar"; -- result: -- !result +CREATE TABLE FUNCTION udtfstring(string) +RETURNS string +symbol = "UDTFstring" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar"; +-- result: +-- !result +CREATE TABLE FUNCTION udtfstring_wrong_match(string) +RETURNS int +symbol = "UDTFstring" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar"; +-- result: +-- !result +CREATE TABLE FUNCTION udtfint(int) +RETURNS int +symbol = "UDTFint" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFint.jar"; +-- result: +-- !result +CREATE TABLE FUNCTION udtfbigint(bigint) +RETURNS bigint +symbol = "UDTFbigint" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFbigint.jar"; +-- result: +-- !result +CREATE TABLE FUNCTION udtffloat(float) +RETURNS float +symbol = "UDTFfloat" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFfloat.jar"; +-- result: +-- !result +CREATE TABLE FUNCTION udtfdouble(double) +RETURNS double +symbol = "UDTFdouble" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFdouble.jar"; +-- result: +-- !result CREATE TABLE `t0` ( `c0` int(11) NULL COMMENT "", `c1` varchar(20) NULL COMMENT "", @@ -25,6 +67,30 @@ PROPERTIES ( insert into t0 SELECT generate_series, generate_series, generate_series, generate_series FROM TABLE(generate_series(1, 40960)); -- result: -- !result +select count(udtfstring) from t0, udtfstring(c1); +-- result: +81920 +-- !result +select count(udtfstring_wrong_match) from t0, udtfstring_wrong_match(c1); +-- result: +E: (1064, 'Type not matched') +-- !result +select count(udtfint) from t0, udtfint(c1); +-- result: +81920 +-- !result +select count(udtfbigint) from t0, udtfbigint(c1); +-- result: +81920 +-- !result +select count(udtffloat) from t0, udtffloat(c1); +-- result: +81920 +-- !result +select count(udtfdouble) from t0, udtfdouble(c1); +-- result: +81920 +-- !result set streaming_preaggregation_mode="force_streaming"; -- result: -- !result diff --git a/test/sql/test_udf/T/test_jvm_udf b/test/sql/test_udf/T/test_jvm_udf index 4ffb61fcd769e..5fde82d1a0dd0 100644 --- a/test/sql/test_udf/T/test_jvm_udf +++ b/test/sql/test_udf/T/test_jvm_udf @@ -8,6 +8,43 @@ symbol = "Sumbigint" type = "StarrocksJar" file = "${udf_url}/starrocks-jdbc%2FSumbigint.jar"; +CREATE TABLE FUNCTION udtfstring(string) +RETURNS string +symbol = "UDTFstring" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar"; + +CREATE TABLE FUNCTION udtfstring_wrong_match(string) +RETURNS int +symbol = "UDTFstring" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar"; + +CREATE TABLE FUNCTION udtfint(int) +RETURNS int +symbol = "UDTFint" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFint.jar"; + +CREATE TABLE FUNCTION udtfbigint(bigint) +RETURNS bigint +symbol = "UDTFbigint" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFbigint.jar"; + +CREATE TABLE FUNCTION udtffloat(float) +RETURNS float +symbol = "UDTFfloat" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFfloat.jar"; + +CREATE TABLE FUNCTION udtfdouble(double) +RETURNS double +symbol = "UDTFdouble" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFdouble.jar"; + + CREATE TABLE `t0` ( `c0` int(11) NULL COMMENT "", `c1` varchar(20) NULL COMMENT "", @@ -22,6 +59,14 @@ PROPERTIES ( insert into t0 SELECT generate_series, generate_series, generate_series, generate_series FROM TABLE(generate_series(1, 40960)); +-- test udtf cases +select count(udtfstring) from t0, udtfstring(c1); +select count(udtfstring_wrong_match) from t0, udtfstring_wrong_match(c1); +select count(udtfint) from t0, udtfint(c1); +select count(udtfbigint) from t0, udtfbigint(c1); +select count(udtffloat) from t0, udtffloat(c1); +select count(udtfdouble) from t0, udtfdouble(c1); + -- test group by limit case: set streaming_preaggregation_mode="force_streaming"; select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3)) as delta from t0 group by c0,c1 limit 10) tb; From 3c8e1443c5965123401629a58f684fad1d3888cd Mon Sep 17 00:00:00 2001 From: stdpain Date: Tue, 3 Sep 2024 14:08:52 +0800 Subject: [PATCH 2/3] remove unused code equals Signed-off-by: stdpain --- be/src/udf/java/java_udf.cpp | 15 --------------- be/src/udf/java/java_udf.h | 1 - 2 files changed, 16 deletions(-) diff --git a/be/src/udf/java/java_udf.cpp b/be/src/udf/java/java_udf.cpp index 8df8a5d0c2914..be918425ebd12 100644 --- a/be/src/udf/java/java_udf.cpp +++ b/be/src/udf/java/java_udf.cpp @@ -102,8 +102,6 @@ void JVMFunctionHelper::_init() { CHECK(_list_class); CHECK(_exception_util_class); - _equals = _env->GetMethodID(_object_class, "equals", "(Ljava/lang/Object;)Z"); - ADD_NUMBERIC_CLASS(boolean, Boolean, Z); ADD_NUMBERIC_CLASS(byte, Byte, B); ADD_NUMBERIC_CLASS(short, Short, S); @@ -222,19 +220,6 @@ std::string JVMFunctionHelper::array_to_string(jobject object) { return value; } -bool JVMFunctionHelper::equals(jobject obj1, jobject obj2) { - if (obj1 == obj2) { - return true; - } - if (obj1 == nullptr) { - return false; - } - _env->ExceptionClear(); - bool res = _env->CallBooleanMethod(obj1, _equals, obj2); - CHECK_FUNCTION_EXCEPTION(_env, "equals") - return res; -} - std::string JVMFunctionHelper::to_string(jobject obj) { _env->ExceptionClear(); std::string value; diff --git a/be/src/udf/java/java_udf.h b/be/src/udf/java/java_udf.h index 749f6be13706b..078b3ad423b31 100644 --- a/be/src/udf/java/java_udf.h +++ b/be/src/udf/java/java_udf.h @@ -176,7 +176,6 @@ class JVMFunctionHelper { jobject _utf8_charsets; jclass _udf_helper_class; - jmethodID _equals; jmethodID _create_boxed_array; jmethodID _batch_update; jmethodID _batch_update_if_not_null; From 381bdf359956a9b620ec1645d8b345ee107ccc57 Mon Sep 17 00:00:00 2001 From: stdpain Date: Tue, 3 Sep 2024 17:05:38 +0800 Subject: [PATCH 3/3] fix comments Signed-off-by: stdpain --- .../table_function/java_udtf_function.cpp | 34 ++++++++++++------- be/src/udf/java/java_data_converter.cpp | 22 ++++++++---- test/sql/test_udf/R/test_jvm_udf | 2 +- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/be/src/exprs/table_function/java_udtf_function.cpp b/be/src/exprs/table_function/java_udtf_function.cpp index 56d3f73232229..33389ac60cde0 100644 --- a/be/src/exprs/table_function/java_udtf_function.cpp +++ b/be/src/exprs/table_function/java_udtf_function.cpp @@ -138,21 +138,28 @@ std::pair JavaUDTFFunction::process(RuntimeState* ru std::vector call_stack; std::vector rets; - DeferOp defer = DeferOp([&]() { - // clean up arrays - for (auto& ret : rets) { - if (ret) { - env->DeleteLocalRef(ret); - } - } - }); size_t num_rows = cols[0]->size(); size_t num_cols = cols.size(); state->set_processed_rows(num_rows); call_stack.reserve(num_cols); rets.resize(num_rows); + + // reserve 16 local refs + DeferOp defer = DeferOp([&]() { + // clean up arrays + env->PopLocalFrame(nullptr); + }); + env->PushLocalFrame(num_cols * num_rows + 16); + for (int i = 0; i < num_rows; ++i) { + DeferOp defer = DeferOp([&]() { + for (int j = 0; j < num_cols; ++j) { + release_jvalue(stateUDTF->method_process()->method_desc[j + 1].is_box, call_stack[j]); + } + call_stack.clear(); + }); + for (int j = 0; j < num_cols; ++j) { auto method_type = stateUDTF->method_process()->method_desc[j + 1]; jvalue val = cast_to_jvalue(method_type.type, method_type.is_box, cols[j].get(), i); @@ -161,11 +168,13 @@ std::pair JavaUDTFFunction::process(RuntimeState* ru rets[i] = env->CallObjectMethodA(stateUDTF->handle(), methodID, call_stack.data()); - for (int j = 0; j < num_cols; ++j) { - release_jvalue(stateUDTF->method_process()->method_desc[j + 1].is_box, call_stack[j]); + if (auto jthr = helper.getEnv()->ExceptionOccurred(); jthr != nullptr) { + std::string err = fmt::format("execute UDF Function meet Exception:{}", helper.dumpExceptionString(jthr)); + LOG(WARNING) << err; + helper.getEnv()->ExceptionClear(); + state->set_status(Status::InternalError(err)); + return std::make_pair(Columns{}, nullptr); } - - call_stack.clear(); } // Build Return Type @@ -192,7 +201,6 @@ std::pair JavaUDTFFunction::process(RuntimeState* ru return std::make_pair(Columns{}, nullptr); } append_jvalue(method_desc, col.get(), {.l = vi}); - release_jvalue(method_desc.is_box, {.l = vi}); } } diff --git a/be/src/udf/java/java_data_converter.cpp b/be/src/udf/java/java_data_converter.cpp index 95e92ecade71a..d6a4ef4fb41ea 100644 --- a/be/src/udf/java/java_data_converter.cpp +++ b/be/src/udf/java/java_data_converter.cpp @@ -21,6 +21,7 @@ #include "column/type_traits.h" #include "common/compiler_util.h" #include "common/status.h" +#include "udf/java/java_udf.h" #include "util/defer_op.h" #define APPLY_FOR_NUMBERIC_TYPE(M) \ @@ -231,12 +232,16 @@ Status check_type_matched(MethodTypeDescriptor method_type_desc, jobject val) { auto* env = helper.getEnv(); switch (method_type_desc.type) { -#define INSTANCE_OF_TYPE(NAME, TYPE) \ - case NAME: { \ - if (!env->IsInstanceOf(val, helper.TYPE##_class())) { \ - return Status::InternalError("Type not matched"); \ - } \ - break; \ +#define INSTANCE_OF_TYPE(NAME, TYPE) \ + case NAME: { \ + if (!env->IsInstanceOf(val, helper.TYPE##_class())) { \ + auto clazz = env->GetObjectClass(val); \ + LOCAL_REF_GUARD(clazz); \ + return Status::InternalError(fmt::format("Type not matched, expect {}, but got {}", \ + helper.to_string(helper.TYPE##_class()), \ + helper.to_string(clazz))); \ + } \ + break; \ } INSTANCE_OF_TYPE(TYPE_BOOLEAN, uint8_t) INSTANCE_OF_TYPE(TYPE_TINYINT, int8_t) @@ -248,7 +253,10 @@ Status check_type_matched(MethodTypeDescriptor method_type_desc, jobject val) { case TYPE_VARCHAR: { std::string buffer; if (!env->IsInstanceOf(val, helper.string_clazz())) { - return Status::InternalError("Type not matched"); + auto clazz = env->GetObjectClass(val); + LOCAL_REF_GUARD(clazz); + return Status::InternalError( + fmt::format("Type not matched, expect string, but got {}", helper.to_string(clazz))); } break; } diff --git a/test/sql/test_udf/R/test_jvm_udf b/test/sql/test_udf/R/test_jvm_udf index 51b233b9a7be8..f05c5e1785b62 100644 --- a/test/sql/test_udf/R/test_jvm_udf +++ b/test/sql/test_udf/R/test_jvm_udf @@ -73,7 +73,7 @@ select count(udtfstring) from t0, udtfstring(c1); -- !result select count(udtfstring_wrong_match) from t0, udtfstring_wrong_match(c1); -- result: -E: (1064, 'Type not matched') +E: (1064, 'Type not matched, expect class java.lang.Integer, but got class java.lang.String') -- !result select count(udtfint) from t0, udtfint(c1); -- result: