diff --git a/lib/std/core/types.c3 b/lib/std/core/types.c3 index da0094096..ebe3a05a4 100644 --- a/lib/std/core/types.c3 +++ b/lib/std/core/types.c3 @@ -123,6 +123,16 @@ macro bool is_slice_convertable($Type) macro bool is_bool($Type) => $Type.kindof == TypeKind.BOOL; macro bool is_int($Type) => $Type.kindof == TypeKind.SIGNED_INT || $Type.kindof == TypeKind.UNSIGNED_INT; +macro bool is_indexable($Type) +{ + return $defined($Type{}[0]); +} + +macro bool is_ref_indexable($Type) +{ + return $defined(&$Type{}[0]); +} + macro bool is_intlike($Type) { $switch ($Type.kindof) diff --git a/lib/std/sort/binarysearch.c3 b/lib/std/sort/binarysearch.c3 index 64a3df7aa..52a1d2b90 100644 --- a/lib/std/sort/binarysearch.c3 +++ b/lib/std/sort/binarysearch.c3 @@ -3,8 +3,8 @@ module std::sort; /** * Perform a binary search over the sorted array and return the index * in [0, array.len) where x would be inserted or cmp(i) is true and cmp(j) is true for j in [i, array.len). - * @require $defined(list[0]) && $defined(list.len) "The list must be indexable" - * @require $or(@typeid(cmp) == void*.typeid, @is_comparer(cmp, list)) "Expected a comparison function which compares values" + * @require @is_sortable(list) "The list must be sortable" + * @require $or(@typeid(cmp) == void*.typeid, @is_cmp_fn(cmp, list)) "Expected a comparison function which compares values" **/ macro usz binarysearch(list, x, cmp = null) @builtin { diff --git a/lib/std/sort/insertionsort.c3 b/lib/std/sort/insertionsort.c3 new file mode 100644 index 000000000..5a27097c1 --- /dev/null +++ b/lib/std/sort/insertionsort.c3 @@ -0,0 +1,58 @@ +module std::sort; +import std::sort::is; + +/** + * Sort list using the quick sort algorithm. + * @require @is_sortable(list) "The list must be indexable and support .len or .len()" + * @require $or(@typeid(cmp) == void*.typeid, @is_cmp_fn(cmp, list)) "Expected a comparison function which compares values" + **/ +macro insertionsort(list, cmp = null) @builtin +{ + usz len = sort::@len_from_list(list); + is::isort(<$typeof(list), $typeof(cmp)>)(list, 0, (isz)len, cmp); +} + +module std::sort::is(); + +def ElementType = $typeof(Type{}[0]); + +fn void isort(Type list, usz low, usz high, Comparer comp) +{ + var $no_cmp = Comparer.typeid == void*.typeid; + var $cmp_by_value = $and(!$no_cmp, Comparer.params[0] == @typeid(list[0])); + var $has_get_ref = $defined(&list[0]); + assert(sort::@is_sortable(list)); + for (usz i = low; i < high; ++i) + { + usz j = i; + for (;j > low;) + { + $if $has_get_ref: + ElementType *rhs = &list[j]; + ElementType *lhs = &list[--j]; + $switch + $case $cmp_by_value: + if (comp(*rhs, *lhs) >= 0) break; + $case !$no_cmp: + if (comp(rhs, lhs) >= 0) break; + $default: + if (!less(*rhs, *lhs)) break; + $endswitch + @swap(*rhs, *lhs); + $else + usz r = j; + --j; + + $switch + $case $cmp_by_value: + if (comp(list[r], list[j]) >= 0) break; + $case !$no_cmp: + if (comp(&list[r], &list[j]) >= 0) break; + $default: + if (!less(list[r], list[j])) break; + $endswitch + @swap(list[r], list[j]); + $endif + } + } +} diff --git a/lib/std/sort/quicksort.c3 b/lib/std/sort/quicksort.c3 index 8edfa8370..f70596262 100644 --- a/lib/std/sort/quicksort.c3 +++ b/lib/std/sort/quicksort.c3 @@ -3,15 +3,13 @@ import std::sort::qs; /** * Sort list using the quick sort algorithm. - * @require $defined(list[0]) && $defined(list.len) "The list must be indexable and support .len or .len()" - * @require $or(@typeid(cmp) == void*.typeid, @is_comparer(cmp, list)) "Expected a comparison function which compares values" + * @require @is_sortable(list) "The list must be indexable and support .len or .len()" + * @require $or(@typeid(cmp) == void*.typeid, @is_cmp_fn(cmp, list)) "Expected a comparison function which compares values" **/ macro quicksort(list, cmp = null) @builtin { - var $Type = $typeof(list); - var $CmpType = $typeof(cmp); usz len = sort::@len_from_list(list); - qs::qsort(<$Type, $CmpType>)(list, 0, (isz)len - 1, cmp); + qs::qsort(<$typeof(list), $typeof(cmp)>)(list, 0, (isz)len - 1, cmp); } module std::sort::qs(); diff --git a/lib/std/sort/sort.c3 b/lib/std/sort/sort.c3 index c9f541752..5748f7c1c 100644 --- a/lib/std/sort/sort.c3 +++ b/lib/std/sort/sort.c3 @@ -10,7 +10,21 @@ macro usz @len_from_list(&list) $endif } -macro bool @is_comparer(#cmp, #list) +macro bool @is_sortable(#list) +{ + $switch + $case !$defined(#list[0]): + return false; + $case !$defined(#list.len): + return false; + $case $and($defined(&#list[0]) && !types::is_same($typeof(&#list[0]), $typeof(#list[0])*)): + return false; + $default: + return true; + $endswitch; +} + +macro bool @is_cmp_fn(#cmp, #list) { var $Type = $typeof(#cmp); $switch diff --git a/test/unit/stdlib/sort/insertionsort.c3 b/test/unit/stdlib/sort/insertionsort.c3 new file mode 100644 index 000000000..f8191548c --- /dev/null +++ b/test/unit/stdlib/sort/insertionsort.c3 @@ -0,0 +1,96 @@ +module sort_test @test; +import std::sort; +import sort::check; +import std::collections::list; + +fn void insertionsort() +{ + int[][] tcases = { + {}, + {10, 3}, + {3, 2, 1}, + {1, 2, 3}, + {2, 1, 3}, + }; + + foreach (tc : tcases) + { + sort::insertionsort(tc); + assert(check::int_ascending_sort(tc)); + } +} + +fn void insertionsort_with_ref() +{ + int[][] tcases = { + {}, + {10, 3}, + {3, 2, 1}, + {1, 2, 3}, + {2, 1, 3}, + }; + + foreach (tc : tcases) + { + sort::insertionsort(tc, &sort::cmp_int_ref); + assert(check::int_ascending_sort(tc)); + } +} + +fn void insertionsort_with_value() +{ + int[][] tcases = { + {}, + {10, 3}, + {3, 2, 1}, + {1, 2, 3}, + {2, 1, 3}, + }; + + foreach (tc : tcases) + { + sort::insertionsort(tc, &sort::cmp_int_value); + assert(check::int_ascending_sort(tc)); + } +} + +fn void insertionsort_with_lambda() +{ + int[][] tcases = { + {}, + {10, 3}, + {3, 2, 1}, + {1, 2, 3}, + {2, 1, 3}, + }; + + foreach (tc : tcases) + { + sort::insertionsort(tc, fn int(int a, int b) => a - b); + assert(check::int_ascending_sort(tc)); + } +} + +def InsertionSortTestList = List(); + +fn void insertionsort_list() +{ + InsertionSortTestList list; + list.temp_init(); + list.add_array({ 2, 1, 3}); + sort::insertionsort(list, &sort::cmp_int_value); + assert(check::int_ascending_sort(list.array_view())); +} + +module sort::check; + +fn bool int_ascending_sort(int[] list) +{ + int prev = int.min; + foreach (x : list) + { + if (prev > x) return false; + prev = x; + } + return true; +}