// Copyright (c) 2023 Christoffer Lerno. All rights reserved.
// Use of this source code is governed by the MIT license
// a copy of which can be found in the LICENSE_STDLIB file.
// NOTE(review): Key and Value are used throughout this file but are not
// declared in the visible source; presumably they are generic module
// parameters - confirm against the module declaration below.
module std::collections::map;
import std::math;

const uint DEFAULT_INITIAL_CAPACITY = 16;
const uint MAXIMUM_CAPACITY = 1u << 31;
const float DEFAULT_LOAD_FACTOR = 0.75;

// A hash map using separate chaining: each slot of `table` holds the head
// of a singly linked list of entries.
struct HashMap
{
	Entry*[] table;       // Bucket array, length is kept a power of two
	Allocator* allocator; // Allocator for the table and entries, null until init
	uint count;           // Number of elements
	uint threshold;       // Resize limit
	float load_factor;
}

/**
 * Initialize the map. The capacity is rounded up to a power of two and a
 * zeroed bucket table of that size is allocated with the given allocator.
 *
 * @require capacity > 0 "The capacity must be 1 or higher"
 * @require load_factor > 0.0 "The load factor must be higher than 0"
 * @require !map.allocator "Map was already initialized"
 * @require capacity < MAXIMUM_CAPACITY "Capacity cannot exceed maximum"
 * @require using != null "The allocator must be non-null"
 **/
fn void HashMap.init(HashMap* map, uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR, Allocator* using = mem::heap())
{
	capacity = math::next_power_of_2(capacity);
	map.allocator = using;
	map.load_factor = load_factor;
	map.threshold = (uint)(capacity * load_factor);
	map.table = calloc(Entry*, capacity, .using = using);
}

/**
 * Initialize the map using the temp allocator.
 *
 * @require capacity > 0 "The capacity must be 1 or higher"
 * @require load_factor > 0.0 "The load factor must be higher than 0"
 * @require !map.allocator "Map was already initialized"
 * @require capacity < MAXIMUM_CAPACITY "Capacity cannot exceed maximum"
 **/
fn void HashMap.tinit(HashMap* map, uint capacity = DEFAULT_INITIAL_CAPACITY, float load_factor = DEFAULT_LOAD_FACTOR)
{
	map.init(capacity, load_factor, mem::temp());
}

/**
 * Has this hash map been initialized yet?
 *
 * @param [&in] map "The hash map we are testing"
 * @return "Returns true if it has been initialized, false otherwise"
 **/
fn bool HashMap.is_initialized(HashMap* map)
{
	return map.allocator != null;
}

/**
 * Initialize the map with the same table length and load factor as
 * other_map, then copy every key/value pair of other_map into it.
 **/
fn void HashMap.init_from_map(HashMap* map, HashMap* other_map, Allocator* using = mem::heap())
{
	map.init(other_map.table.len, other_map.load_factor, using);
	map.put_all_for_create(other_map);
}

/**
 * Same as init_from_map, but using the temp allocator.
 **/
fn void HashMap.tinit_from_map(HashMap* map, HashMap* other_map)
{
	map.init_from_map(other_map, mem::temp()) @inline;
}

/**
 * Returns true if the map holds no elements.
 **/
fn bool HashMap.is_empty(HashMap* map) @inline
{
	return !map.count;
}

/**
 * Look up a key and return a pointer to the stored value, or
 * SearchResult.MISSING if the key is not present.
 **/
fn Value*! HashMap.get_ref(HashMap* map, Key key)
{
	// An empty map may also be uninitialized, so never touch the table then.
	if (!map.count) return SearchResult.MISSING?;
	uint hash = rehash(key.hash());
	for (Entry *e = map.table[index_for(hash, map.table.len)]; e != null; e = e.next)
	{
		if (e.hash == hash && equals(key, e.key)) return &e.value;
	}
	return SearchResult.MISSING?;
}

/**
 * Look up a key and return a pointer to its entry, or
 * SearchResult.MISSING if the key is not present.
 **/
fn Entry*! HashMap.get_entry(HashMap* map, Key key)
{
	// An empty map may also be uninitialized, so never touch the table then.
	if (!map.count) return SearchResult.MISSING?;
	uint hash = rehash(key.hash());
	for (Entry *e = map.table[index_for(hash, map.table.len)]; e != null; e = e.next)
	{
		if (e.hash == hash && equals(key, e.key)) return e;
	}
	return SearchResult.MISSING?;
}

/**
 * Get the value stored for the key. If the key is missing, evaluate the
 * expression, store the result under the key and return it. The expression
 * is only evaluated when the key is missing.
 **/
macro Value HashMap.@get_or_set(HashMap* map, Key key, Value #expr)
{
	if (!map.count)
	{
		// Goes through set() so that an uninitialized map gets initialized.
		Value val = #expr;
		map.set(key, val);
		return val;
	}
	uint hash = rehash(key.hash());
	uint index = index_for(hash, map.table.len);
	for (Entry *e = map.table[index]; e != null; e = e.next)
	{
		if (e.hash == hash && equals(key, e.key)) return e.value;
	}
	Value val = #expr;
	map.add_entry(hash, key, val, index);
	return val;
}

/**
 * Return a copy of the value stored for the key, or
 * SearchResult.MISSING if the key is not present.
 **/
fn Value! HashMap.get(HashMap* map, Key key) @operator([])
{
	return *map.get_ref(key) @inline;
}

/**
 * Returns true if the key is present in the map.
 **/
fn bool HashMap.has_key(HashMap* map, Key key)
{
	return @ok(map.get_ref(key));
}

/**
 * Store a value under the key, initializing the map with default settings
 * if needed. Returns true if an existing value was overwritten and false
 * if a new entry was added.
 **/
fn bool HashMap.set(HashMap* map, Key key, Value value) @operator([]=)
{
	// If the map isn't initialized, use the defaults to initialize it.
	if (!map.allocator)
	{
		map.init();
	}
	uint hash = rehash(key.hash());
	uint index = index_for(hash, map.table.len);
	for (Entry *e = map.table[index]; e != null; e = e.next)
	{
		if (e.hash == hash && equals(key, e.key))
		{
			e.value = value;
			return true;
		}
	}
	map.add_entry(hash, key, value, index);
	return false;
}

/**
 * Remove the entry for the key. Returns SearchResult.MISSING if the key
 * was not present.
 **/
fn void! HashMap.remove(HashMap* map, Key key) @maydiscard
{
	if (!map.remove_entry_for_key(key)) return SearchResult.MISSING?;
}

/**
 * Remove and free all entries, keeping the bucket table allocated.
 **/
fn void HashMap.clear(HashMap* map)
{
	if (!map.count) return;
	foreach (Entry** &entry_ref : map.table)
	{
		Entry* entry = *entry_ref;
		// Free the whole collision chain, not only its head entry.
		while (entry)
		{
			Entry* next = entry.next;
			map.free_internal(entry);
			entry = next;
		}
		*entry_ref = null;
	}
	map.count = 0;
}

/**
 * Free all entries and the bucket table. Does nothing on a map that was
 * never initialized. The allocator field is left set, so is_initialized()
 * still reports true afterwards.
 **/
fn void HashMap.free(HashMap* map)
{
	if (!map.allocator) return;
	map.clear();
	map.free_internal(map.table.ptr);
	map.table = Entry*[] {};
}

/**
 * Return all keys as a list allocated with the temp allocator.
 **/
fn Key[] HashMap.key_tlist(HashMap* map)
{
	return map.key_list(mem::temp()) @inline;
}

/**
 * Return all keys as a newly allocated list, in bucket order.
 * An empty map yields an empty list without allocating.
 **/
fn Key[] HashMap.key_list(HashMap* map, Allocator* using = mem::heap())
{
	if (!map.count) return Key[] {};

	Key[] list = calloc(Key, map.count, .using = using);
	usz index = 0;
	foreach (Entry* entry : map.table)
	{
		while (entry)
		{
			list[index++] = entry.key;
			entry = entry.next;
		}
	}
	return list;
}

/**
 * Return all values as a list allocated with the temp allocator.
 **/
fn Value[] HashMap.value_tlist(HashMap* map)
{
	return map.value_list(mem::temp()) @inline;
}

/**
 * Return all values as a newly allocated list, in bucket order.
 * An empty map yields an empty list without allocating.
 **/
fn Value[] HashMap.value_list(HashMap* map, Allocator* using = mem::heap())
{
	if (!map.count) return Value[] {};
	Value[] list = calloc(Value, map.count, .using = using);
	usz index = 0;
	foreach (Entry* entry : map.table)
	{
		while (entry)
		{
			list[index++] = entry.value;
			entry = entry.next;
		}
	}
	return list;
}

$if types::is_equatable(Value):
/**
 * Returns true if any entry holds a value equal to v. Scans every entry.
 **/
fn bool HashMap.has_value(HashMap* map, Value v)
{
	if (!map.count) return false;
	foreach (Entry* entry : map.table)
	{
		while (entry)
		{
			if (equals(v, entry.value)) return true;
			entry = entry.next;
		}
	}
	return false;
}
$endif

// --- private methods

/**
 * Allocate a new entry, push it at the head of the given bucket and
 * double the table once the element count reaches the threshold.
 **/
fn void HashMap.add_entry(HashMap* map, uint hash, Key key, Value value, uint bucket_index) @private
{
	Entry* entry = malloc(Entry, .using = map.allocator);
	*entry = { .hash = hash, .key = key, .value = value, .next = map.table[bucket_index] };
	map.table[bucket_index] = entry;
	// The comparison uses the count before the increment.
	if (map.count++ >= map.threshold)
	{
		map.resize(map.table.len * 2);
	}
}

/**
 * Replace the bucket table with one of new_capacity slots and relink all
 * entries into it. At MAXIMUM_CAPACITY the table is kept and the threshold
 * is set to uint.max instead.
 **/
fn void HashMap.resize(HashMap* map, uint new_capacity) @private
{
	Entry*[] old_table = map.table;
	uint old_capacity = old_table.len;
	if (old_capacity == MAXIMUM_CAPACITY)
	{
		map.threshold = uint.max;
		return;
	}
	Entry*[] new_table = calloc(Entry*, new_capacity, .using = map.allocator);
	// transfer() reads map.table, so the old table must still be installed here.
	map.transfer(new_table);
	map.table = new_table;
	map.free_internal(old_table.ptr);
	map.threshold = (uint)(new_capacity * map.load_factor);
}

/**
 * Mix the higher bits of a hash into the lower ones, since index_for()
 * only uses the low bits.
 **/
fn uint rehash(uint hash) @inline @private
{
	hash ^= (hash >> 20) ^ (hash >> 12);
	return hash ^ ((hash >> 7) ^ (hash >> 4));
}

/**
 * Map a hash to a bucket index by masking with capacity - 1.
 * NOTE(review): this is only a valid index when capacity is a non-zero
 * power of two; init() rounds up and add_entry() doubles, which keeps it so.
 **/
macro uint index_for(uint hash, uint capacity) @private
{
	return hash & (capacity - 1);
}

/**
 * Relink every entry of the current table into new_table. Entries are
 * reused, not reallocated; the current table is left with stale pointers.
 **/
fn void HashMap.transfer(HashMap* map, Entry*[] new_table) @private
{
	Entry*[] src = map.table;
	uint new_capacity = new_table.len;
	foreach (Entry *e : src)
	{
		while (e)
		{
			// Save the successor first, as e.next is overwritten below.
			Entry* next = e.next;
			uint i = index_for(e.hash, new_capacity);
			e.next = new_table[i];
			new_table[i] = e;
			e = next;
		}
	}
}

/**
 * Copy every key/value pair of other_map into this map without
 * triggering a resize.
 **/
fn void HashMap.put_all_for_create(HashMap* map, HashMap* other_map) @private
{
	if (!other_map.count) return;
	foreach (Entry *e : other_map.table)
	{
		// Walk the whole collision chain, not only its head entry.
		while (e)
		{
			map.put_for_create(e.key, e.value);
			e = e.next;
		}
	}
}

/**
 * Insert or overwrite a key/value pair without checking the resize
 * threshold.
 **/
fn void HashMap.put_for_create(HashMap* map, Key key, Value value) @private
{
	uint hash = rehash(key.hash());
	uint i = index_for(hash, map.table.len);
	for (Entry *e = map.table[i]; e != null; e = e.next)
	{
		if (e.hash == hash && equals(key, e.key))
		{
			e.value = value;
			return;
		}
	}
	map.create_entry(hash, key, value, i);
}

/**
 * Free a pointer with the map's allocator. The trailing !! forces the
 * result of the allocator call rather than returning a fault to the caller.
 **/
fn void HashMap.free_internal(HashMap* map, void* ptr) @inline @private
{
	map.allocator.free(ptr)!!;
}

/**
 * Unlink and free the entry for the key. Returns true if an entry was
 * removed and false if the key was not present.
 **/
fn bool HashMap.remove_entry_for_key(HashMap* map, Key key) @private
{
	// An empty map may have a zero-length table, which must not be indexed.
	if (!map.count) return false;
	uint hash = rehash(key.hash());
	uint i = index_for(hash, map.table.len);
	Entry* prev = map.table[i];
	Entry* e = prev;
	while (e)
	{
		Entry *next = e.next;
		if (e.hash == hash && equals(key, e.key))
		{
			map.count--;
			// prev == e only while e is still the head of the bucket.
			if (prev == e)
			{
				map.table[i] = next;
			}
			else
			{
				prev.next = next;
			}
			map.free_internal(e);
			return true;
		}
		prev = e;
		e = next;
	}
	return false;
}

/**
 * Allocate a new entry and push it at the head of the given bucket.
 * Unlike add_entry() this never resizes the table.
 **/
fn void HashMap.create_entry(HashMap* map, uint hash, Key key, Value value, int bucket_index) @private
{
	Entry* entry = malloc(Entry, .using = map.allocator);
	*entry = { .hash = hash, .key = key, .value = value, .next = map.table[bucket_index] };
	map.table[bucket_index] = entry;
	map.count++;
}

// A single key/value pair and its link in a bucket's chain.
struct Entry
{
	uint hash;   // The rehashed key hash, stored to avoid recomputing it
	Key key;
	Value value;
	Entry* next; // Next entry in the same bucket, or null
}