diff --git a/lib/std/collections/linked_hashmap.c3 b/lib/std/collections/linked_hashmap.c3 new file mode 100644 index 000000000..d5891e0cb --- /dev/null +++ b/lib/std/collections/linked_hashmap.c3 @@ -0,0 +1,650 @@ +// Copyright (c) 2023 Christoffer Lerno. All rights reserved. +// Use of this source code is governed by the MIT license +// a copy of which can be found in the LICENSE_STDLIB file. +<* + @require $defined((Key){}.hash()) : `No .hash function found on the key` +*> +module std::collections::map{Key, Value}; +import std::math; +import std::io @norecurse; + +const LinkedHashMap LINKEDONHEAP = { .allocator = MAP_HEAP_ALLOCATOR }; + +struct LinkedEntry +{ + uint hash; + Key key; + Value value; + LinkedEntry* next; // For bucket chain + LinkedEntry* before; // Previous in insertion order + LinkedEntry* after; // Next in insertion order +} + +struct LinkedHashMap (Printable) +{ + LinkedEntry*[] table; + Allocator allocator; + usz count; + usz threshold; + float load_factor; + LinkedEntry* head; // First inserted LinkedEntry + LinkedEntry* tail; // Last inserted LinkedEntry +} + + +<* + @param [&inout] allocator : "The allocator to use" + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Map was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +fn LinkedHashMap* LinkedHashMap.init(&self, Allocator allocator, usz capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + capacity = math::next_power_of_2(capacity); + self.allocator = allocator; + self.load_factor = load_factor; + self.threshold = (usz)(capacity * load_factor); + self.table = allocator::new_array(allocator, LinkedEntry*, capacity); + self.head = null; + self.tail = null; + return self; +} + +<* + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be 
higher than 0" + @require !self.is_initialized() : "Map was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +fn LinkedHashMap* LinkedHashMap.tinit(&self, usz capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + return self.init(tmem, capacity, load_factor) @inline; +} + +<* + @param [&inout] allocator : "The allocator to use" + @require $vacount % 2 == 0 : "There must be an even number of arguments provided for keys and values" + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Map was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +macro LinkedHashMap* LinkedHashMap.init_with_key_values(&self, Allocator allocator, ..., uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + self.init(allocator, capacity, load_factor); + $for var $i = 0; $i < $vacount; $i += 2: + self.set($vaarg[$i], $vaarg[$i + 1]); + $endfor + return self; +} + +<* + @require $vacount % 2 == 0 : "There must be an even number of arguments provided for keys and values" + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Map was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +macro LinkedHashMap* LinkedHashMap.tinit_with_key_values(&self, ..., uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + return self.init_with_key_values(tmem, $vasplat, capacity: capacity, load_factor: load_factor); +} + +<* + @param [in] keys : "The keys for the LinkedHashMap entries" + @param [in] values : "The values for the LinkedHashMap entries" + @param [&inout] allocator : "The allocator to use" + @require keys.len == values.len : "Both keys 
and values arrays must be the same length" + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Map was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +fn LinkedHashMap* LinkedHashMap.init_from_keys_and_values(&self, Allocator allocator, Key[] keys, Value[] values, uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + assert(keys.len == values.len); + self.init(allocator, capacity, load_factor); + for (usz i = 0; i < keys.len; i++) + { + self.set(keys[i], values[i]); + } + return self; +} + +<* + @param [in] keys : "The keys for the LinkedHashMap entries" + @param [in] values : "The values for the LinkedHashMap entries" + @require keys.len == values.len : "Both keys and values arrays must be the same length" + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Map was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +fn LinkedHashMap* LinkedHashMap.tinit_from_keys_and_values(&self, Key[] keys, Value[] values, uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + return self.init_from_keys_and_values(tmem, keys, values, capacity, load_factor); +} + +<* + Has this hash map been initialized yet? + + @param [&in] map : "The hash map we are testing" + @return "Returns true if it has been initialized, false otherwise" +*> +fn bool LinkedHashMap.is_initialized(&map) +{ + return map.allocator && map.allocator.ptr != &dummy; +} + +<* + @param [&inout] allocator : "The allocator to use" + @param [&in] other_map : "The map to copy from." 
+ @require !self.is_initialized() : "Map was already initialized" +*> +fn LinkedHashMap* LinkedHashMap.init_from_map(&self, Allocator allocator, LinkedHashMap* other_map) +{ + self.init(allocator, other_map.table.len, other_map.load_factor); + self.put_all_for_create(other_map); + return self; +} + +<* + @param [&in] other_map : "The map to copy from." + @require !map.is_initialized() : "Map was already initialized" +*> +fn LinkedHashMap* LinkedHashMap.tinit_from_map(&map, LinkedHashMap* other_map) +{ + return map.init_from_map(tmem, other_map) @inline; +} + +fn bool LinkedHashMap.is_empty(&map) @inline +{ + return !map.count; +} + +fn usz LinkedHashMap.len(&map) @inline => map.count; + +fn Value*? LinkedHashMap.get_ref(&map, Key key) +{ + if (!map.count) return NOT_FOUND?; + uint hash = rehash(key.hash()); + for (LinkedEntry *e = map.table[index_for(hash, map.table.len)]; e != null; e = e.next) + { + if (e.hash == hash && equals(key, e.key)) return &e.value; + } + return NOT_FOUND?; +} + +fn LinkedEntry*? LinkedHashMap.get_entry(&map, Key key) +{ + if (!map.count) return NOT_FOUND?; + uint hash = rehash(key.hash()); + for (LinkedEntry *e = map.table[index_for(hash, map.table.len)]; e != null; e = e.next) + { + if (e.hash == hash && equals(key, e.key)) return e; + } + return NOT_FOUND?; +} + +<* + Get the value or update and + @require @assignable_to(#expr, Value) +*> +macro Value LinkedHashMap.@get_or_set(&map, Key key, Value #expr) +{ + if (!map.count) + { + Value val = #expr; + map.set(key, val); + return val; + } + uint hash = rehash(key.hash()); + uint index = index_for(hash, map.table.len); + for (LinkedEntry *e = map.table[index]; e != null; e = e.next) + { + if (e.hash == hash && equals(key, e.key)) return e.value; + } + Value val = #expr; + map.add_entry(hash, key, val, index); + return val; +} + +fn Value? 
LinkedHashMap.get(&map, Key key) @operator([]) => *map.get_ref(key) @inline; + +fn bool LinkedHashMap.has_key(&map, Key key) => @ok(map.get_ref(key)); + +fn bool LinkedHashMap.set(&map, Key key, Value value) @operator([]=) +{ + // If the map isn't initialized, use the defaults to initialize it. + switch (map.allocator.ptr) + { + case &dummy: + map.init(mem); + case null: + map.tinit(); + default: + break; + } + uint hash = rehash(key.hash()); + uint index = index_for(hash, map.table.len); + for (LinkedEntry *e = map.table[index]; e != null; e = e.next) + { + if (e.hash == hash && equals(key, e.key)) + { + e.value = value; + return true; + } + } + map.add_entry(hash, key, value, index); + return false; +} + +fn void? LinkedHashMap.remove(&map, Key key) @maydiscard +{ + if (!map.remove_entry_for_key(key)) return NOT_FOUND?; +} + +fn void LinkedHashMap.clear(&map) +{ + if (!map.count) return; + + LinkedEntry* entry = map.head; + while (entry) + { + LinkedEntry* next = entry.after; + map.free_entry(entry); + entry = next; + } + + foreach (LinkedEntry** &bucket : map.table) + { + *bucket = null; + } + + map.count = 0; + map.head = null; + map.tail = null; +} + +fn void LinkedHashMap.free(&map) +{ + if (!map.is_initialized()) return; + map.clear(); + map.free_internal(map.table.ptr); + map.table = {}; +} + +fn Key[] LinkedHashMap.tkeys(&self) +{ + return self.keys(tmem) @inline; +} + +fn Key[] LinkedHashMap.keys(&self, Allocator allocator) +{ + if (!self.count) return {}; + + Key[] list = allocator::alloc_array(allocator, Key, self.count); + usz index = 0; + + LinkedEntry* entry = self.head; + while (entry) + { + $if COPY_KEYS: + list[index++] = entry.key.copy(allocator); + $else + list[index++] = entry.key; + $endif + entry = entry.after; + } + return list; +} + +macro LinkedHashMap.@each(map; @body(key, value)) +{ + map.@each_entry(; LinkedEntry* entry) + { + @body(entry.key, entry.value); + }; +} + +macro LinkedHashMap.@each_entry(map; @body(entry)) +{ + LinkedEntry* 
entry = map.head; + while (entry) + { + @body(entry); + entry = entry.after; + } +} + +fn Value[] LinkedHashMap.tvalues(&map) => map.values(tmem) @inline; + +fn Value[] LinkedHashMap.values(&self, Allocator allocator) +{ + if (!self.count) return {}; + Value[] list = allocator::alloc_array(allocator, Value, self.count); + usz index = 0; + LinkedEntry* entry = self.head; + while (entry) + { + list[index++] = entry.value; + entry = entry.after; + } + return list; +} + +fn bool LinkedHashMap.has_value(&map, Value v) @if(VALUE_IS_EQUATABLE) +{ + if (!map.count) return false; + + LinkedEntry* entry = map.head; + while (entry) + { + if (equals(v, entry.value)) return true; + entry = entry.after; + } + return false; +} + +fn LinkedHashMapIterator LinkedHashMap.iter(&self) => { .map = self, .current = self.head, .started = false }; + +fn LinkedHashMapValueIterator LinkedHashMap.value_iter(&self) => { .map = self, .current = self.head, .started = false }; + +fn LinkedHashMapKeyIterator LinkedHashMap.key_iter(&self) => { .map = self, .current = self.head, .started = false }; + +fn bool LinkedHashMapIterator.next(&self) +{ + if (!self.started) + { + self.current = self.map.head; + self.started = true; + } + else if (self.current) + { + self.current = self.current.after; + } + return self.current != null; +} + +fn LinkedEntry*? LinkedHashMapIterator.get(&self) +{ + return self.current ? self.current : NOT_FOUND?; +} + +fn Value*? LinkedHashMapValueIterator.get(&self) +{ + return self.current ? &self.current.value : NOT_FOUND?; +} + +fn Key*? LinkedHashMapKeyIterator.get(&self) +{ + return self.current ? 
&self.current.key : NOT_FOUND?; +} + +fn bool LinkedHashMapIterator.has_next(&self) +{ + if (!self.started) return self.map.head != null; + return self.current && self.current.after != null; +} + +// --- private methods + +fn void LinkedHashMap.add_entry(&map, uint hash, Key key, Value value, uint bucket_index) @private +{ + $if COPY_KEYS: + key = key.copy(map.allocator); + $endif + + LinkedEntry* entry = allocator::new(map.allocator, LinkedEntry, { + .hash = hash, + .key = key, + .value = value, + .next = map.table[bucket_index], + .before = map.tail, + .after = null + }); + + // Update bucket chain + map.table[bucket_index] = entry; + + // Update linked list + if (map.tail) + { + map.tail.after = entry; + entry.before = map.tail; + } + else + { + map.head = entry; + } + map.tail = entry; + + if (map.count++ >= map.threshold) + { + map.resize(map.table.len * 2); + } +} + +fn void LinkedHashMap.resize(&map, uint new_capacity) @private +{ + LinkedEntry*[] old_table = map.table; + uint old_capacity = old_table.len; + + if (old_capacity == MAXIMUM_CAPACITY) + { + map.threshold = uint.max; + return; + } + + LinkedEntry*[] new_table = allocator::new_array(map.allocator, LinkedEntry*, new_capacity); + map.table = new_table; + map.threshold = (uint)(new_capacity * map.load_factor); + + // Rehash all entries - linked list order remains unchanged + foreach (uint i, LinkedEntry *e : old_table) + { + if (!e) continue; + + // Split the bucket chain into two chains based on new bit + LinkedEntry* lo_head = null; + LinkedEntry* lo_tail = null; + LinkedEntry* hi_head = null; + LinkedEntry* hi_tail = null; + + do + { + LinkedEntry* next = e.next; + if ((e.hash & old_capacity) == 0) + { + if (!lo_tail) + { + lo_head = e; + } + else + { + lo_tail.next = e; + } + lo_tail = e; + } + else + { + if (!hi_tail) + { + hi_head = e; + } + else + { + hi_tail.next = e; + } + hi_tail = e; + } + e.next = null; + e = next; + } + while (e); + + if (lo_tail) + { + lo_tail.next = null; + 
new_table[i] = lo_head;
+		}
+		if (hi_tail)
+		{
+			// Entries on the hi chain have the old_capacity bit set in their hash,
+			// so they land exactly old_capacity buckets further up.
+			hi_tail.next = null;
+			new_table[i + old_capacity] = hi_head;
+		}
+	}
+
+	map.free_internal(old_table.ptr);
+}
+
+<*
+ Print the map as "{ key: value, key: value }", visiting entries in insertion order.
+
+ @return "The sum of the values returned by the formatter's print calls"
+*>
+fn usz? LinkedHashMap.to_format(&self, Formatter* f) @dynamic
+{
+	usz len;
+	len += f.print("{ ")!;
+	self.@each_entry(; LinkedEntry* entry)
+	{
+		// len is 2 right after the opening "{ ", so this is true from the second entry on.
+		if (len > 2) len += f.print(", ")!;
+		len += f.printf("%s: %s", entry.key, entry.value)!;
+	};
+	return len + f.print(" }");
+}
+
+<*
+ Re-bucket every entry reachable from the current table into new_table,
+ recomputing each index for new_table.len. Only the bucket chains (next) are
+ rewritten; map.table itself is not reassigned here.
+*>
+fn void LinkedHashMap.transfer(&map, LinkedEntry*[] new_table) @private
+{
+	LinkedEntry*[] src = map.table;
+	uint new_capacity = new_table.len;
+	foreach (uint j, LinkedEntry *e : src)
+	{
+		if (!e) continue;
+		do
+		{
+			LinkedEntry* next = e.next;
+			uint i = index_for(e.hash, new_capacity);
+			e.next = new_table[i];
+			new_table[i] = e;
+			e = next;
+		}
+		while (e);
+	}
+}
+
+<*
+ Copy every entry of other_map into this map, preserving other_map's insertion order.
+*>
+fn void LinkedHashMap.put_all_for_create(&map, LinkedHashMap* other_map) @private
+{
+	if (!other_map.count) return;
+	// Walk the insertion-order list, not the bucket table: bucket order depends on
+	// the hashes and would scramble the order the copy is supposed to keep.
+	for (LinkedEntry *e = other_map.head; e != null; e = e.after)
+	{
+		map.put_for_create(e.key, e.value);
+	}
+}
+
+<*
+ Set key to value while populating a map: overwrite the value if the key is
+ already present, otherwise create a new entry for it.
+*>
+fn void LinkedHashMap.put_for_create(&map, Key key, Value value) @private
+{
+	uint hash = rehash(key.hash());
+	uint i = index_for(hash, map.table.len);
+	for (LinkedEntry *e = map.table[i]; e != null; e = e.next)
+	{
+		if (e.hash == hash && equals(key, e.key))
+		{
+			e.value = value;
+			return;
+		}
+	}
+	map.create_entry(hash, key, value, i);
+}
+
+// Release ptr through the map's own allocator.
+fn void LinkedHashMap.free_internal(&map, void* ptr) @inline @private
+{
+	allocator::free(map.allocator, ptr);
+}
+
+<*
+ Unlink the entry for key from both its bucket chain and the insertion-order
+ list, then free it.
+
+ @return "true if an entry was removed, false if the key was not present"
+*>
+fn bool LinkedHashMap.remove_entry_for_key(&map, Key key) @private
+{
+	if (!map.count) return false;
+
+	uint hash = rehash(key.hash());
+	uint i = index_for(hash, map.table.len);
+	LinkedEntry* prev = null;
+	LinkedEntry* e = map.table[i];
+
+	while (e)
+	{
+		if (e.hash == hash && equals(key, e.key))
+		{
+			// Bucket chain: bypass e.
+			if (prev)
+			{
+				prev.next = e.next;
+			}
+			else
+			{
+				map.table[i] = e.next;
+			}
+
+			// Insertion-order list: patch the neighbours, or head/tail at the ends.
+			if (e.before)
+			{
+				e.before.after = e.after;
+			}
+			else
+			{
+				map.head = e.after;
+			}
+
+			if (e.after)
+			{
+				e.after.before = e.before;
+			}
+			else
+			{
+				map.tail = e.before;
+			}
+
+			map.count--;
+			map.free_entry(e);
+			return true;
+		}
+		prev = e;
+		e = e.next;
+	}
+	return false;
+}
+
+<*
+ Allocate a new entry, push it onto the front of bucket bucket_index and append
+ it to the end of the insertion-order list. Unlike add_entry this never
+ resizes the table.
+*>
+fn void LinkedHashMap.create_entry(&map, uint hash, Key key, Value value, int bucket_index) @private
+{
+	$if COPY_KEYS:
+		key = key.copy(map.allocator);
+	$endif
+	LinkedEntry* entry = allocator::new(map.allocator, LinkedEntry, {
+		.hash = hash,
+		.key = key,
+		.value = value,
+		.next = map.table[bucket_index],
+		.before = map.tail,
+		.after = null
+	});
+
+	// Update bucket chain
+	map.table[bucket_index] = entry;
+
+	// Update linked list. Without this head/tail stay null and iteration,
+	// keys(), values() and clear() never see the entry.
+	if (map.tail)
+	{
+		map.tail.after = entry;
+	}
+	else
+	{
+		map.head = entry;
+	}
+	map.tail = entry;
+
+	map.count++;
+}
+
+// Free an entry; when keys are copied on insert (COPY_KEYS) the key is freed first.
+fn void LinkedHashMap.free_entry(&self, LinkedEntry *entry) @local
+{
+	$if COPY_KEYS:
+		allocator::free(self.allocator, entry.key);
+	$endif
+	self.free_internal(entry);
+}
+
+
+// Cursor over a map's insertion-order list. `started` is false until the
+// first call to next(), which (re)positions `current` on the map's head.
+struct LinkedHashMapIterator
+{
+	LinkedHashMap* map;
+	LinkedEntry* current;
+	bool started;
+}
+
+typedef LinkedHashMapValueIterator = inline LinkedHashMapIterator;
+typedef LinkedHashMapKeyIterator = inline LinkedHashMapIterator;
+
+fn usz LinkedHashMapValueIterator.len(self) @operator(len) => self.map.count;
+fn usz LinkedHashMapKeyIterator.len(self) @operator(len) => self.map.count;
+fn usz LinkedHashMapIterator.len(self) @operator(len) => self.map.count;
+
+int dummy @local;
diff --git a/lib/std/collections/linked_hashset.c3 b/lib/std/collections/linked_hashset.c3
new file mode 100644
index 000000000..635fec1b8
--- /dev/null
+++ b/lib/std/collections/linked_hashset.c3
@@ -0,0 +1,723 @@
+<*
+ @require $defined((Value){}.hash()) : `No .hash function found on the value`
+*>
+module std::collections::set {Value};
+import std::math;
+import std::io @norecurse;
+
+const LinkedHashSet LINKEDONHEAP = { .allocator = SET_HEAP_ALLOCATOR };
+
+struct LinkedEntry
+{
+	uint hash;
+	Value value;
+	LinkedEntry* next; // For bucket chain
+	LinkedEntry* before; // Previous in insertion order
+	LinkedEntry* after; // Next in insertion order
+}
+
+struct LinkedHashSet (Printable)
+{ + LinkedEntry*[] table; + Allocator allocator; + usz count; // Number of elements + usz threshold; // Resize limit + float load_factor; + LinkedEntry* head; // First inserted LinkedEntry + LinkedEntry* tail; // Last inserted LinkedEntry +} + +fn int LinkedHashSet.len(&self) @operator(len) => (int) self.count; + +<* + @param [&inout] allocator : "The allocator to use" + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Set was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +fn LinkedHashSet* LinkedHashSet.init(&self, Allocator allocator, usz capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + capacity = math::next_power_of_2(capacity); + self.allocator = allocator; + self.threshold = (usz)(capacity * load_factor); + self.load_factor = load_factor; + self.table = allocator::new_array(allocator, LinkedEntry*, capacity); + + self.head = null; + self.tail = null; + return self; +} + +<* + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Set was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +fn LinkedHashSet* LinkedHashSet.tinit(&self, usz capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + return self.init(tmem, capacity, load_factor) @inline; +} + +<* + @param [&inout] allocator : "The allocator to use" + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Set was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +macro LinkedHashSet* LinkedHashSet.init_with_values(&self, Allocator allocator, ..., uint capacity 
= DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + self.init(allocator, capacity, load_factor); + $for var $i = 0; $i < $vacount; $i++: + self.add($vaarg[$i]); + $endfor + return self; +} + +<* + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Set was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +macro LinkedHashSet* LinkedHashSet.tinit_with_values(&self, ..., uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + return self.init_with_values(tmem, $vasplat, capacity: capacity, load_factor: load_factor); +} + +<* + @param [in] values : "The values for the LinkedHashSet" + @param [&inout] allocator : "The allocator to use" + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Set was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +fn LinkedHashSet* LinkedHashSet.init_from_values(&self, Allocator allocator, Value[] values, uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + self.init(allocator, capacity, load_factor); + foreach (v : values) self.add(v); + return self; +} + +<* + @param [in] values : "The values for the LinkedHashSet entries" + @require capacity > 0 : "The capacity must be 1 or higher" + @require load_factor > 0.0 : "The load factor must be higher than 0" + @require !self.is_initialized() : "Set was already initialized" + @require capacity < MAXIMUM_CAPACITY : "Capacity cannot exceed maximum" +*> +fn LinkedHashSet* LinkedHashSet.tinit_from_values(&self, Value[] values, uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR) +{ + return self.init_from_values(tmem, values, capacity, load_factor); +} + 
+<* + Has this linked hash set been initialized yet? + + @param [&in] set : "The linked hash set we are testing" + @return "Returns true if it has been initialized, false otherwise" +*> +fn bool LinkedHashSet.is_initialized(&set) +{ + return set.allocator && set.allocator.ptr != &dummy; +} + +<* + @param [&inout] allocator : "The allocator to use" + @param [&in] other_set : "The set to copy from." + @require !self.is_initialized() : "Set was already initialized" +*> +fn LinkedHashSet* LinkedHashSet.init_from_set(&self, Allocator allocator, LinkedHashSet* other_set) +{ + self.init(allocator, other_set.table.len, other_set.load_factor); + LinkedEntry* entry = other_set.head; + while (entry) // Save insertion order + { + self.put_for_create(entry.value); + entry = entry.after; + } + return self; +} + +<* + @param [&in] other_set : "The set to copy from." + @require !set.is_initialized() : "Set was already initialized" +*> +fn LinkedHashSet* LinkedHashSet.tinit_from_set(&set, LinkedHashSet* other_set) +{ + return set.init_from_set(tmem, other_set) @inline; +} + +<* + Check if the set is empty + + @return "true if it is empty" + @pure +*> +fn bool LinkedHashSet.is_empty(&set) @inline +{ + return !set.count; +} + +<* + Add all elements in the slice to the set. 
+
+ @param [in] list
+ @return "The number of new elements added"
+ @ensure return <= list.len
+*>
+fn usz LinkedHashSet.add_all(&set, Value[] list)
+{
+	usz total;
+	foreach (v : list)
+	{
+		if (set.add(v)) total++;
+	}
+	return total;
+}
+
+<*
+ @param [&in] other
+ @return "The number of new elements added"
+ @ensure return <= other.count
+*>
+fn usz LinkedHashSet.add_all_from(&set, LinkedHashSet* other)
+{
+	usz total;
+	other.@each(;Value value)
+	{
+		if (set.add(value)) total++;
+	};
+	return total;
+}
+
+<*
+ @param value : "The value to add"
+ @return "true if the value didn't exist in the set"
+*>
+fn bool LinkedHashSet.add(&set, Value value)
+{
+	// If the set isn't initialized, use the defaults to initialize it.
+	switch (set.allocator.ptr)
+	{
+		case &dummy:
+			set.init(mem);
+		case null:
+			set.tinit();
+		default:
+			break;
+	}
+	uint hash = rehash(value.hash());
+	uint index = index_for(hash, set.table.len);
+	for (LinkedEntry *e = set.table[index]; e != null; e = e.next)
+	{
+		if (e.hash == hash && equals(value, e.value)) return false;
+	}
+	set.add_entry(hash, value, index);
+	return true;
+}
+
+<*
+ Iterate over all the values in the set
+*>
+macro LinkedHashSet.@each(set; @body(value))
+{
+	if (!set.count) return;
+	LinkedEntry* entry = set.head;
+	while (entry)
+	{
+		@body(entry.value);
+		entry = entry.after;
+	}
+}
+
+<*
+ Check if the set contains the given value.
+
+ @param value : "The value to check"
+ @return "true if it exists in the set"
+*>
+fn bool LinkedHashSet.contains(&set, Value value)
+{
+	if (!set.count) return false;
+	uint hash = rehash(value.hash());
+	for (LinkedEntry *e = set.table[index_for(hash, set.table.len)]; e != null; e = e.next)
+	{
+		if (e.hash == hash && equals(value, e.value)) return true;
+	}
+	return false;
+}
+
+<*
+ Remove a single value from the set.
+
+ @param value : "The value to remove"
+ @return? NOT_FOUND : "If the entry is not found"
+*>
+fn void? LinkedHashSet.remove(&set, Value value) @maydiscard
+{
+	if (!set.remove_entry_for_value(value)) return NOT_FOUND?;
+}
+
+<*
+ Remove every value in the slice from the set.
+
+ @return "The number of values that were present and removed"
+*>
+fn usz LinkedHashSet.remove_all(&set, Value[] values)
+{
+	usz total;
+	foreach (v : values)
+	{
+		if (set.remove_entry_for_value(v)) total++;
+	}
+	return total;
+}
+
+<*
+ @param [&in] other : "Other set"
+*>
+fn usz LinkedHashSet.remove_all_from(&set, LinkedHashSet* other)
+{
+	usz total;
+	other.@each(;Value val)
+	{
+		if (set.remove_entry_for_value(val)) total++;
+	};
+	return total;
+}
+
+<*
+ Free all memory allocated by the hash set.
+*>
+fn void LinkedHashSet.free(&set)
+{
+	if (!set.is_initialized()) return;
+	set.clear();
+	set.free_internal(set.table.ptr);
+	set.table = {};
+}
+
+<*
+ Clear all elements from the set while keeping the underlying storage
+
+ @ensure set.count == 0
+*>
+fn void LinkedHashSet.clear(&set)
+{
+	if (!set.count) return;
+
+	LinkedEntry* entry = set.head;
+	while (entry)
+	{
+		LinkedEntry* next = entry.after;
+		set.free_entry(entry);
+		entry = next;
+	}
+
+	foreach (LinkedEntry** &bucket : set.table)
+	{
+		*bucket = null;
+	}
+
+	set.count = 0;
+	set.head = null;
+	set.tail = null;
+}
+
+<*
+ Grow the bucket table until it has at least next_power_of_2(capacity) buckets
+ (or has reached MAXIMUM_CAPACITY). Does nothing when capacity is within the
+ current threshold or the table is already that large.
+
+ @param capacity : "The number of elements to make room for"
+*>
+fn void LinkedHashSet.reserve(&set, usz capacity)
+{
+	if (capacity <= set.threshold) return;
+	usz target = math::next_power_of_2(capacity);
+	if (!set.table.len)
+	{
+		// No buckets yet, so resize() has no chains to split.
+		set.resize(target);
+		return;
+	}
+	// resize() moves old bucket i into new buckets i and i + old_capacity only,
+	// which is correct solely when the table exactly doubles. Jumping straight to
+	// `target` would misplace entries (target > 2x) or write past the end of the
+	// new table (target <= current size), so grow in doubling steps instead.
+	while (set.table.len < target && set.table.len < MAXIMUM_CAPACITY)
+	{
+		set.resize(set.table.len * 2);
+	}
+}
+
+// --- Set Operations ---
+
+<*
+ Returns the union of two sets (A | B)
+
+ @param [&in] other : "The other set to union with"
+ @param [&inout] allocator : "Allocator for the new set"
+ @return "A new set containing the union of both sets"
+*>
+fn LinkedHashSet LinkedHashSet.set_union(&self, Allocator allocator, LinkedHashSet* other)
+{
+	usz new_capacity = math::next_power_of_2(self.count + other.count);
+	LinkedHashSet result;
+	result.init(allocator, new_capacity, self.load_factor);
+	result.add_all_from(self);
+	result.add_all_from(other);
+	return result;
+}
+
+fn LinkedHashSet LinkedHashSet.tset_union(&self, LinkedHashSet* other) => self.set_union(tmem, other) @inline;
+
+<*
+ Returns the intersection of the two sets (A & B) + + @param [&in] other : "The other set to intersect with" + @param [&inout] allocator : "Allocator for the new set" + @return "A new set containing the intersection of both sets" +*> +fn LinkedHashSet LinkedHashSet.intersection(&self, Allocator allocator, LinkedHashSet* other) +{ + LinkedHashSet result; + result.init(allocator, math::min(self.table.len, other.table.len), self.load_factor); + + // Iterate through the smaller set for efficiency + LinkedHashSet* smaller = self.count <= other.count ? self : other; + LinkedHashSet* larger = self.count > other.count ? self : other; + + smaller.@each(;Value value) + { + if (larger.contains(value)) result.add(value); + }; + + return result; +} + +fn LinkedHashSet LinkedHashSet.tintersection(&self, LinkedHashSet* other) => self.intersection(tmem, other) @inline; + +<* + Return this set - other, so (A & ~B) + + @param [&in] other : "The other set to compare with" + @param [&inout] allocator : "Allocator for the new set" + @return "A new set containing elements in this set but not in the other" +*> +fn LinkedHashSet LinkedHashSet.difference(&self, Allocator allocator, LinkedHashSet* other) +{ + LinkedHashSet result; + result.init(allocator, self.table.len, self.load_factor); + self.@each(;Value value) + { + if (!other.contains(value)) + { + result.add(value); + } + }; + return result; +} + +fn LinkedHashSet LinkedHashSet.tdifference(&self, LinkedHashSet* other) => self.difference(tmem, other) @inline; + +<* + Return (A ^ B) + + @param [&in] other : "The other set to compare with" + @param [&inout] allocator : "Allocator for the new set" + @return "A new set containing elements in this set or the other, but not both" +*> +fn LinkedHashSet LinkedHashSet.symmetric_difference(&self, Allocator allocator, LinkedHashSet* other) +{ + LinkedHashSet result; + result.init(allocator, self.table.len, self.load_factor); + result.add_all_from(self); + other.@each(;Value value) + { + if 
(!result.add(value))
+		{
+			result.remove(value);
+		}
+	};
+	return result;
+}
+
+<*
+ Temp-allocator convenience wrapper: forwards to symmetric_difference,
+ passing `tmem` as the allocator.
+
+ @param other : "The other set, handed through unchanged"
+ @return "Whatever symmetric_difference(tmem, other) returns"
+*>
+fn LinkedHashSet LinkedHashSet.tsymmetric_difference(&self, LinkedHashSet* other) => self.symmetric_difference(tmem, other) @inline;
+
+<*
+ Check if this hash set is a subset of another set.
+
+ @param [&in] other : "The other set to check against"
+ @return "True if all elements of this set are in the other set"
+*>
+fn bool LinkedHashSet.is_subset(&self, LinkedHashSet* other)
+{
+	// The empty set is a subset of every set.
+	if (self.count == 0) return true;
+	// A set holding more elements than `other` cannot be contained in it.
+	if (self.count > other.count) return false;
+
+	self.@each(; Value value)
+	{
+		if (!other.contains(value)) return false;
+	};
+	return true;
+}
+
+// --- private methods
+
+<*
+ Insert a new entry for `value` and grow the table when the number of
+ entries held before the insertion had already reached the threshold.
+
+ @param hash : "Hash stored in the new entry"
+ @param value : "Value stored in the new entry"
+ @param bucket_index : "Index of the table slot that receives the entry"
+*>
+fn void LinkedHashSet.add_entry(&set, uint hash, Value value, uint bucket_index) @private
+{
+	// Allocation and linking are shared with create_entry; only the
+	// growth check below is specific to this function.
+	set.create_entry(hash, value, bucket_index);
+
+	// create_entry has already incremented the count, so `count > threshold`
+	// is the same test as `count >= threshold` on the pre-insertion count.
+	if (set.count > set.threshold)
+	{
+		set.resize(set.table.len * 2);
+	}
+}
+
+<*
+ Replace the bucket table with a newly allocated one of `new_capacity` slots
+ and redistribute every bucket chain. Only the `next` links are rewritten,
+ so the before/after insertion-order links stay as they are.
+
+ Each old bucket `i` is split between the new slots `i` and
+ `i + old_capacity`, which presumes `new_capacity` is twice the old capacity
+ (add_entry passes `table.len * 2`).
+
+ @param new_capacity : "Number of slots in the new table"
+*>
+fn void LinkedHashSet.resize(&set, usz new_capacity) @private
+{
+	LinkedEntry*[] old_table = set.table;
+	usz old_capacity = old_table.len;
+
+	if (old_capacity == MAXIMUM_CAPACITY)
+	{
+		// The table cannot grow further: keep it and raise the threshold.
+		// NOTE(review): the threshold goes through `uint` here and below;
+		// confirm that this matches the declared type of the field.
+		set.threshold = uint.max;
+		return;
+	}
+
+	LinkedEntry*[] new_table = allocator::new_array(set.allocator, LinkedEntry*, new_capacity);
+	set.table = new_table;
+	set.threshold = (uint)(new_capacity * set.load_factor);
+
+	// Rehash all entries - linked list order remains unchanged
+	foreach (uint i, LinkedEntry *e : old_table)
+	{
+		if (!e) continue;
+
+		// Split the bucket chain into two chains based on new bit:
+		// entries whose hash has the `old_capacity` bit clear stay at `i`,
+		// the others move to `i + old_capacity`.
+		// NOTE(review): presumably index_for masks the hash with
+		// (capacity - 1); index_for is not visible here, please confirm.
+		LinkedEntry* lo_head = null;
+		LinkedEntry* lo_tail = null;
+		LinkedEntry* hi_head = null;
+		LinkedEntry* hi_tail = null;
+
+		do
+		{
+			// Remember the successor before the link is overwritten.
+			LinkedEntry* next = e.next;
+			if ((e.hash & old_capacity) == 0)
+			{
+				if (!lo_tail)
+				{
+					lo_head = e;
+				}
+				else
+				{
+					lo_tail.next = e;
+				}
+				lo_tail = e;
+			}
+			else
+			{
+				if (!hi_tail)
+				{
+					hi_head = e;
+				}
+				else
+				{
+					hi_tail.next = e;
+				}
+				hi_tail = e;
+			}
+			e.next = null;
+			e = next;
+		}
+		while (e);
+
+		if (lo_tail)
+		{
+			lo_tail.next = null;
+			new_table[i] = lo_head;
+		}
+		if (hi_tail)
+		{
+			hi_tail.next = null;
+			new_table[i + old_capacity] = hi_head;
+		}
+	}
+
+	set.free_internal(old_table.ptr);
+}
+
+<*
+ Print the set as "{ a, b, c }", visiting the values in the order @each
+ yields them.
+
+ @param f : "Formatter that receives the output"
+ @return "Sum of the counts returned by the formatter calls"
+*>
+fn usz? LinkedHashSet.to_format(&self, Formatter* f) @dynamic
+{
+	usz len = f.print("{ ")!;
+	// Track the first element explicitly: deciding on the running length
+	// would drop the separator after a value that prints as nothing.
+	bool first = true;
+	self.@each(; Value value)
+	{
+		if (!first) len += f.print(", ")!;
+		first = false;
+		len += f.printf("%s", value)!;
+	};
+	return len + f.print(" }");
+}
+
+<*
+ Move every entry of the current table into `new_table`, pushing each one
+ onto the front of the bucket that index_for selects for the new table
+ length. The set's own `table` field is not reassigned here.
+
+ @param new_table : "Destination bucket table"
+*>
+fn void LinkedHashSet.transfer(&set, LinkedEntry*[] new_table) @private
+{
+	LinkedEntry*[] src = set.table;
+	uint new_capacity = new_table.len;
+	foreach (LinkedEntry *e : src)
+	{
+		if (!e) continue;
+		do
+		{
+			// Save the successor, then relink the entry as the new bucket head.
+			LinkedEntry* next = e.next;
+			uint i = index_for(e.hash, new_capacity);
+			e.next = new_table[i];
+			new_table[i] = e;
+			e = next;
+		}
+		while (e);
+	}
+}
+
+<*
+ Insert `value` unless an equal value is already stored. The entry is made
+ through create_entry, so this path never triggers a resize.
+
+ @param value : "The value to insert"
+*>
+fn void LinkedHashSet.put_for_create(&set, Value value) @private
+{
+	uint hash = rehash(value.hash());
+	uint i = index_for(hash, set.table.len);
+	for (LinkedEntry *e = set.table[i]; e != null; e = e.next)
+	{
+		if (e.hash == hash && equals(value, e.value))
+		{
+			// Value already exists, no need to do anything
+			return;
+		}
+	}
+	set.create_entry(hash, value, i);
+}
+
+<*
+ Release `ptr` through the set's allocator.
+
+ @param ptr : "Memory to free"
+*>
+fn void LinkedHashSet.free_internal(&set, void* ptr) @inline @private
+{
+	allocator::free(set.allocator, ptr);
+}
+
+<*
+ Allocate an entry for `value`, make it the head of the given bucket chain,
+ append it to the insertion-order list and increment the count. No resize
+ check is made here.
+
+ @param hash : "Hash stored in the new entry"
+ @param value : "Value stored in the new entry"
+ @param bucket_index : "Index of the table slot that receives the entry"
+*>
+fn void LinkedHashSet.create_entry(&set, uint hash, Value value, uint bucket_index) @private
+{
+	// The initializer already links the entry in front of the current
+	// bucket chain and behind the current tail.
+	LinkedEntry* entry = allocator::new(set.allocator, LinkedEntry, {
+		.hash = hash,
+		.value = value,
+		.next = set.table[bucket_index],
+		.before = set.tail,
+		.after = null
+	});
+
+	// Update bucket chain
+	set.table[bucket_index] = entry;
+
+	// Update linked list: an empty list gets the entry as its head too.
+	if (set.tail)
+	{
+		set.tail.after = entry;
+	}
+	else
+	{
+		set.head = entry;
+	}
+	set.tail = entry;
+
+	set.count++;
+}
+
+<*
+ Remove the entry holding `value`, unlinking it from its bucket chain and
+ from the insertion-order list before freeing it.
+
+ @param value : "The value to remove"
+ @return "True if an entry was found and removed, false otherwise"
+*>
+fn bool LinkedHashSet.remove_entry_for_value(&set, Value value) @private
+{
+	if (!set.count) return false;
+
+	uint hash = rehash(value.hash());
+	uint i = index_for(hash, set.table.len);
+	LinkedEntry* prev = null;
+	LinkedEntry* e = set.table[i];
+
+	while (e)
+	{
+		if (e.hash == hash && equals(value, e.value))
+		{
+			// Unlink from the bucket chain.
+			if (prev)
+			{
+				prev.next = e.next;
+			}
+			else
+			{
+				set.table[i] = e.next;
+			}
+
+			// Unlink from the insertion-order list, moving head/tail
+			// when the entry sits at either end.
+			if (e.before)
+			{
+				e.before.after = e.after;
+			}
+			else
+			{
+				set.head = e.after;
+			}
+
+			if (e.after)
+			{
+				e.after.before = e.before;
+			}
+			else
+			{
+				set.tail = e.before;
+			}
+
+			set.count--;
+			set.free_entry(e);
+			return true;
+		}
+		prev = e;
+		e = e.next;
+	}
+	return false;
+}
+
+<*
+ Release a single entry through the set's allocator.
+
+ @param entry : "The entry to free"
+*>
+fn void LinkedHashSet.free_entry(&set, LinkedEntry *entry) @private
+{
+	allocator::free(set.allocator, entry);
+}
+
+<*
+ Cursor over a LinkedHashSet that starts at `head` and follows the `after`
+ links.
+*>
+struct LinkedHashSetIterator
+{
+	LinkedHashSet* set;
+	LinkedEntry* current; // Entry the cursor is on, null if there is none
+	bool started;         // False until next() has been called once
+}
+
+<*
+ Create an iterator positioned on the set's head that has not been started.
+*>
+fn LinkedHashSetIterator LinkedHashSet.iter(&set) => { .set = set, .current = set.head, .started = false };
+
+<*
+ Advance the cursor. The first call (re)positions it on the set's head,
+ later calls follow the `after` link of the current entry.
+
+ @return "True if the cursor is on an entry after the call"
+*>
+fn bool LinkedHashSetIterator.next(&self)
+{
+	if (!self.started)
+	{
+		self.current = self.set.head;
+		self.started = true;
+	}
+	else if (self.current)
+	{
+		self.current = self.current.after;
+	}
+	return self.current != null;
+}
+
+<*
+ Access the value under the cursor.
+
+ @return "Pointer to the current entry's value, or NOT_FOUND when the cursor is on no entry"
+*>
+fn Value*? LinkedHashSetIterator.get(&self)
+{
+	return self.current ? &self.current.value : NOT_FOUND?;
+}
+
+<*
+ Tell whether a following call to next() would land on an entry.
+
+ @return "Before the first next(): whether the set has a head. Afterwards: whether the current entry has a successor"
+*>
+fn bool LinkedHashSetIterator.has_next(&self)
+{
+	if (!self.started) return self.set.head != null;
+	return self.current && self.current.after != null;
+}
+
+<*
+ @return "The element count of the underlying set (independent of the cursor position)"
+*>
+fn usz LinkedHashSetIterator.len(&self) @operator(len)
+{
+	return self.set.count;
+}
+
+int dummy @local;
\ No newline at end of file
diff --git a/test/unit/stdlib/collections/linked_map.c3 b/test/unit/stdlib/collections/linked_map.c3
new file mode 100644
index 000000000..388eecc36
--- /dev/null
+++ b/test/unit/stdlib/collections/linked_map.c3
@@ -0,0 +1,218 @@
+module linked_map_test @test;
+import std::collections::list;
+import std::collections::map;
+import std::sort;
+import std::io;
+
+alias TestLinkedHashMap = LinkedHashMap{String, usz};
+
+struct MapTest
+{
+	String key;
+	usz value;
+}
+alias List = List{MapTest};
+
+// Initialization state, emptiness, and set/get/remove round trips.
+fn void linked_map_basic()
+{
+	TestLinkedHashMap m;
+	assert(!m.is_initialized());
+	m.tinit();
+	assert(m.is_initialized());
+	assert(m.is_empty());
+	assert(m.len() == 0);
+
+	m.set("a", 1);
+	assert(!m.is_empty());
+	assert(m.len() == 1);
+	m.remove("a");
+	assert(m.is_empty());
+
+	MapTest[] tcases = { {"key1", 0}, {"key2", 1}, {"key3", 2} };
+	foreach (tc : tcases)
+	{
+		m.set(tc.key, tc.value);
+	}
+	assert(m.len() == tcases.len);
+	foreach (tc : tcases)
+	{
+		usz v = m.get(tc.key)!!;
+		assert(tc.value == v);
+	}
+}
+
+// Iteration follows insertion order; a removed and re-added key moves last.
+fn void linked_map_insertion_order()
+{
+	TestLinkedHashMap m;
+	m.tinit();
+
+	String[] keys = { "first", "second", "third", "fourth" };
+	foreach (i, key : keys)
+	{
+		m.set(key, i);
+	}
+
+	usz index = 0;
+	m.@each(; String key, usz value)
+	{
+		assert(key == keys[index]);
+		assert(value == index);
+		index++;
+	};
+	// Every entry must have been visited, otherwise the checks above are vacuous.
+	assert(index == keys.len);
+
+	m.remove("second");
+	m.set("second", 1);
+
+	String[] new_order = { "first", "third", "fourth", "second" };
+	index = 0;
+	m.@each(; String key, usz value)
+	{
+		assert(key == new_order[index]);
+		index++;
+	};
+	assert(index == new_order.len);
+}
+
+// Key/value initialization stores every pair and keeps the argument order.
+fn void linked_map_init_with_values()
+{
+	TestLinkedHashMap m;
+	m.tinit_with_key_values("a", 1, "b", 2, "c", 3);
+
+	assert(m.len() == 3);
+	assert(m.get("a")!! == 1);
+	assert(m.get("b")!! == 2);
+	assert(m.get("c")!! == 3);
+
+	// Verify order
+	String[] expected_order = { "a", "b", "c" };
+	usz index = 0;
+	m.@each(; String key, usz value)
+	{
+		assert(key == expected_order[index]);
+		index++;
+	};
+	assert(index == expected_order.len);
+}
+
+// remove() fails on uninitialized/empty maps and on missing keys, and
+// removing a middle entry keeps the order of the remaining ones.
+fn void linked_map_remove()
+{
+	TestLinkedHashMap m;
+	assert(!@ok(m.remove("A")));
+	m.tinit();
+	assert(!@ok(m.remove("A")));
+	m.set("A", 0);
+	assert(@ok(m.remove("A")));
+
+	m.set("a", 1);
+	m.set("b", 2);
+	m.set("c", 3);
+	m.remove("b");
+
+	String[] expected_order = { "a", "c" };
+	usz index = 0;
+	m.@each(; String key, usz value)
+	{
+		assert(key == expected_order[index]);
+		index++;
+	};
+	assert(index == expected_order.len);
+}
+
+// A map initialized from another map has the same length and key order.
+fn void linked_map_copy()
+{
+	TestLinkedHashMap hash_map;
+	hash_map.tinit();
+
+	hash_map.set("aa", 1);
+	hash_map.set("b", 2);
+	hash_map.set("bb", 1);
+
+	TestLinkedHashMap hash_map_copy;
+	hash_map_copy.tinit_from_map(&hash_map);
+
+	assert(hash_map_copy.len() == hash_map.len());
+
+	String[] expected_order = { "aa", "b", "bb" };
+	usz index = 0;
+	hash_map_copy.@each(; String key, usz value)
+	{
+		assert(key == expected_order[index]);
+		index++;
+	};
+	assert(index == expected_order.len);
+}
+
+// Entry, key and value iterators each visit all three entries.
+// The iterator types use the same {String, usz} parameters as the map alias.
+fn void linked_map_iterators()
+{
+	TestLinkedHashMap m;
+	m.tinit_with_key_values("a", 1, "b", 2, "c", 3);
+
+	usz count = 0;
+	LinkedHashMapIterator{String, usz} iter = m.iter();
+	while (iter.next())
+	{
+		count++;
+		LinkedEntry{String, usz}* current = iter.get()!!;
+		assert(current.key.len > 0);
+		assert(current.value > 0);
+	}
+	assert(count == 3);
+
+	count = 0;
+
+	LinkedHashMapKeyIterator{String, usz} key_iter = m.key_iter();
+	while (key_iter.next())
+	{
+		count++;
+		assert(key_iter.get()!!.len > 0);
+	}
+	assert(count == 3);
+
+	count = 0;
+	usz sum = 0;
+	LinkedHashMapValueIterator{String, usz} value_iter = m.value_iter();
+
+	while (value_iter.next())
+	{
+		count++;
+		sum += *(value_iter.get()!!);
+	}
+	assert(count == 3);
+	assert(sum == 6);
+}
+
+alias FooLinkedMap = LinkedHashMap{char, Foobar};
+enum Foobar : inline char
+{
+	FOO,
+	BAR,
+	BAZ
+}
+
+enum Foobar2 : const inline int
+{
+	ABC = 3,
+	DEF = 5,
+}
+
+// Inline enum values can be used directly as `char` keys; the expected
+// string shows BAZ stored under 2 and ABC under 3, in insertion order.
+fn void linked_map_inline_enum()
+{
+	FooLinkedMap x;
+	x.tinit();
+	x[Foobar.BAZ] = FOO;
+	x[Foobar2.ABC] = BAR;
+	test::eq(string::tformat("%s", x), "{ 2: FOO, 3: BAR }");
+	x.free();
+}
+
+// clear() empties the map and leaves it usable for new insertions.
+fn void linked_map_clear()
+{
+	TestLinkedHashMap m;
+	m.tinit_with_key_values("a", 1, "b", 2, "c", 3);
+
+	assert(m.len() == 3);
+	m.clear();
+	assert(m.len() == 0);
+	assert(m.is_empty());
+
+	m.set("x", 10);
+	assert(m.len() == 1);
+	assert(@ok(m.get("x")));
+	assert((m.get("x")??0) == 10);
+}
diff --git a/test/unit/stdlib/collections/linked_set.c3 b/test/unit/stdlib/collections/linked_set.c3
new file mode 100644
index 000000000..9aa81b215
--- /dev/null
+++ b/test/unit/stdlib/collections/linked_set.c3
@@ -0,0 +1,300 @@
+module linked_set_test @test;
+import std::collections::set;
+
+alias TestLinkedHashSet = LinkedHashSet{String};
+
+// Initialization state, emptiness, and add/contains/remove round trips.
+fn void linked_set_basic()
+{
+	TestLinkedHashSet set;
+	assert(!set.is_initialized());
+	set.tinit();
+	assert(set.is_initialized());
+	assert(set.is_empty());
+	assert(set.len() == 0);
+
+	assert(set.add("a"));
+	assert(!set.is_empty());
+	assert(set.len() == 1);
+	set.remove("a");
+	assert(set.is_empty());
+
+	String[] values = { "key1", "key2", "key3" };
+	foreach (value : values)
+	{
+		assert(set.add(value));
+	}
+	assert(set.len() == values.len);
+	foreach (value : values)
+	{
+		assert(set.contains(value));
+	}
+}
+
+// Iteration follows insertion order; a removed and re-added value moves last.
+fn void linked_set_insertion_order()
+{
+	TestLinkedHashSet set;
+	set.tinit();
+
+	String[] values = { "first", "second", "third", "fourth" };
+	foreach (value : values)
+	{
+		set.add(value);
+	}
+
+	// Test iteration follows insertion order
+	usz index = 0;
+	set.@each(; String value)
+	{
+		assert(value == values[index]);
+		index++;
+	};
+	// Every value must have been visited, otherwise the checks above are vacuous.
+	assert(index == values.len);
+
+	// Test that removing and re-inserting changes order
+	set.remove("second");
+	set.add("second");
+
+	String[] new_order = { "first", "third", "fourth", "second" };
+	index = 0;
+	set.@each(; String value)
+	{
+		assert(value == new_order[index]);
+		index++;
+	};
+	assert(index == new_order.len);
+}
+
+// Value initialization stores every value and keeps the argument order.
+fn void linked_set_init_with_values()
+{
+	TestLinkedHashSet set;
+	set.tinit_with_values("a", "b", "c");
+
+	assert(set.len() == 3);
+	assert(set.contains("a"));
+	assert(set.contains("b"));
+	assert(set.contains("c"));
+
+	// Verify order
+	String[] expected_order = { "a", "b", "c" };
+	usz index = 0;
+	set.@each(; String value)
+	{
+		assert(value == expected_order[index]);
+		index++;
+	};
+	assert(index == expected_order.len);
+}
+
+// remove() fails on uninitialized/empty sets and on missing values, and
+// removing a middle value keeps the order of the remaining ones.
+fn void linked_set_remove()
+{
+	TestLinkedHashSet set;
+	assert(!@ok(set.remove("A")));
+	set.tinit();
+	assert(!@ok(set.remove("A")));
+	set.add("A");
+	assert(@ok(set.remove("A")));
+
+	set.add("a");
+	set.add("b");
+	set.add("c");
+	set.remove("b");
+
+	String[] expected_order = { "a", "c" };
+	usz index = 0;
+	set.@each(; String value)
+	{
+		assert(value == expected_order[index]);
+		index++;
+	};
+	assert(index == expected_order.len);
+}
+
+// A set initialized from another set has the same length and value order.
+fn void linked_set_copy()
+{
+	TestLinkedHashSet original;
+	original.tinit();
+
+	original.add("aa");
+	original.add("b");
+	original.add("bb");
+
+	TestLinkedHashSet copy;
+	copy.tinit_from_set(&original);
+
+	assert(copy.len() == original.len());
+
+	String[] expected_order = { "aa", "b", "bb" };
+	usz index = 0;
+	copy.@each(; String value)
+	{
+		assert(value == expected_order[index]);
+		index++;
+	};
+	assert(index == expected_order.len);
+}
+
+// The explicit iterator and the @each macro both visit all three values.
+fn void linked_set_iterators()
+{
+	TestLinkedHashSet set;
+	set.tinit_with_values("a", "b", "c");
+
+	// Test entry iterator
+	usz count = 0;
+	LinkedHashSetIterator{String} iter = set.iter();
+	while (iter.next())
+	{
+		count++;
+		assert(iter.get()!!.len > 0);
+	}
+	assert(count == 3);
+
+	// Test direct each macro
+	count = 0;
+	set.@each(; String value)
+	{
+		count++;
+		assert(value.len > 0);
+	};
+	assert(count == 3);
+}
+
+// clear() empties the set and leaves it usable for new insertions.
+fn void linked_set_clear()
+{
+	TestLinkedHashSet set;
+	set.tinit_with_values("a", "b", "c");
+
+	assert(set.len() == 3);
+	set.clear();
+	assert(set.len() == 0);
+	assert(set.is_empty());
+
+	// Should be able to reuse after clear
+	set.add("x");
+	assert(set.len() == 1);
+	assert(set.contains("x"));
+}
+
+// Union, intersection, difference and subset on {a, b, c} and {b, c, d}.
+fn void linked_set_operations()
+{
+	TestLinkedHashSet set1;
+	set1.tinit_with_values("a", "b", "c");
+
+	TestLinkedHashSet set2;
+	set2.tinit_with_values("b", "c", "d");
+
+	// Test union
+	TestLinkedHashSet union_set = set1.tset_union(&set2);
+	defer union_set.free();
+	assert(union_set.contains("a"));
+	assert(union_set.contains("b"));
+	assert(union_set.contains("c"));
+	assert(union_set.contains("d"));
+	assert(union_set.len() == 4);
+
+	// Verify union preserves order (elements from first set first)
+	String[] expected_union_order = { "a", "b", "c", "d" };
+	usz index = 0;
+	union_set.@each(; String value)
+	{
+		assert(value == expected_union_order[index]);
+		index++;
+	};
+	assert(index == expected_union_order.len);
+
+	// Test intersection
+	TestLinkedHashSet intersect_set = set1.tintersection(&set2);
+	defer intersect_set.free();
+	assert(intersect_set.contains("b"));
+	assert(intersect_set.contains("c"));
+	assert(!intersect_set.contains("a"));
+	assert(!intersect_set.contains("d"));
+	assert(intersect_set.len() == 2);
+
+	// Test difference
+	TestLinkedHashSet diff_set = set1.tdifference(&set2);
+	defer diff_set.free();
+	assert(diff_set.contains("a"));
+	assert(!diff_set.contains("b"));
+	assert(!diff_set.contains("c"));
+	assert(!diff_set.contains("d"));
+	assert(diff_set.len() == 1);
+
+	// Test subset
+	TestLinkedHashSet subset;
+	subset.tinit_with_values("b", "c");
+	assert(subset.is_subset(&set1));
+	assert(!set1.is_subset(&subset));
+}
+
+alias IntLinkedSet = LinkedHashSet{int};
+
+// Empty-set operations, and ordering/clearing with 1000 int values.
+fn void linked_set_edge_cases()
+{
+	// Test empty set
+	IntLinkedSet empty;
+	empty.tinit();
+	defer empty.free();
+
+	assert(empty.is_empty());
+	assert(!empty.contains(0));
+	empty.remove(0); // Shouldn't crash
+
+	// Test large set
+	IntLinkedSet large;
+	large.tinit();
+	defer large.free();
+
+	// Insert in reverse order to test ordering
+	for (int i = 1000; i > 0; i--)
+	{
+		large.add(i);
+	}
+	assert(large.len() == 1000);
+
+	// Verify order is maintained
+	int expected = 1000;
+	large.@each(; int value)
+	{
+		assert(value == expected);
+		expected--;
+	};
+	// All 1000 values must have been visited.
+	assert(expected == 0);
+
+	// Test clear
+	large.clear();
+	assert(large.is_empty());
+	for (int i = 1; i <= 1000; i++)
+	{
+		assert(!large.contains(i));
+	}
+}
+
+// Duplicate adds are rejected; order and membership hold across a removal.
+fn void linked_set_string_values()
+{
+	TestLinkedHashSet set;
+	set.tinit();
+	defer set.free();
+
+	assert(set.add("hello"));
+	assert(set.add("world"));
+	assert(!set.add("hello"));
+
+	assert(set.contains("hello"));
+	assert(set.contains("world"));
+	assert(!set.contains("foo"));
+
+	// Test order
+	String[] expected_order = { "hello", "world" };
+	usz index = 0;
+	set.@each(; String value)
+	{
+		assert(value == expected_order[index]);
+		index++;
+	};
+	assert(index == expected_order.len);
+
+	set.remove("hello");
+	assert(!set.contains("hello"));
+	assert(set.len() == 1);
+
+	// Test order after removal
+	assert(set.contains("world"));
+	set.@each(; String value)
+	{
+		assert(value == "world");
+	};
+}
+
+// is_initialized() flips from false to true once tinit() has run.
+fn void linked_set_is_initialized()
+{
+	TestLinkedHashSet test;
+	assert(!test.is_initialized());
+	test.tinit();
+	assert(test.is_initialized());
+	test.free();
+}