From ede70ee321adf8d00d63710f7d2b0b4b84185d78 Mon Sep 17 00:00:00 2001 From: NEUpanning Date: Mon, 20 Jan 2025 15:13:46 +0800 Subject: [PATCH 1/2] first commit --- .../aggregates/CollectListAggregate.cpp | 293 ++++++++++++++---- .../tests/CollectListAggregateTest.cpp | 4 + velox/row/UnsafeRowDeserializers.h | 121 ++++++++ .../tests/UnsafeRowSerializerTest.cpp | 14 + 4 files changed, 373 insertions(+), 59 deletions(-) diff --git a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp index e2c14cfa7969..88a6bee0c68d 100644 --- a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp +++ b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp @@ -18,88 +18,258 @@ #include "velox/exec/SimpleAggregateAdapter.h" #include "velox/functions/lib/aggregates/ValueList.h" +#include "velox/row/UnsafeRowDeserializers.h" +#include "velox/row/UnsafeRowFast.h" using namespace facebook::velox::aggregate; using namespace facebook::velox::exec; namespace facebook::velox::functions::aggregate::sparksql { namespace { -class CollectListAggregate { +struct ArrayAccumulator { + ValueList elements; +}; + +class CollectListAggregate : public exec::Aggregate { public: - using InputType = Row>; + CollectListAggregate(TypePtr resultType, std::optional elementType) + : Aggregate(resultType), elementType_(elementType) {} - using IntermediateType = Array>; + int32_t accumulatorFixedWidthSize() const override { + return sizeof(ArrayAccumulator); + } - using OutputType = Array>; + bool isFixedSize() const override { + return false; + } - /// In Spark, when all inputs are null, the output is an empty array instead - /// of null. Therefore, in the writeIntermediateResult and writeFinalResult, - /// we still need to output the empty element_ when the group is null. This - /// behavior can only be achieved when the default-null behavior is disabled. - static constexpr bool default_null_behavior_ = false; + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto vector = (*result)->as(); + VELOX_CHECK(vector); + vector->resize(numGroups); - static bool toIntermediate( - exec::out_type>>& out, - exec::optional_arg_type> in) { - if (in.has_value()) { - out.add_item().copy_from(in.value()); - return true; + auto elements = vector->elements(); + elements->resize(countElements(groups, numGroups)); + + uint64_t* rawNulls = getRawNulls(vector); + vector_size_t offset = 0; + for (int32_t i = 0; i < numGroups; ++i) { + auto& values = value(groups[i])->elements; + auto arraySize = values.size(); + // If the group's accumulator is null, the corresponding result is an + // empty array rather than null. + clearNull(rawNulls, i); + if (arraySize) { + ValueListReader reader(values); + for (auto index = 0; index < arraySize; ++index) { + reader.next(*elements, offset + index); + } + } + vector->setOffsetAndSize(i, offset, arraySize); + offset += arraySize; } - return false; } - struct AccumulatorType { - ValueList elements_; + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + VELOX_CHECK(result); + auto vector = BaseVector::create( + ROW({"data", "type"}, {ARRAY(elementType_.value()), VARCHAR()}), + numGroups, + allocator_->pool()); + auto rowVector = vector->asChecked(); + auto arrayVector = rowVector->childAt(0)->asChecked(); + auto flatVector = + rowVector->childAt(1)->asChecked>(); + auto elements = arrayVector->elements(); + elements->resize(countElements(groups, numGroups)); - explicit AccumulatorType(HashStringAllocator* /*allocator*/) - : elements_{} {} + auto typeStr = folly::toJson(elementType_.value()->serialize()); + vector_size_t offset = 0; + for (int32_t i = 0; i < numGroups; ++i) { + flatVector->set(i, StringView(typeStr)); - static constexpr bool is_fixed_size_ = false; - - bool addInput( - HashStringAllocator* allocator, - exec::optional_arg_type> data) { - if (data.has_value()) { - elements_.appendValue(data, allocator); - return true; + // To align with Spark's intermediate data, if the group's accumulator is + // null, the corresponding result is an empty array. + auto& values = value(groups[i])->elements; + auto arraySize = values.size(); + arrayVector->setNull(i, false); + if (arraySize) { + ValueListReader reader(values); + for (auto index = 0; index < arraySize; ++index) { + reader.next(*elements, offset + index); + } } - return false; + arrayVector->setOffsetAndSize(i, offset, arraySize); + offset += arraySize; + } + + auto flatResult = (*result)->asUnchecked>(); + flatResult->resize(numGroups); + RowVectorPtr rowVectorPtr = std::dynamic_pointer_cast(vector); + auto serializer = std::make_unique(rowVectorPtr); + size_t totalSize = 0; + for (vector_size_t i = 0; i < numGroups; ++i) { + int32_t rowSize = serializer->rowSize(i); + totalSize += rowSize; + } + + char* rawBuffer = flatResult->getRawStringBufferWithSpace(totalSize); + // RawBuffer must be set to all zeros. + std::memset(rawBuffer, 0, totalSize); + for (vector_size_t i = 0; i < numGroups; ++i) { + auto size = serializer->serialize(i, rawBuffer); + VELOX_DCHECK(!StringView::isInline(size)); + StringView serialized = StringView(rawBuffer, size); + rawBuffer += size; + flatResult->setNoCopy(i, serialized); } + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodedElements_.decode(*args[0], rows); + rows.applyToSelected([&](vector_size_t row) { + if (decodedElements_.isNullAt(row)) { + return; + } + auto group = groups[row]; + auto tracker = trackRowSize(group); + value(group)->elements.appendValue( + decodedElements_, row, allocator_); + }); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodedIntermediate_.decode(*args[0], rows); - bool combine( - HashStringAllocator* allocator, - exec::optional_arg_type other) { - if (!other.has_value()) { - return false; + rows.applyToSelected([&](auto row) { + if (!decodedIntermediate_.isNullAt(row)) { + auto group = groups[row]; + auto tracker = trackRowSize(group); + auto serialized = decodedIntermediate_.valueAt(row); + if (!elementType_.has_value()) { + VectorPtr strVector = + row::UnsafeRowDeserializer::deserializeStructField( + {std::string_view(serialized.data())}, + VARBINARY(), + kTypeIndex, + kFieldNum, + allocator_->pool()); + elementType_ = Type::create(folly::parseJson(strVector->toString(0))); + } + VectorPtr dataVector = + row::UnsafeRowDeserializer::deserializeStructField( + {std::string_view(serialized.data())}, + ARRAY(elementType_.value()), + kDataIndex, + kFieldNum, + allocator_->pool()); + auto arrayVector = dataVector->as(); + value(group)->elements.appendRange( + arrayVector->elements(), + arrayVector->offsetAt(0), + arrayVector->sizeAt(0), + allocator_); } - for (auto element : other.value()) { - elements_.appendValue(element, allocator); + }); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + auto& values = value(group)->elements; + + decodedElements_.decode(*args[0], rows); + auto tracker = trackRowSize(group); + rows.applyToSelected([&](vector_size_t row) { + if (decodedElements_.isNullAt(row)) { + return; } - return true; - } + values.appendValue(decodedElements_, row, allocator_); + }); + } - bool writeIntermediateResult( - bool /*nonNullGroup*/, - exec::out_type& out) { - // If the group's accumulator is null, the corresponding intermediate - // result is an empty array. - copyValueListToArrayWriter(out, elements_); - return true; + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedIntermediate_.decode(*args[0], rows); + + rows.applyToSelected([&](auto row) { + if (!decodedIntermediate_.isNullAt(row)) { + auto serialized = decodedIntermediate_.valueAt(row); + if (!elementType_.has_value()) { + VectorPtr strVector = + row::UnsafeRowDeserializer::deserializeStructField( + {std::string_view(serialized.data())}, + VARBINARY(), + kTypeIndex, + kFieldNum, + allocator_->pool()); + elementType_ = Type::create(folly::parseJson(strVector->toString(0))); + } + VectorPtr dataVector = + row::UnsafeRowDeserializer::deserializeStructField( + {std::string_view(serialized.data())}, + ARRAY(elementType_.value()), + kDataIndex, + kFieldNum, + allocator_->pool()); + auto arrayVector = dataVector->as(); + value(group)->elements.appendRange( + arrayVector->elements(), + arrayVector->offsetAt(0), + arrayVector->sizeAt(0), + allocator_); + } + }); + } + + protected: + void initializeNewGroupsInternal( + char** groups, + folly::Range indices) override { + for (auto index : indices) { + new (groups[index] + offset_) ArrayAccumulator(); } + } - bool writeFinalResult( - bool /*nonNullGroup*/, - exec::out_type& out) { - // If the group's accumulator is null, the corresponding result is an - // empty array. - copyValueListToArrayWriter(out, elements_); - return true; + void destroyInternal(folly::Range groups) override { + for (auto group : groups) { + if (isInitialized(group)) { + value(group)->elements.free(allocator_); + } } + } - void destroy(HashStringAllocator* allocator) { - elements_.free(allocator); + private: + vector_size_t countElements(char** groups, int32_t numGroups) const { + vector_size_t size = 0; + for (int32_t i = 0; i < numGroups; ++i) { + size += value(groups[i])->elements.size(); } - }; + return size; + } + + // Reusable instance of DecodedVector for decoding input vectors. + DecodedVector decodedElements_; + DecodedVector decodedIntermediate_; + std::optional elementType_; + static constexpr int32_t kDataIndex{0}; + static constexpr int32_t kTypeIndex{1}; + static constexpr int32_t kFieldNum{2}; }; AggregateRegistrationResult registerCollectList( @@ -110,22 +280,27 @@ AggregateRegistrationResult registerCollectList( exec::AggregateFunctionSignatureBuilder() .typeVariable("E") .returnType("array(E)") - .intermediateType("array(E)") + .intermediateType("varbinary") .argumentType("E") .build()}; return exec::registerAggregateFunction( name, std::move(signatures), [name]( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& /*config*/) -> std::unique_ptr { VELOX_CHECK_EQ( argTypes.size(), 1, "{} takes at most one argument", name); - return std::make_unique>( - resultType); + if (step == core::AggregationNode::Step::kIntermediate) { + return std::make_unique( + resultType, std::nullopt); + } else if (step == core::AggregationNode::Step::kFinal) { + return std::make_unique(resultType, resultType); + } + return std::make_unique(resultType, argTypes[0]); }, withCompanionFunctions, overwrite); diff --git a/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp index 73088a47b620..bb19c360909c 100644 --- a/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp @@ -28,6 +28,10 @@ class CollectListAggregateTest : public AggregationTestBase { void SetUp() override { AggregationTestBase::SetUp(); registerAggregateFunctions("spark_"); + // The intermediate data of spark_collect_list includes the json string of + // type, but it is non-deterministic. Therefore, it cannot be tested + // incrementally. + disableTestIncremental(); } }; diff --git a/velox/row/UnsafeRowDeserializers.h b/velox/row/UnsafeRowDeserializers.h index bcda23dc60a8..852389e00943 100644 --- a/velox/row/UnsafeRowDeserializers.h +++ b/velox/row/UnsafeRowDeserializers.h @@ -169,6 +169,82 @@ struct PrimitiveBatchIterator : UnsafeRowDataBatchIterator { size_t currentRow_ = 0; }; +struct StructFieldBatchIterator : UnsafeRowDataBatchIterator { + public: + /** + * StructFieldBatchIterator constructor. + * @param data the UnsafeRow data + * @param type the field type of struct + * @param fieldIdx the field index + * @param numElements the number of field in struct + */ + StructFieldBatchIterator( + const std::vector>& data, + const TypePtr& type, + size_t fieldIdx, + size_t numElements) + : UnsafeRowDataBatchIterator(data, type), + idx_(fieldIdx), + numElements_(numElements) { + columnData_.resize(numRows_); + }; + + /** + * @return return the field column batch. + */ + const std::vector>& columnBatch() { + std::size_t fixedSize = + isFixedWidth(type()) ? serializedSizeInBytes(type()) : 0; + std::size_t fieldOffset = UnsafeRow::getNullLength(numElements_) + + idx_ * UnsafeRow::kFieldWidthBytes; + + for (int32_t i = 0; i < numRows_; ++i) { + // if null + if (!data_[i] || bits::isBitSet(data_[i].value().data(), idx_)) { + columnData_[i] = std::nullopt; + continue; + } + + const char* rawData = data_[i]->data(); + const char* fieldData = rawData + fieldOffset; + + // Fixed length field + if (fixedSize > 0) { + columnData_[i] = std::string_view(fieldData, fixedSize); + continue; + } + + auto [size, offset] = readDataPointer(fieldData); + columnData_[i] = std::string_view(rawData + offset, size); + } + + return columnData_; + } + + std::string toString(size_t idx) const override { + std::stringstream str; + str << "Struct field data iterator of type " << type()->toString() + << " of size " << size(idx); + return str.str(); + } + + private: + /** + * The field element index. + */ + size_t idx_; + + /** + * The number of elements in the struct. + */ + const size_t numElements_; + + /** + * The column data of the field. + */ + std::vector> columnData_; +}; + /** * Iterator representation of an UnsafeRow Struct object. UnsafeRow Struct representation: @@ -711,6 +787,28 @@ struct UnsafeRowDeserializer { pool, type, nulls, numStructs, std::move(columnVectors)); } + /** + * Converts a list of structFieldBatchIterator to Vectors. + * @param dataIterator iterator that points to whole column batch of data. + * @param pool + * @return a VectorPtr + */ + static VectorPtr convertFieldStructIteratorsToVectors( + const DataBatchIteratorPtr& dataIterator, + memory::MemoryPool* pool) { + const TypePtr& type = dataIterator->type(); + auto* structFieldBatchIterator = + static_cast(dataIterator.get()); + size_t numRows = structFieldBatchIterator->numRows(); + + auto nulls = populateNulls(dataIterator, pool, numRows); + + return deserialize( + structFieldBatchIterator->columnBatch(), + structFieldBatchIterator->type(), + pool); + } + /** * Converts a list of PrimitiveBatchIterators to a FlatVector * @tparam Kind the element's type kind. @@ -855,6 +953,29 @@ struct UnsafeRowDeserializer { memory::MemoryPool* pool) { return convertToVectors(getBatchIteratorPtr(data, type), pool); } + + /** + * Deserializes a field of a struct type to its Vector representation. + * @param data A vector of string_view over a given element in the + *UnsafeRow. + * @param type the field element type. + * @param field the field index + * @param fieldsNum the number of fields + * @param pool the memory pool to allocate Vectors from + *data to a array. + * @return a VectorPtr + */ + static VectorPtr deserializeStructField( + const std::vector>& data, + const TypePtr& type, + size_t field, + size_t fieldsNum, + memory::MemoryPool* pool) { + return convertFieldStructIteratorsToVectors( + std::make_shared( + data, type, field, fieldsNum), + pool); + } }; } // namespace facebook::velox::row diff --git a/velox/serializers/tests/UnsafeRowSerializerTest.cpp b/velox/serializers/tests/UnsafeRowSerializerTest.cpp index 050c104e33d2..345aec381149 100644 --- a/velox/serializers/tests/UnsafeRowSerializerTest.cpp +++ b/velox/serializers/tests/UnsafeRowSerializerTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/serializers/UnsafeRowSerializer.h" +#include "velox/row/UnsafeRowDeserializers.h" #include #include "velox/common/base/tests/GTestUtils.h" #include "velox/row/UnsafeRowFast.h" @@ -258,6 +259,19 @@ TEST_P(UnsafeRowSerializerTest, manyRows) { testDeserialize(data, 140, expected); } +TEST_P(UnsafeRowSerializerTest, structField) { + int8_t data[32] = {0, 0, 0, 0, 0, 0, 0, 0, -24, 3, 0, 0, 0, 0, 0, 0, + 5, 0, 0, 0, 24, 0, 0, 0, 97, 112, 112, 108, 101, 0, 0, 0}; + auto expected = makeFlatVector(std::vector{"apple"}); + VectorPtr results = row::UnsafeRowDeserializer::deserializeStructField( + {std::string_view(reinterpret_cast(data), 32)}, + VARCHAR(), + 1, + 2, + pool()); + test::assertEqualVectors(expected, results); +} + TEST_P(UnsafeRowSerializerTest, splitRow) { int8_t data[20] = {0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 62, 28, -36, -33, 2, 0, 0, 0}; From 81cc0b53fac9c3dad1b74c53bb04aa75cce8e6a6 Mon Sep 17 00:00:00 2001 From: NEUpanning Date: Mon, 20 Jan 2025 15:23:46 +0800 Subject: [PATCH 2/2] remove testAggregationsWithCompanion --- .../aggregates/CollectListAggregate.cpp | 5 ++- .../tests/CollectListAggregateTest.cpp | 35 ------------------- 2 files changed, 4 insertions(+), 36 deletions(-) diff --git a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp index 88a6bee0c68d..4e543ce6ddb4 100644 --- a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp +++ b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp @@ -302,7 +302,10 @@ AggregateRegistrationResult registerCollectList( } return std::make_unique(resultType, argTypes[0]); }, - withCompanionFunctions, + // Currently, Velox doesn't support automatically generating companion + // functions for aggregate functions whose result type is not resolvable + // given solely the concrete intermediate type. + false /*withCompanionFunctions*/, overwrite); } } // namespace diff --git a/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp index bb19c360909c..a9a8fe0891c3 100644 --- a/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp @@ -64,15 +64,6 @@ TEST_F(CollectListAggregateTest, groupBy) { {"spark_collect_list(c1)"}, {"c0", "array_sort(a0)"}, {expected}); - testAggregationsWithCompanion( - batches, - [](auto& /*builder*/) {}, - {"c0"}, - {"spark_collect_list(c1)"}, - {{BIGINT()}}, - {"c0", "array_sort(a0)"}, - {expected}, - {}); } TEST_F(CollectListAggregateTest, global) { @@ -83,14 +74,6 @@ TEST_F(CollectListAggregateTest, global) { testAggregations( {data}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected}); - testAggregationsWithCompanion( - {data}, - [](auto& /*builder*/) {}, - {}, - {"spark_collect_list(c0)"}, - {{INTEGER()}}, - {"array_sort(a0)"}, - {expected}); } TEST_F(CollectListAggregateTest, ignoreNulls) { @@ -101,15 +84,6 @@ TEST_F(CollectListAggregateTest, ignoreNulls) { makeRowVector({makeArrayVectorFromJson({"[1, 2, 4, 6]"})}); testAggregations( {input}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected}); - testAggregationsWithCompanion( - {input}, - [](auto& /*builder*/) {}, - {}, - {"spark_collect_list(c0)"}, - {{INTEGER()}}, - {"array_sort(a0)"}, - {expected}, - {}); } TEST_F(CollectListAggregateTest, allNullsInput) { @@ -117,15 +91,6 @@ TEST_F(CollectListAggregateTest, allNullsInput) { // If all input data is null, Spark will output an empty array. auto expected = makeRowVector({makeArrayVectorFromJson({"[]"})}); testAggregations({input}, {}, {"spark_collect_list(c0)"}, {expected}); - testAggregationsWithCompanion( - {input}, - [](auto& /*builder*/) {}, - {}, - {"spark_collect_list(c0)"}, - {{BIGINT()}}, - {}, - {expected}, - {}); } } // namespace } // namespace facebook::velox::functions::aggregate::sparksql::test