Skip to content

Commit

Permalink
Migrate lists/set_operations to pylibcudf (#16190)
Browse files Browse the repository at this point in the history
Apart of #15162

Authors:
  - Matthew Murray (https://github.com/Matt711)

Approvers:
  - Thomas Li (https://github.com/lithomas1)

URL: #16190
  • Loading branch information
Matt711 authored Jul 24, 2024
1 parent 73937fb commit 8bba6df
Show file tree
Hide file tree
Showing 4 changed files with 339 additions and 1 deletion.
39 changes: 39 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/libcudf/lists/set_operations.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2021-2024, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr

from cudf._lib.pylibcudf.libcudf.column.column cimport column
from cudf._lib.pylibcudf.libcudf.lists.lists_column_view cimport (
lists_column_view,
)
from cudf._lib.pylibcudf.libcudf.types cimport nan_equality, null_equality


cdef extern from "cudf/lists/set_operations.hpp" namespace "cudf::lists" nogil:
cdef unique_ptr[column] difference_distinct(
const lists_column_view& lhs,
const lists_column_view& rhs,
null_equality nulls_equal,
nan_equality nans_equal
) except +

cdef unique_ptr[column] have_overlap(
const lists_column_view& lhs,
const lists_column_view& rhs,
null_equality nulls_equal,
nan_equality nans_equal
) except +

cdef unique_ptr[column] intersect_distinct(
const lists_column_view& lhs,
const lists_column_view& rhs,
null_equality nulls_equal,
nan_equality nans_equal
) except +

cdef unique_ptr[column] union_distinct(
const lists_column_view& lhs,
const lists_column_view& rhs,
null_equality nulls_equal,
nan_equality nans_equal
) except +
8 changes: 8 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/lists.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,11 @@ cpdef Column count_elements(Column)
cpdef Column sequences(Column, Column, Column steps = *)

cpdef Column sort_lists(Column, bool, null_order, bool stable = *)

cpdef Column difference_distinct(Column, Column, bool nulls_equal=*, bool nans_equal=*)

cpdef Column have_overlap(Column, Column, bool nulls_equal=*, bool nans_equal=*)

cpdef Column intersect_distinct(Column, Column, bool nulls_equal=*, bool nans_equal=*)

cpdef Column union_distinct(Column, Column, bool nulls_equal=*, bool nans_equal=*)
203 changes: 202 additions & 1 deletion python/cudf/cudf/_lib/pylibcudf/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from cudf._lib.pylibcudf.libcudf.lists cimport (
filling as cpp_filling,
gather as cpp_gather,
reverse as cpp_reverse,
set_operations as cpp_set_operations,
)
from cudf._lib.pylibcudf.libcudf.lists.combine cimport (
concatenate_list_elements as cpp_concatenate_list_elements,
Expand All @@ -29,7 +30,13 @@ from cudf._lib.pylibcudf.libcudf.lists.sorting cimport (
stable_sort_lists as cpp_stable_sort_lists,
)
from cudf._lib.pylibcudf.libcudf.table.table cimport table
from cudf._lib.pylibcudf.libcudf.types cimport null_order, order, size_type
from cudf._lib.pylibcudf.libcudf.types cimport (
nan_equality,
null_equality,
null_order,
order,
size_type,
)
from cudf._lib.pylibcudf.lists cimport ColumnOrScalar, ColumnOrSizeType

from .column cimport Column, ListColumnView
Expand Down Expand Up @@ -413,3 +420,197 @@ cpdef Column sort_lists(
na_position,
))
return Column.from_libcudf(move(c_result))


cpdef Column difference_distinct(
Column lhs,
Column rhs,
bool nulls_equal=True,
bool nans_equal=True
):
"""Create a column of index values indicating the position of a search
key row within the corresponding list row in the lists column.
For details, see :cpp:func:`difference_distinct`.
Parameters
----------
lhs : Column
The input lists column of elements that may be included.
rhs : Column
The input lists column of elements to exclude.
nulls_equal : bool, default True
If true, null elements are considered equal. Otherwise, unequal.
nans_equal : bool, default True
If true, libcudf will treat nan elements from {-nan, +nan}
as equal. Otherwise, unequal. Otherwise, unequal.
Returns
-------
Column
A lists column containing the difference results.
"""
cdef unique_ptr[column] c_result
cdef ListColumnView lhs_view = lhs.list_view()
cdef ListColumnView rhs_view = rhs.list_view()

cdef null_equality c_nulls_equal = (
null_equality.EQUAL if nulls_equal else null_equality.UNEQUAL
)
cdef nan_equality c_nans_equal = (
nan_equality.ALL_EQUAL if nans_equal else nan_equality.UNEQUAL
)

with nogil:
c_result = move(cpp_set_operations.difference_distinct(
lhs_view.view(),
rhs_view.view(),
c_nulls_equal,
c_nans_equal,
))
return Column.from_libcudf(move(c_result))


cpdef Column have_overlap(
Column lhs,
Column rhs,
bool nulls_equal=True,
bool nans_equal=True
):
"""Check if lists at each row of the given lists columns overlap.
For details, see :cpp:func:`have_overlap`.
Parameters
----------
lhs : Column
The input lists column for one side.
rhs : Column
The input lists column for the other side.
nulls_equal : bool, default True
If true, null elements are considered equal. Otherwise, unequal.
nans_equal : bool, default True
If true, libcudf will treat nan elements from {-nan, +nan}
as equal. Otherwise, unequal. Otherwise, unequal.
Returns
-------
Column
A column containing the check results.
"""
cdef unique_ptr[column] c_result
cdef ListColumnView lhs_view = lhs.list_view()
cdef ListColumnView rhs_view = rhs.list_view()

cdef null_equality c_nulls_equal = (
null_equality.EQUAL if nulls_equal else null_equality.UNEQUAL
)
cdef nan_equality c_nans_equal = (
nan_equality.ALL_EQUAL if nans_equal else nan_equality.UNEQUAL
)

with nogil:
c_result = move(cpp_set_operations.have_overlap(
lhs_view.view(),
rhs_view.view(),
c_nulls_equal,
c_nans_equal,
))
return Column.from_libcudf(move(c_result))


cpdef Column intersect_distinct(
Column lhs,
Column rhs,
bool nulls_equal=True,
bool nans_equal=True
):
"""Create a lists column of distinct elements common to two input lists columns.
For details, see :cpp:func:`intersect_distinct`.
Parameters
----------
lhs : Column
The input lists column of elements that may be included.
rhs : Column
The input lists column of elements to exclude.
nulls_equal : bool, default True
If true, null elements are considered equal. Otherwise, unequal.
nans_equal : bool, default True
If true, libcudf will treat nan elements from {-nan, +nan}
as equal. Otherwise, unequal. Otherwise, unequal.
Returns
-------
Column
A lists column containing the intersection results.
"""
cdef unique_ptr[column] c_result
cdef ListColumnView lhs_view = lhs.list_view()
cdef ListColumnView rhs_view = rhs.list_view()

cdef null_equality c_nulls_equal = (
null_equality.EQUAL if nulls_equal else null_equality.UNEQUAL
)
cdef nan_equality c_nans_equal = (
nan_equality.ALL_EQUAL if nans_equal else nan_equality.UNEQUAL
)

with nogil:
c_result = move(cpp_set_operations.intersect_distinct(
lhs_view.view(),
rhs_view.view(),
c_nulls_equal,
c_nans_equal,
))
return Column.from_libcudf(move(c_result))


cpdef Column union_distinct(
Column lhs,
Column rhs,
bool nulls_equal=True,
bool nans_equal=True
):
"""Create a lists column of distinct elements found in
either of two input lists columns.
For details, see :cpp:func:`union_distinct`.
Parameters
----------
lhs : Column
The input lists column of elements that may be included.
rhs : Column
The input lists column of elements to exclude.
nulls_equal : bool, default True
If true, null elements are considered equal. Otherwise, unequal.
nans_equal : bool, default True
If true, libcudf will treat nan elements from {-nan, +nan}
as equal. Otherwise, unequal. Otherwise, unequal.
Returns
-------
Column
A lists column containing the union results.
"""
cdef unique_ptr[column] c_result
cdef ListColumnView lhs_view = lhs.list_view()
cdef ListColumnView rhs_view = rhs.list_view()

cdef null_equality c_nulls_equal = (
null_equality.EQUAL if nulls_equal else null_equality.UNEQUAL
)
cdef nan_equality c_nans_equal = (
nan_equality.ALL_EQUAL if nans_equal else nan_equality.UNEQUAL
)

with nogil:
c_result = move(cpp_set_operations.union_distinct(
lhs_view.view(),
rhs_view.view(),
c_nulls_equal,
c_nans_equal,
))
return Column.from_libcudf(move(c_result))
90 changes: 90 additions & 0 deletions python/cudf/cudf/pylibcudf_tests/test_lists.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import numpy as np
import pyarrow as pa
import pytest
from utils import assert_column_eq
Expand All @@ -22,6 +23,13 @@ def column():
return pa.array([3, 2, 5, 6]), pa.array([-1, 0, 0, 0], type=pa.int32())


@pytest.fixture
def set_lists_column():
lhs = [[np.nan, np.nan, 2, 1, 2], [1, 2, 3], None, [4, None, 5]]
rhs = [[np.nan, 1, 2, 3], [4, 5], [None, 7, 8], [None, None]]
return lhs, rhs


@pytest.fixture
def lists_column():
return [[4, 2, 3, 1], [1, 2, None, 4], [-10, 10, 10, 0]]
Expand Down Expand Up @@ -253,3 +261,85 @@ def test_sort_lists(lists_column, ascending, na_position, expected):

assert_column_eq(expect, res)
assert_column_eq(expect, res_stable)


@pytest.mark.parametrize(
"set_operation,nans_equal,nulls_equal,expected",
[
(
plc.lists.difference_distinct,
True,
True,
[[], [1, 2, 3], None, [4, 5]],
),
(
plc.lists.difference_distinct,
False,
True,
[[], [1, 2, 3], None, [4, None, 5]],
),
(
plc.lists.have_overlap,
True,
True,
[True, False, None, True],
),
(
plc.lists.have_overlap,
False,
False,
[True, False, None, False],
),
(
plc.lists.intersect_distinct,
True,
True,
[[np.nan, 1, 2], [], None, [None]],
),
(
plc.lists.intersect_distinct,
True,
False,
[[1, 2], [], None, [None]],
),
(
plc.lists.union_distinct,
False,
True,
[
[np.nan, 2, 1, 3],
[1, 2, 3, 4, 5],
None,
[4, None, 5, None, None],
],
),
(
plc.lists.union_distinct,
False,
False,
[
[np.nan, np.nan, 2, 1, np.nan, 3],
[1, 2, 3, 4, 5],
None,
[4, None, 5, None, None],
],
),
],
)
def test_set_operations(
set_lists_column, set_operation, nans_equal, nulls_equal, expected
):
lhs, rhs = set_lists_column

res = set_operation(
plc.interop.from_arrow(pa.array(lhs)),
plc.interop.from_arrow(pa.array(rhs)),
nans_equal,
nulls_equal,
)

if set_operation != plc.lists.have_overlap:
expect = pa.array(expected, type=pa.list_(pa.float64()))
else:
expect = pa.array(expected)
assert_column_eq(expect, res)

0 comments on commit 8bba6df

Please sign in to comment.