From 77b32147461a4b87f328bde3749eddcc31b3527c Mon Sep 17 00:00:00 2001 From: Pierre Curto Date: Sat, 8 Jul 2023 11:40:51 +0200 Subject: [PATCH] std/lib/sort: update quicksort to use the new generics Signed-off-by: Pierre Curto --- lib/std/sort/quicksort.c3 | 110 +++++++------------------- test/unit/stdlib/sort/binarysearch.c3 | 4 +- test/unit/stdlib/sort/quicksort.c3 | 34 +++----- test/unit/stdlib/sort/sort.c3 | 6 +- 4 files changed, 46 insertions(+), 108 deletions(-) diff --git a/lib/std/sort/quicksort.c3 b/lib/std/sort/quicksort.c3 index 0b40d6a07..6fa41177f 100644 --- a/lib/std/sort/quicksort.c3 +++ b/lib/std/sort/quicksort.c3 @@ -1,99 +1,45 @@ -module std::sort::quicksort(); -import std::sort; +module std::sort; +import std::sort::qs; + +macro quicksort(list, cmp = null) +{ + var $Type = $typeof(list); + var $CmpType = $typeof(cmp); + usz len = sort::@len_from_list(list); + qs::qsort(<$Type, $CmpType>)(list, 0, (isz)len - 1, cmp); +} + +module std::sort::qs(); def ElementType = $typeof(Type{}[0]); -def Comparer = fn int(ElementType, ElementType); -def ComparerRef = fn int(ElementType*, ElementType*); -const bool ELEMENT_COMPARABLE = $checks(ElementType x, greater(x, x)); - -fn void sort_fn(Type list, Comparer cmp) -{ - usz len = sort::@len_from_list(list); - qsort_value(list, 0, (isz)len - 1, cmp); -} - -fn void sort_ref_fn(Type list, ComparerRef cmp) -{ - usz len = sort::@len_from_list(list); - qsort_ref(list, 0, (isz)len - 1, cmp); -} - -fn void sort(Type list) @if(ELEMENT_COMPARABLE) -{ - usz len = sort::@len_from_list(list); - qsort(list, 0, (isz)len - 1); -} - -fn void qsort(Type list, isz low, isz high) @local @if(ELEMENT_COMPARABLE) +fn void qsort(Type list, isz low, isz high, Comparer cmp) { if (low < high) { - isz p = partition(list, low, high); - qsort(list, low, p - 1); - qsort(list, p + 1, high); + isz p = partition(list, low, high, cmp); + qsort(list, low, p - 1, cmp); + qsort(list, p + 1, high, cmp); } } -fn void qsort_value(Type list, isz low, isz high, Comparer cmp) @local -{ - if (low < high) - { - isz p = partition_value(list, low, high, cmp); - qsort_value(list, low, p - 1, cmp); - qsort_value(list, p + 1, high, cmp); - } -} - -fn void qsort_ref(Type list, isz low, isz high, ComparerRef cmp) @local -{ - if (low < high) - { - isz p = partition_ref(list, low, high, cmp); - qsort_ref(list, low, p - 1, cmp); - qsort_ref(list, p + 1, high, cmp); - } -} - -fn isz partition(Type list, isz low, isz high) @inline @local +fn isz partition(Type list, isz low, isz high, Comparer cmp) @inline @local { ElementType pivot = list[high]; isz i = low - 1; for (isz j = low; j < high; j++) { - if (greater(list[j], pivot)) continue; - i++; - @swap(list[i], list[j]); - } - i++; - @swap(list[i], list[high]); - return i; -} - -fn isz partition_value(Type list, isz low, isz high, Comparer cmp) @inline @private -{ - ElementType pivot = list[high]; - isz i = low - 1; - for (isz j = low; j < high; j++) - { - if (cmp(list[j], pivot) <= 0) - { - i++; - @swap(list[i], list[j]); - } - } - i++; - @swap(list[i], list[high]); - return i; -} - -fn isz partition_ref(Type list, isz low, isz high, ComparerRef cmp) @inline @private -{ - ElementType* pivot = &list[high]; - isz i = low - 1; - for (isz j = low; j < high; j++) - { - if (cmp(&list[j], pivot) <= 0) + $if $checks(cmp(list[0], list[0])): + int res = cmp(list[j], pivot); + $else + $if $checks(cmp(&list[0], &list[0])): + int res = cmp(&list[j], &pivot); + $else + int res; + if (greater(list[j], pivot)) continue; + $endif + $endif + if (res <= 0) { i++; @swap(list[i], list[j]); diff --git a/test/unit/stdlib/sort/binarysearch.c3 b/test/unit/stdlib/sort/binarysearch.c3 index 11033c345..e1683823c 100644 --- a/test/unit/stdlib/sort/binarysearch.c3 +++ b/test/unit/stdlib/sort/binarysearch.c3 @@ -25,10 +25,10 @@ fn void binarysearch() usz idx = sort::binarysearch(tc.data, tc.x); assert(idx == tc.index, "%s: got %d; want %d", tc.data, idx, tc.index); - usz cmp_idx = sort::binarysearch_with(tc.data, tc.x, &sort::cmp_int); + usz cmp_idx = sort::binarysearch_with(tc.data, tc.x, &sort::cmp_int_ref); assert(cmp_idx == tc.index, "%s: got %d; want %d", tc.data, cmp_idx, tc.index); - usz cmp_idx2 = sort::binarysearch_with(tc.data, tc.x, &sort::cmp_int2); + usz cmp_idx2 = sort::binarysearch_with(tc.data, tc.x, &sort::cmp_int_value); assert(cmp_idx2 == tc.index, "%s: got %d; want %d", tc.data, cmp_idx2, tc.index); usz cmp_idx3 = sort::binarysearch_with(tc.data, tc.x, fn int(int a, int b) => a - b); diff --git a/test/unit/stdlib/sort/quicksort.c3 b/test/unit/stdlib/sort/quicksort.c3 index 7c47c1059..a010b41ae 100644 --- a/test/unit/stdlib/sort/quicksort.c3 +++ b/test/unit/stdlib/sort/quicksort.c3 @@ -1,8 +1,6 @@ module sort_test @test; import std::sort; -import std::sort::quicksort; - -def qs_int = quicksort::sort(); +import sort::check; fn void quicksort() { @@ -16,16 +14,12 @@ fn void quicksort() foreach (tc : tcases) { - qs_int(tc); - assert(sort::check_int_sort(tc)); + sort::quicksort(tc); + assert(check::int_sort(tc)); } } -def Cmp = fn int(int*, int*); - -def qs_int_ref = quicksort::sort_ref_fn(); - -fn void quicksort_with() +fn void quicksort_with_ref() { int[][] tcases = { {}, @@ -37,14 +31,12 @@ fn void quicksort_with() foreach (tc : tcases) { - qs_int_ref(tc, (Cmp)&sort::cmp_int); - assert(sort::check_int_sort(tc)); + sort::quicksort(tc, &sort::cmp_int_ref); + assert(check::int_sort(tc)); } } -def qs_int_fn = quicksort::sort_fn(); - -fn void quicksort_with2() +fn void quicksort_with_value() { int[][] tcases = { {}, @@ -56,8 +48,8 @@ fn void quicksort_with2() foreach (tc : tcases) { - qs_int_fn(tc, &sort::cmp_int2); - assert(sort::check_int_sort(tc)); + sort::quicksort(tc, &sort::cmp_int_value); + assert(check::int_sort(tc)); } } @@ -73,14 +65,14 @@ fn void quicksort_with_lambda() foreach (tc : tcases) { - qs_int_fn(tc, fn int(int a, int b) => a - b); - assert(sort::check_int_sort(tc)); + sort::quicksort(tc, fn int(int a, int b) => a - b); + assert(check::int_sort(tc)); } } -module std::sort; +module sort::check; -fn bool check_int_sort(int[] list) +fn bool int_sort(int[] list) { int prev = int.min; foreach (x : list) diff --git a/test/unit/stdlib/sort/sort.c3 b/test/unit/stdlib/sort/sort.c3 index 90ab431ae..538a8b203 100644 --- a/test/unit/stdlib/sort/sort.c3 +++ b/test/unit/stdlib/sort/sort.c3 @@ -1,9 +1,9 @@ module std::sort; -fn int cmp_int(void* x, void* y) { - return *(int*)x - *(int*)y; +fn int cmp_int_ref(int* x, int* y) { + return *x - *y; } -fn int cmp_int2(int x, int y) { +fn int cmp_int_value(int x, int y) { return x - y; } \ No newline at end of file