Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Add unique_count algorithm #1619

Merged
merged 10 commits into from
May 7, 2022
125 changes: 125 additions & 0 deletions testing/cuda/unique.cu
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,128 @@ void TestUniqueCopyCudaStreamsNoSync()
}
DECLARE_UNITTEST(TestUniqueCopyCudaStreamsNoSync);


template<typename ExecutionPolicy, typename Iterator1, typename Iterator2>
__global__
void unique_count_kernel(ExecutionPolicy exec, Iterator1 first, Iterator1 last, Iterator2 result)
{
*result = thrust::unique_count(exec, first, last);
}


template<typename ExecutionPolicy, typename Iterator1, typename BinaryPredicate, typename Iterator2>
__global__
void unique_count_kernel(ExecutionPolicy exec, Iterator1 first, Iterator1 last, BinaryPredicate pred, Iterator2 result)
{
*result = thrust::unique_count(exec, first, last, pred);
}


template<typename ExecutionPolicy>
void TestUniqueCountDevice(ExecutionPolicy exec)
{
typedef thrust::device_vector<int> Vector;
typedef Vector::value_type T;

Vector data(10);
data[0] = 11;
data[1] = 11;
data[2] = 12;
data[3] = 20;
data[4] = 29;
data[5] = 21;
data[6] = 21;
data[7] = 31;
data[8] = 31;
data[9] = 37;

Vector output(1, -1);

unique_count_kernel<<<1,1>>>(exec, data.begin(), data.end(), output.begin());
{
cudaError_t const err = cudaDeviceSynchronize();
ASSERT_EQUAL(cudaSuccess, err);
}

ASSERT_EQUAL(output[0], 7);

unique_count_kernel<<<1,1>>>(exec, data.begin(), data.end(), is_equal_div_10_unique<T>(), output.begin());
{
cudaError_t const err = cudaDeviceSynchronize();
ASSERT_EQUAL(cudaSuccess, err);
}

ASSERT_EQUAL(output[0], 3);
}


void TestUniqueCountDeviceSeq()
{
TestUniqueCountDevice(thrust::seq);
}
DECLARE_UNITTEST(TestUniqueCountDeviceSeq);


void TestUniqueCountDeviceDevice()
{
TestUniqueCountDevice(thrust::device);
}
DECLARE_UNITTEST(TestUniqueCountDeviceDevice);


void TestUniqueCountDeviceNoSync()
{
TestUniqueCountDevice(thrust::cuda::par_nosync);
}
DECLARE_UNITTEST(TestUniqueCountDeviceNoSync);


template<typename ExecutionPolicy>
void TestUniqueCountCudaStreams(ExecutionPolicy policy)
{
typedef thrust::device_vector<int> Vector;
typedef Vector::value_type T;

Vector data(10);
data[0] = 11;
data[1] = 11;
data[2] = 12;
data[3] = 20;
data[4] = 29;
data[5] = 21;
data[6] = 21;
data[7] = 31;
data[8] = 31;
data[9] = 37;

cudaStream_t s;
cudaStreamCreate(&s);

auto streampolicy = policy.on(s);

int result = thrust::unique_count(streampolicy, data.begin(), data.end());
cudaStreamSynchronize(s);

ASSERT_EQUAL(result, 7);

result = thrust::unique_count(streampolicy, data.begin(), data.end(), is_equal_div_10_unique<T>());
cudaStreamSynchronize(s);

ASSERT_EQUAL(result, 3);

cudaStreamDestroy(s);
}

void TestUniqueCountCudaStreamsSync()
{
TestUniqueCountCudaStreams(thrust::cuda::par);
}
DECLARE_UNITTEST(TestUniqueCountCudaStreamsSync);


void TestUniqueCountCudaStreamsNoSync()
{
TestUniqueCountCudaStreams(thrust::cuda::par_nosync);
}
DECLARE_UNITTEST(TestUniqueCountCudaStreamsNoSync);

89 changes: 89 additions & 0 deletions testing/unique.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,50 @@ void TestUniqueCopyDispatchImplicit()
DECLARE_UNITTEST(TestUniqueCopyDispatchImplicit);


template <typename ForwardIterator>
typename thrust::iterator_traits<ForwardIterator>::difference_type
unique_count(my_system &system,
ForwardIterator,
ForwardIterator)
{
system.validate_dispatch();
return 0;
}

void TestUniqueCountDispatchExplicit()
{
thrust::device_vector<int> vec(1);

my_system sys(0);
thrust::unique_count(sys, vec.begin(), vec.begin());

ASSERT_EQUAL(true, sys.is_valid());
}
DECLARE_UNITTEST(TestUniqueCountDispatchExplicit);


template <typename ForwardIterator>
typename thrust::iterator_traits<ForwardIterator>::difference_type
unique_count(my_tag,
ForwardIterator,
ForwardIterator)
{
return 13;
}

void TestUniqueCountDispatchImplicit()
{
thrust::device_vector<int> vec(1);

auto result = thrust::unique_count(
thrust::retag<my_tag>(vec.begin()),
thrust::retag<my_tag>(vec.begin()));

ASSERT_EQUAL(13, result);
}
DECLARE_UNITTEST(TestUniqueCountDispatchImplicit);


template<typename T>
struct is_equal_div_10_unique
{
Expand Down Expand Up @@ -266,3 +310,48 @@ struct TestUniqueCopyToDiscardIterator
VariableUnitTest<TestUniqueCopyToDiscardIterator, IntegralTypes> TestUniqueCopyToDiscardIteratorInstance;


template <typename Vector>
void TestUniqueCountSimple(void)
{
typedef typename Vector::value_type T;

Vector data(10);
data[0] = 11;
data[1] = 11;
data[2] = 12;
data[3] = 20;
data[4] = 29;
data[5] = 21;
data[6] = 21;
data[7] = 31;
data[8] = 31;
data[9] = 37;

int count = thrust::unique_count(data.begin(), data.end());

ASSERT_EQUAL(count, 7);

int div_10_count = thrust::unique_count(data.begin(), data.end(), is_equal_div_10_unique<T>());

ASSERT_EQUAL(div_10_count, 3);
}
DECLARE_INTEGRAL_VECTOR_UNITTEST(TestUniqueCountSimple);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unique_count claims to work with forward iterators. There should be a test using forward iterators.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a nice way to wrap an iterator into a forward_iterator in Thrust? I wrote a small wrapper class and that seems to work and compile, but I suppose that problem has been solved elsewhere already?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ericniebler is right, we should be testing this. Unfortunately we lack robust testing for these sorts of things.

Thanks for adding the new testing infrastructure! Please include it in this PR, ideally in the testing framework so we can reuse it from other tests later 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a test to unique_copy and unique. I am not 100% sure it does what we would expect - due to the missing iterator tag, it gets executed sequentially on the CPU using device references for access.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean, forward iterators always dispatch to the CPU? @allisonvacanti can you comment on that? I mean, it seems reasonable to me.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounded odd to me, as I've never seen any logic in Thrust that would dispatch forward iterators to serial implementations. So I dug into it, and unfortunately this is due to a pretty nasty bug in Thrust's iterator traits.

The details are gory, but I've summarized them in a comment on #902.

Until that's fixed, I'm not comfortable merging this test using the forward_iterator_wrapper, since they only "do the right thing" because the iterator framework is broken.

I hate to churn on this PR even more, but I think we should remove the iterator wrappers for now and just test that the regular iterators work. We can re-introduce the wrapper tests as part of #55, after #902 is fixed and settled.

@ericniebler Can you review the two linked issues and see if you agree with my suggestion?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that forward iterators actually need to be dispatched to the sequential backend. They support multipass reading and should be usable in a parallel algorithm, so long as they're only copied and incremented. Is there something in the unique_count/count_if algorithms that would break them?

Copy link
Contributor Author

@upsj upsj May 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main issue I see with parallel execution on forward iterators is that they introduce an essentially linear dependency chain that means either every thread i starts from begin and increments it i times, or waits until one of its predecessors j is done and writes its iterator somewhere, to then increments it i - j times. Both don't really seem useful for parallel execution to me.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will require more increments, but if the work-per-element is expensive compared to the cost of the iterator increment, it may still make sense to parallelize. I'd rather let the user make that call, since they can opt-in to sequential execution by passing in the sequential exeuction policy (thrust::seq).

More importantly, the sequential implementation executes on CPU, and some types of device memory aren't accessible from the CPU's address space, so switching to seq really needs to be opt-in rather than default.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification, that makes sense! I was only thinking of simple but massively parallel cases.


template <typename T>
struct TestUniqueCount
{
void operator()(const size_t n)
{
thrust::host_vector<T> h_data = unittest::random_integers<bool>(n);
thrust::device_vector<T> d_data = h_data;

int h_count{};
int d_count{};

h_count = thrust::unique_count(h_data.begin(), h_data.end());
d_count = thrust::unique_count(d_data.begin(), d_data.end());

ASSERT_EQUAL(h_count, d_count);
}
};
VariableUnitTest<TestUniqueCount, IntegralTypes> TestUniqueCountInstance;
2 changes: 1 addition & 1 deletion thrust/count.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,4 @@ template <typename InputIterator, typename Predicate>

THRUST_NAMESPACE_END

#include <thrust/detail/count.inl>
#include <thrust/detail/count.h>
60 changes: 60 additions & 0 deletions thrust/detail/count.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright 2008-2013 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <thrust/detail/config.h>
#include <thrust/detail/execution_policy.h>

THRUST_NAMESPACE_BEGIN

template<typename DerivedPolicy,
typename InputIterator,
typename EqualityComparable>
__host__ __device__
typename thrust::iterator_traits<InputIterator>::difference_type
count(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
InputIterator first,
InputIterator last,
const EqualityComparable& value);

template<typename DerivedPolicy,
typename InputIterator,
typename Predicate>
__host__ __device__
typename thrust::iterator_traits<InputIterator>::difference_type
count_if(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
InputIterator first,
InputIterator last,
Predicate pred);

template <typename InputIterator,
typename EqualityComparable>
typename thrust::iterator_traits<InputIterator>::difference_type
count(InputIterator first,
InputIterator last,
const EqualityComparable& value);

template <typename InputIterator,
typename Predicate>
typename thrust::iterator_traits<InputIterator>::difference_type
count_if(InputIterator first,
InputIterator last,
Predicate pred);

THRUST_NAMESPACE_END

#include <thrust/detail/count.inl>
61 changes: 61 additions & 0 deletions thrust/detail/unique.inl
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,67 @@ template<typename InputIterator1,
return thrust::unique_by_key_copy(select_system(system1,system2,system3,system4), keys_first, keys_last, values_first, keys_output, values_output, binary_pred);
} // end unique_by_key_copy()

__thrust_exec_check_disable__
template <typename DerivedPolicy,
typename ForwardIterator,
typename BinaryPredicate>
__host__ __device__
typename thrust::iterator_traits<ForwardIterator>::difference_type
unique_count(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
ForwardIterator first,
ForwardIterator last,
BinaryPredicate binary_pred)
{
using thrust::system::detail::generic::unique_count;
return unique_count(thrust::detail::derived_cast(thrust::detail::strip_const(exec)), first, last, binary_pred);
} // end unique_count()

__thrust_exec_check_disable__
template <typename DerivedPolicy,
typename ForwardIterator>
__host__ __device__
typename thrust::iterator_traits<ForwardIterator>::difference_type
unique_count(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
ForwardIterator first,
ForwardIterator last)
{
using thrust::system::detail::generic::unique_count;
return unique_count(thrust::detail::derived_cast(thrust::detail::strip_const(exec)), first, last);
} // end unique_count()

__thrust_exec_check_disable__
template <typename ForwardIterator,
typename BinaryPredicate>
__host__ __device__
typename thrust::iterator_traits<ForwardIterator>::difference_type
unique_count(ForwardIterator first,
ForwardIterator last,
BinaryPredicate binary_pred)
{
using thrust::system::detail::generic::select_system;

typedef typename thrust::iterator_system<ForwardIterator>::type System;

System system;

return thrust::unique_count(select_system(system), first, last, binary_pred);
} // end unique_count()

__thrust_exec_check_disable__
template <typename ForwardIterator>
__host__ __device__
typename thrust::iterator_traits<ForwardIterator>::difference_type
unique_count(ForwardIterator first,
ForwardIterator last)
{
using thrust::system::detail::generic::select_system;

typedef typename thrust::iterator_system<ForwardIterator>::type System;

System system;

return thrust::unique_count(select_system(system), first, last);
} // end unique_count()

THRUST_NAMESPACE_END

Loading