Skip to content

Commit

Permalink
[BugFix] add return type check in Java UDTF (#50615)
Browse files Browse the repository at this point in the history
Signed-off-by: stdpain <[email protected]>
(cherry picked from commit f147df2)
  • Loading branch information
stdpain authored and mergify[bot] committed Sep 3, 2024
1 parent fb9009b commit 89a5ae0
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 31 deletions.
40 changes: 27 additions & 13 deletions be/src/exprs/table_function/java_udtf_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "column/column_helper.h"
#include "column/nullable_column.h"
#include "column/vectorized_fwd.h"
#include "common/compiler_util.h"
#include "exprs/table_function/table_function.h"
#include "gutil/casts.h"
#include "jni.h"
Expand Down Expand Up @@ -137,21 +138,28 @@ std::pair<Columns, UInt32Column::Ptr> JavaUDTFFunction::process(RuntimeState* ru

std::vector<jvalue> call_stack;
std::vector<jobject> rets;
DeferOp defer = DeferOp([&]() {
// clean up arrays
for (auto& ret : rets) {
if (ret) {
env->DeleteLocalRef(ret);
}
}
});
size_t num_rows = cols[0]->size();
size_t num_cols = cols.size();
state->set_processed_rows(num_rows);

call_stack.reserve(num_cols);
rets.resize(num_rows);

// reserve 16 local refs
DeferOp defer = DeferOp([&]() {
// clean up arrays
env->PopLocalFrame(nullptr);
});
env->PushLocalFrame(num_cols * num_rows + 16);

for (int i = 0; i < num_rows; ++i) {
DeferOp defer = DeferOp([&]() {
for (int j = 0; j < num_cols; ++j) {
release_jvalue(stateUDTF->method_process()->method_desc[j + 1].is_box, call_stack[j]);
}
call_stack.clear();
});

for (int j = 0; j < num_cols; ++j) {
auto method_type = stateUDTF->method_process()->method_desc[j + 1];
jvalue val = cast_to_jvalue<true>(method_type.type, method_type.is_box, cols[j].get(), i);
Expand All @@ -160,11 +168,13 @@ std::pair<Columns, UInt32Column::Ptr> JavaUDTFFunction::process(RuntimeState* ru

rets[i] = env->CallObjectMethodA(stateUDTF->handle(), methodID, call_stack.data());

for (int j = 0; j < num_cols; ++j) {
release_jvalue(stateUDTF->method_process()->method_desc[j + 1].is_box, call_stack[j]);
if (auto jthr = helper.getEnv()->ExceptionOccurred(); jthr != nullptr) {
std::string err = fmt::format("execute UDF Function meet Exception:{}", helper.dumpExceptionString(jthr));
LOG(WARNING) << err;
helper.getEnv()->ExceptionClear();
state->set_status(Status::InternalError(err));
return std::make_pair(Columns{}, nullptr);
}

call_stack.clear();
}

// Build Return Type
Expand All @@ -185,8 +195,12 @@ std::pair<Columns, UInt32Column::Ptr> JavaUDTFFunction::process(RuntimeState* ru
for (int j = 0; j < len; ++j) {
jobject vi = env->GetObjectArrayElement((jobjectArray)rets[i], j);
LOCAL_REF_GUARD_ENV(env, vi);
auto st = check_type_matched(method_desc, vi);
if (UNLIKELY(!st.ok())) {
state->set_status(st);
return std::make_pair(Columns{}, nullptr);
}
append_jvalue(method_desc, col.get(), {.l = vi});
release_jvalue(method_desc.is_box, {.l = vi});
}
}

Expand Down
52 changes: 50 additions & 2 deletions be/src/udf/java/java_data_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "column/type_traits.h"
#include "common/compiler_util.h"
#include "common/status.h"
#include "udf/java/java_udf.h"
#include "util/defer_op.h"

#define APPLY_FOR_NUMBERIC_TYPE(M) \
M(TYPE_BOOLEAN) \
Expand Down Expand Up @@ -161,7 +163,8 @@ void assign_jvalue(MethodTypeDescriptor method_type_desc, Column* col, int row_n
if (val.l == nullptr) {
col->append_nulls(1);
} else {
auto slice = helper.sliceVal((jstring)val.l);
std::string buffer;
auto slice = helper.sliceVal((jstring)val.l, &buffer);
col->append_datum(Datum(slice));
}
break;
Expand Down Expand Up @@ -209,7 +212,8 @@ void append_jvalue(MethodTypeDescriptor method_type_desc, Column* col, jvalue va
APPEND_BOX_TYPE(TYPE_DOUBLE, double)

case TYPE_VARCHAR: {
auto slice = helper.sliceVal((jstring)val.l);
std::string buffer;
auto slice = helper.sliceVal((jstring)val.l, &buffer);
col->append_datum(Datum(slice));
break;
}
Expand All @@ -220,6 +224,50 @@ void append_jvalue(MethodTypeDescriptor method_type_desc, Column* col, jvalue va
}
}

Status check_type_matched(MethodTypeDescriptor method_type_desc, jobject val) {
if (val == nullptr) {
return Status::OK();
}
auto& helper = JVMFunctionHelper::getInstance();
auto* env = helper.getEnv();

switch (method_type_desc.type) {
#define INSTANCE_OF_TYPE(NAME, TYPE) \
case NAME: { \
if (!env->IsInstanceOf(val, helper.TYPE##_class())) { \
auto clazz = env->GetObjectClass(val); \
LOCAL_REF_GUARD(clazz); \
return Status::InternalError(fmt::format("Type not matched, expect {}, but got {}", \
helper.to_string(helper.TYPE##_class()), \
helper.to_string(clazz))); \
} \
break; \
}
INSTANCE_OF_TYPE(TYPE_BOOLEAN, uint8_t)
INSTANCE_OF_TYPE(TYPE_TINYINT, int8_t)
INSTANCE_OF_TYPE(TYPE_SMALLINT, int16_t)
INSTANCE_OF_TYPE(TYPE_INT, int32_t)
INSTANCE_OF_TYPE(TYPE_BIGINT, int64_t)
INSTANCE_OF_TYPE(TYPE_FLOAT, float)
INSTANCE_OF_TYPE(TYPE_DOUBLE, double)
case TYPE_VARCHAR: {
std::string buffer;
if (!env->IsInstanceOf(val, helper.string_clazz())) {
auto clazz = env->GetObjectClass(val);
LOCAL_REF_GUARD(clazz);
return Status::InternalError(
fmt::format("Type not matched, expect string, but got {}", helper.to_string(clazz)));
}
break;
}
default:
DCHECK(false) << "unsupport UDF TYPE" << method_type_desc.type;
break;
}

return Status::OK();
}

Status ConvertDirectBufferVistor::do_visit(const NullableColumn& column) {
const auto& null_data = column.immutable_null_column_data();
_buffers.emplace_back((void*)null_data.data(), null_data.size());
Expand Down
1 change: 1 addition & 0 deletions be/src/udf/java/java_data_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,5 @@ template <bool handle_null>
jvalue cast_to_jvalue(LogicalType type, bool is_boxed, const Column* col, int row_num);
void release_jvalue(bool is_boxed, jvalue val);
void append_jvalue(MethodTypeDescriptor method_type_desc, Column* col, jvalue val);
Status check_type_matched(MethodTypeDescriptor method_type_desc, jobject val);
} // namespace starrocks
4 changes: 0 additions & 4 deletions be/src/udf/java/java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,6 @@ Slice JVMFunctionHelper::sliceVal(jstring jstr, std::string* buffer) {
return {buffer->data(), buffer->length()};
}

Slice JVMFunctionHelper::sliceVal(jstring jstr) {
return {_env->GetStringUTFChars(jstr, nullptr)};
}

std::string JVMFunctionHelper::to_jni_class_name(const std::string& name) {
std::string jni_class_name;
auto inserter = std::inserter(jni_class_name, jni_class_name.end());
Expand Down
26 changes: 14 additions & 12 deletions be/src/udf/java/java_udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ extern "C" JNIEnv* getJNIEnv(void);
jmethodID _value_of_##TYPE; \
jmethodID _val_##TYPE;

#define DECLARE_NEW_BOX(TYPE, CLAZZ) \
jobject new##CLAZZ(TYPE value); \
TYPE val##TYPE(jobject obj);
#define DECLARE_NEW_BOX(PRIM_CLAZZ, TYPE, CLAZZ) \
jobject new##CLAZZ(TYPE value); \
TYPE val##TYPE(jobject obj); \
jclass TYPE##_class() { return _class_##PRIM_CLAZZ; }

namespace starrocks {
class DirectByteBuffer;
Expand All @@ -59,6 +60,7 @@ class JVMFunctionHelper {
// Arrays.toString()
std::string array_to_string(jobject object);
// Object::toString()
bool equals(jobject obj1, jobject obj2);
std::string to_string(jobject obj);
std::string to_cxx_string(jstring str);
std::string dumpExceptionString(jthrowable throwable);
Expand Down Expand Up @@ -117,19 +119,19 @@ class JVMFunctionHelper {
jobject list_get(jobject obj, int idx);
int list_size(jobject obj);

DECLARE_NEW_BOX(uint8_t, Boolean)
DECLARE_NEW_BOX(int8_t, Byte)
DECLARE_NEW_BOX(int16_t, Short)
DECLARE_NEW_BOX(int32_t, Integer)
DECLARE_NEW_BOX(int64_t, Long)
DECLARE_NEW_BOX(float, Float)
DECLARE_NEW_BOX(double, Double)
DECLARE_NEW_BOX(boolean, uint8_t, Boolean)
DECLARE_NEW_BOX(byte, int8_t, Byte)
DECLARE_NEW_BOX(short, int16_t, Short)
DECLARE_NEW_BOX(int, int32_t, Integer)
DECLARE_NEW_BOX(long, int64_t, Long)
DECLARE_NEW_BOX(float, float, Float)
DECLARE_NEW_BOX(double, double, Double)

jobject newString(const char* data, size_t size);

Slice sliceVal(jstring jstr);
size_t string_length(jstring jstr);

Slice sliceVal(jstring jstr, std::string* buffer);
jclass string_clazz() { return _string_class; }
// replace '.' as '/'
// eg: java.lang.Integer -> java/lang/Integer
static std::string to_jni_class_name(const std::string& name);
Expand Down
66 changes: 66 additions & 0 deletions test/sql/test_udf/R/test_jvm_udf
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,48 @@ type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FSumbigint.jar";
-- result:
-- !result
CREATE TABLE FUNCTION udtfstring(string)
RETURNS string
symbol = "UDTFstring"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar";
-- result:
-- !result
CREATE TABLE FUNCTION udtfstring_wrong_match(string)
RETURNS int
symbol = "UDTFstring"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar";
-- result:
-- !result
CREATE TABLE FUNCTION udtfint(int)
RETURNS int
symbol = "UDTFint"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFint.jar";
-- result:
-- !result
CREATE TABLE FUNCTION udtfbigint(bigint)
RETURNS bigint
symbol = "UDTFbigint"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFbigint.jar";
-- result:
-- !result
CREATE TABLE FUNCTION udtffloat(float)
RETURNS float
symbol = "UDTFfloat"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFfloat.jar";
-- result:
-- !result
CREATE TABLE FUNCTION udtfdouble(double)
RETURNS double
symbol = "UDTFdouble"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFdouble.jar";
-- result:
-- !result
CREATE TABLE `t0` (
`c0` int(11) NULL COMMENT "",
`c1` varchar(20) NULL COMMENT "",
Expand All @@ -25,6 +67,30 @@ PROPERTIES (
insert into t0 SELECT generate_series, generate_series, generate_series, generate_series FROM TABLE(generate_series(1, 40960));
-- result:
-- !result
select count(udtfstring) from t0, udtfstring(c1);
-- result:
81920
-- !result
select count(udtfstring_wrong_match) from t0, udtfstring_wrong_match(c1);
-- result:
E: (1064, 'Type not matched, expect class java.lang.Integer, but got class java.lang.String')
-- !result
select count(udtfint) from t0, udtfint(c1);
-- result:
81920
-- !result
select count(udtfbigint) from t0, udtfbigint(c1);
-- result:
81920
-- !result
select count(udtffloat) from t0, udtffloat(c1);
-- result:
81920
-- !result
select count(udtfdouble) from t0, udtfdouble(c1);
-- result:
81920
-- !result
set streaming_preaggregation_mode="force_streaming";
-- result:
-- !result
Expand Down
45 changes: 45 additions & 0 deletions test/sql/test_udf/T/test_jvm_udf
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,43 @@ symbol = "Sumbigint"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FSumbigint.jar";

CREATE TABLE FUNCTION udtfstring(string)
RETURNS string
symbol = "UDTFstring"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar";

CREATE TABLE FUNCTION udtfstring_wrong_match(string)
RETURNS int
symbol = "UDTFstring"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFstring.jar";

CREATE TABLE FUNCTION udtfint(int)
RETURNS int
symbol = "UDTFint"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFint.jar";

CREATE TABLE FUNCTION udtfbigint(bigint)
RETURNS bigint
symbol = "UDTFbigint"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFbigint.jar";

CREATE TABLE FUNCTION udtffloat(float)
RETURNS float
symbol = "UDTFfloat"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFfloat.jar";

CREATE TABLE FUNCTION udtfdouble(double)
RETURNS double
symbol = "UDTFdouble"
type = "StarrocksJar"
file = "${udf_url}/starrocks-jdbc%2FUDTFdouble.jar";


CREATE TABLE `t0` (
`c0` int(11) NULL COMMENT "",
`c1` varchar(20) NULL COMMENT "",
Expand All @@ -22,6 +59,14 @@ PROPERTIES (

insert into t0 SELECT generate_series, generate_series, generate_series, generate_series FROM TABLE(generate_series(1, 40960));

-- test udtf cases
select count(udtfstring) from t0, udtfstring(c1);
select count(udtfstring_wrong_match) from t0, udtfstring_wrong_match(c1);
select count(udtfint) from t0, udtfint(c1);
select count(udtfbigint) from t0, udtfbigint(c1);
select count(udtffloat) from t0, udtffloat(c1);
select count(udtfdouble) from t0, udtfdouble(c1);

-- test group by limit case:
set streaming_preaggregation_mode="force_streaming";
select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3)) as delta from t0 group by c0,c1 limit 10) tb;
Expand Down

0 comments on commit 89a5ae0

Please sign in to comment.