Skip to content

Commit

Permalink
std/lib/sort: update quicksort to use the new generics
Browse files Browse the repository at this point in the history
Signed-off-by: Pierre Curto <[email protected]>
  • Loading branch information
pierrec authored and lerno committed Jul 8, 2023
1 parent 5f71140 commit 77b3214
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 104 deletions.
102 changes: 24 additions & 78 deletions lib/std/sort/quicksort.c3
Original file line number Diff line number Diff line change
@@ -1,99 +1,45 @@
module std::sort::quicksort(<Type>);
import std::sort;
module std::sort;
import std::sort::qs;

def ElementType = $typeof(Type{}[0]);
def Comparer = fn int(ElementType, ElementType);
def ComparerRef = fn int(ElementType*, ElementType*);

const bool ELEMENT_COMPARABLE = $checks(ElementType x, greater(x, x));

fn void sort_fn(Type list, Comparer cmp)
{
usz len = sort::@len_from_list(list);
qsort_value(list, 0, (isz)len - 1, cmp);
}

fn void sort_ref_fn(Type list, ComparerRef cmp)
{
usz len = sort::@len_from_list(list);
qsort_ref(list, 0, (isz)len - 1, cmp);
}

fn void sort(Type list) @if(ELEMENT_COMPARABLE)
macro quicksort(list, cmp = null)
{
var $Type = $typeof(list);
var $CmpType = $typeof(cmp);
usz len = sort::@len_from_list(list);
qsort(list, 0, (isz)len - 1);
qs::qsort(<$Type, $CmpType>)(list, 0, (isz)len - 1, cmp);
}

fn void qsort(Type list, isz low, isz high) @local @if(ELEMENT_COMPARABLE)
{
if (low < high)
{
isz p = partition(list, low, high);
qsort(list, low, p - 1);
qsort(list, p + 1, high);
}
}
module std::sort::qs(<Type, Comparer>);

fn void qsort_value(Type list, isz low, isz high, Comparer cmp) @local
{
if (low < high)
{
isz p = partition_value(list, low, high, cmp);
qsort_value(list, low, p - 1, cmp);
qsort_value(list, p + 1, high, cmp);
}
}
def ElementType = $typeof(Type{}[0]);

fn void qsort_ref(Type list, isz low, isz high, ComparerRef cmp) @local
fn void qsort(Type list, isz low, isz high, Comparer cmp)
{
if (low < high)
{
isz p = partition_ref(list, low, high, cmp);
qsort_ref(list, low, p - 1, cmp);
qsort_ref(list, p + 1, high, cmp);
isz p = partition(list, low, high, cmp);
qsort(list, low, p - 1, cmp);
qsort(list, p + 1, high, cmp);
}
}

fn isz partition(Type list, isz low, isz high) @inline @local
fn isz partition(Type list, isz low, isz high, Comparer cmp) @inline @local
{
ElementType pivot = list[high];
isz i = low - 1;
for (isz j = low; j < high; j++)
{
if (greater(list[j], pivot)) continue;
i++;
@swap(list[i], list[j]);
}
i++;
@swap(list[i], list[high]);
return i;
}

fn isz partition_value(Type list, isz low, isz high, Comparer cmp) @inline @private
{
ElementType pivot = list[high];
isz i = low - 1;
for (isz j = low; j < high; j++)
{
if (cmp(list[j], pivot) <= 0)
{
i++;
@swap(list[i], list[j]);
}
}
i++;
@swap(list[i], list[high]);
return i;
}

fn isz partition_ref(Type list, isz low, isz high, ComparerRef cmp) @inline @private
{
ElementType* pivot = &list[high];
isz i = low - 1;
for (isz j = low; j < high; j++)
{
if (cmp(&list[j], pivot) <= 0)
$if $checks(cmp(list[0], list[0])):
int res = cmp(list[j], pivot);
$else
$if $checks(cmp(&list[0], &list[0])):
int res = cmp(&list[j], &pivot);
$else
int res;
if (greater(list[j], pivot)) continue;
$endif
$endif
if (res <= 0)
{
i++;
@swap(list[i], list[j]);
Expand Down
4 changes: 2 additions & 2 deletions test/unit/stdlib/sort/binarysearch.c3
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ fn void binarysearch()
usz idx = sort::binarysearch(tc.data, tc.x);
assert(idx == tc.index, "%s: got %d; want %d", tc.data, idx, tc.index);

usz cmp_idx = sort::binarysearch_with(tc.data, tc.x, &sort::cmp_int);
usz cmp_idx = sort::binarysearch_with(tc.data, tc.x, &sort::cmp_int_ref);
assert(cmp_idx == tc.index, "%s: got %d; want %d", tc.data, cmp_idx, tc.index);

usz cmp_idx2 = sort::binarysearch_with(tc.data, tc.x, &sort::cmp_int2);
usz cmp_idx2 = sort::binarysearch_with(tc.data, tc.x, &sort::cmp_int_value);
assert(cmp_idx2 == tc.index, "%s: got %d; want %d", tc.data, cmp_idx2, tc.index);

usz cmp_idx3 = sort::binarysearch_with(tc.data, tc.x, fn int(int a, int b) => a - b);
Expand Down
34 changes: 13 additions & 21 deletions test/unit/stdlib/sort/quicksort.c3
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
module sort_test @test;
import std::sort;
import std::sort::quicksort;

def qs_int = quicksort::sort(<int[]>);
import sort::check;

fn void quicksort()
{
Expand All @@ -16,16 +14,12 @@ fn void quicksort()

foreach (tc : tcases)
{
qs_int(tc);
assert(sort::check_int_sort(tc));
sort::quicksort(tc);
assert(check::int_sort(tc));
}
}

def Cmp = fn int(int*, int*);

def qs_int_ref = quicksort::sort_ref_fn(<int[]>);

fn void quicksort_with()
fn void quicksort_with_ref()
{
int[][] tcases = {
{},
Expand All @@ -37,14 +31,12 @@ fn void quicksort_with()

foreach (tc : tcases)
{
qs_int_ref(tc, (Cmp)&sort::cmp_int);
assert(sort::check_int_sort(tc));
sort::quicksort(tc, &sort::cmp_int_ref);
assert(check::int_sort(tc));
}
}

def qs_int_fn = quicksort::sort_fn(<int[]>);

fn void quicksort_with2()
fn void quicksort_with_value()
{
int[][] tcases = {
{},
Expand All @@ -56,8 +48,8 @@ fn void quicksort_with2()

foreach (tc : tcases)
{
qs_int_fn(tc, &sort::cmp_int2);
assert(sort::check_int_sort(tc));
sort::quicksort(tc, &sort::cmp_int_value);
assert(check::int_sort(tc));
}
}

Expand All @@ -73,14 +65,14 @@ fn void quicksort_with_lambda()

foreach (tc : tcases)
{
qs_int_fn(tc, fn int(int a, int b) => a - b);
assert(sort::check_int_sort(tc));
sort::quicksort(tc, fn int(int a, int b) => a - b);
assert(check::int_sort(tc));
}
}

module std::sort;
module sort::check;

fn bool check_int_sort(int[] list)
fn bool int_sort(int[] list)
{
int prev = int.min;
foreach (x : list)
Expand Down
6 changes: 3 additions & 3 deletions test/unit/stdlib/sort/sort.c3
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module std::sort;

fn int cmp_int(void* x, void* y) {
return *(int*)x - *(int*)y;
fn int cmp_int_ref(int* x, int* y) {
return *x - *y;
}

fn int cmp_int2(int x, int y) {
fn int cmp_int_value(int x, int y) {
return x - y;
}

0 comments on commit 77b3214

Please sign in to comment.