diff --git a/lib/std/math/bigint.c3 b/lib/std/math/bigint.c3 index 53a7ed925..d67244c26 100644 --- a/lib/std/math/bigint.c3 +++ b/lib/std/math/bigint.c3 @@ -100,10 +100,10 @@ fn BigInt*! BigInt.init_string_radix(&self, String value, int radix) } if (pos_val >= radix) return NumberConversion.MALFORMED_INTEGER?; if (limit == 1) pos_val = -pos_val; - self.set_add(multiplier.mult(from_int(pos_val))); + self.add_this(multiplier.mult(from_int(pos_val))); if (i - 1 >= limit) { - multiplier.set_mult(radix_big); + multiplier.mult_this(radix_big); } } switch @@ -123,11 +123,11 @@ fn bool BigInt.is_negative(&self) fn BigInt BigInt.add(self, BigInt other) { - self.set_add(other); + self.add_this(other); return self; } -fn void BigInt.set_add(&self, BigInt other) +fn void BigInt.add_this(&self, BigInt other) { bool sign = self.is_negative(); bool sign_arg = other.is_negative(); @@ -166,21 +166,28 @@ macro uint find_length(uint* data, uint length) return length; } -fn void BigInt.set_mult(&self, BigInt bi2) +fn BigInt BigInt.mult(self, BigInt bi2) { - *self = self.mult(bi2) @inline; + self.mult_this(bi2); + return self; } -fn BigInt BigInt.mult(&self, BigInt bi2) +fn void BigInt.mult_this(&self, BigInt bi2) { - BigInt res = { .len = 1 }; - BigInt bi1 = *self; + if (bi2.is_zero()) + { + *self = ZERO; + return; + } + if (bi2.is_one()) return; + + BigInt res = ZERO; bool negative_sign = false; - if (bi1.is_negative()) + if (self.is_negative()) { - bi1.negate(); + self.negate(); negative_sign = !negative_sign; } if (bi2.is_negative()) @@ -190,18 +197,16 @@ fn BigInt BigInt.mult(&self, BigInt bi2) } // multiply the absolute values - for (uint i = 0; i < bi1.len; i++) + for (uint i = 0; i < self.len; i++) { - if (bi1.data[i] == 0) continue; + if (self.data[i] == 0) continue; ulong mcarry = 0; for (int j = 0, int k = i; j < bi2.len; j++, k++) { // k = i + j - - ulong bi1_val = (ulong)bi1.data[i]; + ulong bi1_val = (ulong)self.data[i]; ulong bi2_val = (ulong)bi2.data[j]; ulong res_val = (ulong)res.data[k]; - ulong val = (bi1_val * bi2_val) + res_val + mcarry; res.data[k] = (uint)(val & 0xFFFFFFFF); mcarry = val >> 32; @@ -213,7 +218,7 @@ fn BigInt BigInt.mult(&self, BigInt bi2) } } - res.len = min(bi1.len + bi2.len, (uint)MAX_LEN); + res.len = min(self.len + bi2.len, (uint)MAX_LEN); res.reduce_len(); @@ -224,8 +229,7 @@ fn BigInt BigInt.mult(&self, BigInt bi2) { res.negate(); } - - return res; + *self = res; } fn void BigInt.negate(&self) @@ -262,11 +266,11 @@ macro bool BigInt.is_zero(&self) => self.len == 1 && self.data[0] == 0; fn BigInt BigInt.sub(self, BigInt other) { - self.set_sub(other); + self.sub_this(other); return self; } -fn BigInt* BigInt.set_sub(&self, BigInt other) +fn BigInt* BigInt.sub_this(&self, BigInt other) { self.len = max(self.len, other.len); @@ -324,19 +328,19 @@ fn BigInt BigInt.unary_minus(&self) } -macro void BigInt.set_div(&self, BigInt other) +macro void BigInt.div(self, BigInt other) { - *self = self.div(other); + self.div_this(other); + return self; } -fn BigInt BigInt.div(&self, BigInt other) +fn void BigInt.div_this(&self, BigInt other) { - BigInt bi1 = *self; bool negate_answer = self.is_negative(); if (negate_answer) { - bi1.negate(); + self.negate(); } if (other.is_negative()) { @@ -344,32 +348,36 @@ fn BigInt BigInt.div(&self, BigInt other) other.negate(); } - if (bi1.less_than(other)) return ZERO; + if (self.less_than(other)) + { + *self = ZERO; + } BigInt quotient = ZERO; BigInt remainder = ZERO; if (other.len == 1) { - bi1.single_byte_divide(&other, "ient, &remainder); + self.single_byte_divide(&other, "ient, &remainder); } else { - bi1.multi_byte_divide(&other, "ient, &remainder); + self.multi_byte_divide(&other, "ient, &remainder); } if (negate_answer) { quotient.negate(); } - return quotient; + *self = quotient; } -fn void BigInt.set_mod(&self, BigInt bi2) +fn BigInt BigInt.mod(self, BigInt bi2) { - *self = self.mod(bi2); + self.mod_this(bi2); + return self; } -fn BigInt BigInt.mod(&self, BigInt bi2) +fn void BigInt.mod_this(&self, BigInt bi2) { if (bi2.is_negative()) { @@ -377,15 +385,15 @@ fn BigInt BigInt.mod(&self, BigInt bi2) } bool negate_answer = self.is_negative(); - BigInt bi1 = *self; if (negate_answer) { - bi1.negate(); + self.negate(); } - if (bi1.less_than(bi2)) + if (self.less_than(bi2)) { - return *self; + if (negate_answer) self.negate(); + return; } BigInt quotient = ZERO; @@ -393,46 +401,47 @@ fn BigInt BigInt.mod(&self, BigInt bi2) if (bi2.len == 1) { - bi1.single_byte_divide(&bi2, "ient, &remainder); + self.single_byte_divide(&bi2, "ient, &remainder); } else { - bi2.multi_byte_divide(&bi2, "ient, &remainder); + self.multi_byte_divide(&bi2, "ient, &remainder); } if (negate_answer) { remainder.negate(); } - return remainder; + *self = remainder; } -fn BigInt BigInt.bit_negate(&self) +fn void BigInt.bit_negate_this(&self) { - BigInt result; - for (uint i = 0; i < MAX_LEN; i++) - { - result.data[i] = ~self.data[i]; - } + foreach (&r : self.data) *r = ~*r; - result.len = MAX_LEN; - result.reduce_len(); - return result; + self.len = MAX_LEN; + self.reduce_len(); +} + +fn BigInt BigInt.bit_negate(self) +{ + self.bit_negate_this(); + return self; } fn BigInt BigInt.shr(self, int shift) { - self.set_shr(shift); + self.shr_this(shift); return self; } -fn void BigInt.set_shr(self, int shift) +fn void BigInt.shr_this(self, int shift) { self.len = shift_right(&self.data, self.len, shift); } fn BigInt BigInt.shl(self, int shift) { - self.set_shl(shift); + self.shl_this(shift); return self; } @@ -565,7 +574,7 @@ fn BigInt BigInt.mod_pow(&self, BigInt exp, BigInt mod) mod.negate(); } - num.set_mod(mod); + num.mod_this(mod); // calculate constant = b^(2k) / m BigInt constant = ZERO; @@ -574,7 +583,7 @@ fn BigInt BigInt.mod_pow(&self, BigInt exp, BigInt mod) constant.data[i] = 0x00000001; constant.len = i + 1; - constant.set_div(mod); + constant.div_this(mod); int total_bits = exp.bitcount(); int count = 0; @@ -690,18 +699,18 @@ fn BigInt barrett_reduction(BigInt x, BigInt n, BigInt constant) r2.len = k_plus_one; r2.reduce_len(); - r1.set_sub(r2); + r1.sub_this(r2); if (r1.is_negative()) { BigInt val = ZERO; val.data[k_plus_one] = 0x00000001; val.len = k_plus_one + 1; - r1.set_add(val); + r1.add_this(val); } while (r1.greater_or_equal(n)) { - r1.set_sub(n); + r1.sub_this(n); } return r1; @@ -751,11 +760,11 @@ fn BigInt BigInt.sqrt(&self) fn BigInt BigInt.bit_and(self, BigInt bi2) { - self.set_bit_and(bi2); + self.bit_and_this(bi2); return self; } -fn void BigInt.set_bit_and(&self, BigInt bi2) +fn void BigInt.bit_and_this(&self, BigInt bi2) { uint len = max(self.len, bi2.len); foreach (i, &ref : self.data[:len]) @@ -769,11 +778,11 @@ fn void BigInt.set_bit_and(&self, BigInt bi2) fn BigInt BigInt.bit_or(self, BigInt bi2) { - self.set_bit_or(bi2); + self.bit_or_this(bi2); return self; } -fn void BigInt.set_bit_or(&self, BigInt bi2) +fn void BigInt.bit_or_this(&self, BigInt bi2) { uint len = max(self.len, bi2.len); foreach (i, &ref : self.data[:len]) @@ -787,11 +796,11 @@ fn void BigInt.set_bit_or(&self, BigInt bi2) fn BigInt BigInt.bit_xor(self, BigInt bi2) { - self.set_bit_xor(bi2); + self.bit_xor_this(bi2); return self; } -fn void BigInt.set_bit_xor(&self, BigInt bi2) +fn void BigInt.bit_xor_this(&self, BigInt bi2) { uint len = max(self.len, bi2.len); foreach (i, &ref : self.data[:len]) @@ -803,7 +812,7 @@ fn void BigInt.set_bit_xor(&self, BigInt bi2) self.reduce_len(); } -fn void BigInt.set_shl(&self, int shift) +fn void BigInt.shl_this(&self, int shift) { self.len = shift_left(&self.data, self.len, shift); } @@ -992,7 +1001,7 @@ fn void BigInt.multi_byte_divide(&self, BigInt* other, BigInt* quotient, BigInt* { q_hat--; - ss.set_sub(bi2); + ss.sub_this(bi2); } BigInt yy = kk.sub(ss); diff --git a/test/unit/stdlib/math/bigint.c3 b/test/unit/stdlib/math/bigint.c3 index 0c72614d9..a63414e63 100644 --- a/test/unit/stdlib/math/bigint.c3 +++ b/test/unit/stdlib/math/bigint.c3 @@ -11,6 +11,17 @@ fn void test_plus() assert(a.add(b).equals(bigint::from_int(12323400012311213314141414i128 + 23400012311213314141414i128))); } +fn void test_mult() +{ + BigInt a = bigint::from_int(123); + BigInt b = bigint::from_int(234); + assert(a.mult(b).equals(bigint::from_int(234 * 123))); + + a = bigint::from_int(1232311213314141414i128); + b = bigint::from_int(234000123112414i128); + assert(a.mult(b).equals(bigint::from_int(1232311213314141414i128 * 234000123112414i128))); +} + fn void test_minus() { BigInt a = bigint::from_int(123);