diff --git a/lib/std/math/math_complex.c3 b/lib/std/math/math_complex.c3 index c323fc60c..4dc6ea7f6 100644 --- a/lib/std/math/math_complex.c3 +++ b/lib/std/math/math_complex.c3 @@ -9,8 +9,8 @@ union Complex Real[<2>] v; } - const Complex IDENTITY = { 1, 0 }; +const Complex IMAGINARY = { 0, 1 }; macro Complex Complex.add(self, Complex b) => Complex { .v = self.v + b.v }; macro Complex Complex.add_each(self, Real b) => Complex { .v = self.v + b }; macro Complex Complex.sub(self, Complex b) => Complex { .v = self.v - b.v }; @@ -22,3 +22,10 @@ macro Complex Complex.div(self, Complex b) Real div = b.v.dot(b.v); return Complex{ (self.r * b.r + self.c * b.c) / div, (self.c * b.r - self.r * b.c) / div }; } +macro Complex Complex.inverse(self) +{ + Real sqr = self.v.dot(self.v); + return Complex{ self.r / sqr, -self.c / sqr }; +} +macro Complex Complex.conjugate(self) => Complex { .r = self.r, .c = -self.c }; +macro bool Complex.equals(self, Complex b) => self.v == b.v; diff --git a/test/unit/stdlib/math/math_complex.c3 b/test/unit/stdlib/math/math_complex.c3 new file mode 100644 index 000000000..0255f6544 --- /dev/null +++ b/test/unit/stdlib/math/math_complex.c3 @@ -0,0 +1,59 @@ +module math_tests @test; +import math_tests::complex; + +def ComplexDouble = ComplexType() @local; +def ComplexInt = ComplexType() @local; + +module math_tests::complex() @test; +import std::math; + +def ComplexType = Complex(); + +fn void! complex_mul_imaginary() +{ + ComplexType i = complex::IMAGINARY(); + assert(i.mul(i).equals(ComplexType{-1, 0})); + assert(i.mul(i).mul(i).equals(ComplexType{0, -1})); +} + +fn void! complex_add() +{ + ComplexType a = {3, 4}; + ComplexType b = {1, 2}; + assert(a.add(b).equals(ComplexType{4, 6})); + assert(a.add_each(1).equals(ComplexType{4, 5})); +} + +fn void! complex_sub() +{ + ComplexType a = {3, 4}; + ComplexType b = {1, 2}; + assert(a.sub(b).equals(ComplexType{2, 2})); + assert(a.sub_each(1).equals(ComplexType{2, 3})); +} + +fn void! complex_scale() +{ + ComplexType a = {2, 1}; + assert(a.scale(2).equals(ComplexType{4, 2})); +} + +fn void! complex_conjugate() +{ + ComplexType a = {3, 4}; + assert(a.conjugate().equals(ComplexType{3, -4})); +} + +fn void! complex_inverse() @if(types::is_float(ElementType)) +{ + ComplexType a = {3, 4}; + assert(a.inverse().mul(a).equals(complex::IDENTITY())); +} + +fn void! complex_div() @if(types::is_float(ElementType)) +{ + ComplexType a = {2, 5}; + ComplexType b = {4, -1}; + assert(a.div(b).equals(ComplexType{3.0/17.0, 22.0/17.0})); +} +