Make countingsort.c3's recursion stage branchless

Tracks the three potential cases for each fallback, item counts ranging from [2,32], [33,128], [128, ...] and uses a loop specifically for each fallback.
This commit is contained in:
Alex Anderson
2024-07-09 12:26:45 -07:00
committed by Christoffer Lerno
parent cf95257c81
commit db75da65db

View File

@@ -51,7 +51,7 @@ fn void csort(Type list, usz low, usz high, KeyFn key_fn, uint byte_idx)
CmpCallback compare_fn = fn (lhs, rhs, key_fn) => compare_to(key_fn(lhs), key_fn(rhs)); CmpCallback compare_fn = fn (lhs, rhs, key_fn) => compare_to(key_fn(lhs), key_fn(rhs));
$endif; $endif;
byte_idx = byte_idx >= KeyFnReturnType.sizeof ? KeyFnReturnType.sizeof - 1 : byte_idx; byte_idx = byte_idx >= KeyFnReturnType.sizeof ? KeyFnReturnType.sizeof - 1 : byte_idx;
Counts counts; Counts counts;
Ranges ranges; Ranges ranges;
@@ -89,22 +89,35 @@ fn void csort(Type list, usz low, usz high, KeyFn key_fn, uint byte_idx)
KeyFnReturnType diff = mx - mn; KeyFnReturnType diff = mx - mn;
if (diff == 0) return; if (diff == 0) return;
ushort parition_count = 0; ushort fallback0_count = 0;
ushort fallback1_count = 0;
ushort recursion_count = 0;
usz total = 0; usz total = 0;
foreach (char i, count : counts) foreach (char i, count : counts)
{ {
indexs[parition_count] = i; indexs[fallback0_count] = i;
parition_count += (ushort)(count > 0); indexs[255 - recursion_count] = i;
fallback0_count += (ushort)(count > 1 && count <= 32);
recursion_count += (ushort)(count > 128);
counts[i] = total; counts[i] = total;
ranges[i] = total; ranges[i] = total;
total += count; total += count;
} }
ranges[256] = total; ranges[256] = total;
for(ushort i = 0; i < 256; i++) {
indexs[fallback0_count + fallback1_count] = (char)i;
ushort count = ranges[i + 1] - ranges[i];
fallback1_count += (ushort)(count > 32 && count <= 128);
}
if (!keys_ordered) if (!keys_ordered)
{ {
usz sorted_count = 0; usz sorted_count = 0;
//ElementType* first = list.first();
do do
{ {
foreach (x, s : counts) foreach (x, s : counts)
@@ -134,23 +147,32 @@ fn void csort(Type list, usz low, usz high, KeyFn key_fn, uint byte_idx)
if (byte_idx) if (byte_idx)
{ {
for (usz p = 0; p < parition_count; p++) for (usz p = 0; p < fallback0_count; p++) {
{
usz i = indexs[p]; usz i = indexs[p];
usz start_offset = ranges[i]; usz start_offset = ranges[i];
usz end_offset = ranges[i + 1]; usz end_offset = ranges[i + 1];
usz items = end_offset - start_offset; insertionsort_indexed(list, low + start_offset, low + end_offset, compare_fn, key_fn);
}
switch (items) for (usz p = 0; p < fallback1_count; p++) {
{ usz i = indexs[fallback0_count + p];
case 0..32:
insertionsort_indexed(list, low + start_offset, low + end_offset, compare_fn, key_fn); usz start_offset = ranges[i];
case 33..128: usz end_offset = ranges[i + 1];
quicksort_indexed(list, low + start_offset, low + end_offset, compare_fn, key_fn);
default: quicksort_indexed(list, low + start_offset, low + end_offset, compare_fn, key_fn);
csort(list, low + start_offset, low + end_offset, key_fn, byte_idx - 1); }
}
for (usz p = 0; p < recursion_count; p++)
{
usz i = indexs[255 - p];
usz start_offset = ranges[i];
usz end_offset = ranges[i + 1];
csort(list, low + start_offset, low + end_offset, key_fn, byte_idx - 1);
} }
} }
} }