Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
lithomas1 committed Aug 1, 2024
1 parent 49a3694 commit 97e740c
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 13 deletions.
2 changes: 1 addition & 1 deletion python/cudf/cudf/_lib/pylibcudf/io/orc.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ cpdef TableWithMetadata read_orc(
list columns = *,
list stripes = *,
size_type skip_rows = *,
size_type num_rows = *,
size_type nrows = *,
bool use_index = *,
bool use_np_dtypes = *,
DataType timestamp_type = *,
Expand Down
8 changes: 4 additions & 4 deletions python/cudf/cudf/_lib/pylibcudf/io/orc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ cpdef TableWithMetadata read_orc(
list columns = None,
list stripes = None,
size_type skip_rows = 0,
size_type num_rows = -1,
size_type nrows = -1,
bool use_index = True,
bool use_np_dtypes = True,
DataType timestamp_type = DataType(type_id.EMPTY),
Expand All @@ -228,7 +228,7 @@ cpdef TableWithMetadata read_orc(
List of stripes to be read.
skip_rows : int64_t, default 0
The number of rows to skip from the start of the file.
num_rows : size_type, default -1
nrows : size_type, default -1
The number of rows to read. By default, read the entire file.
use_index : bool, default True
Whether to use the row index to speed up reading.
Expand All @@ -247,8 +247,8 @@ cpdef TableWithMetadata read_orc(
.use_index(use_index)
.build()
)
if num_rows >= 0:
opts.set_num_rows(num_rows)
if nrows >= 0:
opts.set_num_rows(nrows)
if skip_rows >= 0:
opts.set_skip_rows(skip_rows)
if stripes is not None:
Expand Down
28 changes: 22 additions & 6 deletions python/cudf/cudf/pylibcudf_tests/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pyarrow as pa
import pytest
from pyarrow.orc import write_table as orc_write_table
from pyarrow.parquet import write_table as pq_write_table

from cudf._lib import pylibcudf as plc
Expand Down Expand Up @@ -240,13 +241,20 @@ def is_nested_list(typ):
return nesting_level(typ)[0] > 1


def _convert_numeric_types_to_floating(pa_table):
def _convert_types(pa_table, input_pred, result_type):
"""
Useful little helper for testing the
dtypes option in I/O readers.
Returns a tuple containing the pylibcudf dtypes
and the new pyarrow schema
Parameters
----------
input_pred : function
Predicate that evaluates to true for types to replace
result_type : pa.DataType
The type to cast to
"""
dtypes = []
new_fields = []
Expand All @@ -255,11 +263,9 @@ def _convert_numeric_types_to_floating(pa_table):
child_types = []

plc_type = plc.interop.from_arrow(field.type)
if pa.types.is_integer(field.type) or pa.types.is_unsigned_integer(
field.type
):
plc_type = plc.interop.from_arrow(pa.float64())
field = field.with_type(pa.float64())
if input_pred(field.type):
plc_type = plc.interop.from_arrow(result_type)
field = field.with_type(result_type)

dtypes.append((field.name, plc_type, child_types))

Expand Down Expand Up @@ -330,6 +336,16 @@ def make_source(path_or_buf, pa_table, format, **kwargs):
if isinstance(path_or_buf, io.IOBase)
else path_or_buf,
)
elif format == "orc":
# The conversion to pandas is lossy (doesn't preserve
# nested types) so we
# will just use pyarrow directly to write this
orc_write_table(
pa_table,
pa.PythonFile(path_or_buf)
if isinstance(path_or_buf, io.IOBase)
else path_or_buf,
)
if isinstance(path_or_buf, io.IOBase):
path_or_buf.seek(0)
return path_or_buf
Expand Down
9 changes: 7 additions & 2 deletions python/cudf/cudf/pylibcudf_tests/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pyarrow as pa
import pytest
from utils import (
_convert_numeric_types_to_floating,
_convert_types,
assert_table_and_meta_eq,
make_source,
write_source_str,
Expand Down Expand Up @@ -149,7 +149,12 @@ def test_read_csv_dtypes(csv_table_data, source_or_sink, usecols):
if usecols is not None:
pa_table = pa_table.select(usecols)

dtypes, new_fields = _convert_numeric_types_to_floating(pa_table)
dtypes, new_fields = _convert_types(
pa_table,
lambda typ: pa.types.is_unsigned_integer(typ)
or pa.types.is_integer(typ),
pa.float64(),
)
# Extract the dtype out of the (name, type, child_types) tuple
# (read_csv doesn't support this format since it doesn't support nested columns)
dtypes = {name: dtype for name, dtype, _ in dtypes}
Expand Down
54 changes: 54 additions & 0 deletions python/cudf/cudf/pylibcudf_tests/io/test_orc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import pyarrow as pa
import pytest
from utils import _convert_types, assert_table_and_meta_eq, make_source

import cudf._lib.pylibcudf as plc

# Shared kwargs to pass to make_source
_COMMON_ORC_SOURCE_KWARGS = {"format": "orc"}


@pytest.mark.parametrize("columns", [None, ["col_int64", "col_bool"]])
def test_read_orc_basic(
table_data, binary_source_or_sink, nrows_skiprows, columns
):
_, pa_table = table_data
nrows, skiprows = nrows_skiprows

# ORC reader doesn't support skip_rows for nested columns
if skiprows > 0:
colnames_to_drop = []
for i in range(len(pa_table.schema)):
field = pa_table.schema.field(i)

if pa.types.is_nested(field.type):
colnames_to_drop.append(field.name)
pa_table = pa_table.drop(colnames_to_drop)
# ORC doesn't support unsigned ints
# let's cast to int64
_, new_fields = _convert_types(
pa_table, pa.types.is_unsigned_integer, pa.int64()
)
pa_table = pa_table.cast(pa.schema(new_fields))

source = make_source(
binary_source_or_sink, pa_table, **_COMMON_ORC_SOURCE_KWARGS
)

res = plc.io.orc.read_orc(
plc.io.SourceInfo([source]),
nrows=nrows,
skip_rows=skiprows,
columns=columns,
)

if columns is not None:
pa_table = pa_table.select(columns)

# Adapt to nrows/skiprows
pa_table = pa_table.slice(
offset=skiprows, length=nrows if nrows != -1 else None
)

assert_table_and_meta_eq(pa_table, res, check_field_nullability=False)

0 comments on commit 97e740c

Please sign in to comment.