From db75da65db9742c6f42f0bb6cabacd6978336abb Mon Sep 17 00:00:00 2001 From: Alex Anderson Date: Tue, 9 Jul 2024 12:26:45 -0700 Subject: [PATCH] Make countingsort.c3's recursion stage branchless Tracks the three potential cases for each bucket, with item counts in the ranges [2,32], [33,128] and [129, ...], and uses a dedicated loop for each case. --- lib/std/sort/countingsort.c3 | 58 +++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/lib/std/sort/countingsort.c3 b/lib/std/sort/countingsort.c3 index 15507310d..19c2836fd 100644 --- a/lib/std/sort/countingsort.c3 +++ b/lib/std/sort/countingsort.c3 @@ -51,7 +51,7 @@ fn void csort(Type list, usz low, usz high, KeyFn key_fn, uint byte_idx) CmpCallback compare_fn = fn (lhs, rhs, key_fn) => compare_to(key_fn(lhs), key_fn(rhs)); $endif; - byte_idx = byte_idx >= KeyFnReturnType.sizeof ? KeyFnReturnType.sizeof - 1 : byte_idx; + byte_idx = byte_idx >= KeyFnReturnType.sizeof ? KeyFnReturnType.sizeof - 1 : byte_idx; Counts counts; Ranges ranges; @@ -89,22 +89,35 @@ fn void csort(Type list, usz low, usz high, KeyFn key_fn, uint byte_idx) KeyFnReturnType diff = mx - mn; if (diff == 0) return; - ushort parition_count = 0; + ushort fallback0_count = 0; + ushort fallback1_count = 0; + ushort recursion_count = 0; + usz total = 0; foreach (char i, count : counts) { - indexs[parition_count] = i; - parition_count += (ushort)(count > 0); + indexs[fallback0_count] = i; + indexs[255 - recursion_count] = i; + + fallback0_count += (ushort)(count > 1 && count <= 32); + recursion_count += (ushort)(count > 128); + counts[i] = total; ranges[i] = total; total += count; } ranges[256] = total; + + for(ushort i = 0; i < 256; i++) { + indexs[fallback0_count + fallback1_count] = (char)i; + ushort count = ranges[i + 1] - ranges[i]; + fallback1_count += (ushort)(count > 32 && count <= 128); + } if (!keys_ordered) { usz sorted_count = 0; - //ElementType* first = list.first(); + do { foreach (x, s : counts) 
@@ -134,23 +147,32 @@ fn void csort(Type list, usz low, usz high, KeyFn key_fn, uint byte_idx) if (byte_idx) { - for (usz p = 0; p < parition_count; p++) - { + for (usz p = 0; p < fallback0_count; p++) { usz i = indexs[p]; + usz start_offset = ranges[i]; - usz end_offset = ranges[i + 1]; + usz end_offset = ranges[i + 1]; - usz items = end_offset - start_offset; + insertionsort_indexed(list, low + start_offset, low + end_offset, compare_fn, key_fn); + } - switch (items) - { - case 0..32: - insertionsort_indexed(list, low + start_offset, low + end_offset, compare_fn, key_fn); - case 33..128: - quicksort_indexed(list, low + start_offset, low + end_offset, compare_fn, key_fn); - default: - csort(list, low + start_offset, low + end_offset, key_fn, byte_idx - 1); - } + for (usz p = 0; p < fallback1_count; p++) { + usz i = indexs[fallback0_count + p]; + + usz start_offset = ranges[i]; + usz end_offset = ranges[i + 1]; + + quicksort_indexed(list, low + start_offset, low + end_offset, compare_fn, key_fn); + } + + for (usz p = 0; p < recursion_count; p++) + { + usz i = indexs[255 - p]; + + usz start_offset = ranges[i]; + usz end_offset = ranges[i + 1]; + + csort(list, low + start_offset, low + end_offset, key_fn, byte_idx - 1); } } }