diff --git a/lib/std/collections/list.c3 b/lib/std/collections/list.c3 index 63514dfa7..4610a95e1 100644 --- a/lib/std/collections/list.c3 +++ b/lib/std/collections/list.c3 @@ -6,6 +6,7 @@ import std::io; import std::math; def ElementPredicate = fn bool(Type *type); +def ElementTest = fn bool(Type *type, any context); const ELEMENT_IS_EQUATABLE = types::is_equatable_type(Type); const ELEMENT_IS_POINTER = Type.kindof == POINTER; @@ -251,6 +252,15 @@ fn usz List.remove_if(&self, ElementPredicate filter) return self._remove_if(filter, false); } +/** + * @param selection "The function to determine if it should be kept or not" + * @return "the number of deleted elements" + **/ +fn usz List.retain_if(&self, ElementPredicate selection) +{ + return self._remove_if(selection, true); +} + macro usz List._remove_if(&self, ElementPredicate filter, bool $invert) @local { usz size = self.size; @@ -264,8 +274,7 @@ macro usz List._remove_if(&self, ElementPredicate filter, bool $invert) @local $endif // Remove the items from this index up to the one not to be deleted. usz n = self.size - k; - // Do explicit copy - copying between the same slice is not well defined. - for (usz j = 0; j < n; j++) self.entries[i + j] = self.entries[k + j]; + self.entries[i:n] = self.entries[k:n]; self.size -= k - i; // Find last index of item not to be deleted. $if $invert: @@ -277,13 +286,39 @@ macro usz List._remove_if(&self, ElementPredicate filter, bool $invert) @local return size - self.size; } -/** - * @param selection "The function to determine if it should be kept or not" - * @return "the number of deleted elements" - **/ -fn usz List.retain_if(&self, ElementPredicate selection) +fn usz List.remove_using_test(&self, ElementTest filter, any context) { - return self._remove_if(selection, true); + return self._remove_using_test(filter, false, context); +} + +fn usz List.retain_using_test(&self, ElementTest filter, any context) +{ + return self._remove_using_test(filter, true, context); +} + +macro usz List._remove_using_test(&self, ElementTest filter, bool $invert, ctx) @local +{ + usz size = self.size; + for (usz i = size, usz k = size; k > 0; k = i) + { + // Find last index of item to be deleted. + $if $invert: + while (i > 0 && !filter(&self.entries[i - 1], ctx)) i--; + $else + while (i > 0 && filter(&self.entries[i - 1], ctx)) i--; + $endif + // Remove the items from this index up to the one not to be deleted. + usz n = self.size - k; + self.entries[i:n] = self.entries[k:n]; + self.size -= k - i; + // Find last index of item not to be deleted. + $if $invert: + while (i > 0 && filter(&self.entries[i - 1], ctx)) i--; + $else + while (i > 0 && !filter(&self.entries[i - 1], ctx)) i--; + $endif + } + return size - self.size; } /** diff --git a/test/unit/stdlib/collections/list.c3 b/test/unit/stdlib/collections/list.c3 index 2531a6dc7..3b34ab57e 100644 --- a/test/unit/stdlib/collections/list.c3 +++ b/test/unit/stdlib/collections/list.c3 @@ -78,6 +78,23 @@ fn void! remove_if() assert(test.array_view() == int[]{11, 10, 20}); } +fn void! remove_using_test() +{ + IntList test; + usz removed; + + test.add_array({ 1, 11, 2, 10, 20 }); + removed = test.remove_using_test(fn bool(i, ctx) => *i >= *(int*)ctx, &&10); + assert(removed == 3); + assert(test.array_view() == int[]{1, 2}); + + test.clear(); + test.add_array({ 1, 11, 2, 10, 20 }); + removed = test.remove_using_test(fn bool(i, ctx) => *i < *(int*)ctx, &&10); + assert(removed == 2); + assert(test.array_view() == int[]{11, 10, 20}); +} + fn void! retain_if() { IntList test; @@ -95,6 +112,23 @@ fn void! retain_if() assert(test.array_view() == int[]{11, 10, 20}); } +fn void! retain_using_test() +{ + IntList test; + usz removed; + + test.add_array({ 1, 11, 2, 10, 20 }); + removed = test.remove_using_test(fn bool(i, ctx) => *i >= *(int*)ctx, &&10); + assert(removed == 3); + assert(test.array_view() == int[]{1, 2}); + + test.clear(); + test.add_array({ 1, 11, 2, 10, 20 }); + removed = test.remove_using_test(fn bool(i, ctx) => *i < *(int*)ctx, &&10); + assert(removed == 2); + assert(test.array_view() == int[]{11, 10, 20}); +} + module list_test; fn bool filter(int* i)