diff --git a/be/src/exprs/table_function/java_udtf_function.cpp b/be/src/exprs/table_function/java_udtf_function.cpp index 9884ab79c1a13..33389ac60cde0 100644 --- a/be/src/exprs/table_function/java_udtf_function.cpp +++ b/be/src/exprs/table_function/java_udtf_function.cpp @@ -21,6 +21,7 @@ #include "column/column_helper.h" #include "column/nullable_column.h" #include "column/vectorized_fwd.h" +#include "common/compiler_util.h" #include "exprs/table_function/table_function.h" #include "gutil/casts.h" #include "jni.h" @@ -137,21 +138,28 @@ std::pair JavaUDTFFunction::process(RuntimeState* ru std::vector call_stack; std::vector rets; - DeferOp defer = DeferOp([&]() { - // clean up arrays - for (auto& ret : rets) { - if (ret) { - env->DeleteLocalRef(ret); - } - } - }); size_t num_rows = cols[0]->size(); size_t num_cols = cols.size(); state->set_processed_rows(num_rows); call_stack.reserve(num_cols); rets.resize(num_rows); + + // reserve 16 local refs + DeferOp defer = DeferOp([&]() { + // clean up arrays + env->PopLocalFrame(nullptr); + }); + env->PushLocalFrame(num_cols * num_rows + 16); + for (int i = 0; i < num_rows; ++i) { + DeferOp defer = DeferOp([&]() { + for (int j = 0; j < num_cols; ++j) { + release_jvalue(stateUDTF->method_process()->method_desc[j + 1].is_box, call_stack[j]); + } + call_stack.clear(); + }); + for (int j = 0; j < num_cols; ++j) { auto method_type = stateUDTF->method_process()->method_desc[j + 1]; jvalue val = cast_to_jvalue(method_type.type, method_type.is_box, cols[j].get(), i); @@ -160,11 +168,13 @@ std::pair JavaUDTFFunction::process(RuntimeState* ru rets[i] = env->CallObjectMethodA(stateUDTF->handle(), methodID, call_stack.data()); - for (int j = 0; j < num_cols; ++j) { - release_jvalue(stateUDTF->method_process()->method_desc[j + 1].is_box, call_stack[j]); + if (auto jthr = helper.getEnv()->ExceptionOccurred(); jthr != nullptr) { + std::string err = fmt::format("execute UDF Function meet Exception:{}", helper.dumpExceptionString(jthr)); + LOG(WARNING) << err; + helper.getEnv()->ExceptionClear(); + state->set_status(Status::InternalError(err)); + return std::make_pair(Columns{}, nullptr); } - - call_stack.clear(); } // Build Return Type @@ -185,8 +195,12 @@ std::pair JavaUDTFFunction::process(RuntimeState* ru for (int j = 0; j < len; ++j) { jobject vi = env->GetObjectArrayElement((jobjectArray)rets[i], j); LOCAL_REF_GUARD_ENV(env, vi); + auto st = check_type_matched(method_desc, vi); + if (UNLIKELY(!st.ok())) { + state->set_status(st); + return std::make_pair(Columns{}, nullptr); + } append_jvalue(method_desc, col.get(), {.l = vi}); - release_jvalue(method_desc.is_box, {.l = vi}); } } diff --git a/be/src/udf/java/java_data_converter.cpp b/be/src/udf/java/java_data_converter.cpp index cc28c5abfae26..7ff8539996347 100644 --- a/be/src/udf/java/java_data_converter.cpp +++ b/be/src/udf/java/java_data_converter.cpp @@ -21,6 +21,8 @@ #include "column/type_traits.h" #include "common/compiler_util.h" #include "common/status.h" +#include "udf/java/java_udf.h" +#include "util/defer_op.h" #define APPLY_FOR_NUMBERIC_TYPE(M) \ M(TYPE_BOOLEAN) \ @@ -161,7 +163,8 @@ void assign_jvalue(MethodTypeDescriptor method_type_desc, Column* col, int row_n if (val.l == nullptr) { col->append_nulls(1); } else { - auto slice = helper.sliceVal((jstring)val.l); + std::string buffer; + auto slice = helper.sliceVal((jstring)val.l, &buffer); col->append_datum(Datum(slice)); } break; @@ -209,7 +212,8 @@ void append_jvalue(MethodTypeDescriptor method_type_desc, Column* col, jvalue va APPEND_BOX_TYPE(TYPE_DOUBLE, double) case TYPE_VARCHAR: { - auto slice = helper.sliceVal((jstring)val.l); + std::string buffer; + auto slice = helper.sliceVal((jstring)val.l, &buffer); col->append_datum(Datum(slice)); break; } @@ -220,6 +224,50 @@ void append_jvalue(MethodTypeDescriptor method_type_desc, Column* col, jvalue va } } +Status check_type_matched(MethodTypeDescriptor method_type_desc, jobject val) { + if (val == nullptr) { + return Status::OK(); + } + auto& helper = JVMFunctionHelper::getInstance(); + auto* env = helper.getEnv(); + + switch (method_type_desc.type) { +#define INSTANCE_OF_TYPE(NAME, TYPE) \ + case NAME: { \ + if (!env->IsInstanceOf(val, helper.TYPE##_class())) { \ + auto clazz = env->GetObjectClass(val); \ + LOCAL_REF_GUARD(clazz); \ + return Status::InternalError(fmt::format("Type not matched, expect {}, but got {}", \ + helper.to_string(helper.TYPE##_class()), \ + helper.to_string(clazz))); \ + } \ + break; \ + } + INSTANCE_OF_TYPE(TYPE_BOOLEAN, uint8_t) + INSTANCE_OF_TYPE(TYPE_TINYINT, int8_t) + INSTANCE_OF_TYPE(TYPE_SMALLINT, int16_t) + INSTANCE_OF_TYPE(TYPE_INT, int32_t) + INSTANCE_OF_TYPE(TYPE_BIGINT, int64_t) + INSTANCE_OF_TYPE(TYPE_FLOAT, float) + INSTANCE_OF_TYPE(TYPE_DOUBLE, double) + case TYPE_VARCHAR: { + std::string buffer; + if (!env->IsInstanceOf(val, helper.string_clazz())) { + auto clazz = env->GetObjectClass(val); + LOCAL_REF_GUARD(clazz); + return Status::InternalError( + fmt::format("Type not matched, expect string, but got {}", helper.to_string(clazz))); + } + break; + } + default: + DCHECK(false) << "unsupport UDF TYPE" << method_type_desc.type; + break; + } + + return Status::OK(); +} + Status ConvertDirectBufferVistor::do_visit(const NullableColumn& column) { const auto& null_data = column.immutable_null_column_data(); _buffers.emplace_back((void*)null_data.data(), null_data.size()); diff --git a/be/src/udf/java/java_data_converter.h b/be/src/udf/java/java_data_converter.h index 1ced5c4376868..1aefce2a5a23b 100644 --- a/be/src/udf/java/java_data_converter.h +++ b/be/src/udf/java/java_data_converter.h @@ -64,4 +64,5 @@ template jvalue cast_to_jvalue(LogicalType type, bool is_boxed, const Column* col, int row_num); void release_jvalue(bool is_boxed, jvalue val); void append_jvalue(MethodTypeDescriptor method_type_desc, Column* col, jvalue val); +Status check_type_matched(MethodTypeDescriptor method_type_desc, jobject val); } // namespace starrocks diff --git a/be/src/udf/java/java_udf.cpp b/be/src/udf/java/java_udf.cpp index 07331f2802bcc..07bf9c9ec1402 100644 --- a/be/src/udf/java/java_udf.cpp +++ b/be/src/udf/java/java_udf.cpp @@ -417,10 +417,6 @@ Slice JVMFunctionHelper::sliceVal(jstring jstr, std::string* buffer) { return {buffer->data(), buffer->length()}; } -Slice JVMFunctionHelper::sliceVal(jstring jstr) { - return {_env->GetStringUTFChars(jstr, nullptr)}; -} - std::string JVMFunctionHelper::to_jni_class_name(const std::string& name) { std::string jni_class_name; auto inserter = std::inserter(jni_class_name, jni_class_name.end()); diff --git a/be/src/udf/java/java_udf.h b/be/src/udf/java/java_udf.h index 1fcc69682be9f..a4f2e25108368 100644 --- a/be/src/udf/java/java_udf.h +++ b/be/src/udf/java/java_udf.h @@ -36,9 +36,10 @@ extern "C" JNIEnv* getJNIEnv(void); jmethodID _value_of_##TYPE; \ jmethodID _val_##TYPE; -#define DECLARE_NEW_BOX(TYPE, CLAZZ) \ - jobject new##CLAZZ(TYPE value); \ - TYPE val##TYPE(jobject obj); +#define DECLARE_NEW_BOX(PRIM_CLAZZ, TYPE, CLAZZ) \ + jobject new##CLAZZ(TYPE value); \ + TYPE val##TYPE(jobject obj); \ + jclass TYPE##_class() { return _class_##PRIM_CLAZZ; } namespace starrocks { class DirectByteBuffer; @@ -58,6 +59,7 @@ class JVMFunctionHelper { // Arrays.toString() std::string array_to_string(jobject object); // Object::toString() + bool equals(jobject obj1, jobject obj2); std::string to_string(jobject obj); std::string to_cxx_string(jstring str); std::string dumpExceptionString(jthrowable throwable); @@ -116,19 +118,19 @@ class JVMFunctionHelper { jobject list_get(jobject obj, int idx); int list_size(jobject obj); - DECLARE_NEW_BOX(uint8_t, Boolean) - DECLARE_NEW_BOX(int8_t, Byte) - DECLARE_NEW_BOX(int16_t, Short) - DECLARE_NEW_BOX(int32_t, Integer) - DECLARE_NEW_BOX(int64_t, Long) - DECLARE_NEW_BOX(float, Float) - DECLARE_NEW_BOX(double, Double) + DECLARE_NEW_BOX(boolean, uint8_t, Boolean) + DECLARE_NEW_BOX(byte, int8_t, Byte) + DECLARE_NEW_BOX(short, int16_t, Short) + DECLARE_NEW_BOX(int, int32_t, Integer) + DECLARE_NEW_BOX(long, int64_t, Long) + DECLARE_NEW_BOX(float, float, Float) + DECLARE_NEW_BOX(double, double, Double) jobject newString(const char* data, size_t size); - - Slice sliceVal(jstring jstr); size_t string_length(jstring jstr); + Slice sliceVal(jstring jstr, std::string* buffer); + jclass string_clazz() { return _string_class; } // replace '.' as '/' // eg: java.lang.Integer -> java/lang/Integer static std::string to_jni_class_name(const std::string& name); diff --git a/test/sql/test_udf/R/test_jvm_udf b/test/sql/test_udf/R/test_jvm_udf index 75d8b184bb535..8785cbcd01b84 100644 --- a/test/sql/test_udf/R/test_jvm_udf +++ b/test/sql/test_udf/R/test_jvm_udf @@ -8,6 +8,48 @@ type = "StarrocksJar" file = "${udf_url}/starrocks-jdbc%2FSumbigint.jar"; -- result: -- !result +CREATE TABLE FUNCTION udtfstring(string) +RETURNS string +symbol = "UDTFstring" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar"; +-- result: +-- !result +CREATE TABLE FUNCTION udtfstring_wrong_match(string) +RETURNS int +symbol = "UDTFstring" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar"; +-- result: +-- !result +CREATE TABLE FUNCTION udtfint(int) +RETURNS int +symbol = "UDTFint" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFint.jar"; +-- result: +-- !result +CREATE TABLE FUNCTION udtfbigint(bigint) +RETURNS bigint +symbol = "UDTFbigint" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFbigint.jar"; +-- result: +-- !result +CREATE TABLE FUNCTION udtffloat(float) +RETURNS float +symbol = "UDTFfloat" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFfloat.jar"; +-- result: +-- !result +CREATE TABLE FUNCTION udtfdouble(double) +RETURNS double +symbol = "UDTFdouble" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFdouble.jar"; +-- result: +-- !result CREATE TABLE `t0` ( `c0` int(11) NULL COMMENT "", `c1` varchar(20) NULL COMMENT "", @@ -24,6 +66,30 @@ PROPERTIES ( insert into t0 SELECT generate_series, generate_series, generate_series, generate_series FROM TABLE(generate_series(1, 40960)); -- result: -- !result +select count(udtfstring) from t0, udtfstring(c1); +-- result: +81920 +-- !result +select count(udtfstring_wrong_match) from t0, udtfstring_wrong_match(c1); +-- result: +E: (1064, 'Type not matched, expect class java.lang.Integer, but got class java.lang.String') +-- !result +select count(udtfint) from t0, udtfint(c1); +-- result: +81920 +-- !result +select count(udtfbigint) from t0, udtfbigint(c1); +-- result: +81920 +-- !result +select count(udtffloat) from t0, udtffloat(c1); +-- result: +81920 +-- !result +select count(udtfdouble) from t0, udtfdouble(c1); +-- result: +81920 +-- !result set streaming_preaggregation_mode="force_streaming"; -- result: -- !result diff --git a/test/sql/test_udf/T/test_jvm_udf b/test/sql/test_udf/T/test_jvm_udf index 5243c5750ea16..d122aed5ba60a 100644 --- a/test/sql/test_udf/T/test_jvm_udf +++ b/test/sql/test_udf/T/test_jvm_udf @@ -6,6 +6,43 @@ symbol = "Sumbigint" type = "StarrocksJar" file = "${udf_url}/starrocks-jdbc%2FSumbigint.jar"; +CREATE TABLE FUNCTION udtfstring(string) +RETURNS string +symbol = "UDTFstring" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar"; + +CREATE TABLE FUNCTION udtfstring_wrong_match(string) +RETURNS int +symbol = "UDTFstring" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar"; + +CREATE TABLE FUNCTION udtfint(int) +RETURNS int +symbol = "UDTFint" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFint.jar"; + +CREATE TABLE FUNCTION udtfbigint(bigint) +RETURNS bigint +symbol = "UDTFbigint" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFbigint.jar"; + +CREATE TABLE FUNCTION udtffloat(float) +RETURNS float +symbol = "UDTFfloat" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFfloat.jar"; + +CREATE TABLE FUNCTION udtfdouble(double) +RETURNS double +symbol = "UDTFdouble" +type = "StarrocksJar" +file = "${udf_url}/starrocks-jdbc%2FUDTFdouble.jar"; + + CREATE TABLE `t0` ( `c0` int(11) NULL COMMENT "", `c1` varchar(20) NULL COMMENT "", @@ -20,6 +57,14 @@ PROPERTIES ( insert into t0 SELECT generate_series, generate_series, generate_series, generate_series FROM TABLE(generate_series(1, 40960)); +-- test udtf cases +select count(udtfstring) from t0, udtfstring(c1); +select count(udtfstring_wrong_match) from t0, udtfstring_wrong_match(c1); +select count(udtfint) from t0, udtfint(c1); +select count(udtfbigint) from t0, udtfbigint(c1); +select count(udtffloat) from t0, udtffloat(c1); +select count(udtfdouble) from t0, udtfdouble(c1); + -- test group by limit case: set streaming_preaggregation_mode="force_streaming"; select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3)) as delta from t0 group by c0,c1 limit 10) tb;