diff --git a/lib/std/sort/quicksort.c3 b/lib/std/sort/quicksort.c3 index 151e42b3a..93bdff2a0 100644 --- a/lib/std/sort/quicksort.c3 +++ b/lib/std/sort/quicksort.c3 @@ -1,66 +1,105 @@ -module std::sort; - -/** - * @require is_searchable(list) "The list must be indexable and support .len or .len()" - **/ -macro quicksort(list, $Type) -{ - (($Type)(list)).sort(null); -} - -/** - * @require is_searchable(list) "The list must be indexable and support .len or .len()" - * @require is_comparer(cmp, list) "Expected a comparison function which compares values" - **/ -macro quicksort_with(list, $Type, cmp) -{ - (($Type)(list)).sort(cmp); -} - -module std::sort::quicksort; +module std::sort::quicksort; import std::sort; -def Quicksort = distinct Type[]; +def ElementType = $typeof(Type{}[0]); +def Comparer = fn int(ElementType, ElementType); +def ComparerRef = fn int(ElementType*, ElementType*); -fn void Quicksort.sort(qs, Comparer cmp) +const bool ELEMENT_COMPARABLE = $checks(ElementType x, greater(x, x)); + +fn void sort_fn(Type list, Comparer cmp) { - usz len = sort::@len_from_list(qs); - qs.qsort(0, (isz)len - 1, cmp); + usz len = sort::@len_from_list(list); + qsort_value(list, 0, (isz)len - 1, cmp); } -fn void Quicksort.qsort(Quicksort qs, isz low, isz high, Comparer cmp) @private +fn void sort_ref_fn(Type list, ComparerRef cmp) +{ + usz len = sort::@len_from_list(list); + qsort_ref(list, 0, (isz)len - 1, cmp); +} + +fn void sort(Type list) @if(ELEMENT_COMPARABLE) +{ + usz len = sort::@len_from_list(list); + qsort(list, 0, (isz)len - 1); +} + +fn void qsort(Type list, isz low, isz high) @local @if(ELEMENT_COMPARABLE) { if (low < high) { - isz p = qs.partition(low, high, cmp); - qs.qsort(low, p - 1, cmp); - qs.qsort(p + 1, high, cmp); + isz p = partition(list, low, high); + qsort(list, low, p - 1); + qsort(list, p + 1, high); } } -fn isz Quicksort.partition(qs, isz low, isz high, Comparer cmp) @inline @private +fn void qsort_value(Type list, isz low, isz high, Comparer cmp) @local { - Type pivot = qs[high]; + if (low < high) + { + isz p = partition_value(list, low, high, cmp); + qsort_value(list, low, p - 1, cmp); + qsort_value(list, p + 1, high, cmp); + } +} + +fn void qsort_ref(Type list, isz low, isz high, ComparerRef cmp) @local +{ + if (low < high) + { + isz p = partition_ref(list, low, high, cmp); + qsort_ref(list, low, p - 1, cmp); + qsort_ref(list, p + 1, high, cmp); + } +} + +fn isz partition(Type list, isz low, isz high) @inline @local +{ + ElementType pivot = list[high]; isz i = low - 1; for (isz j = low; j < high; j++) { - $if $checks(cmp(qs[0], qs[0])): - int res = cmp(qs[j], pivot); - $else - $if $checks(cmp(&qs[0], &qs[0])): - int res = cmp(&qs[j], &pivot); - $else - int res; - if (greater(qs[j], pivot)) res = 1; - $endif - $endif - if (res <= 0) + if (greater(list[j], pivot)) continue; + i++; + @swap(list[i], list[j]); + } + i++; + @swap(list[i], list[high]); + return i; +} + +fn isz partition_value(Type list, isz low, isz high, Comparer cmp) @inline @private +{ + ElementType pivot = list[high]; + isz i = low - 1; + for (isz j = low; j < high; j++) + { + if (cmp(list[j], pivot) <= 0) { i++; - @swap(qs[i], qs[j]); + @swap(list[i], list[j]); } } i++; - @swap(qs[i], qs[high]); + @swap(list[i], list[high]); + return i; +} + +fn isz partition_ref(Type list, isz low, isz high, ComparerRef cmp) @inline @private +{ + ElementType* pivot = &list[high]; + isz i = low - 1; + for (isz j = low; j < high; j++) + { + if (cmp(&list[j], pivot) <= 0) + { + i++; + @swap(list[i], list[j]); + } + } + i++; + @swap(list[i], list[high]); return i; } \ No newline at end of file diff --git a/test/unit/stdlib/sort/quicksort.c3 b/test/unit/stdlib/sort/quicksort.c3 index ccf0e674c..f399c46ce 100644 --- a/test/unit/stdlib/sort/quicksort.c3 +++ b/test/unit/stdlib/sort/quicksort.c3 @@ -2,7 +2,7 @@ module sort_test @test; import std::sort; import std::sort::quicksort; -def QSInt = quicksort::Quicksort; +def qs_int = quicksort::sort; fn void quicksort() { @@ -16,13 +16,13 @@ fn void quicksort() foreach (tc : tcases) { - sort::quicksort(tc, QSInt); + qs_int(tc); assert(sort::check_int_sort(tc)); } } -def Cmp = fn int (void*, void*); -def QSIntCmp = quicksort::Quicksort; +def Cmp = fn int(int*, int*); +def qs_int_ref = quicksort::sort_ref_fn; fn void quicksort_with() { @@ -36,13 +36,12 @@ fn void quicksort_with() foreach (tc : tcases) { - sort::quicksort_with(tc, QSIntCmp, &sort::cmp_int); + qs_int_ref(tc, (Cmp)&sort::cmp_int); assert(sort::check_int_sort(tc)); } } -def Cmp2 = fn int (int, int); -def QSIntCmp2 = quicksort::Quicksort; +def qs_int_fn = quicksort::sort_fn; fn void quicksort_with2() { @@ -56,7 +55,7 @@ fn void quicksort_with2() foreach (tc : tcases) { - sort::quicksort_with(tc, QSIntCmp2, &sort::cmp_int2); + qs_int_fn(tc, &sort::cmp_int2); assert(sort::check_int_sort(tc)); } } @@ -73,7 +72,7 @@ fn void quicksort_with_lambda() foreach (tc : tcases) { - sort::quicksort_with(tc, QSIntCmp2, fn int(int a, int b) => a - b); + qs_int_fn(tc, fn int(int a, int b) => a - b); assert(sort::check_int_sort(tc)); } }