fix: Change the intermediate type of collect_list to VARBINARY #12121

Draft. Wants to merge 2 commits into base: main

298 changes: 238 additions & 60 deletions velox/functions/sparksql/aggregates/CollectListAggregate.cpp
@@ -18,88 +18,258 @@

#include "velox/exec/SimpleAggregateAdapter.h"
#include "velox/functions/lib/aggregates/ValueList.h"
#include "velox/row/UnsafeRowDeserializers.h"
#include "velox/row/UnsafeRowFast.h"

using namespace facebook::velox::aggregate;
using namespace facebook::velox::exec;

namespace facebook::velox::functions::aggregate::sparksql {
namespace {
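// collect_list accumulates the non-null input values of a group into a
// ValueList. To keep the intermediate representation self-describing, each
// group's partial result is emitted as a single VARBINARY value: an
// UnsafeRow-encoded ROW({"data", "type"}) where "data" holds the collected
// ARRAY(E) and "type" holds the JSON-serialized element type E.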
struct ArrayAccumulator {
  ValueList elements;
};

class CollectListAggregate : public exec::Aggregate {
 public:
  CollectListAggregate(TypePtr resultType, std::optional<TypePtr> elementType)
      : Aggregate(resultType), elementType_(elementType) {}

  int32_t accumulatorFixedWidthSize() const override {
    return sizeof(ArrayAccumulator);
  }

  bool isFixedSize() const override {
    return false;
  }

  void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
      override {
    auto vector = (*result)->as<ArrayVector>();
    VELOX_CHECK(vector);
    vector->resize(numGroups);

    auto elements = vector->elements();
    elements->resize(countElements(groups, numGroups));

    uint64_t* rawNulls = getRawNulls(vector);
    vector_size_t offset = 0;
    for (int32_t i = 0; i < numGroups; ++i) {
      auto& values = value<ArrayAccumulator>(groups[i])->elements;
      auto arraySize = values.size();
      // In Spark, when all inputs are null, collect_list returns an empty
      // array instead of null. Therefore, if the group's accumulator is null,
      // the corresponding result is an empty array rather than null.
      clearNull(rawNulls, i);
      if (arraySize) {
        ValueListReader reader(values);
        for (auto index = 0; index < arraySize; ++index) {
          reader.next(*elements, offset + index);
        }
      }
      vector->setOffsetAndSize(i, offset, arraySize);
      offset += arraySize;
    }
  }

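  // Serializes each group's accumulator into one VARBINARY value. The group
  // is first materialized as a row {data: ARRAY(E), type: VARCHAR}, where
  // "type" carries folly::toJson(elementType_.value()->serialize()), and the
  // row is then flattened with row::UnsafeRowFast. Storing the element type
  // next to the data lets later stages recover E from the payload alone.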
  void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result)
      override {
    VELOX_CHECK(result);
    auto vector = BaseVector::create(
        ROW({"data", "type"}, {ARRAY(elementType_.value()), VARCHAR()}),
        numGroups,
        allocator_->pool());
    auto rowVector = vector->asChecked<RowVector>();
    auto arrayVector = rowVector->childAt(0)->asChecked<ArrayVector>();
    auto flatVector =
        rowVector->childAt(1)->asChecked<FlatVector<StringView>>();
    auto elements = arrayVector->elements();
    elements->resize(countElements(groups, numGroups));

    auto typeStr = folly::toJson(elementType_.value()->serialize());
    vector_size_t offset = 0;
    for (int32_t i = 0; i < numGroups; ++i) {
      flatVector->set(i, StringView(typeStr));

      // To align with Spark's intermediate data, if the group's accumulator
      // is null, the corresponding result is an empty array.
      auto& values = value<ArrayAccumulator>(groups[i])->elements;
      auto arraySize = values.size();
      arrayVector->setNull(i, false);
      if (arraySize) {
        ValueListReader reader(values);
        for (auto index = 0; index < arraySize; ++index) {
          reader.next(*elements, offset + index);
        }
      }
      arrayVector->setOffsetAndSize(i, offset, arraySize);
      offset += arraySize;
    }

    auto flatResult = (*result)->asUnchecked<FlatVector<StringView>>();
    flatResult->resize(numGroups);
    RowVectorPtr rowVectorPtr = std::dynamic_pointer_cast<RowVector>(vector);
    auto serializer = std::make_unique<row::UnsafeRowFast>(rowVectorPtr);
    size_t totalSize = 0;
    for (vector_size_t i = 0; i < numGroups; ++i) {
      int32_t rowSize = serializer->rowSize(i);
      totalSize += rowSize;
    }

    char* rawBuffer = flatResult->getRawStringBufferWithSpace(totalSize);
    // The raw buffer must be set to all zeros before serializing into it.
    std::memset(rawBuffer, 0, totalSize);
    for (vector_size_t i = 0; i < numGroups; ++i) {
      auto size = serializer->serialize(i, rawBuffer);
      VELOX_DCHECK(!StringView::isInline(size));
      StringView serialized = StringView(rawBuffer, size);
      rawBuffer += size;
      flatResult->setNoCopy(i, serialized);
    }
  }

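  // Accumulates raw input values into each group's ValueList. Null inputs are
  // skipped, so a group whose inputs are all null ends up with an empty array.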
  void addRawInput(
      char** groups,
      const SelectivityVector& rows,
      const std::vector<VectorPtr>& args,
      bool /*mayPushdown*/) override {
    decodedElements_.decode(*args[0], rows);
    rows.applyToSelected([&](vector_size_t row) {
      if (decodedElements_.isNullAt(row)) {
        return;
      }
      auto group = groups[row];
      auto tracker = trackRowSize(group);
      value<ArrayAccumulator>(group)->elements.appendValue(
          decodedElements_, row, allocator_);
    });
  }

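  // Merges serialized partial results into the accumulators. The element type
  // is recovered lazily from the "type" field of the UnsafeRow payload the
  // first time it is needed; the "data" field is then deserialized as
  // ARRAY(E) and its elements are appended to the group's ValueList.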
  void addIntermediateResults(
      char** groups,
      const SelectivityVector& rows,
      const std::vector<VectorPtr>& args,
      bool /*mayPushdown*/) override {
    decodedIntermediate_.decode(*args[0], rows);

    rows.applyToSelected([&](auto row) {
      if (!decodedIntermediate_.isNullAt(row)) {
        auto group = groups[row];
        auto tracker = trackRowSize(group);
        auto serialized = decodedIntermediate_.valueAt<StringView>(row);
        if (!elementType_.has_value()) {
          VectorPtr strVector =
              row::UnsafeRowDeserializer::deserializeStructField(
                  {std::string_view(serialized.data())},
                  VARBINARY(),
                  kTypeIndex,
                  kFieldNum,
                  allocator_->pool());
          elementType_ = Type::create(folly::parseJson(strVector->toString(0)));
        }
        VectorPtr dataVector =
            row::UnsafeRowDeserializer::deserializeStructField(
                {std::string_view(serialized.data())},
                ARRAY(elementType_.value()),
                kDataIndex,
                kFieldNum,
                allocator_->pool());
        auto arrayVector = dataVector->as<ArrayVector>();
        value<ArrayAccumulator>(group)->elements.appendRange(
            arrayVector->elements(),
            arrayVector->offsetAt(0),
            arrayVector->sizeAt(0),
            allocator_);
      }
    });
  }

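  // The two single-group variants below mirror addRawInput and
  // addIntermediateResults for the case where all selected rows belong to a
  // single group, e.g. global aggregation.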
  void addSingleGroupRawInput(
      char* group,
      const SelectivityVector& rows,
      const std::vector<VectorPtr>& args,
      bool /* mayPushdown */) override {
    auto& values = value<ArrayAccumulator>(group)->elements;

    decodedElements_.decode(*args[0], rows);
    auto tracker = trackRowSize(group);
    rows.applyToSelected([&](vector_size_t row) {
      if (decodedElements_.isNullAt(row)) {
        return;
      }
      values.appendValue(decodedElements_, row, allocator_);
    });
  }

  void addSingleGroupIntermediateResults(
      char* group,
      const SelectivityVector& rows,
      const std::vector<VectorPtr>& args,
      bool /* mayPushdown */) override {
    decodedIntermediate_.decode(*args[0], rows);

    rows.applyToSelected([&](auto row) {
      if (!decodedIntermediate_.isNullAt(row)) {
        auto serialized = decodedIntermediate_.valueAt<StringView>(row);
        if (!elementType_.has_value()) {
          VectorPtr strVector =
              row::UnsafeRowDeserializer::deserializeStructField(
                  {std::string_view(serialized.data())},
                  VARBINARY(),
                  kTypeIndex,
                  kFieldNum,
                  allocator_->pool());
          elementType_ = Type::create(folly::parseJson(strVector->toString(0)));
        }
        VectorPtr dataVector =
            row::UnsafeRowDeserializer::deserializeStructField(
                {std::string_view(serialized.data())},
                ARRAY(elementType_.value()),
                kDataIndex,
                kFieldNum,
                allocator_->pool());
        auto arrayVector = dataVector->as<ArrayVector>();
        value<ArrayAccumulator>(group)->elements.appendRange(
            arrayVector->elements(),
            arrayVector->offsetAt(0),
            arrayVector->sizeAt(0),
            allocator_);
      }
    });
  }

 protected:
  void initializeNewGroupsInternal(
      char** groups,
      folly::Range<const vector_size_t*> indices) override {
    for (auto index : indices) {
      new (groups[index] + offset_) ArrayAccumulator();
    }
  }

  void destroyInternal(folly::Range<char**> groups) override {
    for (auto group : groups) {
      if (isInitialized(group)) {
        value<ArrayAccumulator>(group)->elements.free(allocator_);
      }
    }
  }

 private:
  vector_size_t countElements(char** groups, int32_t numGroups) const {
    vector_size_t size = 0;
    for (int32_t i = 0; i < numGroups; ++i) {
      size += value<ArrayAccumulator>(groups[i])->elements.size();
    }
    return size;
  }

  // Reusable instances of DecodedVector for decoding input vectors.
  DecodedVector decodedElements_;
  DecodedVector decodedIntermediate_;
  std::optional<TypePtr> elementType_;
  static constexpr int32_t kDataIndex{0};
  static constexpr int32_t kTypeIndex{1};
  static constexpr int32_t kFieldNum{2};
};

AggregateRegistrationResult registerCollectList(
@@ -110,24 +280,32 @@ AggregateRegistrationResult registerCollectList(
      exec::AggregateFunctionSignatureBuilder()
          .typeVariable("E")
          .returnType("array(E)")
          .intermediateType("varbinary")
          .argumentType("E")
          .build()};
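  // The type passed as the element type depends on the plan step: partial and
  // single steps use the raw argument type, the final step is constructed
  // with the result type, and the intermediate step starts with std::nullopt
  // and recovers the element type from the serialized payload.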
  return exec::registerAggregateFunction(
      name,
      std::move(signatures),
      [name](
          core::AggregationNode::Step step,
          const std::vector<TypePtr>& argTypes,
          const TypePtr& resultType,
          const core::QueryConfig& /*config*/)
          -> std::unique_ptr<exec::Aggregate> {
        VELOX_CHECK_EQ(
            argTypes.size(), 1, "{} takes at most one argument", name);
        if (step == core::AggregationNode::Step::kIntermediate) {
          return std::make_unique<CollectListAggregate>(
              resultType, std::nullopt);
        } else if (step == core::AggregationNode::Step::kFinal) {
          return std::make_unique<CollectListAggregate>(
              resultType, resultType);
        }
        return std::make_unique<CollectListAggregate>(resultType, argTypes[0]);
      },
      // Currently, Velox doesn't support automatically generating companion
      // functions for aggregate functions whose result type is not resolvable
      // given solely the concrete intermediate type.
      false /*withCompanionFunctions*/,
      overwrite);
}
} // namespace