diff --git a/cpp/include/cudf/table/experimental/row_operators.cuh b/cpp/include/cudf/table/experimental/row_operators.cuh index 336420ed840..6b7350eda32 100644 --- a/cpp/include/cudf/table/experimental/row_operators.cuh +++ b/cpp/include/cudf/table/experimental/row_operators.cuh @@ -32,6 +32,8 @@ #include #include +#include +#include #include #include #include @@ -69,6 +71,48 @@ struct dispatch_void_if_nested { }; namespace row { + +enum class lhs_index_type : size_type {}; +enum class rhs_index_type : size_type {}; + +template > +struct strong_index_iterator : public thrust::iterator_facade, + Index, + thrust::use_default, + thrust::random_access_traversal_tag, + Index, + Underlying> { + using super_t = thrust::iterator_adaptor, Index>; + + explicit constexpr strong_index_iterator(Underlying n) : begin{n} {} + + friend class thrust::iterator_core_access; + + private: + __device__ constexpr void increment() { ++begin; } + __device__ constexpr void decrement() { --begin; } + + __device__ constexpr void advance(Underlying n) { begin += n; } + + __device__ constexpr bool equal(strong_index_iterator const& other) const noexcept + { + return begin == other.begin; + } + + __device__ constexpr Index dereference() const noexcept { return static_cast(begin); } + + __device__ constexpr Underlying distance_to( + strong_index_iterator const& other) const noexcept + { + return other.begin - begin; + } + + Underlying begin{}; +}; + +using lhs_iterator = strong_index_iterator; +using rhs_iterator = strong_index_iterator; + namespace lexicographic { /** @@ -91,6 +135,8 @@ namespace lexicographic { template class device_row_comparator { friend class self_comparator; + friend class two_table_comparator; + /** * @brief Construct a function object for performing a lexicographic * comparison between the rows of two tables. @@ -183,9 +229,9 @@ class device_row_comparator { template () and - not std::is_same_v), - typename... Args> - __device__ cuda::std::pair operator()(Args...) const noexcept + not std::is_same_v)> + __device__ cuda::std::pair operator()(size_type const, + size_type const) const noexcept { CUDF_UNREACHABLE("Attempted to compare elements of uncomparable types."); } @@ -234,12 +280,13 @@ class device_row_comparator { * @brief Checks whether the row at `lhs_index` in the `lhs` table compares * lexicographically less, greater, or equivalent to the row at `rhs_index` in the `rhs` table. * - * @param lhs_index The index of row in the `lhs` table to examine + * @param lhs_index The index of the row in the `lhs` table to examine * @param rhs_index The index of the row in the `rhs` table to examine * @return weak ordering comparison of the row in the `lhs` table relative to the row in the `rhs` * table */ - __device__ weak_ordering operator()(size_type lhs_index, size_type rhs_index) const noexcept + __device__ weak_ordering operator()(size_type const lhs_index, + size_type const rhs_index) const noexcept { int last_null_depth = std::numeric_limits::max(); for (size_type i = 0; i < _lhs.num_columns(); ++i) { @@ -288,12 +335,14 @@ class device_row_comparator { */ template struct weak_ordering_comparator_impl { - __device__ bool operator()(size_type const lhs, size_type const rhs) const noexcept + template + __device__ constexpr bool operator()(LhsType const lhs_index, + RhsType const rhs_index) const noexcept { - weak_ordering const result = comparator(lhs, rhs); + weak_ordering const result = comparator(lhs_index, rhs_index); return ((result == values) || ...); } - Comparator comparator; + Comparator const comparator; }; /** @@ -302,14 +351,12 @@ struct weak_ordering_comparator_impl { * * @tparam Nullate A cudf::nullate type describing whether to check for nulls. */ -template -using less_comparator = - weak_ordering_comparator_impl, weak_ordering::LESS>; +template +using less_comparator = weak_ordering_comparator_impl; -template -using less_equivalent_comparator = weak_ordering_comparator_impl, - weak_ordering::LESS, - weak_ordering::EQUIVALENT>; +template +using less_equivalent_comparator = + weak_ordering_comparator_impl; struct preprocessed_table { using table_device_view_owner = @@ -319,7 +366,7 @@ struct preprocessed_table { * @brief Preprocess table for use with lexicographical comparison * * Sets up the table for use with lexicographical comparison. The resulting preprocessed table can - * be passed to the constructor of `lex::self_comparator` to avoid preprocessing again. + * be passed to the constructor of `lexicographic::self_comparator` to avoid preprocessing again. * * @param table The table to preprocess * @param column_order Optional, host array the same length as a row that indicates the desired @@ -337,6 +384,7 @@ struct preprocessed_table { private: friend class self_comparator; + friend class two_table_comparator; preprocessed_table(table_device_view_owner&& table, rmm::device_uvector&& column_order, @@ -395,10 +443,10 @@ struct preprocessed_table { } private: - table_device_view_owner _t; - rmm::device_uvector _column_order; - rmm::device_uvector _null_precedence; - rmm::device_uvector _depths; + table_device_view_owner const _t; + rmm::device_uvector const _column_order; + rmm::device_uvector const _null_precedence; + rmm::device_uvector const _depths; }; /** @@ -459,9 +507,9 @@ class self_comparator { * @tparam Nullate A cudf::nullate type describing whether to check for nulls. */ template - less_comparator device_comparator(Nullate nullate = {}) const + less_comparator> device_comparator(Nullate nullate = {}) const { - return less_comparator{device_row_comparator( + return less_comparator>{device_row_comparator( nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence())}; } @@ -469,6 +517,124 @@ class self_comparator { std::shared_ptr d_t; }; +template +struct strong_index_comparator_adapter { + __device__ constexpr weak_ordering operator()(lhs_index_type const lhs_index, + rhs_index_type const rhs_index) const noexcept + { + return comparator(static_cast(lhs_index), + static_cast(rhs_index)); + } + + __device__ constexpr weak_ordering operator()(rhs_index_type const rhs_index, + lhs_index_type const lhs_index) const noexcept + { + auto const left_right_ordering = + comparator(static_cast(lhs_index), static_cast(rhs_index)); + + // Invert less/greater values to reflect right to left ordering + if (left_right_ordering == weak_ordering::LESS) { + return weak_ordering::GREATER; + } else if (left_right_ordering == weak_ordering::GREATER) { + return weak_ordering::LESS; + } + return weak_ordering::EQUIVALENT; + } + + Comparator const comparator; +}; + +/** + * @brief An owning object that can be used to lexicographically compare rows of two different + * tables + * + * This class takes two table_views and preprocesses certain columns to allow for lexicographical + * comparison. The preprocessed table and temporary data required for the comparison are created and + * owned by this class. + * + * Alternatively, `two_table_comparator` can be constructed from two existing + * `shared_ptr`s when sharing the same tables among multiple comparators. + * + * This class can then provide a functor object that can used on the device. + * The object of this class must outlive the usage of the device functor. + */ +class two_table_comparator { + public: + /** + * @brief Construct an owning object for performing a lexicographic comparison between rows of + * two different tables. + * + * The left and right table are expected to have the same number of columns + * and data types for each column. + * + * @param left The left table to compare + * @param right The right table to compare + * @param column_order Optional, host array the same length as a row that indicates the desired + * ascending/descending order of each column in a row. If empty, it is assumed all columns are + * sorted in ascending order. + * @param null_precedence Optional, device array the same length as a row and indicates how null + * values compare to all other for every column. If empty, then null precedence would be + * `null_order::BEFORE` for all columns. + * @param stream The stream to construct this object on. Not the stream that will be used for + * comparisons using this object. + */ + two_table_comparator(table_view const& left, + table_view const& right, + host_span column_order = {}, + host_span null_precedence = {}, + rmm::cuda_stream_view stream = rmm::cuda_stream_default); + + /** + * @brief Construct an owning object for performing a lexicographic comparison between two rows of + * the same preprocessed table. + * + * This constructor allows independently constructing a `preprocessed_table` and sharing it among + * multiple comparators. + * + * @param left A table preprocessed for lexicographic comparison + * @param right A table preprocessed for lexicographic comparison + */ + two_table_comparator(std::shared_ptr left, + std::shared_ptr right) + : d_left_table{std::move(left)}, d_right_table{std::move(right)} + { + } + + /** + * @brief Return the binary operator for comparing rows in the table. + * + * Returns a binary callable, `F`, with signatures + * `bool F(lhs_index_type, rhs_index_type)` and + * `bool F(rhs_index_type, lhs_index_type)`. + * + * `F(lhs_index_type i, rhs_index_type j)` returns true if and only if row + * `i` of the left table compares lexicographically less than row `j` of the + * right table. + * + * Similarly, `F(rhs_index_type i, lhs_index_type j)` returns true if and + * only if row `i` of the right table compares lexicographically less than row + * `j` of the left table. + * + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + */ + template + less_comparator>> + device_comparator(Nullate nullate = {}) const + { + return less_comparator>>{ + device_row_comparator(nullate, + *d_left_table, + *d_right_table, + d_left_table->depths(), + d_left_table->column_order(), + d_left_table->null_precedence())}; + } + + private: + std::shared_ptr d_left_table; + std::shared_ptr d_right_table; +}; + } // namespace lexicographic namespace hash { diff --git a/cpp/include/cudf/table/row_operators.cuh b/cpp/include/cudf/table/row_operators.cuh index 4d503cd53b8..a181e9bae63 100644 --- a/cpp/include/cudf/table/row_operators.cuh +++ b/cpp/include/cudf/table/row_operators.cuh @@ -389,7 +389,7 @@ class row_lexicographic_comparator { * @brief Checks whether the row at `lhs_index` in the `lhs` table compares * lexicographically less than the row at `rhs_index` in the `rhs` table. * - * @param lhs_index The index of row in the `lhs` table to examine + * @param lhs_index The index of the row in the `lhs` table to examine * @param rhs_index The index of the row in the `rhs` table to examine * @return `true` if row from the `lhs` table compares less than row in the * `rhs` table diff --git a/cpp/src/search/contains.cu b/cpp/src/search/contains.cu index 2748dc18676..e6cc5d75e6c 100644 --- a/cpp/src/search/contains.cu +++ b/cpp/src/search/contains.cu @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/src/search/search_ordered.cu b/cpp/src/search/search_ordered.cu index 7188d328689..d3feba0aef2 100644 --- a/cpp/src/search/search_ordered.cu +++ b/cpp/src/search/search_ordered.cu @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -63,53 +64,34 @@ std::unique_ptr search_ordered(table_view const& haystack, // This utility will ensure all corresponding dictionary columns have matching keys. // It will return any new dictionary columns created as well as updated table_views. auto const matched = dictionary::detail::match_dictionaries({haystack, needles}, stream); + auto const& matched_haystack = matched.second.front(); + auto const& matched_needles = matched.second.back(); - // Prepare to flatten the structs column - auto const has_null_elements = has_nested_nulls(haystack) or has_nested_nulls(needles); - auto const flatten_nullability = has_null_elements - ? structs::detail::column_nullability::FORCE - : structs::detail::column_nullability::MATCH_INCOMING; - - // 0-table_view, 1-column_order, 2-null_precedence, 3-validity_columns - auto const t_flattened = structs::detail::flatten_nested_columns( - matched.second.front(), column_order, null_precedence, flatten_nullability); - auto const values_flattened = - structs::detail::flatten_nested_columns(matched.second.back(), {}, {}, flatten_nullability); - - auto const t_d = table_device_view::create(t_flattened, stream); - auto const values_d = table_device_view::create(values_flattened, stream); - auto const& lhs = find_first ? *t_d : *values_d; - auto const& rhs = find_first ? *values_d : *t_d; - - auto const& column_order_flattened = t_flattened.orders(); - auto const& null_precedence_flattened = t_flattened.null_orders(); - auto const column_order_dv = detail::make_device_uvector_async(column_order_flattened, stream); - auto const null_precedence_dv = - detail::make_device_uvector_async(null_precedence_flattened, stream); - - auto const count_it = thrust::make_counting_iterator(0); - auto const comp = row_lexicographic_comparator(nullate::DYNAMIC{has_null_elements}, - lhs, - rhs, - column_order_dv.data(), - null_precedence_dv.data()); + auto const comparator = cudf::experimental::row::lexicographic::two_table_comparator( + matched_haystack, matched_needles, column_order, null_precedence, stream); + auto const has_null_elements = + has_nested_nulls(matched_haystack) or has_nested_nulls(matched_needles); + auto const d_comparator = comparator.device_comparator(nullate::DYNAMIC{has_null_elements}); + + auto const haystack_it = cudf::experimental::row::lhs_iterator(0); + auto const needles_it = cudf::experimental::row::rhs_iterator(0); if (find_first) { thrust::lower_bound(rmm::exec_policy(stream), - count_it, - count_it + haystack.num_rows(), - count_it, - count_it + needles.num_rows(), + haystack_it, + haystack_it + haystack.num_rows(), + needles_it, + needles_it + needles.num_rows(), out_it, - comp); + d_comparator); } else { thrust::upper_bound(rmm::exec_policy(stream), - count_it, - count_it + haystack.num_rows(), - count_it, - count_it + needles.num_rows(), + haystack_it, + haystack_it + haystack.num_rows(), + needles_it, + needles_it + needles.num_rows(), out_it, - comp); + d_comparator); } return result; } diff --git a/cpp/src/table/row_operators.cu b/cpp/src/table/row_operators.cu index 3c51ae22418..b48566fe837 100644 --- a/cpp/src/table/row_operators.cu +++ b/cpp/src/table/row_operators.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include @@ -301,6 +302,16 @@ void check_eq_compatibility(table_view const& input) } } +void check_shape_compatibility(table_view const& lhs, table_view const& rhs) +{ + CUDF_EXPECTS(lhs.num_columns() == rhs.num_columns(), + "Cannot compare tables with different number of columns"); + for (size_type i = 0; i < lhs.num_columns(); ++i) { + CUDF_EXPECTS(column_types_equal(lhs.column(i), rhs.column(i)), + "Cannot compare tables with different column types"); + } +} + } // namespace namespace row { @@ -327,6 +338,17 @@ std::shared_ptr preprocessed_table::create( std::move(d_t), std::move(d_column_order), std::move(d_null_precedence), std::move(d_depths))); } +two_table_comparator::two_table_comparator(table_view const& left, + table_view const& right, + host_span column_order, + host_span null_precedence, + rmm::cuda_stream_view stream) + : d_left_table{preprocessed_table::create(left, column_order, null_precedence, stream)}, + d_right_table{preprocessed_table::create(right, column_order, null_precedence, stream)} +{ + check_shape_compatibility(left, right); +} + } // namespace lexicographic namespace equality {