From 35bffdadc235735cf1e7c166631a36d144673127 Mon Sep 17 00:00:00 2001 From: Pierre Curto Date: Sat, 15 Jul 2023 19:08:54 +0200 Subject: [PATCH] improve the sort and collections libs (#853) * lib/std/sort: unify binarysearch and binarysearch_with; add comments to quicksort Signed-off-by: Pierre Curto * lib/std/collections: mark List.{len, is_empty, get} with @inline Signed-off-by: Pierre Curto * lib/std/collections: add PriorityQueueMax; add tests for PriorityQueue and PriorityQueueMax Signed-off-by: Pierre Curto --------- Signed-off-by: Pierre Curto --- lib/std/collections/list.c3 | 6 +- lib/std/collections/priorityqueue.c3 | 80 ++++++++++++------- lib/std/sort/binarysearch.c3 | 59 ++++++-------- lib/std/sort/quicksort.c3 | 5 ++ test/test_suite/stdlib/priorityqueue.c3t | 5 +- test/unit/stdlib/collections/priorityqueue.c3 | 59 ++++++++++++++ test/unit/stdlib/sort/binarysearch.c3 | 7 +- 7 files changed, 145 insertions(+), 76 deletions(-) create mode 100644 test/unit/stdlib/collections/priorityqueue.c3 diff --git a/lib/std/collections/list.c3 b/lib/std/collections/list.c3 index 0aac53f65..9eb16b582 100644 --- a/lib/std/collections/list.c3 +++ b/lib/std/collections/list.c3 @@ -199,17 +199,17 @@ fn Type* List.last(&self) return self.size ? &self.entries[self.size - 1] : null; } -fn bool List.is_empty(&self) +fn bool List.is_empty(&self) @inline { return !self.size; } -fn usz List.len(&self) @operator(len) +fn usz List.len(&self) @operator(len) @inline { return self.size; } -fn Type List.get(&self, usz index) +fn Type List.get(&self, usz index) @inline { return self.entries[index]; } diff --git a/lib/std/collections/priorityqueue.c3 b/lib/std/collections/priorityqueue.c3 index af7d59cd5..8678f89ff 100644 --- a/lib/std/collections/priorityqueue.c3 +++ b/lib/std/collections/priorityqueue.c3 @@ -21,94 +21,114 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. module std::collections::priorityqueue(); +import std::collections::priorityqueue::private; + +def PriorityQueue = distinct inline PrivatePriorityQueue(); +def PriorityQueueMax = distinct inline PrivatePriorityQueue(); + +module std::collections::priorityqueue::private(); import std::collections::list; def Heap = List(); -struct PriorityQueue +struct PrivatePriorityQueue { Heap heap; - bool max; // true if max-heap, false if min-heap } -fn void PriorityQueue.push(&self, Type element) +fn void PrivatePriorityQueue.push(&self, Type element) { self.heap.push(element); usz i = self.heap.len() - 1; while (i > 0) { usz parent = (i - 1) / 2; - if ((self.max && greater(self.heap.get(i), self.heap.get(parent))) || (!self.max && less(self.heap.get(i), self.heap.get(parent)))) - { - self.heap.swap(i, parent); - i = parent; - continue; - } - break; + Type item = self.heap[i]; + Type parent_item = self.heap[parent]; + $if MAX: + bool ok = greater(item, parent_item); + $else + bool ok = less(item, parent_item); + $endif + if (!ok) break; + self.heap.swap(i, parent); + i = parent; } } /** * @require self != null */ -fn Type! PriorityQueue.pop(&self) +fn Type! PrivatePriorityQueue.pop(&self) { usz i = 0; - usz len = self.heap.len() @inline; + usz len = self.heap.len(); if (!len) return IteratorResult.NO_MORE_ELEMENT?; usz newCount = len - 1; self.heap.swap(0, newCount); while ((2 * i + 1) < newCount) { usz j = 2 * i + 1; - if (((j + 1) < newCount) && - ((self.max && greater(self.heap.get(j + 1), self.heap[j])) - || (!self.max && less(self.heap.get(j + 1), self.heap.get(j))))) + Type itemj = self.heap[j]; + if ((j + 1) < newCount) { - j++; + Type nextj = self.heap[j + 1]; + $if MAX: + bool ok = greater(nextj, itemj); + $else + bool ok = less(nextj, itemj); + $endif + if (ok) j++; } - if ((self.max && less(self.heap.get(i), self.heap.get(j))) || (!self.max && greater(self.heap.get(i), self.heap.get(j)))) - { - self.heap.swap(i, j); - i = j; - continue; - } - break; + Type item = self.heap[i]; + $if MAX: + bool ok = less(item, itemj); + $else + bool ok = greater(item, itemj); + $endif + if (!ok) break; + self.heap.swap(i, j); + i = j; } return self.heap.pop(); } -fn Type! PriorityQueue.peek(&self) +fn Type! PrivatePriorityQueue.peek(&self) { if (!self.len()) return IteratorResult.NO_MORE_ELEMENT?; return self.heap.get(0); } -fn void PriorityQueue.free(&self) +fn void PrivatePriorityQueue.free(&self) { self.heap.free(); } -fn usz PriorityQueue.len(&self) @operator(len) +fn usz PrivatePriorityQueue.len(&self) @operator(len) { return self.heap.len(); } +fn bool PrivatePriorityQueue.is_empty(&self) +{ + return self.heap.is_empty(); +} + /** * @require index < self.len() */ -fn Type PriorityQueue.peek_at(&self, usz index) @operator([]) +fn Type PrivatePriorityQueue.peek_at(&self, usz index) @operator([]) { return self.heap[index]; } -fn void! PriorityQueue.to_format(&self, Formatter* formatter) @dynamic +fn void! PrivatePriorityQueue.to_format(&self, Formatter* formatter) @dynamic { return self.heap.to_format(formatter); } -fn String PriorityQueue.to_string(&self, Allocator* using = mem::heap()) @dynamic +fn String PrivatePriorityQueue.to_string(&self, Allocator* using = mem::heap()) @dynamic { return self.heap.to_string(using); -} +} \ No newline at end of file diff --git a/lib/std/sort/binarysearch.c3 b/lib/std/sort/binarysearch.c3 index e31a24058..c3e2da9a5 100644 --- a/lib/std/sort/binarysearch.c3 +++ b/lib/std/sort/binarysearch.c3 @@ -1,52 +1,39 @@ module std::sort; /** - * Perform a binary search over the sorted array and return the smallest index - * in [0, array.len) where cmp(i) is true and cmp(j) is true for j in [i, array.len). + * Perform a binary search over the sorted array and return the index + * in [0, array.len) where x would be inserted or cmp(i) is true and cmp(j) is true for j in [i, array.len). * @require is_searchable(list) "The list must be indexable and support .len or .len()" - * @require is_comparer(cmp, list) "Expected a comparison function which compares values" + * @require !cmp || is_comparer(cmp, list) "Expected a comparison function which compares values" **/ -macro usz binarysearch_with(list, x, cmp) +macro usz binarysearch(list, x, cmp = null) @builtin { usz i; usz len = @len_from_list(list); for (usz j = len; i < j;) { usz half = i + (j - i) / 2; - $if $checks(cmp(list[0], list[0])): - int res = cmp(list[half], x); + $if $checks(!cmp): + switch + { + case greater(list[half], x): j = half; + case less(list[half], x): i = half + 1; + default: return half; + } $else - int res = cmp(&list[half], &x); + $switch + $case $checks(cmp(list[0], list[0])): + int res = cmp(list[half], x); + $case $checks(cmp(&list[0], &list[0])): + int res = cmp(&list[half], &x); + $endswitch + switch + { + case res > 0: j = half; + case res < 0: i = half + 1; + default: return half; + } $endif - switch - { - case res > 0: j = half; - case res < 0: i = half + 1; - default: return half; - } - } - return i; -} - -/** - * Perform a binary search over the sorted array and return the index - * in [0, array.len) where x would be inserted. - * @require is_searchable(list) "The list must be indexable and support .len or .len()" - * @checked less(list[0], x) "The values must be comparable" - **/ -macro usz binarysearch(list, x) @builtin -{ - usz i; - usz len = @len_from_list(list); - for (usz j = len; i < j;) - { - usz half = (i + j) / 2; - switch - { - case greater(list[half], x): j = half; - case less(list[half], x): i = half + 1; - default: return half; - } } return i; } \ No newline at end of file diff --git a/lib/std/sort/quicksort.c3 b/lib/std/sort/quicksort.c3 index c8c53e5e3..9245efe08 100644 --- a/lib/std/sort/quicksort.c3 +++ b/lib/std/sort/quicksort.c3 @@ -1,6 +1,11 @@ module std::sort; import std::sort::qs; +/** + * Sort list using the quick sort algorithm. + * @require is_searchable(list) "The list must be indexable and support .len or .len()" + * @require !cmp || is_comparer(cmp, list) "Expected a comparison function which compares values" + **/ macro quicksort(list, cmp = null) @builtin { var $Type = $typeof(list); diff --git a/test/test_suite/stdlib/priorityqueue.c3t b/test/test_suite/stdlib/priorityqueue.c3t index cbadfad4c..e677e3ad6 100644 --- a/test/test_suite/stdlib/priorityqueue.c3t +++ b/test/test_suite/stdlib/priorityqueue.c3t @@ -4,13 +4,12 @@ import std::io; import std::math; import std::collections::priorityqueue; -def FooPriorityQueue = PriorityQueue(); +def FooPriorityQueue = PriorityQueueMax(); fn void main() { FooPriorityQueue agh; - agh.max = true; agh.push(Foo { 3 }); agh.push(Foo { 101 }); agh.push(Foo { 10 }); @@ -29,4 +28,4 @@ fn bool Foo.less(Foo* x, Foo y) @inline /* #expect: test.ll -%PriorityQueue = type { %List, i8 } +%PrivatePriorityQueue = type { %List } diff --git a/test/unit/stdlib/collections/priorityqueue.c3 b/test/unit/stdlib/collections/priorityqueue.c3 new file mode 100644 index 000000000..ae5553374 --- /dev/null +++ b/test/unit/stdlib/collections/priorityqueue.c3 @@ -0,0 +1,59 @@ +module priorityqueue_test @test; +import std::collections; +import std::collections::priorityqueue; + +def Queue = PriorityQueue(); + +fn void! priorityqueue() +{ + Queue q; + assert(q.is_empty()); + + q.push(1); + q.push(2); + assert(q.len() == 2); + + int x; + x = q.pop()!; + assert(x == 1, "got %d; want %d", x, 1); + x = q.pop()!; + assert(x == 2, "got %d; want %d", x, 2); + + q.push(3); + q.push(2); + q.push(1); + x = q.pop()!; + assert(x == 1, "got %d; want %d", x, 1); + x = q.pop()!; + assert(x == 2, "got %d; want %d", x, 2); + x = q.pop()!; + assert(x == 3, "got %d; want %d", x, 3); +} + +def QueueMax = PriorityQueueMax(); + +fn void! priorityqueue_max() +{ + QueueMax q; + assert(q.is_empty()); + + q.push(1); + q.push(2); + assert(q.len() == 2); + + int x; + x = q.pop()!; + assert(x == 2, "got %d; want %d", x, 2); + x = q.pop()!; + assert(x == 1, "got %d; want %d", x, 1); + + q.push(3); + q.push(2); + q.push(1); + x = q.pop()!; + assert(x == 3, "got %d; want %d", x, 3); + x = q.pop()!; + assert(x == 2, "got %d; want %d", x, 2); + x = q.pop()!; + assert(x == 1, "got %d; want %d", x, 1); +} \ No newline at end of file diff --git a/test/unit/stdlib/sort/binarysearch.c3 b/test/unit/stdlib/sort/binarysearch.c3 index e1683823c..35150222a 100644 --- a/test/unit/stdlib/sort/binarysearch.c3 +++ b/test/unit/stdlib/sort/binarysearch.c3 @@ -25,14 +25,13 @@ fn void binarysearch() usz idx = sort::binarysearch(tc.data, tc.x); assert(idx == tc.index, "%s: got %d; want %d", tc.data, idx, tc.index); - usz cmp_idx = sort::binarysearch_with(tc.data, tc.x, &sort::cmp_int_ref); + usz cmp_idx = sort::binarysearch(tc.data, tc.x, &sort::cmp_int_ref); assert(cmp_idx == tc.index, "%s: got %d; want %d", tc.data, cmp_idx, tc.index); - usz cmp_idx2 = sort::binarysearch_with(tc.data, tc.x, &sort::cmp_int_value); + usz cmp_idx2 = sort::binarysearch(tc.data, tc.x, &sort::cmp_int_value); assert(cmp_idx2 == tc.index, "%s: got %d; want %d", tc.data, cmp_idx2, tc.index); - usz cmp_idx3 = sort::binarysearch_with(tc.data, tc.x, fn int(int a, int b) => a - b); + usz cmp_idx3 = sort::binarysearch(tc.data, tc.x, fn int(int a, int b) => a - b); assert(cmp_idx3 == tc.index, "%s: got %d; want %d", tc.data, cmp_idx2, tc.index); } - } \ No newline at end of file