Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup iterator.cuh and add fixed point support for scalar_optional_accessor #10999

Merged
merged 6 commits into from
Jun 1, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 18 additions & 48 deletions cpp/include/cudf/detail/iterator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct null_replaced_value_accessor {
if (has_nulls) CUDF_EXPECTS(col.nullable(), "column with nulls must have a validity bitmask");
}

__device__ inline Element operator()(cudf::size_type i) const
__device__ inline Element const operator()(cudf::size_type i) const
{
return has_nulls && col.is_null_nocheck(i) ? null_replacement : col.element<Element>(i);
}
Expand Down Expand Up @@ -329,9 +329,12 @@ CUDF_HOST_DEVICE auto inline make_validity_iterator(column_device_view const& co
*
* For `p = *(iter + i)`, `p` is the validity of the scalar.
*
* @tparam bool unused. This template parameter exists to enforce the same
* template interface as @ref make_validity_iterator(column_device_view const&).
* @param scalar_value The scalar to iterate
* @return auto Iterator that returns scalar validity
*/
template <bool safe = false>
bdice marked this conversation as resolved.
Show resolved Hide resolved
auto inline make_validity_iterator(scalar const& scalar_value)
{
return thrust::make_constant_iterator(scalar_value.is_valid());
Expand All @@ -358,21 +361,7 @@ struct scalar_value_accessor {
"the data type mismatch");
}

/**
* @brief returns the value of the scalar.
*
* @throw `cudf::logic_error` if this function is called in host.
*
* @return value of the scalar.
*/
__device__ inline const Element operator()(size_type) const
{
#if defined(__CUDA_ARCH__)
return dscalar.value();
#else
CUDF_FAIL("unsupported device scalar iterator operation");
#endif
}
__device__ inline Element const operator()(size_type) const { return dscalar.value(); }
bdice marked this conversation as resolved.
Show resolved Hide resolved
};

/**
Expand Down Expand Up @@ -436,20 +425,19 @@ struct scalar_optional_accessor : public scalar_value_accessor<Element> {
{
}

/**
* @brief returns a thrust::optional<Element>.
*
* @throw `cudf::logic_error` if this function is called in host.
*
* @return a thrust::optional<Element> for the scalar value.
*/
CUDF_HOST_DEVICE inline const value_type operator()(size_type) const
__device__ inline value_type const operator()(size_type) const
{
if (has_nulls) {
return (super_t::dscalar.is_valid()) ? Element{super_t::dscalar.value()}
: value_type{thrust::nullopt};
if (has_nulls && !super_t::dscalar.is_valid()) { return value_type{thrust::nullopt}; }

if constexpr (cudf::is_fixed_point<Element>()) {
using namespace numeric;
using rep = typename Element::rep;
auto const value = super_t::dscalar.rep();
auto const scale = scale_type{super_t::dscalar.type().scale()};
return Element{scaled_integer<rep>{value, scale}};
} else {
return Element{super_t::dscalar.value()};
}
return Element{super_t::dscalar.value()};
}

Nullate has_nulls{};
Expand All @@ -469,20 +457,9 @@ struct scalar_pair_accessor : public scalar_value_accessor<Element> {
using value_type = thrust::pair<Element, bool>;
scalar_pair_accessor(scalar const& scalar_value) : scalar_value_accessor<Element>(scalar_value) {}

/**
* @brief returns a pair with value and validity of the scalar.
*
* @throw `cudf::logic_error` if this function is called in host.
*
* @return a pair with value and validity of the scalar.
*/
CUDF_HOST_DEVICE inline const value_type operator()(size_type) const
__device__ inline value_type const operator()(size_type) const
{
#if defined(__CUDA_ARCH__)
return {Element(super_t::dscalar.value()), super_t::dscalar.is_valid()};
#else
CUDF_FAIL("unsupported device scalar iterator operation");
#endif
}
};

Expand Down Expand Up @@ -520,14 +497,7 @@ struct scalar_representation_pair_accessor : public scalar_value_accessor<Elemen

scalar_representation_pair_accessor(scalar const& scalar_value) : base(scalar_value) {}

/**
* @brief returns a pair with representative value and validity of the scalar.
*
* @throw `cudf::logic_error` if this function is called in host.
*
* @return a pair with representative value and validity of the scalar.
*/
__device__ inline const value_type operator()(size_type) const
__device__ inline value_type const operator()(size_type) const
{
return {get_rep(base::dscalar), base::dscalar.is_valid()};
}
Expand Down