From 63b442e2b58c846efa49c1d70ba9dd15b1903ea6 Mon Sep 17 00:00:00 2001 From: Muhammad Haseeb <14217455+mhaseeb123@users.noreply.github.com> Date: Thu, 19 Sep 2024 15:14:30 -0700 Subject: [PATCH] Revert "Refactor mixed_semi_join using cuco::static_set (#16230)" This reverts commit e68f55c98f257bdeedeb31e68c9737264bd0b393. --- cpp/src/join/join_common_utils.hpp | 6 ++ cpp/src/join/mixed_join_common_utils.cuh | 33 --------- cpp/src/join/mixed_join_kernels_semi.cu | 35 +++++---- cpp/src/join/mixed_join_kernels_semi.cuh | 6 +- cpp/src/join/mixed_join_semi.cu | 90 +++++++++++++++++------- cpp/tests/join/mixed_join_tests.cu | 30 -------- 6 files changed, 91 insertions(+), 109 deletions(-) diff --git a/cpp/src/join/join_common_utils.hpp b/cpp/src/join/join_common_utils.hpp index 573101cefd9..86402a0e7de 100644 --- a/cpp/src/join/join_common_utils.hpp +++ b/cpp/src/join/join_common_utils.hpp @@ -22,6 +22,7 @@ #include #include +#include #include #include @@ -50,6 +51,11 @@ using mixed_multimap_type = cudf::detail::cuco_allocator, cuco::legacy::double_hashing<1, hash_type, hash_type>>; +using semi_map_type = cuco::legacy::static_map>; + using row_hash_legacy = cudf::row_hasher; diff --git a/cpp/src/join/mixed_join_common_utils.cuh b/cpp/src/join/mixed_join_common_utils.cuh index 89c13285cfe..19701816867 100644 --- a/cpp/src/join/mixed_join_common_utils.cuh +++ b/cpp/src/join/mixed_join_common_utils.cuh @@ -25,7 +25,6 @@ #include #include -#include namespace cudf { namespace detail { @@ -161,38 +160,6 @@ struct pair_expression_equality : public expression_equality { } }; -/** - * @brief Equality comparator that composes two row_equality comparators. - */ -struct double_row_equality_comparator { - row_equality const equality_comparator; - row_equality const conditional_comparator; - - __device__ bool operator()(size_type lhs_row_index, size_type rhs_row_index) const noexcept - { - using experimental::row::lhs_index_type; - using experimental::row::rhs_index_type; - - return equality_comparator(lhs_index_type{lhs_row_index}, rhs_index_type{rhs_row_index}) && - conditional_comparator(lhs_index_type{lhs_row_index}, rhs_index_type{rhs_row_index}); - } -}; - -// A CUDA Cooperative Group of 4 threads for the hash set. -auto constexpr DEFAULT_MIXED_JOIN_CG_SIZE = 4; - -// The hash set type used by mixed_semi_join with the build_table. -using hash_set_type = cuco::static_set, - cuda::thread_scope_device, - double_row_equality_comparator, - cuco::linear_probing, - cudf::detail::cuco_allocator, - cuco::storage<1>>; - -// The hash_set_ref_type used by mixed_semi_join kerenels for probing. -using hash_set_ref_type = hash_set_type::ref_type; - } // namespace detail } // namespace cudf diff --git a/cpp/src/join/mixed_join_kernels_semi.cu b/cpp/src/join/mixed_join_kernels_semi.cu index f2c5ff13638..7459ac3e99c 100644 --- a/cpp/src/join/mixed_join_kernels_semi.cu +++ b/cpp/src/join/mixed_join_kernels_semi.cu @@ -38,16 +38,12 @@ CUDF_KERNEL void __launch_bounds__(block_size) table_device_view right_table, table_device_view probe, table_device_view build, + row_hash const hash_probe, row_equality const equality_probe, - hash_set_ref_type set_ref, + cudf::detail::semi_map_type::device_view hash_table_view, cudf::device_span left_table_keep_mask, cudf::ast::detail::expression_device_view device_expression_data) { - auto constexpr cg_size = hash_set_ref_type::cg_size; - - auto const tile = - cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); - // Normally the casting of a shared memory array is used to create multiple // arrays of different types from the shared memory buffer, but here it is // used to circumvent conflicts between arrays of different types between @@ -56,24 +52,24 @@ CUDF_KERNEL void __launch_bounds__(block_size) cudf::ast::detail::IntermediateDataType* intermediate_storage = reinterpret_cast*>(raw_intermediate_storage); auto thread_intermediate_storage = - &intermediate_storage[tile.meta_group_rank() * device_expression_data.num_intermediates]; + &intermediate_storage[threadIdx.x * device_expression_data.num_intermediates]; + + cudf::size_type const left_num_rows = left_table.num_rows(); + cudf::size_type const right_num_rows = right_table.num_rows(); + auto const outer_num_rows = left_num_rows; - cudf::size_type const outer_num_rows = left_table.num_rows(); - auto const outer_row_index = cudf::detail::grid_1d::global_thread_id() / cg_size; + cudf::size_type outer_row_index = threadIdx.x + blockIdx.x * block_size; auto evaluator = cudf::ast::detail::expression_evaluator( left_table, right_table, device_expression_data); if (outer_row_index < outer_num_rows) { - // Make sure to swap_tables here as hash_set will use probe table as the left one. - auto constexpr swap_tables = true; // Figure out the number of elements for this key. auto equality = single_expression_equality{ - evaluator, thread_intermediate_storage, swap_tables, equality_probe}; + evaluator, thread_intermediate_storage, false, equality_probe}; - auto const set_ref_equality = set_ref.with_key_eq(equality); - auto const result = set_ref_equality.contains(tile, outer_row_index); - if (tile.thread_rank() == 0) left_table_keep_mask[outer_row_index] = result; + left_table_keep_mask[outer_row_index] = + hash_table_view.contains(outer_row_index, hash_probe, equality); } } @@ -82,8 +78,9 @@ void launch_mixed_join_semi(bool has_nulls, table_device_view right_table, table_device_view probe, table_device_view build, + row_hash const hash_probe, row_equality const equality_probe, - hash_set_ref_type set_ref, + cudf::detail::semi_map_type::device_view hash_table_view, cudf::device_span left_table_keep_mask, cudf::ast::detail::expression_device_view device_expression_data, detail::grid_1d const config, @@ -97,8 +94,9 @@ void launch_mixed_join_semi(bool has_nulls, right_table, probe, build, + hash_probe, equality_probe, - set_ref, + hash_table_view, left_table_keep_mask, device_expression_data); } else { @@ -108,8 +106,9 @@ void launch_mixed_join_semi(bool has_nulls, right_table, probe, build, + hash_probe, equality_probe, - set_ref, + hash_table_view, left_table_keep_mask, device_expression_data); } diff --git a/cpp/src/join/mixed_join_kernels_semi.cuh b/cpp/src/join/mixed_join_kernels_semi.cuh index b08298e64e4..43714ffb36a 100644 --- a/cpp/src/join/mixed_join_kernels_semi.cuh +++ b/cpp/src/join/mixed_join_kernels_semi.cuh @@ -45,8 +45,9 @@ namespace detail { * @param[in] right_table The right table * @param[in] probe The table with which to probe the hash table for matches. * @param[in] build The table with which the hash table was built. + * @param[in] hash_probe The hasher used for the probe table. * @param[in] equality_probe The equality comparator used when probing the hash table. - * @param[in] set_ref The hash table device view built from `build`. + * @param[in] hash_table_view The hash table built from `build`. * @param[out] left_table_keep_mask The result of the join operation with "true" element indicating * the corresponding index from left table is present in output * @param[in] device_expression_data Container of device data required to evaluate the desired @@ -57,8 +58,9 @@ void launch_mixed_join_semi(bool has_nulls, table_device_view right_table, table_device_view probe, table_device_view build, + row_hash const hash_probe, row_equality const equality_probe, - hash_set_ref_type set_ref, + cudf::detail::semi_map_type::device_view hash_table_view, cudf::device_span left_table_keep_mask, cudf::ast::detail::expression_device_view device_expression_data, detail::grid_1d const config, diff --git a/cpp/src/join/mixed_join_semi.cu b/cpp/src/join/mixed_join_semi.cu index 719b1d47105..cfb785e242c 100644 --- a/cpp/src/join/mixed_join_semi.cu +++ b/cpp/src/join/mixed_join_semi.cu @@ -46,6 +46,45 @@ namespace cudf { namespace detail { +namespace { +/** + * @brief Device functor to create a pair of hash value and index for a given row. + */ +struct make_pair_function_semi { + __device__ __forceinline__ cudf::detail::pair_type operator()(size_type i) const noexcept + { + // The value is irrelevant since we only ever use the hash map to check for + // membership of a particular row index. + return cuco::make_pair(static_cast(i), 0); + } +}; + +/** + * @brief Equality comparator that composes two row_equality comparators. + */ +class double_row_equality { + public: + double_row_equality(row_equality equality_comparator, row_equality conditional_comparator) + : _equality_comparator{equality_comparator}, _conditional_comparator{conditional_comparator} + { + } + + __device__ bool operator()(size_type lhs_row_index, size_type rhs_row_index) const noexcept + { + using experimental::row::lhs_index_type; + using experimental::row::rhs_index_type; + + return _equality_comparator(lhs_index_type{lhs_row_index}, rhs_index_type{rhs_row_index}) && + _conditional_comparator(lhs_index_type{lhs_row_index}, rhs_index_type{rhs_row_index}); + } + + private: + row_equality _equality_comparator; + row_equality _conditional_comparator; +}; + +} // namespace + std::unique_ptr> mixed_join_semi( table_view const& left_equality, table_view const& right_equality, @@ -57,7 +96,7 @@ std::unique_ptr> mixed_join_semi( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS((join_type != join_kind::INNER_JOIN) and (join_type != join_kind::LEFT_JOIN) and + CUDF_EXPECTS((join_type != join_kind::INNER_JOIN) && (join_type != join_kind::LEFT_JOIN) && (join_type != join_kind::FULL_JOIN), "Inner, left, and full joins should use mixed_join."); @@ -98,7 +137,7 @@ std::unique_ptr> mixed_join_semi( // output column and follow the null-supporting expression evaluation code // path. auto const has_nulls = cudf::nullate::DYNAMIC{ - cudf::has_nulls(left_equality) or cudf::has_nulls(right_equality) or + cudf::has_nulls(left_equality) || cudf::has_nulls(right_equality) || binary_predicate.may_evaluate_null(left_conditional, right_conditional, stream)}; auto const parser = ast::detail::expression_parser{ @@ -117,20 +156,27 @@ std::unique_ptr> mixed_join_semi( auto right_conditional_view = table_device_view::create(right_conditional, stream); auto const preprocessed_build = - cudf::experimental::row::equality::preprocessed_table::create(build, stream); + experimental::row::equality::preprocessed_table::create(build, stream); auto const preprocessed_probe = - cudf::experimental::row::equality::preprocessed_table::create(probe, stream); + experimental::row::equality::preprocessed_table::create(probe, stream); auto const row_comparator = - cudf::experimental::row::equality::two_table_comparator{preprocessed_build, preprocessed_probe}; + cudf::experimental::row::equality::two_table_comparator{preprocessed_probe, preprocessed_build}; auto const equality_probe = row_comparator.equal_to(has_nulls, compare_nulls); + semi_map_type hash_table{ + compute_hash_table_size(build.num_rows()), + cuco::empty_key{std::numeric_limits::max()}, + cuco::empty_value{cudf::detail::JoinNoneValue}, + cudf::detail::cuco_allocator{rmm::mr::polymorphic_allocator{}, stream}, + stream.value()}; + // Create hash table containing all keys found in right table // TODO: To add support for nested columns we will need to flatten in many // places. However, this probably isn't worth adding any time soon since we // won't be able to support AST conditions for those types anyway. auto const build_nulls = cudf::nullate::DYNAMIC{cudf::has_nulls(build)}; auto const row_hash_build = cudf::experimental::row::hash::row_hasher{preprocessed_build}; - + auto const hash_build = row_hash_build.device_hasher(build_nulls); // Since we may see multiple rows that are identical in the equality tables // but differ in the conditional tables, the equality comparator used for // insertion must account for both sets of tables. An alternative solution @@ -145,28 +191,20 @@ std::unique_ptr> mixed_join_semi( auto const equality_build_equality = row_comparator_build.equal_to(build_nulls, compare_nulls); auto const preprocessed_build_condtional = - cudf::experimental::row::equality::preprocessed_table::create(right_conditional, stream); + experimental::row::equality::preprocessed_table::create(right_conditional, stream); auto const row_comparator_conditional_build = cudf::experimental::row::equality::two_table_comparator{preprocessed_build_condtional, preprocessed_build_condtional}; auto const equality_build_conditional = row_comparator_conditional_build.equal_to(build_nulls, compare_nulls); + double_row_equality equality_build{equality_build_equality, equality_build_conditional}; + make_pair_function_semi pair_func_build{}; - hash_set_type row_set{ - {compute_hash_table_size(build.num_rows())}, - cuco::empty_key{JoinNoneValue}, - {equality_build_equality, equality_build_conditional}, - {row_hash_build.device_hasher(build_nulls)}, - {}, - {}, - cudf::detail::cuco_allocator{rmm::mr::polymorphic_allocator{}, stream}, - {stream.value()}}; - - auto iter = thrust::make_counting_iterator(0); + auto iter = cudf::detail::make_counting_transform_iterator(0, pair_func_build); // skip rows that are null here. if ((compare_nulls == null_equality::EQUAL) or (not nullable(build))) { - row_set.insert(iter, iter + right_num_rows, stream.value()); + hash_table.insert(iter, iter + right_num_rows, hash_build, equality_build, stream.value()); } else { thrust::counting_iterator stencil(0); auto const [row_bitmask, _] = @@ -174,19 +212,18 @@ std::unique_ptr> mixed_join_semi( row_is_valid pred{static_cast(row_bitmask.data())}; // insert valid rows - row_set.insert_if(iter, iter + right_num_rows, stencil, pred, stream.value()); + hash_table.insert_if( + iter, iter + right_num_rows, stencil, pred, hash_build, equality_build, stream.value()); } + auto hash_table_view = hash_table.get_device_view(); + detail::grid_1d const config(outer_num_rows, DEFAULT_JOIN_BLOCK_SIZE); - auto const shmem_size_per_block = - parser.shmem_per_thread * - cuco::detail::int_div_ceil(config.num_threads_per_block, hash_set_type::cg_size); + auto const shmem_size_per_block = parser.shmem_per_thread * config.num_threads_per_block; auto const row_hash = cudf::experimental::row::hash::row_hasher{preprocessed_probe}; auto const hash_probe = row_hash.device_hasher(has_nulls); - hash_set_ref_type const row_set_ref = row_set.ref(cuco::contains).with_hash_function(hash_probe); - // Vector used to indicate indices from left/probe table which are present in output auto left_table_keep_mask = rmm::device_uvector(probe.num_rows(), stream); @@ -195,8 +232,9 @@ std::unique_ptr> mixed_join_semi( *right_conditional_view, *probe_view, *build_view, + hash_probe, equality_probe, - row_set_ref, + hash_table_view, cudf::device_span(left_table_keep_mask), parser.device_expression_data, config, diff --git a/cpp/tests/join/mixed_join_tests.cu b/cpp/tests/join/mixed_join_tests.cu index 08a0136700d..6c147c8a128 100644 --- a/cpp/tests/join/mixed_join_tests.cu +++ b/cpp/tests/join/mixed_join_tests.cu @@ -778,21 +778,6 @@ TYPED_TEST(MixedLeftSemiJoinTest, BasicEquality) {1}); } -TYPED_TEST(MixedLeftSemiJoinTest, MixedLeftSemiJoinGatherMap) -{ - auto const col_ref_left_1 = cudf::ast::column_reference(0, cudf::ast::table_reference::LEFT); - auto const col_ref_right_1 = cudf::ast::column_reference(0, cudf::ast::table_reference::RIGHT); - auto left_one_greater_right_one = - cudf::ast::operation(cudf::ast::ast_operator::GREATER, col_ref_left_1, col_ref_right_1); - - this->test({{2, 3, 9, 0, 1, 7, 4, 6, 5, 8}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}}, - {{6, 5, 9, 8, 10, 32}, {0, 1, 2, 3, 4, 5}, {7, 8, 9, 0, 1, 2}}, - {0}, - {1}, - left_one_greater_right_one, - {2, 7, 8}); -} - TYPED_TEST(MixedLeftSemiJoinTest, BasicEqualityDuplicates) { this->test({{0, 1, 2, 1}, {3, 4, 5, 6}, {10, 20, 30, 40}}, @@ -915,18 +900,3 @@ TYPED_TEST(MixedLeftAntiJoinTest, AsymmetricLeftLargerEquality) left_zero_eq_right_zero, {0, 1, 3}); } - -TYPED_TEST(MixedLeftAntiJoinTest, MixedLeftAntiJoinGatherMap) -{ - auto const col_ref_left_1 = cudf::ast::column_reference(0, cudf::ast::table_reference::LEFT); - auto const col_ref_right_1 = cudf::ast::column_reference(0, cudf::ast::table_reference::RIGHT); - auto left_one_greater_right_one = - cudf::ast::operation(cudf::ast::ast_operator::GREATER, col_ref_left_1, col_ref_right_1); - - this->test({{2, 3, 9, 0, 1, 7, 4, 6, 5, 8}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}}, - {{6, 5, 9, 8, 10, 32}, {0, 1, 2, 3, 4, 5}, {7, 8, 9, 0, 1, 2}}, - {0}, - {1}, - left_one_greater_right_one, - {0, 1, 3, 4, 5, 6, 9}); -}