Skip to content

Commit

Permalink
Add stream param to list explode APIs (#16317)
Browse files Browse the repository at this point in the history
Add `stream` param to list `explode*` APIs. Partially fixes #13744

Authors:
  - Jayjeet Chakraborty (https://github.com/JayjeetAtGithub)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #16317
  • Loading branch information
JayjeetAtGithub authored Jul 22, 2024
1 parent e54b82c commit e0a00c1
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 14 deletions.
8 changes: 8 additions & 0 deletions cpp/include/cudf/lists/explode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ namespace cudf {
*
* @param input_table Table to explode.
* @param explode_column_idx Column index to explode inside the table.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*
* @return A new table with explode_col exploded.
*/
std::unique_ptr<table> explode(
table_view const& input_table,
size_type explode_column_idx,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -109,6 +111,7 @@ std::unique_ptr<table> explode(
*
* @param input_table Table to explode.
* @param explode_column_idx Column index to explode inside the table.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*
* @return A new table with exploded value and position. The column order of return table is
Expand All @@ -117,6 +120,7 @@ std::unique_ptr<table> explode(
std::unique_ptr<table> explode_position(
table_view const& input_table,
size_type explode_column_idx,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -152,13 +156,15 @@ std::unique_ptr<table> explode_position(
*
* @param input_table Table to explode.
* @param explode_column_idx Column index to explode inside the table.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*
* @return A new table with explode_col exploded.
*/
std::unique_ptr<table> explode_outer(
table_view const& input_table,
size_type explode_column_idx,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -196,13 +202,15 @@ std::unique_ptr<table> explode_outer(
*
* @param input_table Table to explode.
* @param explode_column_idx Column index to explode inside the table.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*
* @return A new table with explode_col exploded.
*/
std::unique_ptr<table> explode_outer_position(
table_view const& input_table,
size_type explode_column_idx,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/** @} */ // end of group
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cudf/lists/set_operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ namespace cudf::lists {
* @param nulls_equal Flag to specify whether null elements should be considered as equal, default
* to be `UNEQUAL` which means only non-null elements are checked for overlapping
* @param nans_equal Flag to specify whether floating-point NaNs should be considered as equal
* @param mr Device memory resource used to allocate the returned object
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned object
* @return A column of type BOOL containing the check results
*/
std::unique_ptr<column> have_overlap(
Expand Down
29 changes: 17 additions & 12 deletions cpp/src/lists/explode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ std::unique_ptr<table> explode_outer(table_view const& input_table,
if (null_or_empty_count == 0) {
// performance penalty to run the below loop if there are no nulls or empty lists.
// run simple explode instead
return include_position ? explode_position(input_table, explode_column_idx, stream, mr)
: explode(input_table, explode_column_idx, stream, mr);
return include_position ? detail::explode_position(input_table, explode_column_idx, stream, mr)
: detail::explode(input_table, explode_column_idx, stream, mr);
}

auto gather_map_size = sliced_child.size() + null_or_empty_count;
Expand Down Expand Up @@ -300,58 +300,63 @@ std::unique_ptr<table> explode_outer(table_view const& input_table,
} // namespace detail

/**
* @copydoc cudf::explode(table_view const&, size_type, rmm::device_async_resource_ref)
* @copydoc cudf::explode(table_view const&, size_type, rmm::cuda_stream_view,
* rmm::device_async_resource_ref)
*/
std::unique_ptr<table> explode(table_view const& input_table,
size_type explode_column_idx,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
CUDF_EXPECTS(input_table.column(explode_column_idx).type().id() == type_id::LIST,
"Unsupported non-list column");
return detail::explode(input_table, explode_column_idx, cudf::get_default_stream(), mr);
return detail::explode(input_table, explode_column_idx, stream, mr);
}

/**
* @copydoc cudf::explode_position(table_view const&, size_type, rmm::device_async_resource_ref)
* @copydoc cudf::explode_position(table_view const&, size_type, rmm::cuda_stream_view,
* rmm::device_async_resource_ref)
*/
std::unique_ptr<table> explode_position(table_view const& input_table,
size_type explode_column_idx,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
CUDF_EXPECTS(input_table.column(explode_column_idx).type().id() == type_id::LIST,
"Unsupported non-list column");
return detail::explode_position(input_table, explode_column_idx, cudf::get_default_stream(), mr);
return detail::explode_position(input_table, explode_column_idx, stream, mr);
}

/**
* @copydoc cudf::explode_outer(table_view const&, size_type, rmm::device_async_resource_ref)
* @copydoc cudf::explode_outer(table_view const&, size_type, rmm::cuda_stream_view,
* rmm::device_async_resource_ref)
*/
std::unique_ptr<table> explode_outer(table_view const& input_table,
size_type explode_column_idx,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
CUDF_EXPECTS(input_table.column(explode_column_idx).type().id() == type_id::LIST,
"Unsupported non-list column");
return detail::explode_outer(
input_table, explode_column_idx, false, cudf::get_default_stream(), mr);
return detail::explode_outer(input_table, explode_column_idx, false, stream, mr);
}

/**
* @copydoc cudf::explode_outer_position(table_view const&, size_type,
* rmm::device_async_resource_ref)
* rmm::cuda_stream_view, rmm::device_async_resource_ref)
*/
std::unique_ptr<table> explode_outer_position(table_view const& input_table,
size_type explode_column_idx,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
CUDF_EXPECTS(input_table.column(explode_column_idx).type().id() == type_id::LIST,
"Unsupported non-list column");
return detail::explode_outer(
input_table, explode_column_idx, true, cudf::get_default_stream(), mr);
return detail::explode_outer(input_table, explode_column_idx, true, stream, mr);
}

} // namespace cudf
57 changes: 56 additions & 1 deletion cpp/tests/streams/lists_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,6 +21,7 @@
#include <cudf/lists/combine.hpp>
#include <cudf/lists/contains.hpp>
#include <cudf/lists/count_elements.hpp>
#include <cudf/lists/explode.hpp>
#include <cudf/lists/extract.hpp>
#include <cudf/lists/filling.hpp>
#include <cudf/lists/gather.hpp>
Expand Down Expand Up @@ -212,3 +213,57 @@ TEST_F(ListTest, HaveOverlap)
cudf::nan_equality::ALL_EQUAL,
cudf::test::get_default_stream());
}

TEST_F(ListTest, Explode)
{
cudf::test::fixed_width_column_wrapper<int32_t> list_col_a{100, 200, 300};
cudf::test::lists_column_wrapper<int32_t> list_col_b{
cudf::test::lists_column_wrapper<int32_t>{1, 2, 7},
cudf::test::lists_column_wrapper<int32_t>{5, 6},
cudf::test::lists_column_wrapper<int32_t>{0, 3}};
cudf::test::strings_column_wrapper list_col_c{"string0", "string1", "string2"};
cudf::table_view lists_table({list_col_a, list_col_b, list_col_c});
cudf::explode(lists_table, 1, cudf::test::get_default_stream());
}

TEST_F(ListTest, ExplodePosition)
{
cudf::test::fixed_width_column_wrapper<int32_t> list_col_a{100, 200, 300};
cudf::test::lists_column_wrapper<int32_t> list_col_b{
cudf::test::lists_column_wrapper<int32_t>{1, 2, 7},
cudf::test::lists_column_wrapper<int32_t>{5, 6},
cudf::test::lists_column_wrapper<int32_t>{0, 3}};
cudf::test::strings_column_wrapper list_col_c{"string0", "string1", "string2"};
cudf::table_view lists_table({list_col_a, list_col_b, list_col_c});
cudf::explode_position(lists_table, 1, cudf::test::get_default_stream());
}

TEST_F(ListTest, ExplodeOuter)
{
constexpr auto null = 0;
auto valids =
cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i % 2 == 0; });
cudf::test::lists_column_wrapper<int32_t> list_col_a{
cudf::test::lists_column_wrapper<int32_t>({1, null, 7}, valids),
cudf::test::lists_column_wrapper<int32_t>({5, null, 0, null}, valids),
cudf::test::lists_column_wrapper<int32_t>{},
cudf::test::lists_column_wrapper<int32_t>({0, null, 8}, valids)};
cudf::test::fixed_width_column_wrapper<int32_t> list_col_b{100, 200, 300, 400};
cudf::table_view lists_table({list_col_a, list_col_b});
cudf::explode_outer(lists_table, 0, cudf::test::get_default_stream());
}

TEST_F(ListTest, ExplodeOuterPosition)
{
constexpr auto null = 0;
auto valids =
cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i % 2 == 0; });
cudf::test::lists_column_wrapper<int32_t> list_col_a{
cudf::test::lists_column_wrapper<int32_t>({1, null, 7}, valids),
cudf::test::lists_column_wrapper<int32_t>({5, null, 0, null}, valids),
cudf::test::lists_column_wrapper<int32_t>{},
cudf::test::lists_column_wrapper<int32_t>({0, null, 8}, valids)};
cudf::test::fixed_width_column_wrapper<int32_t> list_col_b{100, 200, 300, 400};
cudf::table_view lists_table({list_col_a, list_col_b});
cudf::explode_outer_position(lists_table, 0, cudf::test::get_default_stream());
}

0 comments on commit e0a00c1

Please sign in to comment.