diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index c0433145dd1d0..245341a6c054f 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -26,16 +26,25 @@ #include "arrow/compute/exec.h" #include "arrow/compute/exec_internal.h" #include "arrow/compute/function_internal.h" +#include "arrow/compute/kernel.h" #include "arrow/compute/kernels/common_internal.h" #include "arrow/compute/registry.h" #include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/type.h" +#include "arrow/util/bit_util.h" #include "arrow/util/cpu_info.h" #include "arrow/util/logging.h" +#include "arrow/util/string.h" +#include "arrow/util/string_builder.h" #include "arrow/util/tracing_internal.h" +#include "arrow/util/vector.h" namespace arrow { using internal::checked_cast; +using internal::JoinStrings; +using internal::MapVector; namespace compute { Result> FunctionOptionsType::Serialize( @@ -117,6 +126,21 @@ Status NoMatchingKernel(const Function* func, const std::vector& typ TypeHolder::ToString(types)); } +Status AmbiguousCall(const Function* func, const std::vector& types, + const std::vector& kernels) { + return Status::TypeError( + "Call of function '", func->name(), + "' is ambiguous with these input types: ", TypeHolder::ToString(types), + ". 
Candidates are: ", + JoinStrings(MapVector( + [&](const Kernel* kernel) { + return util::StringBuilder(func->name(), "(", + kernel->signature->ToString(), ")"); + }, + kernels), + ", ")); +} + template const KernelType* DispatchExactImpl(const std::vector& kernels, const std::vector& values) { @@ -306,9 +330,79 @@ Result Function::DispatchExact( return detail::NoMatchingKernel(this, values); } +Result Function::DispatchWithExtensionCast( + std::vector* values) const { + if (kind_ == Function::META) { + return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); + } + RETURN_NOT_OK(CheckArity(values->size())); + + std::vector extension_indices; + for (size_t i = 0; i < values->size(); ++i) { + if ((*values)[i].id() == Type::EXTENSION) { + extension_indices.push_back(static_cast(i)); + } + } + + if (extension_indices.empty()) { + return detail::NoMatchingKernel(this, *values); + } + + // Enumerate all possible combinations of extensions to cast, in increasing number + // of replacements. Try DispatchExact for each combination. 
+ for (size_t num_replacement = 1; num_replacement <= extension_indices.size(); + ++num_replacement) { + // create bitmasks with num_replacement 1s + uint32_t mask = (1 << num_replacement) - 1; + const Kernel* matched_kernel = nullptr; + std::vector matched_values; + while (mask < (1ULL << extension_indices.size())) { + std::vector replaced_values = *values; + for (size_t i = 0; i < extension_indices.size(); ++i) { + if (mask & (1 << i)) { + replaced_values[extension_indices[i]] = + static_cast(*replaced_values[extension_indices[i]]) + .storage_type(); + } + } + + if (auto kernel = detail::DispatchExactImpl(this, replaced_values)) { + if (matched_kernel) { // If there are multiple matches, the call is ambiguous + return detail::AmbiguousCall(this, *values, {matched_kernel, kernel}); + } + matched_kernel = kernel; + matched_values = std::move(replaced_values); + } else { + } + // next lexicographical permutation of mask + mask = bit_util::NextBitPermutation(mask); + } + if (matched_kernel) { + *values = std::move(matched_values); + return matched_kernel; + } + } + return detail::NoMatchingKernel(this, *values); +} + Result Function::DispatchBest(std::vector* values) const { // TODO(ARROW-11508) permit generic conversions here - return DispatchExact(*values); + auto exact_result = DispatchExact(*values); + if (exact_result.ok()) { + return exact_result; + } + + // Try to cast extension types to their storage types + auto extension_result = DispatchWithExtensionCast(values); + if (extension_result.ok()) { + return extension_result; + } else if (extension_result.status().IsTypeError()) { + // If DispatchWithExtensionCast returns an ambiguous error, return it + return extension_result; + } + + // Otherwise returns the error from DispatchExact + return exact_result; } Result> Function::GetBestExecutor( diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index 333c9a65c56c4..b8f6fdbf46d43 100644 --- a/cpp/src/arrow/compute/function.h +++ 
b/cpp/src/arrow/compute/function.h @@ -284,6 +284,18 @@ class ARROW_EXPORT Function { Status CheckArity(size_t num_args) const; + /// \brief Return a best-match kernel that can execute the function that enables casting + /// ExtensionTypes to their storage types. If multiple kernels match, the one with the + /// minimum number of casts is returned. If there are multiple kernels with the minimum + /// number of casts, the call is ambiguous and a TypeError is returned. + /// + /// + /// \param[in,out] values Argument types. An element may be modified to + /// indicate that the returned kernel only approximately matches the input + /// value descriptors; callers are responsible for casting inputs to the type + /// required by the kernel. + Result DispatchWithExtensionCast(std::vector* values) const; + std::string name_; Function::Kind kind_; Arity arity_; diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index 66d38ecd64d49..5595753629768 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -32,6 +32,7 @@ #include "arrow/compute/kernel.h" #include "arrow/datum.h" #include "arrow/status.h" +#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/type.h" @@ -456,5 +457,39 @@ TEST(FunctionExecutor, Basics) { } } +TEST(DispatchBest, ImplicitExtensionToStorage) { + auto add = std::make_shared("add", Arity::Binary(), + /*doc=*/FunctionDoc::Empty()); + auto assert_kernel_match = [&add](const std::vector& expected_types) { + std::vector types = {smallint(), tinyint()}; + ASSERT_OK_AND_ASSIGN(auto kernel, add->DispatchBest(&types)); + ASSERT_TRUE(kernel->signature->MatchesInputs(expected_types)); + }; + + auto assert_fail = [&add](StatusCode code, std::string_view msg) { + std::vector types = {smallint(), tinyint()}; + EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT(code, ::testing::HasSubstr(msg), + add->DispatchBest(&types)); + }; + + // no 
kernel, fail + assert_fail(StatusCode::NotImplemented, "no kernel matching input types"); + + // should match kernel with two casts + ASSERT_OK(add->AddKernel({int16(), int8()}, int16(), ExecNYI)); + assert_kernel_match({int16(), int8()}); + + // should prefer kernel with one cast + ASSERT_OK(add->AddKernel({smallint(), int8()}, smallint(), ExecNYI)); + assert_kernel_match({smallint(), int8()}); + + // two kernels with one cast, ambiguous + ASSERT_OK(add->AddKernel({int16(), tinyint()}, int16(), ExecNYI)); + assert_fail(StatusCode::TypeError, "is ambiguous"); + + // should prefer kernel with no casts + ASSERT_OK(add->AddKernel({smallint(), tinyint()}, smallint(), ExecNYI)); + assert_kernel_match({smallint(), tinyint()}); +} } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 00a833742f957..a4e10eabc65e6 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -138,6 +138,19 @@ void ReplaceTypes(const TypeHolder& replacement, TypeHolder* begin, size_t count } } +void EnsureExtensionToStorage(std::vector* types) { + EnsureExtensionToStorage(types->data(), types->size()); +} + +void EnsureExtensionToStorage(TypeHolder* begin, size_t count) { + auto* end = begin + count; + for (auto* it = begin; it != end; it++) { + if (it->type->id() == Type::EXTENSION) { + *it = checked_cast(*it->type).storage_type(); + } + } +} + TypeHolder CommonNumeric(const std::vector& types) { return CommonNumeric(types.data(), types.size()); } diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 72b29057b82e0..bd0d38be4c6a1 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1365,6 +1365,12 @@ void ReplaceTypes(const TypeHolder& replacement, TypeHolder* types, size_t count ARROW_EXPORT 
void ReplaceTemporalTypes(TimeUnit::type unit, std::vector* types); +ARROW_EXPORT +void EnsureExtensionToStorage(std::vector* types); + +ARROW_EXPORT +void EnsureExtensionToStorage(TypeHolder* begin, size_t count); + ARROW_EXPORT TypeHolder CommonNumeric(const std::vector& types); diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index ad33d7f8951f4..1fc7abd3a84ad 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -630,6 +630,7 @@ struct ArithmeticFunction : ScalarFunction { if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); // Only promote types for binary functions if (types->size() == 2) { @@ -685,6 +686,7 @@ struct ArithmeticDecimalToFloatingPointFunction : public ArithmeticFunction { if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); if (types->size() == 2) { ReplaceNullWithOtherType(types); @@ -717,6 +719,7 @@ struct ArithmeticIntegerToFloatingPointFunction : public ArithmeticFunction { if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); if (types->size() == 2) { ReplaceNullWithOtherType(types); @@ -748,6 +751,7 @@ struct ArithmeticFloatingPointFunction : public ArithmeticFunction { if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); if (types->size() == 2) { ReplaceNullWithOtherType(types); diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 756b3028c4a59..caf8f0f9daee8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -36,6 +36,7 @@ 
#include "arrow/util/string.h" #include "arrow/testing/builder.h" +#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" @@ -2753,6 +2754,11 @@ TYPED_TEST(TestUnaryArithmeticFloating, Sign) { this->AssertUnaryOp(sign, this->MakeScalar(max), this->MakeScalar(1)); } +TEST(TestArithmeticExtension, Extension) { + // Allow extension types to be implicitly cast to their storage types + ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(int16(), "[-32640, null, 0, 0, 0, 0, 32640]"), + *Subtract(ExampleSmallint(), ExampleTinyint())->make_array()); +} } // namespace } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index aad648ca275c3..e12298a24829a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -348,6 +348,7 @@ struct CompareFunction : ScalarFunction { EnsureDictionaryDecoded(types); ReplaceNullWithOtherType(types); + EnsureExtensionToStorage(types); if (auto type = CommonNumeric(*types)) { ReplaceTypes(type, types); @@ -372,6 +373,7 @@ struct VarArgsCompareFunction : ScalarFunction { if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); if (auto type = CommonNumeric(*types)) { ReplaceTypes(type, types); diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 48fa780b03104..4b3051438abb2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -28,6 +28,7 @@ #include "arrow/compute/api.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/builder.h" +#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/random.h" @@ -2123,5 +2124,11 @@ 
TEST(TestMaxElementWiseMinElementWise, CommonTemporal) { ResultWith(ScalarFromJSON(date64(), "86400000"))); } +TEST(TestCompareExtension, Extension) { + // Allow extension types to be implicitly cast to their storage types + ASSERT_ARRAYS_EQUAL( + *ArrayFromJSON(int16(), "[-32768, null, 1, 2, 3, 4, 127]"), + *MinElementWise({ExampleSmallint(), ExampleTinyint()})->make_array()); +} } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index ee181c053c053..d57f0c91976da 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1226,6 +1226,7 @@ struct IfElseFunction : ScalarFunction { } internal::EnsureDictionaryDecoded(left_arg, num_args); + internal::EnsureExtensionToStorage(types); if (auto type = internal::CommonNumeric(left_arg, num_args)) { internal::ReplaceTypes(type, left_arg, num_args); @@ -1431,6 +1432,7 @@ struct CaseWhenFunction : ScalarFunction { } EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); TypeHolder* first_arg = &(*types)[1]; const size_t num_args = types->size() - 1; if (auto type = CommonNumeric(first_arg, num_args)) { @@ -1964,6 +1966,7 @@ struct CoalesceFunction : ScalarFunction { // Do not DispatchExact here since we want to rescale decimals if necessary EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); if (auto type = CommonNumeric(types->data(), types->size())) { ReplaceTypes(type, types); } @@ -2663,6 +2666,7 @@ struct ChooseFunction : ScalarFunction { // based on the type of the rest of the arguments. 
RETURN_NOT_OK(CheckArity(types->size())); EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); if (types->front().id() != Type::INT64) { (*types)[0] = int64(); } diff --git a/cpp/src/arrow/compute/kernels/scalar_round.cc b/cpp/src/arrow/compute/kernels/scalar_round.cc index 36e59c8c1deac..7234d410218cb 100644 --- a/cpp/src/arrow/compute/kernels/scalar_round.cc +++ b/cpp/src/arrow/compute/kernels/scalar_round.cc @@ -1053,6 +1053,7 @@ struct RoundFunction : ScalarFunction { if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); // for binary round functions, the second scalar must be int32 if (types->size() == 2 && (*types)[1].id() != Type::INT32) { @@ -1074,6 +1075,7 @@ struct RoundDecimalToFloatingPointFunction : public RoundFunction { if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); // Size of types is checked above. const auto originalType = (*types)[0]; @@ -1099,6 +1101,7 @@ struct RoundIntegerToFloatingPointFunction : public RoundFunction { if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); // Size of types is checked above. const auto originalType = (*types)[0]; @@ -1124,6 +1127,7 @@ struct RoundFloatingPointFunction : public RoundFunction { if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); // Size of types is checked above. 
const auto originalType = (*types)[0]; diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index e2d5583e36e6b..0cc1c1878a5e2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -636,6 +636,7 @@ struct SetLookupFunction : ScalarFunction { Result DispatchBest(std::vector* values) const override { EnsureDictionaryDecoded(values); + EnsureExtensionToStorage(values); return DispatchExact(*values); } }; diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc index 6764845dfca81..9a19a24cda025 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -3258,6 +3258,7 @@ struct ScalarCTypeToInt64Function : public ScalarFunction { if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); + EnsureExtensionToStorage(types); for (auto it = types->begin(); it < types->end(); ++it) { if (is_integer(it->id())) { diff --git a/cpp/src/arrow/util/bit_util.h b/cpp/src/arrow/util/bit_util.h index 04ab07af1d779..d3f9d448a397d 100644 --- a/cpp/src/arrow/util/bit_util.h +++ b/cpp/src/arrow/util/bit_util.h @@ -317,6 +317,18 @@ static inline void SetBitTo(uint8_t* bits, int64_t i, bool bit_is_set) { kBitmask[i % 8]; } +// Inputs a pattern of N bits set to 1 in an integer and outputs the next permutation of N +// 1 bits in a lexicographical sense. For example, if N is 3 and the bit pattern is +// 00010011, the next patterns would be 00010101, 00010110, 00011001,00011010, 00011100, +// 00100011, and so forth. The following is a fast way to compute the next permutation. 
+// https://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation +static inline uint32_t NextBitPermutation(uint32_t v) { + uint32_t t = v | (v - 1); // t gets v's least significant 0 bits set to 1 + // Next set to 1 the most significant bit to change, + // set to 0 the least significant ones, and add the necessary 1 bits. + return (t + 1) | (((~t & (0 - ~t)) - 1) >> (CountTrailingZeros(v) + 1)); +} + /// \brief set or clear a range of bits quickly ARROW_EXPORT void SetBitsTo(uint8_t* bits, int64_t start_offset, int64_t length, bool bits_are_set);