diff --git a/cpp/src/io/parquet/bloom_filter_reader.cu b/cpp/src/io/parquet/bloom_filter_reader.cu index a883981a467..87024719d87 100644 --- a/cpp/src/io/parquet/bloom_filter_reader.cu +++ b/cpp/src/io/parquet/bloom_filter_reader.cu @@ -32,7 +32,6 @@ #include #include #include -#include #include #include @@ -163,108 +162,6 @@ struct bloom_filter_caster { } }; -/** - * @brief Collects lists of equality predicate literals in the AST expression, one list per input - * table column. This is used in row group filtering based on bloom filters. - */ -class equality_literals_collector : public ast::detail::expression_transformer { - public: - equality_literals_collector() = default; - - equality_literals_collector(ast::expression const& expr, cudf::size_type num_input_columns) - : _num_input_columns{num_input_columns} - { - _equality_literals.resize(_num_input_columns); - expr.accept(*this); - } - - /** - * @copydoc ast::detail::expression_transformer::visit(ast::literal const& ) - */ - std::reference_wrapper visit(ast::literal const& expr) override - { - return expr; - } - - /** - * @copydoc ast::detail::expression_transformer::visit(ast::column_reference const& ) - */ - std::reference_wrapper visit(ast::column_reference const& expr) override - { - CUDF_EXPECTS(expr.get_table_source() == ast::table_reference::LEFT, - "BloomfilterAST supports only left table"); - CUDF_EXPECTS(expr.get_column_index() < _num_input_columns, - "Column index cannot be more than number of columns in the table"); - return expr; - } - - /** - * @copydoc ast::detail::expression_transformer::visit(ast::column_name_reference const& ) - */ - std::reference_wrapper visit( - ast::column_name_reference const& expr) override - { - CUDF_FAIL("Column name reference is not supported in BloomfilterAST"); - } - - /** - * @copydoc ast::detail::expression_transformer::visit(ast::operation const& ) - */ - std::reference_wrapper visit(ast::operation const& expr) override - { - using cudf::ast::ast_operator; - auto const operands = expr.get_operands(); - auto const op = expr.get_operator(); - - if (auto* v = dynamic_cast(&operands[0].get())) { - // First operand should be column reference, second should be literal. - CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2, - "Only binary operations are supported on column reference"); - auto const literal_ptr = dynamic_cast(&operands[1].get()); - CUDF_EXPECTS(literal_ptr != nullptr, - "Second operand of binary operation with column reference must be a literal"); - v->accept(*this); - - // Push to the corresponding column's literals list iff equality predicate is seen - if (op == ast_operator::EQUAL) { - auto const col_idx = v->get_column_index(); - _equality_literals[col_idx].emplace_back(const_cast(literal_ptr)); - } - } else { - // Just visit the operands and ignore any output - std::ignore = visit_operands(operands); - } - - return expr; - } - - /** - * @brief Vectors of equality literals in the AST expression, one per input table column - * - * @return Vectors of equality literals, one per input table column - */ - [[nodiscard]] std::vector> get_equality_literals() && - { - return std::move(_equality_literals); - } - - private: - std::vector> _equality_literals; - - protected: - std::vector> visit_operands( - cudf::host_span const> operands) - { - std::vector> transformed_operands; - for (auto const& operand : operands) { - auto const new_operand = operand.get().accept(*this); - transformed_operands.push_back(new_operand); - } - return transformed_operands; - } - size_type _num_input_columns; -}; - /** * @brief Converts AST expression to bloom filter membership (BloomfilterAST) expression. * This is used in row group filtering based on equality predicate. @@ -502,6 +399,17 @@ void read_bloom_filter_data(host_span const> sources } // namespace +size_t aggregate_reader_metadata::get_bloom_filter_alignment() const +{ + // Required alignment: + // https://github.com/NVIDIA/cuCollections/blob/deab5799f3e4226cb8a49acf2199c03b14941ee4/include/cuco/detail/bloom_filter/bloom_filter_impl.cuh#L55-L67 + using policy_type = cuco::arrow_filter_policy; + return alignof(cuco::bloom_filter_ref, + cuco::thread_scope_thread, + policy_type>::filter_block_type); +} + std::vector aggregate_reader_metadata::read_bloom_filters( host_span const> sources, host_span const> row_group_indices, @@ -599,55 +507,19 @@ std::vector aggregate_reader_metadata::get_parquet_types( return parquet_types; } -std::pair>>, bool> -aggregate_reader_metadata::apply_bloom_filters( - host_span const> sources, +std::optional>> aggregate_reader_metadata::apply_bloom_filters( + std::vector& bloom_filter_data, host_span const> input_row_group_indices, + host_span const> literals, size_type total_row_groups, host_span output_dtypes, - host_span output_column_schemas, + host_span equality_col_schemas, std::reference_wrapper filter, rmm::cuda_stream_view stream) const { // Number of input table columns auto const num_input_columns = static_cast(output_dtypes.size()); - // Collect equality literals for each input table column - auto const equality_literals = - equality_literals_collector{filter.get(), num_input_columns}.get_equality_literals(); - - // Collect schema indices of columns with equality predicate(s) - std::vector equality_col_schemas; - thrust::copy_if(thrust::host, - output_column_schemas.begin(), - output_column_schemas.end(), - equality_literals.begin(), - std::back_inserter(equality_col_schemas), - [](auto& eq_literals) { return not eq_literals.empty(); }); - - // Return early if no column with equality predicate(s) - if (equality_col_schemas.empty()) { return {std::nullopt, false}; } - - // Required alignment: - // https://github.com/NVIDIA/cuCollections/blob/deab5799f3e4226cb8a49acf2199c03b14941ee4/include/cuco/detail/bloom_filter/bloom_filter_impl.cuh#L55-L67 - using policy_type = cuco::arrow_filter_policy; - auto constexpr alignment = alignof(cuco::bloom_filter_ref, - cuco::thread_scope_thread, - policy_type>::filter_block_type); - - // Aligned resource adaptor to allocate bloom filter buffers with - auto aligned_mr = - rmm::mr::aligned_resource_adaptor(cudf::get_current_device_resource(), alignment); - - // Read a vector of bloom filter bitset device buffers for all columns with equality - // predicate(s) across all row groups - auto bloom_filter_data = read_bloom_filters( - sources, input_row_group_indices, equality_col_schemas, total_row_groups, stream, aligned_mr); - - // No bloom filter buffers, return early - if (bloom_filter_data.empty()) { return {std::nullopt, false}; } - // Get parquet types for the predicate columns auto const parquet_types = get_parquet_types(input_row_group_indices, equality_col_schemas); @@ -684,13 +556,13 @@ aggregate_reader_metadata::apply_bloom_filters( auto const& dtype = output_dtypes[input_col_idx]; // Skip if no equality literals for this column - if (equality_literals[input_col_idx].empty()) { return; } + if (literals[input_col_idx].empty()) { return; } // Skip if non-comparable (compound) type except string if (cudf::is_compound(dtype) and dtype.id() != cudf::type_id::STRING) { return; } // Add a column for all literals associated with an equality column - for (auto const& literal : equality_literals[input_col_idx]) { + for (auto const& literal : literals[input_col_idx]) { bloom_filter_membership_columns.emplace_back(cudf::type_dispatcher( dtype, bloom_filter_col, equality_col_idx, dtype, literal, stream)); } @@ -702,16 +574,92 @@ aggregate_reader_metadata::apply_bloom_filters( // Convert AST to BloomfilterAST expression with reference to bloom filter membership // in above `bloom_filter_membership_table` - bloom_filter_expression_converter bloom_filter_expr{ - filter.get(), num_input_columns, {equality_literals}}; + bloom_filter_expression_converter bloom_filter_expr{filter.get(), num_input_columns, {literals}}; // Filter bloom filter membership table with the BloomfilterAST expression and collect // filtered row group indices - return {collect_filtered_row_group_indices(bloom_filter_membership_table, - bloom_filter_expr.get_bloom_filter_expr(), - input_row_group_indices, - stream), - true}; + return collect_filtered_row_group_indices(bloom_filter_membership_table, + bloom_filter_expr.get_bloom_filter_expr(), + input_row_group_indices, + stream); +} + +equality_literals_collector::equality_literals_collector() = default; + +equality_literals_collector::equality_literals_collector(ast::expression const& expr, + cudf::size_type num_input_columns) + : _num_input_columns{num_input_columns} +{ + _literals.resize(_num_input_columns); + expr.accept(*this); +} + +std::reference_wrapper equality_literals_collector::visit( + ast::literal const& expr) +{ + return expr; +} + +std::reference_wrapper equality_literals_collector::visit( + ast::column_reference const& expr) +{ + CUDF_EXPECTS(expr.get_table_source() == ast::table_reference::LEFT, + "BloomfilterAST supports only left table"); + CUDF_EXPECTS(expr.get_column_index() < _num_input_columns, + "Column index cannot be more than number of columns in the table"); + return expr; +} + +std::reference_wrapper equality_literals_collector::visit( + ast::column_name_reference const& expr) +{ + CUDF_FAIL("Column name reference is not supported in BloomfilterAST"); +} + +std::reference_wrapper equality_literals_collector::visit( + ast::operation const& expr) +{ + using cudf::ast::ast_operator; + auto const operands = expr.get_operands(); + auto const op = expr.get_operator(); + + if (auto* v = dynamic_cast(&operands[0].get())) { + // First operand should be column reference, second should be literal. + CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2, + "Only binary operations are supported on column reference"); + auto const literal_ptr = dynamic_cast(&operands[1].get()); + CUDF_EXPECTS(literal_ptr != nullptr, + "Second operand of binary operation with column reference must be a literal"); + v->accept(*this); + + // Push to the corresponding column's literals list iff equality predicate is seen + if (op == ast_operator::EQUAL) { + auto const col_idx = v->get_column_index(); + _literals[col_idx].emplace_back(const_cast(literal_ptr)); + } + } else { + // Just visit the operands and ignore any output + std::ignore = visit_operands(operands); + } + + return expr; +} + +std::vector> equality_literals_collector::get_literals() && +{ + return std::move(_literals); +} + +std::vector> +equality_literals_collector::visit_operands( + cudf::host_span const> operands) +{ + std::vector> transformed_operands; + for (auto const& operand : operands) { + auto const new_operand = operand.get().accept(*this); + transformed_operands.push_back(new_operand); + } + return transformed_operands; } } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index 1508b7eef8b..e1d7dbb03b3 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -29,6 +29,8 @@ #include #include +#include + #include #include @@ -388,9 +390,7 @@ class stats_expression_converter : public ast::detail::expression_transformer { }; } // namespace -std::pair>>, surviving_row_group_metrics> -aggregate_reader_metadata::filter_row_groups( - host_span const> sources, +std::optional>> aggregate_reader_metadata::apply_stats_filters( host_span const> input_row_group_indices, size_type total_row_groups, host_span output_dtypes, @@ -430,14 +430,33 @@ aggregate_reader_metadata::filter_row_groups( static_cast(output_dtypes.size())}; // Filter stats table with StatsAST expression and collect filtered row group indices - auto const filtered_row_group_indices = collect_filtered_row_group_indices( + return collect_filtered_row_group_indices( stats_table, stats_expr.get_stats_expr(), input_row_group_indices, stream); +} + +std::pair>>, surviving_row_group_metrics> +aggregate_reader_metadata::filter_row_groups( + host_span const> sources, + host_span const> input_row_group_indices, + size_type total_row_groups, + host_span output_dtypes, + host_span output_column_schemas, + std::reference_wrapper filter, + rmm::cuda_stream_view stream) const +{ + // Apply stats filtering on input row groups + auto const stats_filtered_row_groups = apply_stats_filters(input_row_group_indices, + total_row_groups, + output_dtypes, + output_column_schemas, + filter, + stream); // Number of surviving row groups after applying stats filter auto const num_stats_filtered_row_groups = - filtered_row_group_indices.has_value() - ? std::accumulate(filtered_row_group_indices.value().cbegin(), - filtered_row_group_indices.value().cend(), + stats_filtered_row_groups.has_value() + ? std::accumulate(stats_filtered_row_groups.value().cbegin(), + stats_filtered_row_groups.value().cend(), size_type{0}, [](auto& sum, auto const& per_file_row_groups) { return sum + per_file_row_groups.size(); @@ -446,37 +465,75 @@ aggregate_reader_metadata::filter_row_groups( // Span of row groups to apply bloom filtering on. auto const bloom_filter_input_row_groups = - filtered_row_group_indices.has_value() - ? host_span const>(filtered_row_group_indices.value()) + stats_filtered_row_groups.has_value() + ? host_span const>(stats_filtered_row_groups.value()) : input_row_group_indices; - // Apply bloom filtering on the bloom filter input row groups - auto const [bloom_filtered_row_groups, bloom_filters_exist] = - apply_bloom_filters(sources, - bloom_filter_input_row_groups, - num_stats_filtered_row_groups, - output_dtypes, - output_column_schemas, - filter, - stream); + // Collect equality literals for each input table column for bloom filtering + auto const equality_literals = + equality_literals_collector{filter.get(), static_cast(output_dtypes.size())} + .get_literals(); + + // Collect schema indices of columns with equality predicate(s) + std::vector equality_col_schemas; + thrust::copy_if(thrust::host, + output_column_schemas.begin(), + output_column_schemas.end(), + equality_literals.begin(), + std::back_inserter(equality_col_schemas), + [](auto& eq_literals) { return not eq_literals.empty(); }); + + // Return early if no column with equality predicate(s) + if (equality_col_schemas.empty()) { + return {stats_filtered_row_groups, + {std::make_optional(num_stats_filtered_row_groups), std::nullopt}}; + } + + // Aligned resource adaptor to allocate bloom filter buffers with + auto aligned_mr = rmm::mr::aligned_resource_adaptor(cudf::get_current_device_resource(), + get_bloom_filter_alignment()); + + // Read a vector of bloom filter bitset device buffers for all columns with equality + // predicate(s) across all row groups + auto bloom_filter_data = read_bloom_filters(sources, + bloom_filter_input_row_groups, + equality_col_schemas, + num_stats_filtered_row_groups, + stream, + aligned_mr); + + // No bloom filter buffers, return early + if (bloom_filter_data.empty()) { + return {stats_filtered_row_groups, + {std::make_optional(num_stats_filtered_row_groups), std::nullopt}}; + } + + // Apply bloom filtering on the output row groups from stats filter + auto const bloom_filtered_row_groups = apply_bloom_filters(bloom_filter_data, + bloom_filter_input_row_groups, + equality_literals, + num_stats_filtered_row_groups, + output_dtypes, + equality_col_schemas, + filter, + stream); // Number of surviving row groups after applying bloom filter auto const num_bloom_filtered_row_groups = - bloom_filters_exist - ? (bloom_filtered_row_groups.has_value() - ? std::make_optional(std::accumulate(bloom_filtered_row_groups.value().cbegin(), - bloom_filtered_row_groups.value().cend(), - size_type{0}, - [](auto& sum, auto const& per_file_row_groups) { - return sum + per_file_row_groups.size(); - })) - : std::make_optional(num_stats_filtered_row_groups)) - : std::nullopt; + bloom_filtered_row_groups.has_value() + ? std::accumulate(bloom_filtered_row_groups.value().cbegin(), + bloom_filtered_row_groups.value().cend(), + size_type{0}, + [](auto& sum, auto const& per_file_row_groups) { + return sum + per_file_row_groups.size(); + }) + : num_stats_filtered_row_groups; // Return bloom filtered row group indices iff collected return { - bloom_filtered_row_groups.has_value() ? bloom_filtered_row_groups : filtered_row_group_indices, - {std::make_optional(num_stats_filtered_row_groups), num_bloom_filtered_row_groups}}; + bloom_filtered_row_groups.has_value() ? bloom_filtered_row_groups : stats_filtered_row_groups, + {std::make_optional(num_stats_filtered_row_groups), + std::make_optional(num_bloom_filtered_row_groups)}}; } // convert column named expression to column index reference expression diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index c4372b2c1ff..f08ba5f8b85 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -203,6 +203,11 @@ class aggregate_reader_metadata { */ void column_info_for_row_group(row_group_info& rg_info, size_type chunk_start_row) const; + /** + * @brief Returns the required alignment for bloom filter buffers + */ + [[nodiscard]] size_t get_bloom_filter_alignment() const; + /** * @brief Reads bloom filter bitsets for the specified columns from the given lists of row * groups. @@ -237,6 +242,50 @@ class aggregate_reader_metadata { host_span const> row_group_indices, host_span column_schemas) const; + /** + * @brief Filters the row groups using stats filter + * + * @param input_row_group_indices Lists of input row groups, one per source + * @param total_row_groups Total number of row groups in `input_row_group_indices` + * @param output_dtypes Datatypes of output columns + * @param output_column_schemas schema indices of output columns + * @param filter AST expression to filter row groups based on bloom filter membership + * @param stream CUDA stream used for device memory operations and kernel launches + * + * @return Filtered row group indices if any is filtered + */ + [[nodiscard]] std::optional>> apply_stats_filters( + host_span const> input_row_group_indices, + size_type total_row_groups, + host_span output_dtypes, + host_span output_column_schemas, + std::reference_wrapper filter, + rmm::cuda_stream_view stream) const; + + /** + * @brief Filters the row groups using bloom filters + * + * @param bloom_filter_data Bloom filter data device buffers for each input row group + * @param input_row_group_indices Lists of input row groups, one per source + * @param literals Lists of equality literals, one per each input row group + * @param total_row_groups Total number of row groups in `input_row_group_indices` + * @param output_dtypes Datatypes of output columns + * @param equality_col_schemas schema indices of equality columns only + * @param filter AST expression to filter row groups based on bloom filter membership + * @param stream CUDA stream used for device memory operations and kernel launches + * + * @return Filtered row group indices if any is filtered + */ + [[nodiscard]] std::optional>> apply_bloom_filters( + std::vector& bloom_filter_data, + host_span const> input_row_group_indices, + host_span const> literals, + size_type total_row_groups, + host_span output_dtypes, + host_span equality_col_schemas, + std::reference_wrapper filter, + rmm::cuda_stream_view stream) const; + public: aggregate_reader_metadata(host_span const> sources, bool use_arrow_schema, @@ -363,7 +412,7 @@ class aggregate_reader_metadata { [[nodiscard]] std::vector get_pandas_index_names() const; /** - * @brief Filters the row groups based on predicate filter + * @brief Filters the row groups using stats and bloom filters based on predicate filter * * @param sources Lists of input datasources * @param input_row_group_indices Lists of input row groups, one per source @@ -385,29 +434,6 @@ class aggregate_reader_metadata { std::reference_wrapper filter, rmm::cuda_stream_view stream) const; - /** - * @brief Filters the row groups using bloom filters - * - * @param sources Dataset sources - * @param input_row_group_indices Lists of input row groups, one per source - * @param total_row_groups Total number of row groups in `input_row_group_indices` - * @param output_dtypes Datatypes of output columns - * @param output_column_schemas schema indices of output columns - * @param filter AST expression to filter row groups based on bloom filter membership - * @param stream CUDA stream used for device memory operations and kernel launches - * - * @return A pair of filtered row group indices if any is filtered, and a boolean indicating if - * bloom filtering was applied - */ - [[nodiscard]] std::pair>>, bool> - apply_bloom_filters(host_span const> sources, - host_span const> input_row_group_indices, - size_type total_row_groups, - host_span output_dtypes, - host_span output_column_schemas, - std::reference_wrapper filter, - rmm::cuda_stream_view stream) const; - /** * @brief Filters and reduces down to a selection of row groups * @@ -513,6 +539,54 @@ class named_to_reference_converter : public ast::detail::expression_transformer std::list _operators; }; +/** + * @brief Collects lists of equality predicate literals in the AST expression, one list per input + * table column. This is used in row group filtering based on bloom filters. + */ +class equality_literals_collector : public ast::detail::expression_transformer { + public: + equality_literals_collector(); + + equality_literals_collector(ast::expression const& expr, cudf::size_type num_input_columns); + + /** + * @copydoc ast::detail::expression_transformer::visit(ast::literal const& ) + */ + std::reference_wrapper visit(ast::literal const& expr) override; + + /** + * @copydoc ast::detail::expression_transformer::visit(ast::column_reference const& ) + */ + std::reference_wrapper visit(ast::column_reference const& expr) override; + + /** + * @copydoc ast::detail::expression_transformer::visit(ast::column_name_reference const& ) + */ + std::reference_wrapper visit( + ast::column_name_reference const& expr) override; + + /** + * @copydoc ast::detail::expression_transformer::visit(ast::operation const& ) + */ + std::reference_wrapper visit(ast::operation const& expr) override; + + /** + * @brief Vectors of equality literals in the AST expression, one per input table column + * + * @return Vectors of equality literals, one per input table column + */ + [[nodiscard]] std::vector> get_literals() &&; + + protected: + std::vector> visit_operands( + cudf::host_span const> operands); + + size_type _num_input_columns; + + private: + std::vector> _literals; +}; + /** * @brief Get the column names in expression object *