diff --git a/lib/std/collections/list.c3 b/lib/std/collections/list.c3 index 1162df7cd..060974206 100644 --- a/lib/std/collections/list.c3 +++ b/lib/std/collections/list.c3 @@ -79,7 +79,7 @@ fn void List.append(&self, Type element) /** * @require self.size > 0 - */ + **/ fn Type List.pop(&self) { return self.entries[--self.size]; @@ -92,7 +92,7 @@ fn void List.clear(&self) /** * @require self.size > 0 - */ + **/ fn Type List.pop_first(&self) { Type value = self.entries[0]; @@ -100,6 +100,9 @@ fn Type List.pop_first(&self) return value; } +/** + * @require index < self.size + **/ fn void List.remove_at(&self, usz index) { for (usz i = index + 1; i < self.size; i++) @@ -162,6 +165,9 @@ fn void List.push_front(&self, Type type) @inline self.insert_at(0, type); } +/** + * @require index < self.size + **/ fn void List.insert_at(&self, usz index, Type type) { self.ensure_capacity(); @@ -181,11 +187,17 @@ fn void List.set_at(&self, usz index, Type type) self.entries[index] = type; } +/** + * @require self.size > 0 + **/ fn void List.remove_last(&self) { self.size--; } +/** + * @require self.size > 0 + **/ fn void List.remove_first(&self) { self.remove_at(0); @@ -235,16 +247,26 @@ fn void List.swap(&self, usz i, usz j) * @return "the number of deleted elements" **/ fn usz List.remove_if(&self, ElementPredicate filter) +{ + return self._remove_if({ filter }, &Filter.same); +} + +macro usz List._remove_if(&self, Filter o, m) @private { usz size = self.size; - for (usz i = size; i > 0; i--) + usz i = size; + usz k = i; + while (k > 0) { - if (filter(&self.entries[i - 1])) continue; - for (usz j = i; j < size; j++) - { - self.entries[j - 1] = self.entries[j]; - } - self.size--; + // Find last index of item to be deleted. + while (i > 0 && m(o, &self.entries[i - 1])) i--; + // Remove the items from this index up to the one not to be deleted. + usz n = self.size - k; + self.entries[i:n] = self.entries[k:n]; + self.size -= k - i; + // Find last index of item not to be deleted. + while (i > 0 && !m(o, &self.entries[i - 1])) i--; + k = i; } return size - self.size; } @@ -255,19 +277,15 @@ fn usz List.remove_if(&self, ElementPredicate filter) **/ fn usz List.retain_if(&self, ElementPredicate selection) { - usz size = self.size; - for (usz i = size; i > 0; i--) - { - if (!selection(&self.entries[i - 1])) continue; - for (usz j = i; j < size; j++) - { - self.entries[j - 1] = self.entries[j]; - } - self.size--; - } - return size - self.size; + return self._remove_if({ selection }, &Filter.opposite); } +struct Filter @private +{ + ElementPredicate p; +} +fn bool Filter.same(self, Type* type) => self.p(type) @inline; +fn bool Filter.opposite(self, Type* type) => !self.p(type) @inline; /** * Reserve at least min_capacity diff --git a/test/unit/stdlib/collections/list.c3 b/test/unit/stdlib/collections/list.c3 index 686d36a26..2531a6dc7 100644 --- a/test/unit/stdlib/collections/list.c3 +++ b/test/unit/stdlib/collections/list.c3 @@ -4,7 +4,7 @@ import std::collections::list; def IntList = List(); def PtrList = List(); -fn void! test_delete_contains_index() +fn void! delete_contains_index() { IntList test; test.add_array({ 1, 2 }); @@ -33,7 +33,7 @@ fn void! test_delete_contains_index() assert(test.array_view() == int[]{ 2, 3 }); } -fn void! test_compact() +fn void! compact() { PtrList test; test.add_array({ null, &test }); @@ -46,7 +46,7 @@ fn void! test_compact() assert(test.compact() == 0); } -fn void! test_reverse() +fn void! reverse() { IntList test; test.reverse(); @@ -59,4 +59,50 @@ fn void! test_reverse() assert(test.array_view() == int[] { 3, 2, 1, 10 }); test.reverse(); assert(test.array_view() == int[] { 10, 1, 2, 3 }); +} + +fn void! remove_if() +{ + IntList test; + usz removed; + + test.add_array({ 1, 11, 2, 10, 20 }); + removed = test.remove_if(&filter); + assert(removed == 3); + assert(test.array_view() == int[]{1, 2}); + + test.clear(); + test.add_array({ 1, 11, 2, 10, 20 }); + removed = test.remove_if(&select); + assert(removed == 2); + assert(test.array_view() == int[]{11, 10, 20}); +} + +fn void! retain_if() +{ + IntList test; + usz removed; + + test.add_array({ 1, 11, 2, 10, 20 }); + removed = test.retain_if(&select); + assert(removed == 3); + assert(test.array_view() == int[]{1, 2}); + + test.clear(); + test.add_array({ 1, 11, 2, 10, 20 }); + removed = test.retain_if(&filter); + assert(removed == 2); + assert(test.array_view() == int[]{11, 10, 20}); +} + +module list_test; + +fn bool filter(int* i) +{ + return *i >= 10; +} + +fn bool select(int* i) +{ + return *i < 10; } \ No newline at end of file