From 5463c398cb03b63a02b39088dccf1589804f00b2 Mon Sep 17 00:00:00 2001 From: konimarti <30975830+konimarti@users.noreply.github.com> Date: Tue, 3 Dec 2024 19:27:26 +0100 Subject: [PATCH] Add quickselect (#1654) * sort: extract partition from quicksort Extract the partition logic from quicksort into a macro. This allows to reuse the partition logic for, e.g., the quickselect algorithm. * sort: implement quickselect implement Hoare's selection algorithm (quickselect) on the basis of the already implemented quicksort. Quickselect allows to find the kth smallest element in a unordered list with an average time complexity of O(N) (worst case: O(N^2)). * add quicksort benchmark Create a top-level benchmarks folder. Add the benchmark implementation for the quicksort algorithm. Benchmarks can then be run in the same way as unit tests from the root folder with: c3c compile-benchmarks benchmarks/stdlib/sort --- benchmarks/stdlib/sort/quicksort.c3 | 69 ++++++++++++++++ lib/std/sort/quicksort.c3 | 116 +++++++++++++++++++-------- test/unit/stdlib/sort/quickselect.c3 | 63 +++++++++++++++ 3 files changed, 216 insertions(+), 32 deletions(-) create mode 100644 benchmarks/stdlib/sort/quicksort.c3 create mode 100644 test/unit/stdlib/sort/quickselect.c3 diff --git a/benchmarks/stdlib/sort/quicksort.c3 b/benchmarks/stdlib/sort/quicksort.c3 new file mode 100644 index 000000000..7a674659b --- /dev/null +++ b/benchmarks/stdlib/sort/quicksort.c3 @@ -0,0 +1,69 @@ +module sort_bench; + +import std::sort; + +fn void init() @init +{ + set_benchmark_warmup_iterations(5); + set_benchmark_max_iterations(10_000); +} + +fn void! quicksort_bench() @benchmark +{ + // test set: 500 numbers between 0 and 99; + int[] data = { + 71, 28, 2, 13, 62, 10, 54, 78, 63, 86, + 33, 65, 89, 51, 58, 0, 51, 16, 87, 30, + 89, 14, 52, 41, 88, 25, 83, 91, 56, 86, + 14, 64, 76, 18, 39, 24, 79, 62, 34, 58, + 90, 24, 56, 73, 85, 82, 79, 63, 47, 69, + 78, 29, 49, 28, 43, 47, 56, 53, 79, 56, + 19, 63, 29, 52, 71, 93, 61, 46, 30, 11, + 21, 26, 37, 86, 93, 74, 62, 0, 41, 17, + 26, 27, 34, 11, 54, 69, 72, 44, 74, 3, + 61, 62, 80, 90, 3, 82, 16, 12, 28, 1, + 2, 49, 4, 44, 57, 86, 63, 74, 33, 41, + 76, 77, 56, 57, 56, 88, 74, 71, 6, 59, + 40, 42, 94, 55, 21, 17, 17, 63, 21, 83, + 73, 19, 39, 88, 93, 74, 21, 0, 63, 45, + 69, 66, 22, 68, 86, 86, 85, 67, 8, 50, + 23, 98, 64, 80, 64, 36, 40, 30, 73, 36, + 23, 14, 1, 77, 82, 8, 18, 73, 37, 86, + 29, 70, 27, 87, 64, 81, 13, 0, 4, 83, + 90, 17, 71, 66, 38, 39, 54, 22, 86, 18, + 84, 66, 77, 25, 64, 93, 80, 91, 2, 92, + 47, 32, 90, 16, 46, 29, 56, 87, 70, 73, + 89, 41, 5, 54, 93, 63, 16, 39, 71, 84, + 74, 91, 69, 59, 49, 87, 74, 37, 75, 83, + 77, 19, 51, 44, 79, 62, 94, 20, 24, 83, + 37, 70, 57, 32, 93, 8, 29, 11, 7, 92, + 8, 23, 20, 21, 7, 70, 28, 20, 96, 6, + 50, 58, 30, 61, 66, 42, 50, 54, 64, 7, + 10, 53, 63, 44, 16, 39, 83, 73, 3, 29, + 97, 32, 36, 68, 84, 64, 73, 5, 29, 13, + 48, 3, 84, 65, 75, 68, 66, 22, 39, 33, + 39, 24, 27, 85, 18, 34, 3, 63, 32, 9, + 29, 66, 24, 90, 75, 50, 11, 95, 47, 14, + 92, 1, 76, 45, 76, 41, 55, 54, 38, 67, + 43, 40, 5, 61, 97, 11, 61, 24, 92, 24, + 76, 53, 60, 34, 78, 80, 70, 75, 30, 90, + 65, 99, 80, 61, 94, 75, 63, 67, 10, 35, + 23, 42, 31, 48, 14, 68, 84, 14, 79, 1, + 25, 94, 23, 53, 49, 69, 44, 73, 63, 51, + 44, 96, 88, 51, 94, 24, 64, 72, 59, 81, + 73, 93, 14, 35, 9, 53, 25, 48, 50, 88, + 46, 97, 67, 40, 27, 17, 2, 42, 11, 82, + 0, 46, 44, 38, 31, 88, 63, 88, 10, 82, + 77, 61, 24, 39, 27, 33, 10, 91, 69, 22, + 42, 74, 71, 13, 32, 56, 12, 46, 81, 74, + 17, 26, 45, 50, 76, 84, 76, 36, 43, 65, + 81, 64, 0, 49, 70, 11, 76, 19, 60, 55, + 15, 98, 31, 91, 56, 8, 97, 9, 3, 94, + 3, 88, 7, 2, 3, 98, 10, 51, 21, 79, + 99, 3, 8, 76, 52, 13, 40, 90, 85, 15, + 70, 77, 43, 30, 4, 89, 18, 21, 59, 17, + }; + sort::quicksort(data); +} + + diff --git a/lib/std/sort/quicksort.c3 b/lib/std/sort/quicksort.c3 index e9ea163b7..fc1b28348 100644 --- a/lib/std/sort/quicksort.c3 +++ b/lib/std/sort/quicksort.c3 @@ -13,6 +13,21 @@ macro quicksort(list, cmp = EMPTY_MACRO_SLOT, context = EMPTY_MACRO_SLOT) @built qs::qsort(<$typeof(list), $typeof(cmp), $typeof(context)>)(list, 0, (isz)len - 1, cmp, context); } +<* +Select the (k+1)th smallest element in an unordered list using Hoare's +selection algorithm (Quickselect). k should be between 0 and len-1. The data +list will be partially sorted. + + @require @is_sortable(list) "The list must be indexable and support .len or .len()" + @require @is_valid_cmp_fn(cmp, list, context) "expected a comparison function which compares values" + @require @is_valid_context(cmp, context) "Expected a valid context" +*> +macro quickselect(list, isz k, cmp = EMPTY_MACRO_SLOT, context = EMPTY_MACRO_SLOT) @builtin +{ + usz len = sort::@len_from_list(list); + return qs::qselect(<$typeof(list), $typeof(cmp), $typeof(context)>)(list, 0, (isz)len - 1, k, cmp, context); +} + module std::sort::qs(); def ElementType = $typeof(Type{}[0]); @@ -29,10 +44,6 @@ def Stack = StackElementItem[64] @private; fn void qsort(Type list, isz low, isz high, CmpFn cmp, Context context) { - var $has_cmp = @is_valid_macro_slot(cmp); - var $has_context = @is_valid_macro_slot(context); - var $cmp_by_value = $has_cmp &&& $assignable(list[0], $typefrom(CmpFn.paramsof[0].type)); - if (low >= 0 && high >= 0 && low < high) { Stack stack; @@ -48,34 +59,7 @@ fn void qsort(Type list, isz low, isz high, CmpFn cmp, Context context) if (l < h) { - ElementType pivot = list[l]; - while (l < h) - { - $switch - $case $cmp_by_value && $has_context: - while (cmp(list[h], pivot, context) >= 0 && l < h) h--; - if (l < h) list[l++] = list[h]; - while (cmp(list[l], pivot, context) <= 0 && l < h) l++; - $case $cmp_by_value: - while (cmp(list[h], pivot) >= 0 && l < h) h--; - if (l < h) list[l++] = list[h]; - while (cmp(list[l], pivot) <= 0 && l < h) l++; - $case $has_cmp && $has_context: - while (cmp(&list[h], &pivot, context) >= 0 && l < h) h--; - if (l < h) list[l++] = list[h]; - while (cmp(&list[l], &pivot, context) <= 0 && l < h) l++; - $case $has_cmp: - while (cmp(&list[h], &pivot) >= 0 && l < h) h--; - if (l < h) list[l++] = list[h]; - while (cmp(&list[l], &pivot) <= 0 && l < h) l++; - $default: - while (greater_eq(list[h], pivot) && l < h) h--; - if (l < h) list[l++] = list[h]; - while (less_eq(list[l], pivot) && l < h) l++; - $endswitch - if (l < h) list[h--] = list[l]; - } - list[l] = pivot; + l = @partition(list, l, h, cmp, context); stack[i + 1].low = l + 1; stack[i + 1].high = stack[i].high; stack[i++].high = l; @@ -91,3 +75,71 @@ fn void qsort(Type list, isz low, isz high, CmpFn cmp, Context context) } } } + +<* +@require low <= k "kth smalles element is smaller than lower bounds" +@require k <= high "kth smalles element is larger than upper bounds" +*> +fn ElementType! qselect(Type list, isz low, isz high, isz k, CmpFn cmp, Context context) +{ + if (low >= 0 && high >= 0 && low < high) + { + isz l = low; + isz h = high; + isz pivot; + + usz max_retries = 64; + while (l <= h && max_retries--) + { + pivot = @partition(list, l, h, cmp, context); + if (k == pivot) return list[k]; + if (k < pivot) + { + h = pivot - 1; + } + else + { + l = pivot + 1; + } + } + } + return SearchResult.MISSING?; +} + +macro @partition(Type list, isz l, isz h, CmpFn cmp, Context context) +{ + var $has_cmp = @is_valid_macro_slot(cmp); + var $has_context = @is_valid_macro_slot(context); + var $cmp_by_value = $has_cmp &&& $assignable(list[0], $typefrom(CmpFn.paramsof[0].type)); + + ElementType pivot = list[l]; + while (l < h) + { + $switch + $case $cmp_by_value && $has_context: + while (cmp(list[h], pivot, context) >= 0 && l < h) h--; + if (l < h) list[l++] = list[h]; + while (cmp(list[l], pivot, context) <= 0 && l < h) l++; + $case $cmp_by_value: + while (cmp(list[h], pivot) >= 0 && l < h) h--; + if (l < h) list[l++] = list[h]; + while (cmp(list[l], pivot) <= 0 && l < h) l++; + $case $has_cmp && $has_context: + while (cmp(&list[h], &pivot, context) >= 0 && l < h) h--; + if (l < h) list[l++] = list[h]; + while (cmp(&list[l], &pivot, context) <= 0 && l < h) l++; + $case $has_cmp: + while (cmp(&list[h], &pivot) >= 0 && l < h) h--; + if (l < h) list[l++] = list[h]; + while (cmp(&list[l], &pivot) <= 0 && l < h) l++; + $default: + while (greater_eq(list[h], pivot) && l < h) h--; + if (l < h) list[l++] = list[h]; + while (less_eq(list[l], pivot) && l < h) l++; + $endswitch + if (l < h) list[h--] = list[l]; + } + list[l] = pivot; + + return l; +} diff --git a/test/unit/stdlib/sort/quickselect.c3 b/test/unit/stdlib/sort/quickselect.c3 new file mode 100644 index 000000000..2d975c593 --- /dev/null +++ b/test/unit/stdlib/sort/quickselect.c3 @@ -0,0 +1,63 @@ +module sort_test @test; +import std::sort; + +struct TestCase @local +{ + int[] list; + isz k; + int want; +} + +fn void! quickselect() +{ + TestCase[] tcases = { + { + .list = {3, 4, 1}, + .k = 0, + .want = 1, + }, + { + .list = {3, 4, 1}, + .k = 1, + .want = 3, + }, + { + .list = {3, 4, 1}, + .k = 2, + .want = 4, + }, + { + .list = {3, 2, 4, 1}, + .k = 1, + .want = 2, + }, + { + .list = {3, 2, 1, 2}, + .k = 1, + .want = 2, + }, + { + .list = {3, 2, 1, 2}, + .k = 2, + .want = 2, + }, + { + .list = {3, 2, 1, 2}, + .k = 3, + .want = 3, + }, + }; + + foreach (i, tc : tcases) + { + if (try got = sort::quickselect(tc.list, tc.k)) + { + assert(got == tc.want, "got: %d, want %d", got, tc.want); + } + else + { + assert(false, "test %d failed", i); + } + } +} +