Skip to content

Commit

Permalink
change quickSort to not use recursion (#1531)
Browse files Browse the repository at this point in the history
this is intended to address AMReX-Astro/Castro#2818 as suggested by @WeiqunZhang
  • Loading branch information
zhichen3 authored Apr 9, 2024
1 parent e5fecb7 commit 43dc7a5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 11 deletions.
2 changes: 1 addition & 1 deletion nse_solver/nse_check.H
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ void nse_grouping(amrex::Array1D<int, 1, NumSpec>& group_ind, const burn_t& stat
// from smallest (fastest) to largest (slowest) timescale
//

quickSort_Array1D(rate_indices, reaction_timescales, 1, Rates::NumRates);
quickSort_Array1D(rate_indices, reaction_timescales);

// After the rate indices are sorted based on reaction timescales.
// Now do the grouping based on the timescale.
Expand Down
46 changes: 36 additions & 10 deletions util/microphysics_sort.H
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,16 @@ template <typename T, typename P, int l, int m>
AMREX_GPU_HOST_DEVICE AMREX_INLINE
int partition(amrex::Array1D<T, l, m>& sort_array,
amrex::Array1D<P, l, m>& metric_array,
const int low, const int high, const bool ascending=true) {
int low, int high, const bool ascending=true) {
// Helper function for quickSort
// It uses the metric array as metric, but swaps sort_array

// T pivot = sort_array(high);
// Choose pivot to be the last element

P pivot = metric_array(high);
int i = low - 1;
for (int j = low; j <= high; j++) {
// if (compare(sort_array(j), pivot, ascending)) {
if (compare(metric_array(j), pivot, ascending)) {
// if (compare(metric_array(sort_array(j)), pivot, ascending)) {
i++;
swap(sort_array(i), sort_array(j));
swap(metric_array(i), metric_array(j));
Expand All @@ -78,23 +77,50 @@ template <typename T, typename P, int l, int m>
AMREX_GPU_HOST_DEVICE AMREX_INLINE
void quickSort_Array1D(amrex::Array1D<T, l, m>& sort_array,
amrex::Array1D<P, l, m>& metric_array,
const int low, const int high,
const bool ascending=true) {
// quickSort implementation
// This implementation uses the metric array as the sorting metric
// to sort BOTH the sort_array and metric_array

if (low < high) {
// Create a stack to keep track of the low and high index to the
// left and right of the pivot

amrex::Array1D<int, l, m> stack {0};
int top = l - 1;

// Set initial values of the range.

stack(++top) = l;
stack(++top) = m;

// Keep popping from stack while it is not empty

while (top >= l) {

// Get high and low

// pi is the partition return index of pivot
int high = stack(top--);
int low = stack(top--);

// Get the pivot index, which splits the array into two parts

int pi = partition(sort_array, metric_array, low, high, ascending);

// Recursive quick sort calls.
// Record the low and high index to the left of the pivot

if (pi - 1 > low) {
stack(++top) = low;
stack(++top) = pi - 1;
}

// Record the low and high index to the right of the pivot

quickSort_Array1D(sort_array, metric_array, low, pi-1, ascending);
quickSort_Array1D(sort_array, metric_array, pi+1, high, ascending);
if (pi + 1 < high) {
stack(++top) = pi + 1;
stack(++top) = high;
}
}

}


Expand Down

0 comments on commit 43dc7a5

Please sign in to comment.