Skip to content

Commit

Permalink
Java support for explode position (#7471)
Browse files Browse the repository at this point in the history
This pull request provides Java interface for `Table.explode_position`, which is required by spark-rapids plugin.

Authors:
  - Alfred Xu (@sperlingxx)

Approvers:
  - Vukasin Milovanovic (@vuule)
  - Christopher Harris (@cwharris)
  - Jason Lowe (@jlowe)
  - Mike Wilson (@hyperbolic2346)
  - Robert (Bobby) Evans (@revans2)

URL: #7471
  • Loading branch information
sperlingxx authored Mar 4, 2021
1 parent e5d0ec9 commit 7871e7a
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 47 deletions.
21 changes: 11 additions & 10 deletions cpp/include/cudf/reshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ std::unique_ptr<table> explode(
* [[20,25], 200],
* [[30], 300],
* returns
* [5, 0, 100],
* [10, 1, 100],
* [15, 2, 100],
* [20, 0, 200],
* [25, 1, 200],
* [30, 0, 300],
* [0, 5, 100],
* [1, 10, 100],
* [2, 15, 100],
* [0, 20, 200],
* [1, 25, 200],
* [0, 30, 300],
* ```
*
* Nulls and empty lists propagate in different ways depending on what is null or empty.
Expand All @@ -164,9 +164,9 @@ std::unique_ptr<table> explode(
* [null, 200],
* [[], 300],
* returns
* [5, 0, 100],
* [null, 1, 100],
* [15, 2, 100],
* [0, 5, 100],
* [1, null, 100],
* [2, 15, 100],
* ```
* Note that null lists are not included in the resulting table, but nulls inside
* lists and empty lists will be represented with a null entry for that column in that row.
Expand All @@ -175,7 +175,8 @@ std::unique_ptr<table> explode(
* @param explode_column_idx Column index to explode inside the table.
* @param mr Device memory resource used to allocate the returned column's device memory.
*
* @return A new table with explode_col exploded.
* @return A new table with exploded value and position. The column order of return table is
* [cols before explode_input, explode_position, explode_value, cols after explode_input].
*/
std::unique_ptr<table> explode_position(
table_view const& input_table,
Expand Down
45 changes: 45 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,8 @@ private static native long[] repeatColumnCount(long tableHandle,

private static native long[] explode(long tableHandle, int index);

private static native long[] explodePosition(long tableHandle, int index);

private static native long createCudfTableView(long[] nativeColumnViewHandles);

private static native long[] columnViewsFromPacked(ByteBuffer metadata, long dataAddress);
Expand Down Expand Up @@ -1753,6 +1755,49 @@ public Table explode(int index) {
return new Table(explode(nativeHandle, index));
}

/**
* Explodes a list column's elements and includes a position column.
*
* Any list is exploded, which means the elements of the list in each row are expanded into new rows
* in the output. The corresponding rows for other columns in the input are duplicated. A position
* column is added that has the index inside the original list for each row. Example:
* <code>
* [[5,10,15], 100],
* [[20,25], 200],
* [[30], 300],
* returns
* [0, 5, 100],
* [1, 10, 100],
* [2, 15, 100],
* [0, 20, 200],
* [1, 25, 200],
* [0, 30, 300],
* </code>
*
* Nulls and empty lists propagate in different ways depending on what is null or empty.
* <code>
* [[5,null,15], 100],
* [null, 200],
* [[], 300],
* returns
* [0, 5, 100],
* [1, null, 100],
* [2, 15, 100],
* </code>
*
* Note that null lists are not included in the resulting table, but nulls inside
* lists and empty lists will be represented with a null entry for that column in that row.
*
* @param index Column index to explode inside the table.
* @return A new table with exploded value and position. The column order of return table is
* [cols before explode_input, explode_position, explode_value, cols after explode_input].
*/
public Table explodePosition(int index) {
assert 0 <= index && index < columns.length : "Column index is out of range";
assert columns[index].getType().equals(DType.LIST) : "Column to explode must be of type LIST";
return new Table(explodePosition(nativeHandle, index));
}

/**
* Gathers the rows of this table according to `gatherMap` such that row "i"
* in the resulting table's columns will contain row "gatherMap[i]" from this table.
Expand Down
14 changes: 14 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2032,4 +2032,18 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explode(JNIEnv *env, jcla
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodePosition(JNIEnv *env, jclass,
jlong input_jtable,
jint column_index) {
JNI_NULL_CHECK(env, input_jtable, "explode: input table is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::table_view *input_table = reinterpret_cast<cudf::table_view *>(input_jtable);
cudf::size_type col_index = static_cast<cudf::size_type>(column_index);
std::unique_ptr<cudf::table> exploded = cudf::explode_position(*input_table, col_index);
return cudf::jni::convert_table_for_return(env, exploded);
}
CATCH_STD(env, 0);
}

} // extern "C"
138 changes: 101 additions & 37 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4532,51 +4532,115 @@ void testBuilderWithColumn() {
}
}

private Table[] buildExplodeTestTableWithPrimitiveTypes(boolean pos, boolean outer) {
try (Table input = new Table.TestBuilder()
.column(new ListType(true, new BasicType(true, DType.INT32)),
Arrays.asList(1, 2, 3),
Arrays.asList(4, 5),
Arrays.asList(6),
null,
Arrays.asList())
.column("s1", "s2", "s3", "s4", "s5")
.column(1, 3, 5, 7, 9)
.column(12.0, 14.0, 13.0, 11.0, 15.0)
.build()) {
Table.TestBuilder expectedBuilder = new Table.TestBuilder();
if (pos) {
Integer[] posData = outer ? new Integer[]{0, 1, 2, 0, 1, 0, 0, 0} : new Integer[]{0, 1, 2, 0, 1, 0};
expectedBuilder.column(posData);
}
List<Object[]> expectedData = new ArrayList<Object[]>(){{
if (!outer) {
this.add(new Integer[]{1, 2, 3, 4, 5, 6});
this.add(new String[]{"s1", "s1", "s1", "s2", "s2", "s3"});
this.add(new Integer[]{1, 1, 1, 3, 3, 5});
this.add(new Double[]{12.0, 12.0, 12.0, 14.0, 14.0, 13.0});
} else {
this.add(new Integer[]{1, 2, 3, 4, 5, 6, null, null});
this.add(new String[]{"s1", "s1", "s1", "s2", "s2", "s3", "s4", "s5"});
this.add(new Integer[]{1, 1, 1, 3, 3, 5, 7, 9});
this.add(new Double[]{12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0, 15.0});
}
}};
try (Table expected = expectedBuilder.column((Integer[]) expectedData.get(0))
.column((String[]) expectedData.get(1))
.column((Integer[]) expectedData.get(2))
.column((Double[]) expectedData.get(3))
.build()) {
return new Table[]{new Table(input.getColumns()), new Table(expected.getColumns())};
}
}
}

private Table[] buildExplodeTestTableWithNestedTypes(boolean pos) {
StructType nestedType = new StructType(true,
new BasicType(false, DType.INT32), new BasicType(false, DType.STRING));
try (Table input = new Table.TestBuilder()
.column(new ListType(false, nestedType),
Arrays.asList(struct(1, "k1"), struct(2, "k2"), struct(3, "k3")),
Arrays.asList(struct(4, "k4"), struct(5, "k5")),
Arrays.asList(struct(6, "k6")),
Arrays.asList(new HostColumnVector.StructData((List) null)),
Arrays.asList())
.column("s1", "s2", "s3", "s4", "s5")
.column(1, 3, 5, 7, 9)
.column(12.0, 14.0, 13.0, 11.0, 15.0)
.build()) {
Table.TestBuilder expectedBuilder = new Table.TestBuilder();
if (pos) {
expectedBuilder.column(0, 1, 2, 0, 1, 0, 0);
}
try (Table expected = expectedBuilder
.column(nestedType,
struct(1, "k1"), struct(2, "k2"), struct(3, "k3"),
struct(4, "k4"), struct(5, "k5"), struct(6, "k6"),
new HostColumnVector.StructData((List) null))
.column("s1", "s1", "s1", "s2", "s2", "s3", "s4")
.column(1, 1, 1, 3, 3, 5, 7)
.column(12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0)
.build()) {
return new Table[]{new Table(input.getColumns()), new Table(expected.getColumns())};
}
}
}

@Test
void testExplode() {
// Child is primitive type
try (Table t1 = new Table.TestBuilder()
.column(new ListType(true, new BasicType(true, DType.INT32)),
Arrays.asList(1, 2, 3),
Arrays.asList(4, 5),
Arrays.asList(6),
null)
.column("s1", "s2", "s3", "s4")
.column( 1, 3, 5, 7)
.column(12.0, 14.0, 13.0, 11.0)
.build();
Table expected = new Table.TestBuilder()
.column( 1, 2, 3, 4, 5, 6)
.column("s1", "s1", "s1", "s2", "s2", "s3")
.column( 1, 1, 1, 3, 3, 5)
.column(12.0, 12.0, 12.0, 14.0, 14.0, 13.0)
.build()) {
try (Table exploded = t1.explode(0)) {
Table[] testTables = buildExplodeTestTableWithPrimitiveTypes(false, false);
try (Table input = testTables[0];
Table expected = testTables[1]) {
try (Table exploded = input.explode(0)) {
assertTablesAreEqual(expected, exploded);
}
}

// Child is nested type
StructType nestedType = new StructType(false,
new BasicType(false, DType.INT32), new BasicType(false, DType.STRING));
try (Table t1 = new Table.TestBuilder()
.column(new ListType(false, nestedType),
Arrays.asList(struct(1, "k1"), struct(2, "k2"), struct(3, "k3")),
Arrays.asList(struct(4, "k4"), struct(5, "k5")),
Arrays.asList(struct(6, "k6")))
.column("s1", "s2", "s3")
.column( 1, 3, 5)
.column(12.0, 14.0, 13.0)
.build();
Table expected = new Table.TestBuilder()
.column(nestedType,
struct(1, "k1"), struct(2, "k2"), struct(3, "k3"),
struct(4, "k4"), struct(5, "k5"), struct(6, "k6"))
.column("s1", "s1", "s1", "s2", "s2", "s3")
.column( 1, 1, 1, 3, 3, 5)
.column(12.0, 12.0, 12.0, 14.0, 14.0, 13.0)
.build()) {
try (Table exploded = t1.explode(0)) {
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(false);
try (Table input = testTables2[0];
Table expected = testTables2[1]) {
try (Table exploded = input.explode(0)) {
assertTablesAreEqual(expected, exploded);
}
}
}

@Test
void testPosExplode() {
// Child is primitive type
Table[] testTables = buildExplodeTestTableWithPrimitiveTypes(true, false);
try (Table input = testTables[0];
Table expected = testTables[1]) {
try (Table exploded = input.explodePosition(0)) {
assertTablesAreEqual(expected, exploded);
}
}

// Child is primitive type
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(true);
try (Table input = testTables2[0];
Table expected = testTables2[1]) {
try (Table exploded = input.explodePosition(0)) {
assertTablesAreEqual(expected, exploded);
}
}
Expand Down

0 comments on commit 7871e7a

Please sign in to comment.