From 26acce246da7185ce269114cd6762ee74c55ba49 Mon Sep 17 00:00:00 2001
From: Christoffer Lerno
Date: Tue, 27 Aug 2024 04:31:14 +0200
Subject: [PATCH] Fixed int128 div/mod. Fix WASM memory init priority.

---
 lib/std/core/mem.c3            |   2 +-
 lib/std/math/math_i128.c3      | 271 ++++++++++++++++++++++++---------
 releasenotes.md                |   2 +
 test/unit/regression/int128.c3 |  21 +++
 4 files changed, 224 insertions(+), 72 deletions(-)
 create mode 100644 test/unit/regression/int128.c3

diff --git a/lib/std/core/mem.c3 b/lib/std/core/mem.c3
index 191d1379a..24264e4af 100644
--- a/lib/std/core/mem.c3
+++ b/lib/std/core/mem.c3
@@ -536,7 +536,7 @@ import std::core::mem::allocator @public;
 SimpleHeapAllocator wasm_allocator @private;
 extern int __heap_base;
 
-fn void initialize_wasm_mem() @init(1) @private
+fn void initialize_wasm_mem() @init(1024) @private
 {
 	allocator::wasm_memory.allocate_block(mem::DEFAULT_MEM_ALIGNMENT)!!; // Give us a valid null.
 	// Check if we need to move the heap.
diff --git a/lib/std/math/math_i128.c3 b/lib/std/math/math_i128.c3
index 7dc539617..8620ce0f4 100644
--- a/lib/std/math/math_i128.c3
+++ b/lib/std/math/math_i128.c3
@@ -10,54 +10,186 @@ fn int128 __divti3(int128 a, int128 b) @extern("__divti3") @weak @nostrip
 	return __udivti3(unsigned_a, unsigned_b) @inline ^ sign_a + (-sign_a);
 }
 
+macro uint128 @__udivmodti4(uint128 a, uint128 b, bool $return_rem)
+{
+	Int128bits n = { .all = a };
+	Int128bits d = { .all = b };
+	Int128bits q @noinit;
+	Int128bits r @noinit;
+	uint sr;
+	if (n.high == 0)
+	{
+		if (d.high == 0)
+		{
+			$if $return_rem:
+				return n.low % d.low;
+			$else
+				return n.low / d.low;
+			$endif
+		}
+		$if $return_rem:
+			return n.low;
+		$else
+			return 0;
+		$endif
+	}
+	if (d.low == 0)
+	{
+		if (d.high == 0)
+		{
+			$if $return_rem:
+				return n.high % d.low;
+			$else
+				return n.high / d.low;
+			$endif
+		}
+		if (n.low == 0)
+		{
+			$if $return_rem:
+				r.high = n.high % d.high;
+				r.low = 0;
+				return r.all;
+			$else
+				return n.high / d.high;
+			$endif
+		}
+		if (d.high & (d.high - 1) == 0) // d.high is a power of 2
+		{
+			$if $return_rem:
+				r.low = n.low;
+				r.high = n.high & (d.high - 1);
+				return r.all;
+			$else
+				return (uint128)(n.high >> $$ctz(d.high));
+			$endif
+		}
+		sr = (uint)$$clz(d.high) - (uint)$$clz(n.high);
+		// 0 <= sr <= 64 - 2 or sr large
+		if (sr > 64 - 2)
+		{
+			$if $return_rem:
+				return n.all;
+			$else
+				return 0;
+			$endif
+		}
+		sr++;
+		// 1 <= sr <= 64 - 1
+		// q.all = n.all << (128 - sr);
+		q.low = 0;
+		q.high = n.low << (64 - sr);
+		r.high = n.high >> sr;
+		r.low = (n.high << (64 - sr)) | (n.low >> sr);
+	}
+	else // d.low != 0
+	{
+		if (d.high == 0)
+		{
+			if (d.low & (d.low - 1) == 0) // if d is a power of 2
+			{
+				$if $return_rem:
+					return (uint128)(n.low & (d.low - 1));
+				$else
+					if (d.low == 1) return n.all;
+					sr = (uint)$$ctz(d.low);
+					q.high = n.high >> sr;
+					q.low = (n.high << (64 - sr)) | (n.low >> sr);
+					return q.all;
+				$endif
+			}
+			sr = 1 + 64 + (uint)$$clz(d.low) - (uint)$$clz(n.high);
+			// 2 <= sr <= 128 - 1
+			// q.all = n.all << (128 - sr);
+			// r.all = n.all >> sr;
+			switch
+			{
+				case sr == 64:
+					q.low = 0;
+					q.high = n.low;
+					r.high = 0;
+					r.low = n.high;
+				case sr < 64:
+					q.low = 0;
+					q.high = n.low << (64 - sr);
+					r.high = n.high >> sr;
+					r.low = (n.high << (64 - sr)) | (n.low >> sr);
+				default: // 64 + 1 <= sr <= 128 - 1
+					q.low = n.low << (128 - sr);
+					q.high = (n.high << (128 - sr)) | (n.low >> (sr - 64));
+					r.high = 0;
+					r.low = n.high >> (sr - 64);
+			}
+		}
+		else
+		{
+			sr = (uint)$$clz(d.high) - (uint)$$clz(n.high);
+			// 0 <= sr <= 64 - 1 or sr large
+			if (sr > 64 - 1)
+			{
+				$if $return_rem:
+					return n.all;
+				$else
+					return 0;
+				$endif
+			}
+
+			sr++;
+			// 1 <= sr <= 64
+			// q.all = n.all << (128 - sr);
+			// r.all = n.all >> sr;
+			q.low = 0;
+			if (sr == 64)
+			{
+				q.high = n.low;
+				r.high = 0;
+				r.low = n.high;
+			}
+			else
+			{
+				r.high = n.high >> sr;
+				r.low = (n.high << (64 - sr)) | (n.low >> sr);
+				q.high = n.low << (64 - sr);
+			}
+		}
+	}
+	// Not a special case
+	// q and r are initialized with:
+	// q.all = n.all << (128 - sr);
+	// r.all = n.all >> sr;
+	// 1 <= sr <= 128 - 1
+	uint carry = 0;
+	for (; sr > 0; sr--)
+	{
+		// r:q = ((r:q) << 1) | carry
+		r.high = (r.high << 1) | (r.low >> (64 - 1));
+		r.low = (r.low << 1) | (q.high >> (64 - 1));
+		q.high = (q.high << 1) | (q.low >> (64 - 1));
+		q.low = (q.low << 1) | carry;
+		// carry = 0;
+		// if (r.all >= d.all)
+		// {
+		//     r.all -= d.all;
+		//     carry = 1;
+		// }
+		int128 s = (int128)(d.all - r.all - 1) >> (128 - 1);
+		carry = (uint)(s & 1);
+		r.all -= d.all & s;
+	}
+	$if $return_rem:
+		return r.all;
+	$else
+		return (q.all << 1) | carry;
+	$endif
+}
+
 fn uint128 __umodti3(uint128 n, uint128 d) @extern("__umodti3") @weak @nostrip
 {
-	// Ignore d = 0
-	uint128 sr = (d ? $$clz(d) : 128) - (n ? $$clz(n) : 128);
-	// If n < d then sr is wrapping.
-	// which means we can just return n.
-	if (sr > 127) return n;
-	// If d == 1 and n = MAX
-	if (sr == 127) return 0;
-	sr++;
-	uint128 r = n >> sr;
-	// Follow known algorithm:
-	n <<= 128 - sr;
-	for (uint128 carry = 0; sr > 0; sr--)
-	{
-		r = (r << 1) | (n >> 127);
-		n = (n << 1) | carry;
-		int128 sign = (int128)(d - r - 1) >> 127;
-		carry = sign & 1;
-		r -= d & sign;
-	}
-	return r;
+	return @__udivmodti4(n, d, true);
 }
 
 fn uint128 __udivti3(uint128 n, uint128 d) @extern("__udivti3") @weak @nostrip
 {
-	// Ignore d = 0
-	uint128 sr = (d ? $$clz(d) : 128) - (n ? $$clz(n) : 128);
-	// If n < d then sr is wrapping.
-	// which means we can just return 0.
-	if (sr > 127) return 0;
-	// If d == 1 and n = MAX
-	if (sr == 127) return n;
-	sr++;
-	uint128 r = n >> sr;
-	// Follow known algorithm:
-	n <<= 128 - sr;
-	uint128 carry = 0;
-	for (; sr > 0; sr--)
-	{
-		r = (r << 1) | (n >> 127);
-		n = (n << 1) | carry;
-		int128 sign = (int128)(d - r - 1) >> 127;
-		carry = sign & 1;
-		r -= d & sign;
-	}
-	n = (n << 1) | carry;
-	return n;
+	return @__udivmodti4(n, d, false);
 }
 
 fn int128 __modti3(int128 a, int128 b) @extern("__modti3") @weak @nostrip
@@ -74,11 +206,8 @@ union Int128bits @private
 {
 	struct
 	{
-		ulong ulow, uhigh;
-	}
-	struct
-	{
-		long ilow, ihigh;
+		ulong low;
+		ulong high;
 	}
 	uint128 all;
 }
@@ -89,14 +218,14 @@ fn uint128 __lshrti3(uint128 a, uint b) @extern("__lshrti3") @weak @nostrip
 	result.all = a;
 	if (b >= 64)
 	{
-		result.ulow = result.uhigh >> (b - 64);
-		result.uhigh = 0;
+		result.low = result.high >> (b - 64);
+		result.high = 0;
 	}
 	else
 	{
 		if (b == 0) return a;
-		result.ulow = (result.uhigh << (64 - b)) | (result.ulow >> b);
-		result.uhigh = result.uhigh >> b;
+		result.low = (result.high << (64 - b)) | (result.low >> b);
+		result.high = result.high >> b;
 	}
 	return result.all;
 }
@@ -107,14 +236,14 @@ fn int128 __ashrti3(int128 a, uint b) @extern("__ashrti3") @weak @nostrip
 	result.all = a;
 	if (b >= 64)
 	{
-		result.ilow = result.ihigh >> (b - 64);
-		result.ihigh = result.ihigh >> 63;
+		result.low = (ulong)((long)result.high >> (b - 64)); // Shift the high word as signed so the sign bit is replicated.
+		result.high = (ulong)((long)result.high >> 63); // All ones for negative input, zero otherwise.
 	}
 	else
 	{
 		if (b == 0) return a;
-		result.ilow = result.ihigh << (64 - b) | (result.ilow >> b);
-		result.ihigh = result.ihigh >> b;
+		result.low = (result.high << (64 - b)) | (result.low >> b); // Low word shifts logically; its top bits come from high.
+		result.high = (ulong)((long)result.high >> b); // High word shifts arithmetically.
 	}
 	return result.all;
 }
@@ -125,14 +254,14 @@ fn int128 __ashlti3(int128 a, uint b) @extern("__ashlti3") @weak @nostrip
 	result.all = a;
 	if (b >= 64)
 	{
-		result.ulow = 0;
-		result.uhigh = result.ulow << (b - 64);
+		result.high = result.low << (b - 64); // Read low before it is cleared below.
+		result.low = 0;
 	}
 	else
 	{
 		if (b == 0) return a;
-		result.uhigh = (result.uhigh << b) | (result.ulow >> (64 - b));
-		result.ulow = result.ulow << b;
+		result.high = (result.high << b) | (result.low >> (64 - b));
+		result.low = result.low << b;
 	}
 	return result.all;
 }
@@ -143,18 +272,18 @@ fn int128 __mulddi3(ulong a, ulong b) @private
 {
 	Int128bits r;
 	const ulong LOWER_MASK = 0xffff_ffff;
-	r.ulow = (a & LOWER_MASK) * (b & LOWER_MASK);
-	ulong t = r.ulow >> 32;
-	r.ulow &= LOWER_MASK;
+	r.low = (a & LOWER_MASK) * (b & LOWER_MASK);
+	ulong t = r.low >> 32;
+	r.low &= LOWER_MASK;
 	t += (a >> 32) * (b & LOWER_MASK);
-	r.ulow += (t & LOWER_MASK) << 32;
-	r.uhigh = t >> 32;
-	t = r.ulow >> 32;
-	r.ulow &= LOWER_MASK;
+	r.low += (t & LOWER_MASK) << 32;
+	r.high = t >> 32;
+	t = r.low >> 32;
+	r.low &= LOWER_MASK;
 	t += (b >> 32) * (a & LOWER_MASK);
-	r.ulow += (t & LOWER_MASK) << 32;
-	r.uhigh += t >> 32;
-	r.uhigh += (a >> 32) * (b >> 32);
+	r.low += (t & LOWER_MASK) << 32;
+	r.high += t >> 32;
+	r.high += (a >> 32) * (b >> 32);
 	return r.all;
 }
 
@@ -162,8 +291,8 @@ fn int128 __multi3(int128 a, int128 b) @extern("__multi3") @weak @nostrip
 {
 	Int128bits x = { .all = a };
 	Int128bits y = { .all = b };
-	Int128bits r = { .all = __mulddi3(x.ulow, y.ulow) };
-	r.uhigh += x.uhigh * y.ulow + x.ulow * y.uhigh;
+	Int128bits r = { .all = __mulddi3(x.low, y.low) };
+	r.high += x.high * y.low + x.low * y.high;
 	return r.all;
 }
 
diff --git a/releasenotes.md b/releasenotes.md
index 89e9b40d9..388b97bb4 100644
--- a/releasenotes.md
+++ b/releasenotes.md
@@ -89,6 +89,8 @@
 - Fix of bug in `defer (catch err)` with a direct return error.
 - Too restrictive compile time checks for @const.
 - Fixes to wasm nolibc in the standard library.
+- Fixed int128 div/mod.
+- Fix WASM memory init priority.
 
 ### Stdlib changes
 
diff --git a/test/unit/regression/int128.c3 b/test/unit/regression/int128.c3
new file mode 100644
index 000000000..12b063f9d
--- /dev/null
+++ b/test/unit/regression/int128.c3
@@ -0,0 +1,21 @@
+module int128_test;
+
+fn void check(uint128 a, uint128 b)
+{
+	uint128 div = a / b;
+	uint128 mod = a % b;
+	assert(div * b + mod == a);
+}
+
+fn void test_big() @test
+{
+	uint128 a = 12345678901234567890012u128;
+	uint128 b = 1234567890123456789001u128;
+	for (int i = 0; i < 10; i++)
+	{
+		for (int j = 0; j < 10; j++)
+		{
+			check(a + i, b + j);
+		}
+	}
+}