Skip to content

Commit

Permalink
[branch-2.1](functions) fix be crash for function random_bytes and ma…
Browse files Browse the repository at this point in the history
…rk_first/last_n (apache#36003)

pick apache#35884
  • Loading branch information
zclllyybb authored Jun 7, 2024
1 parent c794ea1 commit f751ca4
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 16 deletions.
31 changes: 15 additions & 16 deletions be/src/vec/functions/function_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -792,28 +792,28 @@ class FunctionMaskPartial : public IFunction {

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
DCHECK_GE(arguments.size(), 1);
DCHECK_LE(arguments.size(), 2);

int n = -1;
int n = -1; // means unassigned

auto res = ColumnString::create();
auto col = block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
const auto& source_column = assert_cast<const ColumnString&>(*col);

if (arguments.size() == 2) {
const auto& col = *block.get_by_position(arguments[1]).column;
// the 2nd arg is const. checked in fe.
if (col.get_int(0) < 0) [[unlikely]] {
return Status::InvalidArgument(
"function {} only accept non-negative input for 2nd argument but got {}",
name, col.get_int(0));
}
n = col.get_int(0);
} else if (arguments.size() > 2) {
return Status::InvalidArgument(
fmt::format("too many arguments for function {}", get_name()));
}

if (n == -1) {
if (n == -1) { // no 2nd arg, just mask all
FunctionMask::vector_mask(source_column, *res, FunctionMask::DEFAULT_UPPER_MASK,
FunctionMask::DEFAULT_LOWER_MASK,
FunctionMask::DEFAULT_NUMBER_MASK);
} else if (n >= 0) {
} else { // n >= 0
vector(source_column, n, *res);
}

Expand Down Expand Up @@ -2901,19 +2901,18 @@ class FunctionRandomBytes : public IFunction {

ColumnPtr argument_column =
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
const auto* length_col = check_and_get_column<ColumnInt32>(argument_column.get());

if (!length_col) {
return Status::InternalError("Not supported input argument type");
}
const auto* length_col = assert_cast<const ColumnInt32*>(argument_column.get());

std::vector<uint8_t> random_bytes;
std::random_device rd;
std::mt19937 gen(rd());

for (size_t i = 0; i < input_rows_count; ++i) {
UInt64 length = length_col->get64(i);
random_bytes.resize(length);
if (length_col->get_element(i) < 0) [[unlikely]] {
return Status::InvalidArgument("argument {} of function {} at row {} was invalid.",
length_col->get_element(i), name, i);
}
random_bytes.resize(length_col->get_element(i));

std::uniform_int_distribution<uint8_t> distribution(0, 255);
for (auto& byte : random_bytes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
Expand Down Expand Up @@ -65,6 +66,13 @@ public MaskFirstN withChildren(List<Expression> children) {
return new MaskFirstN(children.get(0), children.get(1));
}

@Override
public void checkLegalityAfterRewrite() {
if (arity() == 2 && !child(1).isLiteral()) {
throw new AnalysisException("mask_first_n must accept literal for 2nd argument");
}
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
Expand Down Expand Up @@ -65,6 +66,13 @@ public MaskLastN withChildren(List<Expression> children) {
return new MaskLastN(children.get(0), children.get(1));
}

@Override
public void checkLegalityAfterRewrite() {
if (arity() == 2 && !child(1).isLiteral()) {
throw new AnalysisException("mask_last_n must accept literal for 2nd argument");
}
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
Expand Down
17 changes: 17 additions & 0 deletions regression-test/suites/correctness_p0/test_mask_function.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,21 @@ suite("test_mask_function") {
qt_select_digital_masking """
select digital_masking(13812345678);
"""

test {
sql """ select mask_last_n("12345", -100); """
exception "function mask_last_n only accept non-negative input for 2nd argument but got -100"
}
test {
sql """ select mask_first_n("12345", -100); """
exception "function mask_first_n only accept non-negative input for 2nd argument but got -100"
}
test {
sql """ select mask_last_n("12345", id) from table_mask_test; """
exception "mask_last_n must accept literal for 2nd argument"
}
test {
sql """ select mask_first_n("12345", id) from table_mask_test; """
exception "mask_first_n must accept literal for 2nd argument"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,8 @@ suite("nereids_scalar_fn_R") {
qt_sql_rtrim_String_String_notnull "select rtrim(kstr, '1') from fn_test_not_nullable order by kstr"
sql "SELECT random_bytes(7);"
qt_sql_random_bytes "SELECT random_bytes(null);"
test {
sql " select random_bytes(-1); "
exception "argument -1 of function random_bytes at row 0 was invalid"
}
}

0 comments on commit f751ca4

Please sign in to comment.