From ed00a4e0f049f4607301e7c9ad08f5eae25387b8 Mon Sep 17 00:00:00 2001 From: Guy Blelloch Date: Tue, 27 Feb 2024 09:55:29 -0500 Subject: [PATCH] updated counting sort --- examples/counting_sort.cpp | 9 +++--- examples/counting_sort.h | 64 ++++++++++++++++++-------------------- 2 files changed, 35 insertions(+), 38 deletions(-) diff --git a/examples/counting_sort.cpp b/examples/counting_sort.cpp index 6a33807..3f8cb98 100644 --- a/examples/counting_sort.cpp +++ b/examples/counting_sort.cpp @@ -21,7 +21,6 @@ int main(int argc, char* argv[]) { catch (...) { std::cout << usage << std::endl; return 1; } long num_buckets = 256; - long num_partitions = std::min(1000l, n / (256 * 16) + 1); parlay::random_generator gen; std::uniform_int_distribution dis(0, num_buckets - 1); @@ -32,11 +31,13 @@ int main(int argc, char* argv[]) { return dis(r);}); parlay::internal::timer t("Time"); - parlay::sequence result; + parlay::sequence result(n); for (int i=0; i < 5; i++) { - result = data; t.start(); - result = counting_sort(data, data, num_buckets, num_partitions); + counting_sort(data.begin(), data.end(), + result.begin(), + data.begin(), + num_buckets); t.next("counting_sort"); } diff --git a/examples/counting_sort.h b/examples/counting_sort.h index e3d5a84..32e7e12 100644 --- a/examples/counting_sort.h +++ b/examples/counting_sort.h @@ -5,57 +5,53 @@ #include #include -// ************************************************************** -// A parallel counting sort -// Works well for a smallish (e.g. up to 256 or perhaps 1000) buckets. -// Input is a sequence of values, and a range of keys (of equal length). -// They could be the same, or the keys could be a field from the values. -// Must also specify the number of buckets and the number of paritions. -// ************************************************************** - -template -parlay::sequence counting_sort(const parlay::sequence& in, const Keys& keys, - long num_buckets, long num_parts) { - long n = in.size(); +template +parlay::sequence +counting_sort(const InIt& begin, const InIt& end, + OutIt out, const KeyIt& keys, + long num_buckets) { + long n = end - begin; + long num_parts = n / (num_buckets * 64) + 1; long part_size = (n - 1)/num_parts + 1; - // For each partition count number of each of the key values - auto all_counts = parlay::tabulate(num_parts, [&] (long i) { + // first count buckets within each partition + auto counts = parlay::sequence::uninitialized(num_buckets * num_parts); + parlay::parallel_for(0, num_parts, [&] (long i) { long start = i * part_size; long end = std::min(start + part_size, n); - parlay::sequence local_counts(num_buckets, 0); - for (size_t j = start; j < end; j++) local_counts[keys[j]]++; - return local_counts;}, 1); - - // need to transpose the counts for the scan - auto counts = parlay::sequence::uninitialized(num_buckets * num_parts); - parlay::parallel_for(0, num_buckets, [&] (long i) { - for (size_t j = 0; j < num_parts; j++) - counts[i* num_parts + j] = all_counts[j][i];}, 1); - all_counts.clear(); + for (int j = 0; j < num_buckets; j++) counts[i*num_buckets + j] = 0; + for (size_t j = start; j < end; j++) counts[i*num_buckets + keys[j]]++; + }, 1); + + // transpose the counts if more than one part + parlay::sequence trans_counts; + if (num_parts > 1) { + trans_counts = parlay::sequence::uninitialized(num_buckets * num_parts); + parlay::parallel_for(0, num_buckets, [&] (long i) { + for (size_t j = 0; j < num_parts; j++) + trans_counts[i* num_parts + j] = counts[j * num_buckets + i];}, 1); + } else trans_counts = std::move(counts); // scan for offsets for all buckets - parlay::scan_inplace(counts); + parlay::scan_inplace(trans_counts); - // the ouput sequence - auto out = parlay::sequence::uninitialized(n); - - // go back over partitions to place the input in final location + // go back over partitions to place in final location parlay::parallel_for(0, num_parts, [&] (long i) { long start = i * part_size; long end = std::min(start + part_size, n); - parlay::sequence local_offsets(num_buckets, 0); + int local_offsets[num_buckets]; // transpose back for (int j = 0; j < num_buckets; j++) - local_offsets[j] = counts[num_parts * j + i]; + local_offsets[j] = trans_counts[num_parts * j + i]; // copy to output for (size_t j = start; j < end; j++) { int k = local_offsets[keys[j]]++; - // the following line helps performance __builtin_prefetch (((char*) &out[k]) + 64); - out[k] = in[j]; + out[k] = begin[j]; }}, 1); - return out; + + return parlay::tabulate(num_buckets, [&] (long i) { + return trans_counts[i * num_parts];}); }