Skip to content

Commit

Permalink
Refactor collect_set to use cudf::distinct and `cudf::lists::dist…
Browse files Browse the repository at this point in the history
…inct` (#11228)

The current groupby/reduction `collect_set` aggregations use `lists::drop_list_duplicates` to generate set(s) of distinct elements. This PR changes that to use `cudf::distinct` and `cudf::lists::distinct` instead, which have some advantages including:
 * Fully supporting nested types, and:
 * Achieving better performance (`O(n)` instead of `O(n log n)`) by internally using a hash table instead of a segmented sort.

This also enables nested types support for `collect_set` in spark-rapids (issue NVIDIA/spark-rapids#5508).

The changes in the Java code here are only to fix unit tests. Previously, those tests were written with the assumption that the `collect_set` results are sorted; they would fail now that the results are no longer sorted.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - David Wendt (https://github.com/davidwendt)
  - MithunR (https://github.com/mythrocks)

URL: #11228
  • Loading branch information
ttnghia authored Jul 15, 2022
1 parent 4528d8e commit b654597
Show file tree
Hide file tree
Showing 9 changed files with 701 additions and 431 deletions.
62 changes: 31 additions & 31 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <cudf/detail/unary.hpp>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/groupby.hpp>
#include <cudf/lists/detail/drop_list_duplicates.hpp>
#include <cudf/lists/detail/stream_compaction.hpp>
#include <cudf/table/table.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
Expand Down Expand Up @@ -99,7 +99,7 @@ void aggregate_result_functor::operator()<aggregation::SUM>(aggregation const& a
agg,
detail::group_sum(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::PRODUCT>(aggregation const& agg)
Expand All @@ -111,7 +111,7 @@ void aggregate_result_functor::operator()<aggregation::PRODUCT>(aggregation cons
agg,
detail::group_product(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::ARGMAX>(aggregation const& agg)
Expand All @@ -126,7 +126,7 @@ void aggregate_result_functor::operator()<aggregation::ARGMAX>(aggregation const
helper.key_sort_order(stream),
stream,
mr));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::ARGMIN>(aggregation const& agg)
Expand All @@ -141,7 +141,7 @@ void aggregate_result_functor::operator()<aggregation::ARGMIN>(aggregation const
helper.key_sort_order(stream),
stream,
mr));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::MIN>(aggregation const& agg)
Expand Down Expand Up @@ -181,7 +181,7 @@ void aggregate_result_functor::operator()<aggregation::MIN>(aggregation const& a
}();

cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::MAX>(aggregation const& agg)
Expand Down Expand Up @@ -221,7 +221,7 @@ void aggregate_result_functor::operator()<aggregation::MAX>(aggregation const& a
}();

cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::MEAN>(aggregation const& agg)
Expand All @@ -248,7 +248,7 @@ void aggregate_result_functor::operator()<aggregation::MEAN>(aggregation const&
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::M2>(aggregation const& agg)
Expand All @@ -263,7 +263,7 @@ void aggregate_result_functor::operator()<aggregation::M2>(aggregation const& ag
values,
agg,
detail::group_m2(get_grouped_values(), mean_result, helper.group_labels(stream), stream, mr));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::VARIANCE>(aggregation const& agg)
Expand All @@ -286,7 +286,7 @@ void aggregate_result_functor::operator()<aggregation::VARIANCE>(aggregation con
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::STD>(aggregation const& agg)
Expand All @@ -300,7 +300,7 @@ void aggregate_result_functor::operator()<aggregation::STD>(aggregation const& a

auto result = cudf::detail::unary_operation(var_result, unary_operator::SQRT, stream, mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::QUANTILE>(aggregation const& agg)
Expand All @@ -321,7 +321,7 @@ void aggregate_result_functor::operator()<aggregation::QUANTILE>(aggregation con
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::MEDIAN>(aggregation const& agg)
Expand All @@ -341,7 +341,7 @@ void aggregate_result_functor::operator()<aggregation::MEDIAN>(aggregation const
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::NUNIQUE>(aggregation const& agg)
Expand All @@ -358,7 +358,7 @@ void aggregate_result_functor::operator()<aggregation::NUNIQUE>(aggregation cons
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::NTH_ELEMENT>(aggregation const& agg)
Expand Down Expand Up @@ -404,7 +404,7 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_LIST>(aggregation
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation const& agg)
Expand All @@ -426,9 +426,9 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
cache.add_result(
values,
agg,
lists::detail::drop_list_duplicates(
lists_column_view(collect_result->view()), nulls_equal, nans_equal, stream, mr));
};
lists::detail::distinct(
lists_column_view{collect_result->view()}, nulls_equal, nans_equal, stream, mr));
}

/**
* @brief Perform merging for the lists that correspond to the same key value.
Expand All @@ -455,7 +455,7 @@ void aggregate_result_functor::operator()<aggregation::MERGE_LISTS>(aggregation
agg,
detail::group_merge_lists(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
};
}

/**
* @brief Perform merging for the lists corresponding to the same key value, then dropping duplicate
Expand All @@ -473,13 +473,13 @@ void aggregate_result_functor::operator()<aggregation::MERGE_LISTS>(aggregation
* column for this aggregation.
*
* Firstly, this aggregation performs `MERGE_LISTS` to concatenate the input lists (corresponding to
* the same key) into intermediate lists, then it calls `lists::drop_list_duplicates` on them to
* the same key) into intermediate lists, then it calls `lists::distinct` on them to
* remove duplicate list entries. As such, the input (partial results) to this aggregation should be
* generated by (distributed) `COLLECT_LIST` aggregations, not `COLLECT_SET`, to avoid unnecessarily
* removing duplicate entries for the partial results.
*
* Since duplicate list entries will be removed, the parameters `null_equality` and `nan_equality`
* are needed for calling to `lists::drop_list_duplicates`.
* are needed for calling `lists::distinct`.
*/
template <>
void aggregate_result_functor::operator()<aggregation::MERGE_SETS>(aggregation const& agg)
Expand All @@ -494,12 +494,12 @@ void aggregate_result_functor::operator()<aggregation::MERGE_SETS>(aggregation c
auto const& merge_sets_agg = dynamic_cast<cudf::detail::merge_sets_aggregation const&>(agg);
cache.add_result(values,
agg,
lists::detail::drop_list_duplicates(lists_column_view(merged_result->view()),
merge_sets_agg._nulls_equal,
merge_sets_agg._nans_equal,
stream,
mr));
};
lists::detail::distinct(lists_column_view{merged_result->view()},
merge_sets_agg._nulls_equal,
merge_sets_agg._nans_equal,
stream,
mr));
}

/**
* @brief Perform merging for the M2 values that correspond to the same key value.
Expand Down Expand Up @@ -528,7 +528,7 @@ void aggregate_result_functor::operator()<aggregation::MERGE_M2>(aggregation con
agg,
detail::group_merge_m2(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
};
}

/**
* @brief Creates column views with only valid elements in both input column views
Expand Down Expand Up @@ -600,7 +600,7 @@ void aggregate_result_functor::operator()<aggregation::COVARIANCE>(aggregation c
cov_agg._ddof,
stream,
mr));
};
}

/**
* @brief Perform correlation between two child columns of non-nullable struct column.
Expand Down Expand Up @@ -710,7 +710,7 @@ void aggregate_result_functor::operator()<aggregation::TDIGEST>(aggregation cons
max_centroids,
stream,
mr));
};
}

/**
* @brief Generate a merged tdigest column from a grouped set of input tdigest columns.
Expand Down Expand Up @@ -752,7 +752,7 @@ void aggregate_result_functor::operator()<aggregation::MERGE_TDIGEST>(aggregatio
max_centroids,
stream,
mr));
};
}

} // namespace detail

Expand Down
64 changes: 45 additions & 19 deletions cpp/src/reductions/collect_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,36 @@
#include <cudf/detail/copy_if.cuh>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/reduction_functions.hpp>
#include <cudf/lists/drop_list_duplicates.hpp>
#include <cudf/lists/lists_column_factories.hpp>
#include <cudf/detail/stream_compaction.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/scalar/scalar_factories.hpp>

namespace cudf {
namespace reduction {

std::unique_ptr<scalar> drop_duplicates(list_scalar const& scalar,
null_equality nulls_equal,
nan_equality nans_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
namespace {

/**
* @brief Check if we need to handle nulls in the input column.
*
* @param input The input column
* @param null_handling The null handling policy
* @return A boolean value indicating if we need to handle nulls
*/
bool need_handle_nulls(column_view const& input, null_policy null_handling)
{
auto list_wrapper = lists::detail::make_lists_column_from_scalar(scalar, 1, stream, mr);
auto lcw = lists_column_view(list_wrapper->view());
auto no_dup_wrapper = lists::drop_list_duplicates(lcw, nulls_equal, nans_equal, mr);
auto no_dup = lists_column_view(no_dup_wrapper->view()).get_sliced_child(stream);
return make_list_scalar(no_dup, stream, mr);
return null_handling == null_policy::EXCLUDE && input.has_nulls();
}

} // namespace

std::unique_ptr<scalar> collect_list(column_view const& col,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
if (null_handling == null_policy::EXCLUDE && col.has_nulls()) {
if (need_handle_nulls(col, null_handling)) {
auto d_view = column_device_view::create(col, stream);
auto filter = detail::validity_accessor(*d_view);
auto null_purged_table = detail::copy_if(table_view{{col}}, filter, stream, mr);
Expand All @@ -72,9 +74,27 @@ std::unique_ptr<scalar> collect_set(column_view const& col,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto scalar = collect_list(col, null_handling, stream, mr);
auto ls = dynamic_cast<list_scalar*>(scalar.get());
return drop_duplicates(*ls, nulls_equal, nans_equal, stream, mr);
// `input_as_collect_list` is the result of the input column that has been processed to obey
// the given null handling behavior.
[[maybe_unused]] auto const [input_as_collect_list, unused_scalar] = [&] {
if (need_handle_nulls(col, null_handling)) {
// Only call `collect_list` when we need to handle nulls.
auto scalar = collect_list(col, null_handling, stream, mr);
return std::pair(static_cast<list_scalar*>(scalar.get())->view(), std::move(scalar));
}

return std::pair(col, std::unique_ptr<scalar>(nullptr));
}();

auto distinct_table = detail::distinct(table_view{{input_as_collect_list}},
std::vector<size_type>{0},
duplicate_keep_option::KEEP_ANY,
nulls_equal,
nans_equal,
stream,
mr);

return std::make_unique<list_scalar>(std::move(distinct_table->get_column(0)), true, stream, mr);
}

std::unique_ptr<scalar> merge_sets(lists_column_view const& col,
Expand All @@ -83,9 +103,15 @@ std::unique_ptr<scalar> merge_sets(lists_column_view const& col,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto flatten_col = col.get_sliced_child(stream);
auto scalar = std::make_unique<list_scalar>(flatten_col, true, stream, mr);
return drop_duplicates(*scalar, nulls_equal, nans_equal, stream, mr);
auto flatten_col = col.get_sliced_child(stream);
auto distinct_table = detail::distinct(table_view{{flatten_col}},
std::vector<size_type>{0},
duplicate_keep_option::KEEP_ANY,
nulls_equal,
nans_equal,
stream,
mr);
return std::make_unique<list_scalar>(std::move(distinct_table->get_column(0)), true, stream, mr);
}

} // namespace reduction
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/rolling/detail/rolling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#include <cudf/detail/utilities/device_operators.cuh>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/dictionary/dictionary_factories.hpp>
#include <cudf/lists/detail/drop_list_duplicates.hpp>
#include <cudf/lists/detail/stream_compaction.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/bit.hpp>
#include <cudf/utilities/error.hpp>
Expand Down Expand Up @@ -928,8 +928,8 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation
stream,
rmm::mr::get_current_device_resource());

result = lists::detail::drop_list_duplicates(
lists_column_view(collected_list->view()), agg._nulls_equal, agg._nans_equal, stream, mr);
result = lists::detail::distinct(
lists_column_view{collected_list->view()}, agg._nulls_equal, agg._nans_equal, stream, mr);
}

// perform the element-wise square root operation on result of VARIANCE
Expand Down
Loading

0 comments on commit b654597

Please sign in to comment.