Skip to content

Commit

Permalink
Fix fb_reshape_row for ArrayType equal comparison (#11169)
Browse files Browse the repository at this point in the history
Summary:

fb_reshape_row returns a wrong result for an array of rows when the "from" and
"to" row types have the same size and the same child types but different field names.

For example:
```
SELECT
    fb_reshape_row(
        col,
        CAST(NULL AS ROW(arr ARRAY(ROW(b VARCHAR, a VARCHAR))))
    ) as col
FROM (
    SELECT CAST(ROW(x) AS ROW(arr ARRAY(ROW(a VARCHAR, b VARCHAR)))) AS col
    FROM (
    VALUES
        (ARRAY[('1', '2')])
    ) t(x)
);
```

In fb_reshape_row::reshapeRow, if fromType is found to be equal to toType, no
transformation is performed:
```
if (fromVector->type()->asRow().equals(toType->asRow())) {
      return fromVector;
    }
```

`RowType::equals` iterates over its children and compares them with
`operator==`. However, `operator==` is not overridden for `ArrayType`,
so the comparison falls back to `ArrayType::equivalent`, which matches
only weakly.

Change `operator==` to use `equals` for complex types to ensure strong
matching in equality comparisons.

Differential Revision: D63993116
  • Loading branch information
Ke Wang authored and facebook-github-bot committed Oct 7, 2024
1 parent 13c18db commit b330160
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 22 deletions.
57 changes: 47 additions & 10 deletions velox/type/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,17 @@ bool ArrayType::equivalent(const Type& other) const {
return child_->equivalent(*otherArray.child_);
}

/// Strong equality check for array types.
///
/// Returns true when 'other' is this very object. Otherwise returns true only
/// if 'other' passes Type::hasSameTypeId() and its element type compares equal
/// (operator==) to this array's element type.
bool ArrayType::equals(const Type& other) const {
  if (this == &other) {
    return true;
  }
  // Short-circuit keeps asArray() from being called on a mismatched type.
  return Type::hasSameTypeId(other) && *child_ == *other.asArray().child_;
}

folly::dynamic ArrayType::serialize() const {
folly::dynamic obj = folly::dynamic::object;
obj["name"] = "Type";
Expand Down Expand Up @@ -476,14 +487,6 @@ bool RowType::equals(const Type& other) const {
return true;
}

/// Equality against an arbitrary Type; the result is whatever equals()
/// reports for 'other'.
bool RowType::operator==(const Type& other) const {
  return equals(other);
}

/// Equality against another RowType; the result is whatever equals()
/// reports for 'other'.
bool RowType::operator==(const RowType& other) const {
  return equals(other);
}

void RowType::printChildren(std::stringstream& ss, std::string_view delimiter)
const {
bool any = false;
Expand Down Expand Up @@ -569,6 +572,17 @@ bool MapType::equivalent(const Type& other) const {
valueType_->equivalent(*otherMap.valueType_);
}

/// Strong equality check for map types.
///
/// Returns true when 'other' is this very object. Otherwise returns true only
/// if 'other' passes Type::hasSameTypeId() and both its key type and its value
/// type compare equal (operator==) to this map's key and value types.
bool MapType::equals(const Type& other) const {
  if (this == &other) {
    return true;
  }
  if (!Type::hasSameTypeId(other)) {
    return false;
  }
  const auto& rhs = other.asMap();
  // Keys are compared first; values are only compared when the keys match.
  if (!(*keyType_ == *rhs.keyType_)) {
    return false;
  }
  return *valueType_ == *rhs.valueType_;
}

FunctionType::FunctionType(
std::vector<std::shared_ptr<const Type>>&& argumentTypes,
std::shared_ptr<const Type> returnType)
Expand Down Expand Up @@ -598,6 +612,29 @@ bool FunctionType::equivalent(const Type& other) const {
return true;
}

bool FunctionType::equals(const Type& other) const {
if (&other == this) {
return true;
}

if (!Type::hasSameTypeId(other)) {
return false;
}

auto& otherTyped = *reinterpret_cast<const FunctionType*>(&other);
if (children_.size() != otherTyped.size()) {
return false;
}

for (auto i = 0; i < children_.size(); ++i) {
if (*children_.at(i) != *otherTyped.children_.at(i)) {
return false;
}
}

return true;
}

std::string FunctionType::toString() const {
std::stringstream out;
out << "FUNCTION<";
Expand Down Expand Up @@ -629,11 +666,11 @@ bool OpaqueType::equivalent(const Type& other) const {
return true;
}

bool OpaqueType::operator==(const Type& other) const {
bool OpaqueType::equals(const Type& other) const {
if (&other == this) {
return true;
}
if (!this->equivalent(other)) {
if (*this != other) {
return false;
}
auto& otherTyped = *reinterpret_cast<const OpaqueType*>(&other);
Expand Down
39 changes: 28 additions & 11 deletions velox/type/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <vector>

#include "velox/common/base/ClassName.h"
#include "velox/common/base/Exceptions.h"
#include "velox/common/serialization/Serializable.h"
#include "velox/type/HugeInt.h"
#include "velox/type/StringView.h"
Expand Down Expand Up @@ -492,12 +493,10 @@ class Type : public Tree<const TypePtr>, public velox::ISerializable {
/// equivalent if the typeKind matches, but the typeIndex could be different.
virtual bool equivalent(const Type& other) const = 0;

/// Types are strongly matched.
/// Examples: Two RowTypes are == if the children types and the children names
/// are same. Two OpaqueTypes are == if the typeKind and the typeIndex are
/// same. Same as equivalent for most types except for Row, Opaque types.
/// For Complex types (Row, Array, Map, Opaque): types are strongly matched.
/// For primitive types: same as equivalent.
virtual bool operator==(const Type& other) const {
return this->equivalent(other);
return this->equals(other);
}

inline bool operator!=(const Type& other) const {
Expand Down Expand Up @@ -566,6 +565,16 @@ class Type : public Tree<const TypePtr>, public velox::ISerializable {
return typeid(*this) == typeid(other);
}

/// For Complex types (Row, Array, Map, Opaque): types are strongly matched.
/// Examples: Two RowTypes are == if the children types and the children names
/// are same. Two OpaqueTypes are == if the typeKind and the typeIndex are
/// same.
/// For primitive types: same as equivalent.
virtual bool equals(const Type& other) const {
VELOX_CHECK(this->isPrimitiveType());
return this->equivalent(other);
}

private:
const TypeKind kind_;
const bool providesCustomComparison_;
Expand Down Expand Up @@ -914,6 +923,8 @@ class ArrayType : public TypeBase<TypeKind::ARRAY> {
}

protected:
bool equals(const Type& other) const override;

TypePtr child_;
const std::vector<TypeParameter> parameters_;
};
Expand Down Expand Up @@ -965,6 +976,9 @@ class MapType : public TypeBase<TypeKind::MAP> {
return parameters_;
}

protected:
bool equals(const Type& other) const override;

private:
TypePtr keyType_;
TypePtr valueType_;
Expand Down Expand Up @@ -1005,10 +1019,6 @@ class RowType : public TypeBase<TypeKind::ROW> {

bool equivalent(const Type& other) const override;

bool equals(const Type& other) const;
bool operator==(const Type& other) const override;
bool operator==(const RowType& other) const;

std::string toString() const override;

/// Print child names and types separated by 'delimiter'.
Expand Down Expand Up @@ -1037,6 +1047,9 @@ class RowType : public TypeBase<TypeKind::ROW> {
return *parameters;
}

protected:
bool equals(const Type& other) const override;

private:
std::unique_ptr<std::vector<TypeParameter>> makeParameters() const;

Expand Down Expand Up @@ -1090,6 +1103,9 @@ class FunctionType : public TypeBase<TypeKind::FUNCTION> {
return parameters_;
}

protected:
bool equals(const Type& other) const override;

private:
static std::vector<std::shared_ptr<const Type>> allChildren(
std::vector<std::shared_ptr<const Type>>&& argumentTypes,
Expand Down Expand Up @@ -1124,8 +1140,6 @@ class OpaqueType : public TypeBase<TypeKind::OPAQUE> {

bool equivalent(const Type& other) const override;

bool operator==(const Type& other) const override;

const std::type_index& typeIndex() const {
return typeIndex_;
}
Expand Down Expand Up @@ -1186,6 +1200,9 @@ class OpaqueType : public TypeBase<TypeKind::OPAQUE> {
deserializeTypeErased);
}

protected:
bool equals(const Type& other) const override;

private:
const std::type_index typeIndex_;

Expand Down
4 changes: 3 additions & 1 deletion velox/vector/tests/VectorMakerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,9 @@ TEST_F(VectorMakerTest, arrayOfRowVectorFromTuples) {
auto expected = maker_.arrayVector(offsets, elements);

ASSERT_EQ(expected->size(), arrayVector->size());
ASSERT_EQ(*expected->type(), *arrayVector->type());
// Check equivalence because arrayVector's row type has no name for each
// column ('', '', ...) whereas expected's row type has names ('c0', 'c1', ...).
ASSERT_TRUE((*expected->type()).equivalent((*arrayVector->type())));
for (auto i = 0; i < expected->size(); i++) {
ASSERT_TRUE(expected->equalValueAt(arrayVector.get(), i, i));
}
Expand Down

0 comments on commit b330160

Please sign in to comment.