From e790df539d84f9225c6944b6f36f2cd2f07ca655 Mon Sep 17 00:00:00 2001
From: Velikiy Kirill
Date: Fri, 25 Jul 2025 20:53:39 +0300
Subject: [PATCH] Add HashSet implementation (#2322)

* Add HashSet implementation

Add a generic HashSet with full allocator support and standard set operations.

- Basic operations: add/remove/contains/clear
- Set operations: set_union/intersection/symmetric_difference/difference/is_subset
- Memory management with allocator support
- Iteration support
- Automatic resizing with load factor control

* Add "add_all" "add_all_from" "remove_all" "remove_all_from"

---------

Co-authored-by: Christoffer Lerno
---
 lib/std/collections/hashset.c3      | 753 ++++++++++++++++++++++++++++
 test/unit/stdlib/collections/set.c3 | 287 +++++++++++
 2 files changed, 1040 insertions(+)
 create mode 100644 lib/std/collections/hashset.c3
 create mode 100644 test/unit/stdlib/collections/set.c3

diff --git a/lib/std/collections/hashset.c3 b/lib/std/collections/hashset.c3
new file mode 100644
index 000000000..09bf5bc18
--- /dev/null
+++ b/lib/std/collections/hashset.c3
@@ -0,0 +1,753 @@
+<*
+ @require $defined((Value){}.hash()) : `No .hash function found on the value`
+*>
+module std::collections::set {Value};
+import std::math;
+import std::io @norecurse;
+
+const uint DEFAULT_INITIAL_CAPACITY = 16; // Bucket count used when none is given
+const uint MAXIMUM_CAPACITY = 1u << 31; // Upper bound for the bucket count
+const float DEFAULT_LOAD_FACTOR = 0.75; // Resize threshold = bucket count * load factor
+
+// Sentinel allocator: a set carrying it is initialized with `mem` once it needs storage.
+const Allocator SET_HEAP_ALLOCATOR = (Allocator)&dummy;
+
+<* Copy ONHEAP to get a set that lazily initializes itself on the heap *>
+const HashSet ONHEAP = { .allocator = SET_HEAP_ALLOCATOR };
+
+// A node in a bucket's singly linked chain.
+struct Entry
+{
+	uint hash; // Rehashed hash of the value, cached for lookups and resizing
+	Value value;
+	Entry* next; // Next entry in the same bucket, or null
+}
+
+// Hash set with separate chaining: `table` holds the head entry of every bucket.
+struct HashSet (Printable)
+{
+	Entry*[] table;
+	Allocator allocator;
+	usz count; // Number of elements
+	usz threshold; // Resize limit
+	float load_factor;
+}
+
+<* @return "The number of elements in the set, narrowed to int" *>
+fn int HashSet.len(&self) @operator(len) => (int) self.count;
+
+<*
+ Initialize the set. The bucket table gets the capacity rounded up to a power of two.
+
+ @param [&inout] allocator : "The allocator to use"
+ @require capacity > 0 : "The capacity must be 1 or higher"
+ @require load_factor > 0.0 : "The load factor must be higher than 0"
+ @require !self.is_initialized() : "Set was already initialized"
+ @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum"
+*>
+fn HashSet* HashSet.init(&self, Allocator allocator, usz capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR)
+{
+	capacity = math::next_power_of_2(capacity); // index_for masks with (capacity - 1)
+	self.allocator = allocator;
+	self.threshold = (usz) (capacity * load_factor);
+	self.load_factor = load_factor;
+	self.table = allocator::new_array(allocator, Entry*, capacity);
+	return self;
+}
+
+<*
+ Initialize the set using the temp allocator.
+
+ @require capacity > 0 : "The capacity must be 1 or higher"
+ @require load_factor > 0.0 : "The load factor must be higher than 0"
+ @require !self.is_initialized() : "Set was already initialized"
+ @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum"
+*>
+fn HashSet* HashSet.tinit(&self, uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR)
+{
+	return self.init(tmem, capacity, load_factor) @inline;
+}
+
+<*
+ Initialize the set and add every vararg value to it.
+
+ @param [&inout] allocator : "The allocator to use"
+ @require capacity > 0 : "The capacity must be 1 or higher"
+ @require load_factor > 0.0 : "The load factor must be higher than 0"
+ @require !self.is_initialized() : "Set was already initialized"
+ @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum"
+*>
+macro HashSet* HashSet.init_with_values(&self, Allocator allocator, ..., uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR)
+{
+	self.init(allocator, capacity, load_factor);
+	$for var $i = 0; $i < $vacount; $i++:
+		self.add($vaarg[$i]);
+	$endfor
+	return self;
+}
+
+<*
+ Temp allocator variant of init_with_values.
+
+ @require capacity > 0 : "The capacity must be 1 or higher"
+ @require load_factor > 0.0 : "The load factor must be higher than 0"
+ @require !self.is_initialized() : "Set was already initialized"
+ @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum"
+*>
+macro HashSet* HashSet.tinit_with_values(&self, ..., uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR)
+{
+	return self.init_with_values(tmem, $vasplat, capacity: capacity, load_factor: load_factor);
+}
+
+<*
+ Initialize the set and add every value in the slice to it.
+
+ @param [in] values : "The values for the HashSet"
+ @param [&inout] allocator : "The allocator to use"
+ @require capacity > 0 : "The capacity must be 1 or higher"
+ @require load_factor > 0.0 : "The load factor must be higher than 0"
+ @require !self.is_initialized() : "Set was already initialized"
+ @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum"
+*>
+fn HashSet* HashSet.init_from_values(&self, Allocator allocator, Value[] values, uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR)
+{
+	self.init(allocator, capacity, load_factor);
+	foreach (v : values) self.add(v);
+	return self;
+}
+
+<*
+ Temp allocator variant of init_from_values.
+
+ @param [in] values : "The values for the HashSet entries"
+ @require capacity > 0 : "The capacity must be 1 or higher"
+ @require load_factor > 0.0 : "The load factor must be higher than 0"
+ @require !self.is_initialized() : "Set was already initialized"
+ @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum"
+*>
+fn HashSet* HashSet.tinit_from_values(&self, Value[] values, uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR)
+{
+	return self.init_from_values(tmem, values, capacity, load_factor);
+}
+
+<*
+ Has this hash set been initialized yet?
+
+ @param [&in] set : "The hash set we are testing"
+ @return "Returns true if it has been initialized, false otherwise"
+*>
+fn bool HashSet.is_initialized(&set)
+{
+	return set.allocator && set.allocator.ptr != &dummy;
+}
+
+<*
+ Initialize the set as a copy of another set, which may itself be uninitialized.
+
+ @param [&inout] allocator : "The allocator to use"
+ @param [&in] other_set : "The set to copy from."
+ @require !self.is_initialized() : "Set was already initialized"
+*>
+fn HashSet* HashSet.init_from_set(&self, Allocator allocator, HashSet* other_set)
+{
+	self.init_sized(allocator, other_set.table.len, other_set.load_factor);
+	self.put_all_for_create(other_set);
+	return self;
+}
+
+<*
+ Temp allocator variant of init_from_set.
+
+ @param [&in] other_set : "The set to copy from."
+ @require !set.is_initialized() : "Set was already initialized"
+*>
+fn HashSet* HashSet.tinit_from_set(&set, HashSet* other_set)
+{
+	return set.init_from_set(tmem, other_set) @inline;
+}
+
+<*
+ Check if the set is empty
+
+ @return "true if it is empty"
+ @pure
+*>
+fn bool HashSet.is_empty(&set) @inline
+{
+	return !set.count;
+}
+
+<*
+ Add all elements in the slice to the set.
+
+ @param [in] list
+ @return "The number of new elements added"
+ @ensure return <= list.len
+*>
+fn usz HashSet.add_all(&set, Value[] list)
+{
+	usz total;
+	foreach (v : list)
+	{
+		if (set.add(v)) total++;
+	}
+	return total;
+}
+
+<*
+ Add all elements of another set to this set.
+
+ @param [&in] other
+ @return "The number of new elements added"
+ @ensure return <= other.count
+*>
+fn usz HashSet.add_all_from(&set, HashSet* other)
+{
+	usz total;
+	other.@each(;Value value)
+	{
+		if (set.add(value)) total++;
+	};
+	return total;
+}
+
+<*
+ Add a value to the set, initializing the set first if that has not been done.
+
+ @param value : "The value to add"
+ @return "true if the value didn't exist in the set"
+*>
+fn bool HashSet.add(&set, Value value)
+{
+	// If the set isn't initialized, use the defaults to initialize it.
+	set.ensure_initialized();
+	uint hash = rehash(value.hash());
+	uint index = index_for(hash, set.table.len);
+	for (Entry *e = set.table[index]; e != null; e = e.next)
+	{
+		if (e.hash == hash && equals(value, e.value)) return false;
+	}
+	set.add_entry(hash, value, index);
+	return true;
+}
+
+<*
+ Iterate over all the values in the set.
+
+ The body may remove the value it is currently visiting; any other change
+ to the set while iterating is not supported.
+*>
+macro HashSet.@each(set; @body(value))
+{
+	if (!set.count) return;
+	foreach (Entry* entry : set.table)
+	{
+		while (entry)
+		{
+			// Read the link first: the body is allowed to free this entry.
+			Entry* next = entry.next;
+			@body(entry.value);
+			entry = next;
+		}
+	}
+}
+
+<*
+ Check if the set contains the given value.
+
+ @param value : "The value to check"
+ @return "true if it exists in the set"
+*>
+fn bool HashSet.contains(&set, Value value)
+{
+	if (!set.count) return false;
+	uint hash = rehash(value.hash());
+	for (Entry *e = set.table[index_for(hash, set.table.len)]; e != null; e = e.next)
+	{
+		if (e.hash == hash && equals(value, e.value)) return true;
+	}
+	return false;
+}
+
+<*
+ Remove a single value from the set.
+
+ @param value : "The value to remove"
+ @return? NOT_FOUND : "If the entry is not found"
+*>
+fn void? HashSet.remove(&set, Value value) @maydiscard
+{
+	if (!set.remove_entry_for_value(value)) return NOT_FOUND?;
+}
+
+<*
+ Remove every value in the slice from the set.
+
+ @param [in] values : "The values to remove"
+ @return "The number of elements removed"
+ @ensure return <= values.len
+*>
+fn usz HashSet.remove_all(&set, Value[] values)
+{
+	usz total;
+	foreach (v : values)
+	{
+		if (set.remove_entry_for_value(v)) total++;
+	}
+	return total;
+}
+
+<*
+ Remove every value found in the other set from this set. Passing the set
+ itself is allowed and empties it.
+
+ @param [&in] other : "Other set"
+ @return "The number of elements removed"
+*>
+fn usz HashSet.remove_all_from(&set, HashSet* other)
+{
+	usz total;
+	other.@each(;Value val)
+	{
+		if (set.remove_entry_for_value(val)) total++;
+	};
+	return total;
+}
+
+<*
+ Free all memory allocated by the hash set.
+*>
+fn void HashSet.free(&set)
+{
+	if (!set.is_initialized()) return;
+	set.clear();
+	set.free_internal(set.table.ptr);
+	*set = {};
+}
+
+<*
+ Clear all elements from the set while keeping the underlying storage
+
+ @ensure set.count == 0
+*>
+fn void HashSet.clear(&set)
+{
+	if (!set.count) return;
+
+	foreach (Entry** &entry_ref : set.table)
+	{
+		Entry* entry = *entry_ref;
+		if (!entry) continue;
+
+		Entry *next = entry.next;
+		while (next)
+		{
+			Entry *to_delete = next;
+			next = next.next;
+			set.free_entry(to_delete);
+		}
+
+		set.free_entry(entry);
+		*entry_ref = null;
+	}
+	set.count = 0;
+}
+
+<*
+ Grow the bucket table to the requested capacity, rounded up to a power of
+ two, when that capacity exceeds the current resize threshold. A set that is
+ not initialized yet is initialized first, the same way add does it.
+
+ @param capacity : "The requested capacity"
+*>
+fn void HashSet.reserve(&set, usz capacity)
+{
+	if (capacity <= set.threshold) return;
+	set.ensure_initialized(); // resize allocates through set.allocator
+	if (capacity > set.threshold)
+	{
+		set.resize(math::next_power_of_2(capacity));
+	}
+}
+
+
+
+// --- Set Operations ---
+
+<*
+ Returns the union of two sets (A | B)
+
+ @param [&in] other : "The other set to union with"
+ @param [&inout] allocator : "Allocator for the new set"
+ @return "A new set containing the union of both sets"
+*>
+fn HashSet HashSet.set_union(&self, Allocator allocator, HashSet* other)
+{
+	usz new_capacity = math::next_power_of_2(self.count + other.count);
+	HashSet result;
+	result.init_sized(allocator, new_capacity, self.load_factor);
+	result.add_all_from(self);
+	result.add_all_from(other);
+	return result;
+}
+
+<* Temp allocator variant of set_union. *>
+fn HashSet HashSet.tset_union(&self, HashSet* other) => self.set_union(tmem, other);
+
+<*
+ Returns the intersection of the two sets (A & B)
+
+ @param [&in] other : "The other set to intersect with"
+ @param [&inout] allocator : "Allocator for the new set"
+ @return "A new set containing the intersection of both sets"
+*>
+fn HashSet HashSet.intersection(&self, Allocator allocator, HashSet* other)
+{
+	HashSet result;
+	result.init_sized(allocator, math::min(self.table.len, other.table.len), self.load_factor);
+
+	// Iterate through the smaller set for efficiency
+	HashSet* smaller = self.count <= other.count ? self : other;
+	HashSet* larger = self.count > other.count ? self : other;
+
+	smaller.@each(;Value value)
+	{
+		if (larger.contains(value)) result.add(value);
+	};
+
+	return result;
+}
+
+<* Temp allocator variant of intersection. *>
+fn HashSet HashSet.tintersection(&self, HashSet* other) => self.intersection(tmem, other);
+
+<*
+ Return this set - other, so (A & ~B)
+
+ @param [&in] other : "The other set to compare with"
+ @param [&inout] allocator : "Allocator for the new set"
+ @return "A new set containing elements in this set but not in the other"
+*>
+fn HashSet HashSet.difference(&self, Allocator allocator, HashSet* other)
+{
+	HashSet result;
+	result.init_sized(allocator, self.table.len, self.load_factor);
+	self.@each(;Value value)
+	{
+		if (!other.contains(value))
+		{
+			result.add(value);
+		}
+	};
+	return result;
+}
+
+<* Temp allocator variant of difference. *>
+fn HashSet HashSet.tdifference(&self, HashSet* other) => self.difference(tmem, other) @inline;
+
+<*
+ Return (A ^ B)
+
+ @param [&in] other : "The other set to compare with"
+ @param [&inout] allocator : "Allocator for the new set"
+ @return "A new set containing elements in this set or the other, but not both"
+*>
+fn HashSet HashSet.symmetric_difference(&self, Allocator allocator, HashSet* other)
+{
+	HashSet result;
+	result.init_sized(allocator, self.table.len, self.load_factor);
+	result.add_all_from(self);
+	other.@each(;Value value)
+	{
+		// add fails only for values copied from self, i.e. present in both sets.
+		if (!result.add(value))
+		{
+			result.remove(value);
+		}
+	};
+	return result;
+}
+
+<* Temp allocator variant of symmetric_difference. *>
+fn HashSet HashSet.tsymmetric_difference(&self, HashSet* other) => self.symmetric_difference(tmem, other) @inline;
+
+<*
+ Check if this hash set is a subset of another set.
+
+ @param [&in] other : "The other set to check against"
+ @return "True if all elements of this set are in the other set"
+*>
+fn bool HashSet.is_subset(&self, HashSet* other)
+{
+	if (self.count == 0) return true;
+	if (self.count > other.count) return false;
+
+	self.@each(;Value value)
+	{
+		if (!other.contains(value)) return false;
+	};
+	return true;
+}
+
+
+// --- private methods
+
+<*
+ Initialize a set that is not initialized yet with its default allocator: the
+ heap for copies of ONHEAP, the temp allocator for a zeroed set. Does nothing
+ if the set is already initialized.
+*>
+fn void HashSet.ensure_initialized(&set) @private
+{
+	switch (set.allocator.ptr)
+	{
+		case &dummy:
+			set.init(mem);
+		case null:
+			set.tinit();
+		default:
+			break;
+	}
+}
+
+<*
+ Initialize the set from the table size and load factor of an existing set.
+ A set that was never initialized has an empty table and a zero load factor,
+ which init rejects, so those values are replaced by the defaults.
+*>
+fn void HashSet.init_sized(&self, Allocator allocator, usz capacity, float load_factor) @private
+{
+	if (!capacity) capacity = DEFAULT_INITIAL_CAPACITY;
+	if (load_factor <= 0.0) load_factor = DEFAULT_LOAD_FACTOR;
+	self.init(allocator, capacity, load_factor);
+}
+
+<*
+ Link a new entry at the head of the given bucket, doubling the table once
+ the element count reaches the threshold.
+*>
+fn void HashSet.add_entry(&set, uint hash, Value value, uint bucket_index) @private
+{
+	Entry* entry = allocator::new(set.allocator, Entry, { .hash = hash, .value = value, .next = set.table[bucket_index] });
+	set.table[bucket_index] = entry;
+	if (set.count++ >= set.threshold)
+	{
+		set.resize(set.table.len * 2);
+	}
+}
+
+<*
+ Move all entries into a new table with new_capacity buckets and free the old
+ table. A table already at MAXIMUM_CAPACITY is left alone and the threshold
+ is raised to its maximum so that no further resize is attempted.
+*>
+fn void HashSet.resize(&self, usz new_capacity) @private
+{
+	Entry*[] old_table = self.table;
+	usz old_capacity = old_table.len;
+	if (old_capacity == MAXIMUM_CAPACITY)
+	{
+		self.threshold = usz.max;
+		return;
+	}
+	Entry*[] new_table = allocator::new_array(self.allocator, Entry*, new_capacity);
+	self.transfer(new_table);
+	self.table = new_table;
+	self.free_internal(old_table.ptr);
+	self.threshold = (usz)(new_capacity * self.load_factor);
+}
+
+<*
+ Print the set as a brace-enclosed, comma-separated list of its values.
+
+ @return "The sum of the lengths reported by the formatter"
+*>
+fn usz? HashSet.to_format(&self, Formatter* f) @dynamic
+{
+	usz len = f.print("{ ")!;
+	// Track the first element explicitly: a value may format to an empty
+	// string, so the running length cannot tell whether one was printed.
+	bool first = true;
+	self.@each(; Value value)
+	{
+		if (!first) len += f.print(", ")!;
+		first = false;
+		len += f.printf("%s", value)!;
+	};
+	return len + f.print(" }");
+}
+
+<*
+ Relink every entry of the current table into new_table, whose length must
+ be a power of two. No entries are allocated or freed.
+*>
+fn void HashSet.transfer(&self, Entry*[] new_table) @private
+{
+	Entry*[] src = self.table;
+	uint new_capacity = new_table.len;
+	foreach (Entry *e : src)
+	{
+		if (!e) continue;
+		do
+		{
+			Entry* next = e.next;
+			uint i = index_for(e.hash, new_capacity);
+			e.next = new_table[i];
+			new_table[i] = e;
+			e = next;
+		}
+		while (e);
+	}
+}
+
+<*
+ Copy every value of other_set into this freshly initialized set without
+ triggering a resize.
+*>
+fn void HashSet.put_all_for_create(&set, HashSet* other_set) @private
+{
+	if (!other_set.count) return;
+	foreach (Entry *e : other_set.table)
+	{
+		while (e)
+		{
+			set.put_for_create(e.value);
+			e = e.next;
+		}
+	}
+}
+
+<* Insert a value unless it is already present; never resizes the table. *>
+fn void HashSet.put_for_create(&set, Value value) @private
+{
+	uint hash = rehash(value.hash());
+	uint i = index_for(hash, set.table.len);
+	for (Entry *e = set.table[i]; e != null; e = e.next)
+	{
+		if (e.hash == hash && equals(value, e.value))
+		{
+			// Value already exists, no need to do anything
+			return;
+		}
+	}
+	set.create_entry(hash, value, i);
+}
+
+<* Release memory that was obtained from the allocator of this set. *>
+fn void HashSet.free_internal(&self, void* ptr) @inline @private
+{
+	allocator::free(self.allocator, ptr);
+}
+
+<* Link a new entry at the head of the given bucket; never resizes the table. *>
+fn void HashSet.create_entry(&set, uint hash, Value value, uint bucket_index) @private
+{
+	Entry* entry = allocator::new(set.allocator, Entry, {
+		.hash = hash,
+		.value = value,
+		.next = set.table[bucket_index]
+	});
+	set.table[bucket_index] = entry;
+	set.count++;
+}
+
+<*
+ Removes the entry for the specified value if present
+ @return "true if found and removed, false otherwise"
+*>
+fn bool HashSet.remove_entry_for_value(&set, Value value) @private
+{
+	if (!set.count) return false;
+	uint hash = rehash(value.hash());
+	uint i = index_for(hash, set.table.len);
+	Entry* prev = set.table[i];
+	Entry* e = prev;
+	while (e)
+	{
+		Entry *next = e.next;
+		if (e.hash == hash && equals(value, e.value))
+		{
+			set.count--;
+			if (prev == e)
+			{
+				set.table[i] = next;
+			}
+			else
+			{
+				prev.next = next;
+			}
+			set.free_entry(e);
+			return true;
+		}
+		prev = e;
+		e = next;
+	}
+
+	return false;
+}
+
+<* Release a single entry that was allocated by add_entry or create_entry. *>
+fn void HashSet.free_entry(&set, Entry *entry) @private
+{
+	allocator::free(set.allocator, entry);
+}
+
+// Cursor over a HashSet. It holds a raw entry pointer, so removing values invalidates it.
+struct HashSetIterator
+{
+	HashSet* set;
+	usz bucket_index; // Next bucket to look at once the current chain is done
+	Entry* current; // Next entry to return from the current chain, or null
+}
+
+<* @return "An iterator positioned before the first value" *>
+fn HashSetIterator HashSet.iter(&set) => { .set = set, .bucket_index = 0, .current = null };
+
+<*
+ Advance the iterator.
+
+ @return "The next value in the set"
+ @return? NOT_FOUND : "If every value has already been returned"
+*>
+fn Value? HashSetIterator.next(&self)
+{
+	if (self.current)
+	{
+		Value value = self.current.value;
+		self.current = self.current.next;
+		return value;
+	}
+
+	while (self.bucket_index < self.set.table.len)
+	{
+		self.current = self.set.table[self.bucket_index++];
+		if (self.current)
+		{
+			Value value = self.current.value;
+			self.current = self.current.next;
+			return value;
+		}
+	}
+
+	return NOT_FOUND?;
+}
+
+<* @return "The number of elements in the underlying set" *>
+fn usz HashSetIterator.len(&self) @operator(len)
+{
+	return self.set.count;
+}
+
+// Fold the high bits of the hash into the low bits, the only ones index_for keeps.
+<* @pure *>
+fn uint rehash(uint hash) @inline @private
+{
+	hash ^= (hash >> 20) ^ (hash >> 12);
+	return hash ^ ((hash >> 7) ^ (hash >> 4));
+}
+
+// Map a hash to a bucket index; `capacity` must be a power of two.
+macro uint index_for(uint hash, uint capacity) @private => hash & (capacity - 1);
+
+// Its address marks sets that should be lazily initialized on the heap (see ONHEAP).
+int dummy @local;
diff --git a/test/unit/stdlib/collections/set.c3 b/test/unit/stdlib/collections/set.c3
new file mode 100644
index 000000000..dc3306376
--- /dev/null
+++ b/test/unit/stdlib/collections/set.c3
@@ -0,0 +1,287 @@
+module set_test @test;
+import std::collections::set;
+
+alias IntSet = HashSet{int};
+
+fn void basic_operations()
+{
+	IntSet set;
+	defer set.free();
+
+	assert(set.is_empty());
+	assert(!set.contains(1));
+
+	assert(set.add(1));
+	assert(set.contains(1));
+	assert(!set.is_empty());
+	assert(set.len() == 1);
+
+	assert(!set.add(1));
+	assert(set.len() == 1);
+
+	assert(set.add(2));
+	assert(set.add(3));
+	assert(set.len() == 3);
+	assert(set.contains(2));
+	assert(set.contains(3));
+
+	set.remove(2);
+	assert(!set.contains(2));
+	assert(set.len() == 2);
+
+	set.clear();
+	assert(set.is_empty());
+	assert(!set.contains(1));
+}
+
+fn void initialization_methods()
+{
+	IntSet set1;
+	set1.tinit();
+	defer set1.free();
+	assert(set1.is_initialized());
+
+	IntSet set2;
+	set2.tinit_with_values(1, 2, 3);
+	defer set2.free();
+	assert(set2.contains(1));
+	assert(set2.contains(2));
+	assert(set2.contains(3));
+	assert(set2.len() == 3);
+
+	int[] values = {4, 5, 6};
+	IntSet set3;
+	set3.tinit_from_values(values);
+	defer set3.free();
+	assert(set3.contains(4));
+	assert(set3.contains(5));
+	assert(set3.contains(6));
+	assert(set3.len() == 3);
+
+	IntSet set4;
+	set4.tinit_from_set(&set3);
+	defer set4.free();
+	assert(set4.contains(4));
+	assert(set4.contains(5));
+	assert(set4.contains(6));
+	assert(set4.len() == 3);
+}
+
+fn void set_operations()
+{
+	IntSet set1;
+	set1.tinit_with_values(1, 2, 3);
+	defer set1.free();
+
+	IntSet set2;
+	set2.tinit_with_values(2, 3, 4);
+	defer set2.free();
+
+	IntSet union_set = set1.tset_union(&set2);
+	defer union_set.free();
+	assert(union_set.contains(1));
+	assert(union_set.contains(2));
+	assert(union_set.contains(3));
+	assert(union_set.contains(4));
+	assert(union_set.len() == 4);
+
+	IntSet intersect_set = set1.tintersection(&set2);
+	defer intersect_set.free();
+	assert(intersect_set.contains(2));
+	assert(intersect_set.contains(3));
+	assert(!intersect_set.contains(1));
+	assert(!intersect_set.contains(4));
+	assert(intersect_set.len() == 2);
+
+	IntSet diff_set = set1.tdifference(&set2);
+	defer diff_set.free();
+	assert(diff_set.contains(1));
+	assert(!diff_set.contains(2));
+	assert(!diff_set.contains(3));
+	assert(!diff_set.contains(4));
+	assert(diff_set.len() == 1);
+
+	IntSet sdiff_set = set1.tsymmetric_difference(&set2);
+	defer sdiff_set.free();
+	assert(sdiff_set.contains(1));
+	assert(!sdiff_set.contains(2));
+	assert(!sdiff_set.contains(3));
+	assert(sdiff_set.contains(4));
+	assert(sdiff_set.len() == 2);
+
+	IntSet subset;
+	subset.tinit_with_values(2, 3);
+	defer subset.free();
+	assert(subset.is_subset(&set1));
+	assert(!set1.is_subset(&subset));
+}
+
+fn void iterator_test()
+{
+	IntSet set;
+	set.tinit_with_values(1, 2, 3);
+	defer set.free();
+
+	int count = 0;
+	bool found1 = false; bool found2 = false; bool found3 = false;
+
+	set.@each(; int value)
+	{
+		count++;
+		switch (value)
+		{
+			case 1: found1 = true;
+			case 2: found2 = true;
+			case 3: found3 = true;
+		}
+	};
+
+	assert(count == 3);
+	assert(found1 && found2 && found3);
+
+	HashSetIterator {int} iter = set.iter();
+	count = 0;
+	while (@ok(iter.next()))
+	{
+		count++;
+	}
+	assert(count == 3);
+}
+
+fn void edge_cases()
+{
+	IntSet empty;
+	empty.tinit();
+	defer empty.free();
+
+	assert(empty.is_empty());
+	assert(!empty.contains(0));
+	empty.remove(0); // Shouldn't crash
+
+	IntSet large;
+	large.tinit();
+	defer large.free();
+
+	for (int i = 0; i < 1000; i++)
+	{
+		large.add(i);
+	}
+	assert(large.len() == 1000);
+	for (int i = 0; i < 1000; i++)
+	{
+		assert(large.contains(i));
+	}
+
+	assert(@catch(large.remove(1001)));
+	assert(large.len() == 1000);
+
+	large.clear();
+	assert(large.is_empty());
+	for (int i = 0; i < 1000; i++)
+	{
+		assert(!large.contains(i));
+	}
+}
+
+alias StringSet = HashSet{String};
+
+fn void string_set_test()
+{
+	StringSet set;
+	set.tinit();
+	defer set.free();
+
+	assert(set.add("hello"));
+	assert(set.add("world"));
+	assert(!set.add("hello"));
+
+	assert(set.contains("hello"));
+	assert(set.contains("world"));
+	assert(!set.contains("foo"));
+
+	set.remove("hello");
+	assert(!set.contains("hello"));
+	assert(set.len() == 1);
+}
+
+fn void add_all_test()
+{
+	StringSet set;
+	set.init(mem);
+	defer set.free();
+
+	String[] list = { "hello", "world", "hello" };
+	usz total = set.add_all(list);
+	assert(total == 2);
+
+	assert(set.contains("hello"));
+	assert(set.contains("world"));
+	assert(!set.contains("foo"));
+
+	set.remove("hello");
+	assert(!set.contains("hello"));
+	assert(set.len() == 1);
+}
+
+fn void is_initialized_test()
+{
+	IntSet test;
+	assert(!test.is_initialized());
+	test.tinit();
+	assert(test.is_initialized());
+	test.free();
+}
+
+// Set operations and copies must accept a set that was never initialized.
+fn void uninitialized_operands()
+{
+	IntSet never_used;
+	IntSet other;
+	other.tinit_with_values(1, 2);
+	defer other.free();
+
+	IntSet duplicate;
+	duplicate.tinit_from_set(&never_used);
+	defer duplicate.free();
+	assert(duplicate.is_empty());
+
+	IntSet union_set = never_used.tset_union(&other);
+	defer union_set.free();
+	assert(union_set.len() == 2);
+
+	IntSet intersect_set = never_used.tintersection(&other);
+	defer intersect_set.free();
+	assert(intersect_set.is_empty());
+
+	IntSet diff_set = never_used.tdifference(&other);
+	defer diff_set.free();
+	assert(diff_set.is_empty());
+
+	IntSet sdiff_set = never_used.tsymmetric_difference(&other);
+	defer sdiff_set.free();
+	assert(sdiff_set.len() == 2);
+}
+
+// reserve() must initialize a declared-only set instead of allocating through a null allocator.
+fn void reserve_uninitialized()
+{
+	IntSet set;
+	defer set.free();
+
+	set.reserve(100);
+	assert(set.is_initialized());
+	assert(set.is_empty());
+	assert(set.add(1));
+	assert(set.len() == 1);
+}
+
+// Removing a set from itself frees entries while they are being iterated.
+fn void remove_all_from_self()
+{
+	IntSet set;
+	set.tinit_with_values(1, 2, 3);
+	defer set.free();
+
+	assert(set.remove_all_from(&set) == 3);
+	assert(set.is_empty());
+}