diff --git a/lib/std/math/math.c3 b/lib/std/math/math.c3 index 8c3b98984..734a536e0 100644 --- a/lib/std/math/math.c3 +++ b/lib/std/math/math.c3 @@ -56,22 +56,6 @@ const DOUBLE_MAX_EXP = 1024; const DOUBLE_MIN_EXP = -1021; const DOUBLE_EPSILON = 2.22044604925031308085e-16; -const QUAD_MANT_DIG = 113; - -/* -Currently unsupported float128 constants -const QUAD_MAX = 1.18973149535723176508575932662800702e+4932; -const QUAD_MIN = 3.36210314311209350626267781732175260e-4932; -const QUAD_DENORM_MIN = 6.47517511943802511092443895822764655e-4966; -const QUAD_DIG = 33; -const QUAD_DEC_DIGITS = 36; -const QUAD_MAX_10_EXP = 4932; -const QUAD_MIN_10_EXP = -4931; -const QUAD_MAX_EXP = 16384; -const QUAD_MIN_EXP = -16481; -const QUAD_EPSILON = 1.92592994438723585305597794258492732e-34; -*/ - enum RoundingMode : int { TOWARD_ZERO, @@ -82,35 +66,6 @@ enum RoundingMode : int faultdef OVERFLOW, MATRIX_INVERSE_DOESNT_EXIST; -alias Complexf = Complex {float}; -alias Complex = Complex {double}; -alias COMPLEX_IDENTITY @builtin = complex::IDENTITY {double}; -alias COMPLEXF_IDENTITY @builtin = complex::IDENTITY {float}; - -alias Quaternionf = Quaternion {float}; -alias Quaternion = Quaternion {double}; -alias QUATERNION_IDENTITY @builtin = quaternion::IDENTITY {double}; -alias QUATERNIONF_IDENTITY @builtin = quaternion::IDENTITY {float}; - -alias Matrix2f = Matrix2x2 {float}; -alias Matrix2 = Matrix2x2 {double}; -alias Matrix3f = Matrix3x3 {float}; -alias Matrix3 = Matrix3x3 {double}; -alias Matrix4f = Matrix4x4 {float}; -alias Matrix4 = Matrix4x4 {double}; -alias matrix4_ortho @builtin = matrix::ortho {double}; -alias matrix4f_ortho @builtin = matrix::ortho {float}; -alias matrix4_perspective @builtin = matrix::perspective {double}; -alias matrix4f_perspective @builtin = matrix::perspective {float}; - -alias MATRIX2_IDENTITY @builtin = matrix::IDENTITY2 {double}; -alias MATRIX2F_IDENTITY @builtin = matrix::IDENTITY2 {float}; -alias MATRIX3_IDENTITY @builtin = matrix::IDENTITY3 {double}; -alias MATRIX3F_IDENTITY @builtin = matrix::IDENTITY3 {float}; -alias MATRIX4_IDENTITY @builtin = matrix::IDENTITY4 {double}; -alias MATRIX4F_IDENTITY @builtin = matrix::IDENTITY4 {float}; - - <* @require types::is_numerical($typeof(x)) : `The input must be a numerical value or numerical vector` *> diff --git a/lib/std/math/math_complex.c3 b/lib/std/math/math_complex.c3 index 2f9ee0aaf..aad3dfd2d 100644 --- a/lib/std/math/math_complex.c3 +++ b/lib/std/math/math_complex.c3 @@ -1,4 +1,22 @@ -module std::math::complex{Real}; +module std::math; + +// Complex number aliases. + +alias Complexf = Complex {float}; +alias Complex = Complex {double}; +alias COMPLEX_IDENTITY @builtin = complex::IDENTITY {double}; +alias COMPLEXF_IDENTITY @builtin = complex::IDENTITY {float}; +alias IMAGINARY @builtin @deprecated("Use I") = complex::IMAGINARY { double }; +alias IMAGINARYF @builtin @deprecated("Use I_F") = complex::IMAGINARY { float }; +alias I @builtin = complex::IMAGINARY { double }; +alias I_F @builtin = complex::IMAGINARY { float }; + +<* + The generic complex number module, for float or double based complex number definitions. + + @require Real.kindof == FLOAT : "A complex number must use a floating type" +*> +module std::math::complex {Real}; import std::io; union Complex (Printable) @@ -13,7 +31,6 @@ union Complex (Printable) const Complex IDENTITY = { 1, 0 }; const Complex IMAGINARY = { 0, 1 }; - macro Complex Complex.add(self, Complex b) @operator(+) => { .v = self.v + b.v }; macro Complex Complex.add_real(self, Real r) @operator_s(+) => { .v = self.v + (Real[<2>]) { r, 0 } }; macro Complex Complex.add_each(self, Real b) => { .v = self.v + b }; @@ -38,6 +55,7 @@ macro Complex Complex.inverse(self) macro Complex Complex.conjugate(self) => { .r = self.r, .c = -self.c }; macro Complex Complex.negate(self) @operator(-) => { .v = -self.v }; macro bool Complex.equals(self, Complex b) @operator(==) => self.v == b.v; +macro bool Complex.equals_real(self, Real r) @operator_s(==) => self.v == { r, 0 }; macro bool Complex.not_equals(self, Complex b) @operator(!=) => self.v != b.v; fn usz? Complex.to_format(&self, Formatter* f) @dynamic diff --git a/lib/std/math/math_libc.c3 b/lib/std/math/math_libc.c3 deleted file mode 100644 index 0f467c540..000000000 --- a/lib/std/math/math_libc.c3 +++ /dev/null @@ -1,47 +0,0 @@ -/* origin: FreeBSD /usr/src/lib/msun/src/s_atan.c - * ==================================================== - * Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved. - * - * Developed at SunPro, a Sun Microsystems, Inc. business. - * Permission to use, copy, modify, and distribute this - * software is freely granted, provided that this notice - * is preserved. - * ==================================================== - */ -/* atan(x) - * Method - * 1. Reduce x to positive by atan(x) = -atan(-x). - * 2. According to the integer k=4t+0.25 chopped, t=x, the argument - * is further reduced to one of the following intervals and the - * arctangent of t is evaluated by the corresponding formula: - * - * [0,7/16] atan(x) = t-t^3*(a1+t^2*(a2+...(a10+t^2*a11)...) - * [7/16,11/16] atan(x) = atan(1/2) + atan( (t-0.5)/(1+t/2) ) - * [11/16.19/16] atan(x) = atan( 1 ) + atan( (t-1)/(1+t) ) - * [19/16,39/16] atan(x) = atan(3/2) + atan( (t-1.5)/(1+1.5t) ) - * [39/16,INF] atan(x) = atan(INF) + atan( -1/t ) - * - * Constants: - * The hexadecimal values are the intended ones for the following - * constants. The decimal values may be used, provided that the - * compiler will convert from decimal to binary accurately enough - * to produce the hexadecimal values shown. - */ -/* origin: FreeBSD /usr/src/lib/msun/src/s_atanf.c */ -/* - * Conversion to float by Ian Lance Taylor, Cygnus Support, ian@cygnus.com. - */ -/* - * ==================================================== - * Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved. - * - * Developed at SunPro, a Sun Microsystems, Inc. business. - * Permission to use, copy, modify, and distribute this - * software is freely granted, provided that this notice - * is preserved. - * ==================================================== - */ -module std::math; - - - diff --git a/lib/std/math/math_matrix.c3 b/lib/std/math/math_matrix.c3 index 45ac9eaf8..f73088642 100644 --- a/lib/std/math/math_matrix.c3 +++ b/lib/std/math/math_matrix.c3 @@ -1,4 +1,32 @@ -module std::math::matrix{Real}; +module std::math; + +// Predefined matrix types +alias Matrix2f = Matrix2x2 {float}; +alias Matrix2 = Matrix2x2 {double}; +alias Matrix3f = Matrix3x3 {float}; +alias Matrix3 = Matrix3x3 {double}; +alias Matrix4f = Matrix4x4 {float}; +alias Matrix4 = Matrix4x4 {double}; + +// Predefined matrix functions +alias matrix4_ortho @builtin = matrix::ortho {double}; +alias matrix4f_ortho @builtin = matrix::ortho {float}; +alias matrix4_perspective @builtin = matrix::perspective {double}; +alias matrix4f_perspective @builtin = matrix::perspective {float}; + +alias MATRIX2_IDENTITY @builtin = matrix::IDENTITY2 {double}; +alias MATRIX2F_IDENTITY @builtin = matrix::IDENTITY2 {float}; +alias MATRIX3_IDENTITY @builtin = matrix::IDENTITY3 {double}; +alias MATRIX3F_IDENTITY @builtin = matrix::IDENTITY3 {float}; +alias MATRIX4_IDENTITY @builtin = matrix::IDENTITY4 {double}; +alias MATRIX4F_IDENTITY @builtin = matrix::IDENTITY4 {float}; + +<* + The generic matrix module, for float or double based matrix definitions. + + @require Real.kindof == FLOAT : "A matrix must use a floating type" +*> +module std::math::matrix {Real}; import std::math::vector; struct Matrix2x2 diff --git a/lib/std/math/math_quaternion.c3 b/lib/std/math/math_quaternion.c3 index 86d7d4c11..589babd25 100644 --- a/lib/std/math/math_quaternion.c3 +++ b/lib/std/math/math_quaternion.c3 @@ -1,4 +1,19 @@ -module std::math::quaternion{Real}; +module std::math; + +// Predefined quaternion aliases. + +alias Quaternionf = Quaternion {float}; +alias Quaternion = Quaternion {double}; +alias QUATERNION_IDENTITY @builtin = quaternion::IDENTITY {double}; +alias QUATERNIONF_IDENTITY @builtin = quaternion::IDENTITY {float}; + +<* + The generic quaternion module, for float or double based quaternion definitions. + + @require Real.kindof == FLOAT : "A quaternion must use a floating type" +*> + +module std::math::quaternion {Real}; import std::math::vector; union Quaternion { diff --git a/lib/std/math/math_random.c3 b/lib/std/math/random.c3 similarity index 100% rename from lib/std/math/math_random.c3 rename to lib/std/math/random.c3 diff --git a/lib/std/math/math_i128.c3 b/lib/std/math/runtime/math_i128.c3 similarity index 100% rename from lib/std/math/math_i128.c3 rename to lib/std/math/runtime/math_i128.c3 diff --git a/lib/std/math/math_builtin.c3 b/lib/std/math/runtime/math_supplemental.c3 similarity index 100% rename from lib/std/math/math_builtin.c3 rename to lib/std/math/runtime/math_supplemental.c3 diff --git a/lib/std/math/math_vector.c3 b/lib/std/math/vector.c3 similarity index 99% rename from lib/std/math/math_vector.c3 rename to lib/std/math/vector.c3 index 70a2c13ac..52c964da9 100644 --- a/lib/std/math/math_vector.c3 +++ b/lib/std/math/vector.c3 @@ -1,3 +1,5 @@ +// Vector supplemental methods + module std::math::vector; import std::math; @@ -51,6 +53,8 @@ fn double[<3>] double[<3>].unproject(self, Matrix4 projection, Matrix4 view) => fn void ortho_normalize(float[<3>]* v1, float[<3>]* v2) => ortho_normalize3(v1, v2); fn void ortho_normalized(double[<3>]* v1, double[<3>]* v2) => ortho_normalize3(v1, v2); +// -- private helpers + macro towards(v, target, max_distance) @private { var delta = target - v; diff --git a/releasenotes.md b/releasenotes.md index 0e4939f26..2889259e9 100644 --- a/releasenotes.md +++ b/releasenotes.md @@ -10,7 +10,8 @@ - Operator overloading for `+ - * / % & | ^ << >> ~ == != += -= *= /= %= &= |= ^= <<= >>=` - Add `@operator_r` and `@operator_s` attributes. - More stdlib tests: `sincos`, `ArenaAllocator`, `Slice2d`. - +- Make aliases able to use `@deprecated`. + ### Fixes - Trying to cast an enum to int and back caused the compiler to crash. - Incorrect rounding at compile time going from double to int. @@ -19,6 +20,7 @@ ### Stdlib changes - Hash functions for integer vectors and arrays. +- Prefer `math::I` and `math::I_F` for `math::IMAGINARY` and `math::IMAGINARYF` the latter is deprecated. ## 0.7.0 Change list diff --git a/src/compiler/c_codegen.c b/src/compiler/c_codegen.c index 14acfb5a0..f3bee9e42 100644 --- a/src/compiler/c_codegen.c +++ b/src/compiler/c_codegen.c @@ -412,6 +412,7 @@ static void c_emit_expr(GenContext *c, CValue *value, Expr *expr) case EXPR_MAKE_SLICE: case EXPR_INT_TO_BOOL: case EXPR_VECTOR_FROM_ARRAY: + case EXPR_TWO: break; case NON_RUNTIME_EXPR: case UNRESOLVED_EXPRS: diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index 57312aa9a..5cc418656 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -978,6 +978,12 @@ typedef struct TypeInfoId type_info; } ExprCast; +typedef struct +{ + Expr *first; + Expr *last; +} ExprTwo; + typedef struct { Expr **values; @@ -1177,6 +1183,7 @@ struct Expr_ ExprSlice slice_expr; ExprSwizzle swizzle_expr; ExprTernary ternary_expr; // 16 + ExprTwo two_expr; BuiltinDefine benchmark_hook_expr; ExprTypeCall type_call_expr; BuiltinDefine test_hook_expr; @@ -2210,6 +2217,8 @@ bool expr_is_simple(Expr *expr, bool to_float); bool expr_is_pure(Expr *expr); bool expr_is_runtime_const(Expr *expr); Expr *expr_generate_decl(Decl *decl, Expr *assign); +Expr *expr_new_two(Expr *first, Expr *second); +void expr_rewrite_two(Expr *original, Expr *first, Expr *second); void expr_insert_addr(Expr *original); void expr_rewrite_insert_deref(Expr *original); Expr *expr_generate_decl(Decl *decl, Expr *assign); @@ -3340,6 +3349,10 @@ static inline void expr_set_span(Expr *expr, SourceSpan loc) expr->span = loc; switch (expr->expr_kind) { + case EXPR_TWO: + expr_set_span(expr->two_expr.first, loc); + expr_set_span(expr->two_expr.last, loc); + return; case EXPR_INT_TO_BOOL: expr_set_span(expr->int_to_bool_expr.inner, loc); return; diff --git a/src/compiler/copying.c b/src/compiler/copying.c index 2dbfd7ed4..e9d2992ac 100644 --- a/src/compiler/copying.c +++ b/src/compiler/copying.c @@ -296,6 +296,10 @@ Expr *copy_expr(CopyStruct *c, Expr *source_expr) Expr *expr = expr_copy(source_expr); switch (source_expr->expr_kind) { + case EXPR_TWO: + MACRO_COPY_EXPR(source_expr->two_expr.first); + MACRO_COPY_EXPR(source_expr->two_expr.last); + return expr; case EXPR_TYPECALL: case EXPR_CT_SUBSCRIPT: UNREACHABLE diff --git a/src/compiler/enums.h b/src/compiler/enums.h index 30997a581..87aedeb9b 100644 --- a/src/compiler/enums.h +++ b/src/compiler/enums.h @@ -799,6 +799,7 @@ typedef enum EXPR_SWIZZLE, EXPR_TERNARY, EXPR_TEST_HOOK, + EXPR_TWO, EXPR_TRY, EXPR_TRY_UNRESOLVED, EXPR_TRY_UNWRAP_CHAIN, diff --git a/src/compiler/expr.c b/src/compiler/expr.c index 2ce062511..6a6212df4 100644 --- a/src/compiler/expr.c +++ b/src/compiler/expr.c @@ -241,6 +241,7 @@ bool expr_may_addr(Expr *expr) case EXPR_RECAST: case EXPR_DISCARD: case EXPR_ADDR_CONVERSION: + case EXPR_TWO: return false; case NON_RUNTIME_EXPR: case EXPR_ASM: @@ -333,6 +334,7 @@ bool expr_is_runtime_const(Expr *expr) case EXPR_INT_TO_FLOAT: case EXPR_FLOAT_TO_INT: case EXPR_SLICE_LEN: + case EXPR_TWO: return false; case UNRESOLVED_EXPRS: UNREACHABLE @@ -550,6 +552,22 @@ static inline bool expr_unary_addr_is_constant_eval(Expr *expr) } } +void expr_rewrite_two(Expr *original, Expr *first, Expr *second) +{ + original->expr_kind = EXPR_TWO; + original->two_expr.first = first; + original->two_expr.last = second; + original->resolve_status = RESOLVE_NOT_DONE; +} + +Expr *expr_new_two(Expr *first, Expr *second) +{ + Expr *expr = expr_new_expr(EXPR_TWO, first); + expr->two_expr.first = first; + expr->two_expr.last = second; + return expr; +} + void expr_insert_addr(Expr *original) { ASSERT(original->resolve_status == RESOLVE_DONE); @@ -807,6 +825,8 @@ bool expr_is_pure(Expr *expr) case EXPR_LAST_FAULT: case EXPR_MEMBER_GET: return true; + case EXPR_TWO: + return expr_is_pure(expr->two_expr.first) && expr_is_pure(expr->two_expr.last); case EXPR_BITASSIGN: return false; case EXPR_BINARY: diff --git a/src/compiler/llvm_codegen_expr.c b/src/compiler/llvm_codegen_expr.c index db5c74980..9971017f7 100644 --- a/src/compiler/llvm_codegen_expr.c +++ b/src/compiler/llvm_codegen_expr.c @@ -6964,6 +6964,10 @@ void llvm_emit_expr(GenContext *c, BEValue *value, Expr *expr) case EXPR_BUILTIN: case EXPR_OPERATOR_CHARS: UNREACHABLE + case EXPR_TWO: + llvm_emit_expr(c, value, expr->two_expr.first); + llvm_emit_expr(c, value, expr->two_expr.last); + return; case EXPR_VECTOR_TO_ARRAY: llvm_emit_vector_to_array(c, value, expr); return; diff --git a/src/compiler/sema_casts.c b/src/compiler/sema_casts.c index 47f352ff9..c49cdcbfe 100644 --- a/src/compiler/sema_casts.c +++ b/src/compiler/sema_casts.c @@ -440,6 +440,9 @@ RETRY: // It's unclear if this can happen. expr = VECLAST(expr->expression_list); goto RETRY; + case EXPR_TWO: + expr = expr->two_expr.last; + goto RETRY; case EXPR_TERNARY: { // In the case a ?: b -> check a and b diff --git a/src/compiler/sema_decls.c b/src/compiler/sema_decls.c index 1ddcb8ae6..60701e9f5 100755 --- a/src/compiler/sema_decls.c +++ b/src/compiler/sema_decls.c @@ -2850,7 +2850,7 @@ static bool sema_analyse_attribute(SemaContext *context, ResolvedAttrData *attr_ [ATTRIBUTE_CALLCONV] = ATTR_FUNC | ATTR_INTERFACE_METHOD | ATTR_FNTYPE, [ATTRIBUTE_COMPACT] = ATTR_STRUCT | ATTR_UNION, [ATTRIBUTE_CONST] = ATTR_MACRO, - [ATTRIBUTE_DEPRECATED] = USER_DEFINED_TYPES | CALLABLE_TYPE | ATTR_CONST | ATTR_GLOBAL | ATTR_MEMBER | ATTR_BITSTRUCT_MEMBER | ATTR_INTERFACE, + [ATTRIBUTE_DEPRECATED] = USER_DEFINED_TYPES | CALLABLE_TYPE | ATTR_CONST | ATTR_GLOBAL | ATTR_MEMBER | ATTR_BITSTRUCT_MEMBER | ATTR_INTERFACE | ATTR_ALIAS, [ATTRIBUTE_DYNAMIC] = ATTR_FUNC, [ATTRIBUTE_EXPORT] = ATTR_FUNC | ATTR_GLOBAL | ATTR_CONST | USER_DEFINED_TYPES | ATTR_ALIAS, [ATTRIBUTE_EXTERN] = ATTR_FUNC | ATTR_GLOBAL | ATTR_CONST | USER_DEFINED_TYPES, diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index de5765eb1..6f5a5a2d7 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -623,6 +623,7 @@ static bool sema_binary_is_expr_lvalue(SemaContext *context, Expr *top_expr, Exp case EXPR_SCALAR_TO_VECTOR: case EXPR_SUBSCRIPT_ADDR: case EXPR_EXPRESSION_LIST: + case EXPR_TWO: goto ERR; } UNREACHABLE @@ -713,6 +714,8 @@ static bool expr_may_ref(Expr *expr) return true; case EXPR_HASH_IDENT: return false; + case EXPR_TWO: + return expr_may_ref(expr->two_expr.last); case EXPR_EXPRESSION_LIST: if (!vec_size(expr->expression_list)) return false; return expr_may_ref(VECLAST(expr->expression_list)); @@ -1480,14 +1483,14 @@ INLINE Expr **sema_splat_arraylike_insert(SemaContext *context, Expr **args, Exp } Decl *temp = decl_new_generated_var(arg->type, VARDECL_LOCAL, arg->span); Expr *decl_expr = expr_generate_decl(temp, arg); - Expr *list = expr_new_expr(EXPR_EXPRESSION_LIST, arg); - vec_add(list->expression_list, decl_expr); + Expr *two = expr_new_expr(EXPR_TWO, arg); + two->two_expr.first = decl_expr; Expr *subscript = expr_new_expr(EXPR_SUBSCRIPT, arg); subscript->subscript_expr.index.expr = exprid(expr_new_const_int(arg->span, type_usz, 0)); subscript->subscript_expr.expr = exprid(expr_variable(temp)); - vec_add(list->expression_list, subscript); - if (!sema_analyse_expr(context, list)) return NULL; - args[index] = list; + two->two_expr.last = subscript; + if (!sema_analyse_expr(context, two)) return NULL; + args[index] = two; for (ArrayIndex i = 1; i < len; i++) { subscript = expr_new_expr(EXPR_SUBSCRIPT, arg); @@ -3350,10 +3353,7 @@ static inline bool sema_expr_analyse_subscript_lvalue(SemaContext *context, Expr if (!sema_analyse_expr(context, current_expr)) return false; Decl *temp = decl_new_generated_var(current_expr->type, VARDECL_PARAM, current_expr->span); Expr *decl = expr_generate_decl(temp, expr_copy(current_expr)); - current_expr->expr_kind = EXPR_EXPRESSION_LIST; - current_expr->expression_list = NULL; - vec_add(current_expr->expression_list, decl); - vec_add(current_expr->expression_list, expr_variable(temp)); + expr_rewrite_two(current_expr, decl, expr_variable(temp)); if (!sema_analyse_expr(context, current_expr)) return false; Expr *var_for_len = expr_variable(temp); Expr *len_expr = expr_new(EXPR_CALL, expr->span); @@ -3469,10 +3469,7 @@ static inline bool sema_expr_analyse_subscript(SemaContext *context, Expr *expr, if (!sema_analyse_expr(context, current_expr)) return false; Decl *temp = decl_new_generated_var(current_expr->type, VARDECL_PARAM, current_expr->span); Expr *decl = expr_generate_decl(temp, expr_copy(current_expr)); - current_expr->expr_kind = EXPR_EXPRESSION_LIST; - current_expr->expression_list = NULL; - vec_add(current_expr->expression_list, decl); - vec_add(current_expr->expression_list, expr_variable(temp)); + expr_rewrite_two(current_expr, decl, expr_variable(temp)); if (!sema_analyse_expr(context, current_expr)) return false; Expr *var_for_len = expr_variable(temp); Expr *len_expr = expr_new(EXPR_CALL, expr->span); @@ -6215,32 +6212,134 @@ static bool sema_binary_analyse_ct_subscript_op_assign(SemaContext *context, Exp static BoolErr sema_insert_overload_in_op_assign_or_error(SemaContext *context, Expr *expr, Expr *left, Expr *right, BinaryOp operator, Type *lhs_type) { - if (type_is_user_defined(lhs_type)) + assert(type_is_user_defined(lhs_type)); + if (!sema_analyse_inferred_expr(context, lhs_type, right)) return BOOL_ERR; + static OperatorOverload MAP[BINARYOP_LAST + 1] = { + [BINARYOP_ADD_ASSIGN] = OVERLOAD_PLUS_ASSIGN, + [BINARYOP_SUB_ASSIGN] = OVERLOAD_MINUS_ASSIGN, + [BINARYOP_MULT_ASSIGN] = OVERLOAD_MULTIPLY_ASSIGN, + [BINARYOP_DIV_ASSIGN] = OVERLOAD_DIVIDE_ASSIGN, + [BINARYOP_MOD_ASSIGN] = OVERLOAD_REMINDER_ASSIGN, + [BINARYOP_BIT_XOR_ASSIGN] = OVERLOAD_XOR_ASSIGN, + [BINARYOP_BIT_OR_ASSIGN] = OVERLOAD_OR_ASSIGN, + [BINARYOP_BIT_AND_ASSIGN] = OVERLOAD_AND_ASSIGN, + [BINARYOP_SHL_ASSIGN] = OVERLOAD_SHL_ASSIGN, + [BINARYOP_SHR_ASSIGN] = OVERLOAD_SHR_ASSIGN, + }; + OperatorOverload overload = MAP[operator]; + assert(overload && "Overload not mapped"); + if (!sema_replace_with_overload(context, expr, left, right, lhs_type, &overload)) return BOOL_ERR; + if (!overload) { - if (lhs_type->type_kind == TYPE_BITSTRUCT) - { - if (operator == BINARYOP_BIT_OR_ASSIGN || operator == BINARYOP_BIT_AND_ASSIGN || operator == BINARYOP_BIT_XOR_ASSIGN) return BOOL_FALSE; - } - if (!sema_analyse_inferred_expr(context, lhs_type, right)) return BOOL_ERR; - static OperatorOverload MAP[BINARYOP_LAST + 1] = { - [BINARYOP_ADD_ASSIGN] = OVERLOAD_PLUS_ASSIGN, - [BINARYOP_SUB_ASSIGN] = OVERLOAD_MINUS_ASSIGN, - [BINARYOP_MULT_ASSIGN] = OVERLOAD_MULTIPLY_ASSIGN, - [BINARYOP_DIV_ASSIGN] = OVERLOAD_DIVIDE_ASSIGN, - [BINARYOP_MOD_ASSIGN] = OVERLOAD_REMINDER_ASSIGN, - [BINARYOP_BIT_XOR_ASSIGN] = OVERLOAD_XOR_ASSIGN, - [BINARYOP_BIT_OR_ASSIGN] = OVERLOAD_OR_ASSIGN, - [BINARYOP_BIT_AND_ASSIGN] = OVERLOAD_AND_ASSIGN, - [BINARYOP_SHL_ASSIGN] = OVERLOAD_SHL_ASSIGN, - [BINARYOP_SHR_ASSIGN] = OVERLOAD_SHR_ASSIGN, - }; - OperatorOverload overload = MAP[operator]; - assert(overload && "Overload not mapped"); - if (!sema_replace_with_overload(context, expr, left, right, lhs_type, &overload)) return BOOL_ERR; - if (!overload) return BOOL_TRUE; + return BOOL_TRUE; } return BOOL_FALSE; } + +INLINE bool sema_rewrite_op_assign(SemaContext *context, Expr *expr, Expr *left, Expr *right, BinaryOp new_op) +{ + // Simple case: f += a => f = f + a + if (left->expr_kind == EXPR_IDENTIFIER) + { + expr->expr_kind = EXPR_BINARY; + left->expr_kind = EXPR_BINARY; + Expr *lvalue = expr_copy(left); + Expr *rvalue = expr_copy(right); + left->binary_expr = (ExprBinary) { .left = exprid(rvalue), .right = exprid(right), new_op }; + expr->binary_expr = (ExprBinary) { .left = exprid(lvalue), .right = exprid(left), .operator = BINARYOP_ASSIGN, .grouped = false }; + expr->resolve_status = RESOLVE_NOT_DONE; + left->resolve_status = RESOLVE_NOT_DONE; + return sema_analyse_expr(context, expr); + } + + Type *lhs_type = type_no_optional(left->type)->canonical; + Decl *variable = decl_new_generated_var(type_get_ptr(left->type), VARDECL_LOCAL, left->span); + Expr *left_copy = expr_copy(left); + + // If we have a &[] overload, then replace left copy with that one. + if (left->expr_kind == EXPR_SUBSCRIPT_ASSIGN) + { + Expr *parent = exprptr(left->subscript_assign_expr.expr); + Type *parent_type = type_no_optional(parent->type)->canonical; + Decl *operator = sema_find_untyped_operator(context, parent_type, OVERLOAD_ELEMENT_REF, NULL); + Expr *index = exprptr(left->subscript_assign_expr.index); + if (operator) + { + Expr **args = NULL; + vec_add(args, index); + if (!sema_insert_method_call(context, left_copy, operator, parent, args, false)) return false; + goto AFTER_ADDR; + } + // If we only have []=, then we need [] + operator = sema_find_untyped_operator(context, parent_type, OVERLOAD_ELEMENT_AT, NULL); + if (!operator) + { + RETURN_SEMA_ERROR(left, "There is no overload for [] for %s.", type_quoted_error_string(type_no_optional(left->type))); + } + Type *return_type = typeget(operator->func_decl.signature.rtype); + if (type_no_optional(return_type->canonical) != lhs_type->canonical) + { + RETURN_SEMA_ERROR(expr, "There is a type mismatch between the overload for [] and []= for %s.", type_quoted_error_string(type_no_optional(left->type))); + } + // First we want to create the indexed value and the index: + Decl *index_val = decl_new_generated_var(index->type, VARDECL_LOCAL, index->span); + // We need to take the address of the parent here, otherwise this might fail, + expr_insert_addr(parent); + Decl *parent_val = decl_new_generated_var(parent->type, VARDECL_LOCAL, parent->span); + Expr **list = NULL; + // temp = parent, temp_2 = index + vec_add(list, expr_generate_decl(parent_val, parent)); + vec_add(list, expr_generate_decl(index_val, index)); + // Now, create a lhs of the binary add: + Expr *lhs = expr_new_expr(EXPR_SUBSCRIPT, left); + Expr *parent_by_variable = expr_variable(parent_val); + expr_rewrite_insert_deref(parent_by_variable); + lhs->subscript_expr = (ExprSubscript) { .expr = exprid(parent_by_variable), .index.expr = exprid(expr_variable(index_val)) }; + // Now create the binary expression + Expr *binary = expr_new_expr(EXPR_BINARY, expr); + binary->binary_expr = (ExprBinary) { .left = exprid(lhs), .right = exprid(right), .operator = new_op }; + // Finally, we need the assign, and here we just need to replace + Expr *assign = expr_new_expr(EXPR_BINARY, expr); + assign->binary_expr = (ExprBinary) { .left = exprid(left), .right = exprid(binary), .operator = BINARYOP_ASSIGN }; + // Now we need to patch the values in `left`: + parent_by_variable = expr_variable(parent_val); + expr_rewrite_insert_deref(parent_by_variable); + left->subscript_assign_expr.expr = exprid(parent_by_variable); + left->subscript_assign_expr.index = exprid(expr_variable(index_val)); + // We add the assign + vec_add(list, assign); + // And rewrite the expression to an expression list: + expr->expr_kind = EXPR_EXPRESSION_LIST; + expr->expression_list = list; + return sema_expr_analyse_expr_list(context, expr); + } + + // f => &f + expr_insert_addr(left_copy); + +AFTER_ADDR:; + + // temp = &f + Expr *init = expr_generate_decl(variable, left_copy); + // lvalue = temp, rvalue = temp + Expr *left_rvalue = expr_variable(variable); + Expr *left_lvalue = expr_variable(variable); + + // lvalue = *temp, rvalue = *temp + expr_rewrite_insert_deref(left_lvalue); + expr_rewrite_insert_deref(left_rvalue); + + // init, expr -> lvalue = rvalue + a + expr->expr_kind = EXPR_BINARY; + left->expr_kind = EXPR_BINARY; + left->binary_expr = (ExprBinary) { .left = exprid(left_lvalue), .right = exprid(right), new_op }; + expr->binary_expr = (ExprBinary) { .left = exprid(left_rvalue), .right = exprid(left), .operator = BINARYOP_ASSIGN, .grouped = false }; + expr->resolve_status = RESOLVE_NOT_DONE; + left->resolve_status = RESOLVE_NOT_DONE; + Expr *binary = expr_copy(expr); + expr_rewrite_two(expr, init, binary); + return sema_analyse_expr(context, expr); +} /** * Analyse *= /= %= ^= |= &= += -= <<= >>= * @@ -6303,10 +6402,42 @@ static bool sema_expr_analyse_op_assign(SemaContext *context, Expr *expr, Expr * Type *no_fail = type_no_optional(left->type); Type *flat = type_flatten(no_fail); - BoolErr b = sema_insert_overload_in_op_assign_or_error(context, expr, left, right, operator, no_fail->canonical); - if (b == BOOL_ERR) return false; - if (b == BOOL_TRUE) return true; - + Type *canonical = no_fail->canonical; + if (type_is_user_defined(canonical)) + { + if (canonical->type_kind == TYPE_BITSTRUCT) + { + if (operator == BINARYOP_BIT_OR_ASSIGN + || operator == BINARYOP_BIT_AND_ASSIGN + || operator == BINARYOP_BIT_XOR_ASSIGN) goto SKIP_OVERLOAD_CHECK; + } + BoolErr b = sema_insert_overload_in_op_assign_or_error(context, expr, left, right, operator, no_fail->canonical); + if (b == BOOL_ERR) return false; + if (b == BOOL_TRUE) return true; + // Maybe we have the corresponding implemented % * etc + // The right hand side is now already checked. + BinaryOp underlying_op = binaryop_assign_base_op(operator); + static OperatorOverload MAP[BINARYOP_LAST + 1] = { + [BINARYOP_ADD] = OVERLOAD_PLUS, + [BINARYOP_SUB] = OVERLOAD_MINUS, + [BINARYOP_DIV] = OVERLOAD_DIVIDE, + [BINARYOP_MOD] = OVERLOAD_REMINDER, + [BINARYOP_BIT_XOR] = OVERLOAD_XOR, + [BINARYOP_BIT_OR] = OVERLOAD_OR, + [BINARYOP_BIT_AND] = OVERLOAD_AND, + [BINARYOP_SHL] = OVERLOAD_SHL, + [BINARYOP_SHR] = OVERLOAD_SHR, + }; + OperatorOverload mapped_overload = MAP[underlying_op]; + Decl *ambiguous = NULL; + bool reverse = false; + Decl *candidate = expr_may_ref(left) ? sema_find_typed_operator(context, mapped_overload, left, right, &ambiguous, &reverse) : NULL; + if (candidate && typeget(candidate->func_decl.signature.rtype)->canonical == canonical) + { + return sema_rewrite_op_assign(context, expr, left, right, underlying_op); + } + } +SKIP_OVERLOAD_CHECK: // 3. If this is only defined for ints (^= |= &= %=) verify that this is an int. if (int_only && !type_flat_is_intlike(flat)) { @@ -7930,9 +8061,9 @@ static bool sema_analyse_assign_mutate_overloaded_subscript(SemaContext *context { Expr *increased = exprptr(subscript_expr->subscript_assign_expr.expr); Type *type_check = increased->type->canonical; - Expr *index = exprptr(subscript_expr->subscript_assign_expr.index); Decl *operator = sema_find_untyped_operator(context, type_check, OVERLOAD_ELEMENT_REF, NULL); Expr **args = NULL; + // The simple case: we have &[] so just replace it by that. if (operator) { vec_add(args, exprptr(subscript_expr->subscript_assign_expr.index)); @@ -7941,6 +8072,7 @@ static bool sema_analyse_assign_mutate_overloaded_subscript(SemaContext *context main->type = subscript_expr->type; return true; } + // We need []= and [] now. operator = sema_find_untyped_operator(context, type_check, OVERLOAD_ELEMENT_AT, NULL); if (!operator) { @@ -7954,6 +8086,7 @@ static bool sema_analyse_assign_mutate_overloaded_subscript(SemaContext *context bool is_optional_result = type_is_optional(increased->type) || type_is_optional(return_type); Type *result_type = type_add_optional(subscript_expr->type, is_optional_result); expr_insert_addr(increased); + Expr *index = exprptr(subscript_expr->subscript_assign_expr.index); Decl *temp_val = decl_new_generated_var(increased->type, VARDECL_LOCAL, increased->span); Decl *index_val = decl_new_generated_var(index->type, VARDECL_LOCAL, index->span); Decl *value_val = decl_new_generated_var(return_type, VARDECL_LOCAL, main->span); @@ -9600,6 +9733,7 @@ static inline bool sema_expr_analyse_ct_defined(SemaContext *context, Expr *expr case EXPR_INT_TO_PTR: case EXPR_PTR_TO_INT: case EXPR_MAKE_SLICE: + case EXPR_TWO: if (!sema_analyse_expr(active_context, main_expr)) goto FAIL; break; } @@ -9923,6 +10057,11 @@ static inline bool sema_analyse_expr_dispatch(SemaContext *context, Expr *expr, case EXPR_MAKE_SLICE: case EXPR_CT_SUBSCRIPT: UNREACHABLE + case EXPR_TWO: + if (!sema_analyse_expr(context, expr->two_expr.first)) return false; + if (!sema_analyse_expr_check(context, expr->two_expr.last, check)) return false; + expr->type = expr->two_expr.last->type; + return true; case EXPR_MAKE_ANY: if (!sema_analyse_expr(context, expr->make_any_expr.typeid)) return false; return sema_analyse_expr(context, expr->make_any_expr.inner); @@ -10481,6 +10620,7 @@ IDENT_CHECK:; case EXPR_TYPEID: case EXPR_VASPLAT: case EXPR_TRY_UNRESOLVED: + case EXPR_TWO: break; case EXPR_BITACCESS: case EXPR_SUBSCRIPT_ASSIGN: @@ -10764,14 +10904,9 @@ bool sema_insert_method_call(SemaContext *context, Expr *method_call, Decl *meth parent->ident_expr = temp; parent->resolve_status = RESOLVE_DONE; parent->type = temp->type; - Expr **list = NULL; - vec_add(list, generate); if (!sema_analyse_expr(context, generate)) return false; Expr *copied_method = expr_copy(method_call); - vec_add(list, copied_method); - method_call->expr_kind = EXPR_EXPRESSION_LIST; - method_call->resolve_status = RESOLVE_NOT_DONE; - method_call->expression_list = list; + expr_rewrite_two(method_call, generate, copied_method); Expr *arg0 = arguments[0]; arguments[0] = parent; parent = arg0; @@ -10813,10 +10948,6 @@ bool sema_insert_method_call(SemaContext *context, Expr *method_call, Decl *meth if (!sema_analyse_expr(context, parent)) return false; expr_rewrite_insert_deref(parent); } - if (!(parent && parent->type && first == parent->type->canonical)) - { - puts("TODO"); - } ASSERT_SPAN(method_call, parent && parent->type && first == parent->type->canonical); if (!sema_expr_analyse_general_call(context, method_call, method_decl, parent, false, NULL)) return expr_poison(method_call); diff --git a/src/compiler/sema_initializers.c b/src/compiler/sema_initializers.c index 21ea89ae9..6c9ea40b5 100644 --- a/src/compiler/sema_initializers.c +++ b/src/compiler/sema_initializers.c @@ -374,16 +374,16 @@ static inline bool sema_expr_analyse_array_plain_initializer(SemaContext *contex SEMA_ERROR(element, "Too many elements in initializer when expanding, expected only %d.", expected_members); return false; } - Expr *expr_list = expr_new_expr(EXPR_EXPRESSION_LIST, element); + Expr *expr_two = expr_new_expr(EXPR_TWO, element); Decl *decl = decl_new_generated_var(element_type, VARDECL_LOCAL, element->span); Expr *decl_expr = expr_generate_decl(decl, element); - vec_add(expr_list->expression_list, decl_expr); + expr_two->two_expr.first = decl_expr; Expr *sub = expr_new_expr(EXPR_SUBSCRIPT, element); sub->subscript_expr.expr = exprid(expr_variable(decl)); sub->subscript_expr.index.expr = exprid(expr_new_const_int(element->span, type_usz, 0)); - vec_add(expr_list->expression_list, sub); - if (!sema_analyse_expr_rhs(context, inner_type, expr_list, true, NULL, false)) return false; - elements[i] = expr_list; + expr_two->two_expr.last = sub; + if (!sema_analyse_expr_rhs(context, inner_type, expr_two, true, NULL, false)) return false; + elements[i] = expr_two; for (unsigned j = 1; j < len; j++) { sub = expr_new_expr(EXPR_SUBSCRIPT, element); diff --git a/src/compiler/sema_internal.h b/src/compiler/sema_internal.h index e4a21d5ad..b9ad235da 100644 --- a/src/compiler/sema_internal.h +++ b/src/compiler/sema_internal.h @@ -94,7 +94,7 @@ bool sema_analyse_expr_lvalue(SemaContext *context, Expr *expr, bool *failed_ref bool sema_analyse_expr_value(SemaContext *context, Expr *expr); Expr *expr_access_inline_member(Expr *parent, Decl *parent_decl); bool sema_analyse_ct_expr(SemaContext *context, Expr *expr); -Decl *sema_find_typed_operator(SemaContext *context, OperatorOverload operator_overload, Expr *rhs, Expr *lhs, Decl **ambiguous_ref, bool *reverse); +Decl *sema_find_typed_operator(SemaContext *context, OperatorOverload operator_overload, Expr *lhs, Expr *rhs, Decl **ambiguous_ref, bool *reverse); Decl *sema_find_untyped_operator(SemaContext *context, Type *type, OperatorOverload operator_overload, Decl *skipped); bool sema_insert_method_call(SemaContext *context, Expr *method_call, Decl *method_decl, Expr *parent, Expr **arguments, bool reverse_overload); bool sema_expr_analyse_builtin_call(SemaContext *context, Expr *expr); diff --git a/src/compiler/sema_liveness.c b/src/compiler/sema_liveness.c index 9c35151f0..0139b0298 100644 --- a/src/compiler/sema_liveness.c +++ b/src/compiler/sema_liveness.c @@ -264,6 +264,10 @@ RETRY: case EXPR_NAMED_ARGUMENT: case UNRESOLVED_EXPRS: UNREACHABLE + case EXPR_TWO: + sema_trace_expr_liveness(expr->two_expr.first); + sema_trace_expr_liveness(expr->two_expr.last); + return; case EXPR_DESIGNATOR: sema_trace_expr_liveness(expr->designator_expr.value); return; diff --git a/src/compiler/sema_stmts.c b/src/compiler/sema_stmts.c index 010e2895c..6869636d4 100644 --- a/src/compiler/sema_stmts.c +++ b/src/compiler/sema_stmts.c @@ -792,6 +792,9 @@ static inline bool sema_expr_valid_try_expression(Expr *expr) case EXPR_ADDR_CONVERSION: case EXPR_ENUM_FROM_ORD: return true; + case EXPR_TWO: + return sema_expr_valid_try_expression(expr->two_expr.last); + } UNREACHABLE } diff --git a/test/test_suite/expressions/overload_through_overload.c3t b/test/test_suite/expressions/overload_through_overload.c3t new file mode 100644 index 000000000..475742914 --- /dev/null +++ b/test/test_suite/expressions/overload_through_overload.c3t @@ -0,0 +1,137 @@ +// #target: macos-x64 +module test; +import std; + +struct Abc +{ + int x; +} + +fn Abc Abc.add(self, Abc other) @operator(+) => { self.x + other.x }; +fn Abc Abc.sub_self(&self, Abc other) @operator(-) { self.x -= other.x; return *self; } + +fn Abc* get_ref(Abc* abc) +{ + return abc; +} + +struct Container +{ + Abc y; +} + +fn Abc Container.get(self, int i) @operator([]) => self.y; +fn Abc Container.set(&self, int i, Abc c) @operator([]=) => self.y = c; + +fn int main() +{ + Abc a = { 3 }; + Abc b = { 5 }; + + Abc[2] y; + y[0] += b; + assert(y[0].x == 5); + *get_ref(&a) += *get_ref(&b); + assert(a.x == 8); + + List {Abc} l; + l.push({3}); + assert(l[0].x == 3); + l[0] += b; + assert(l[0].x == 8); + l[0] -= b; + assert(l[0].x == 3); + Container c = { { 5 }}; + c[0] += b; + assert(c.y.x == 10); + c[0] -= b; + assert(c.y.x == 5); + return 0; +} + +/* #expect: test.ll + +define i32 @main() #0 { +entry: + %a = alloca %Abc, align 4 + %b = alloca %Abc, align 4 + %y = alloca [2 x %Abc], align 4 + %result = alloca %Abc, align 4 + %result1 = alloca %Abc, align 4 + %l = alloca %List, align 8 + %literal = alloca %Abc, align 4 + %result3 = alloca %Abc, align 4 + %result4 = alloca %Abc, align 4 + %c = alloca %Container, align 4 + %result5 = alloca %Abc, align 4 + %result6 = alloca %Abc, align 4 + %result7 = alloca %Abc, align 4 + %result9 = alloca %Abc, align 4 + %result10 = alloca %Abc, align 4 + %result11 = alloca %Abc, align 4 + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %a, ptr align 4 @.__const, i32 4, i1 false) + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %b, ptr align 4 @.__const.1, i32 4, i1 false) + store i32 0, ptr %y, align 4 + %ptradd = getelementptr inbounds i8, ptr %y, i64 4 + store i32 0, ptr %ptradd, align 4 + %0 = load i32, ptr %y, align 4 + %1 = load i32, ptr %b, align 4 + %2 = call i32 @test.Abc.add(i32 %0, i32 %1) + store i32 %2, ptr %result, align 4 + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %y, ptr align 4 %result, i32 4, i1 false) + %3 = load i32, ptr %y, align 4 + %eq = icmp eq i32 %3, 5 + call void @llvm.assume(i1 %eq) + %4 = call ptr @test.get_ref(ptr %a) + %5 = call ptr @test.get_ref(ptr %b) + %6 = load i32, ptr %4, align 4 + %7 = load i32, ptr %5, align 4 + %8 = call i32 @test.Abc.add(i32 %6, i32 %7) + store i32 %8, ptr %result1, align 4 + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %4, ptr align 4 %result1, i32 4, i1 false) + %9 = load i32, ptr %a, align 4 + %eq2 = icmp eq i32 %9, 8 + call void @llvm.assume(i1 %eq2) + call void @llvm.memset.p0.i64(ptr align 8 %l, i8 0, i64 40, i1 false) + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %literal, ptr align 4 @.__const.2, i32 4, i1 false) + %10 = load i32, ptr %literal, align 4 + call void @"std_collections_list$test.Abc$.List.push"(ptr %l, i32 %10) #4 + %11 = call ptr @"std_collections_list$test.Abc$.List.get_ref"(ptr %l, i64 0) #4 + %12 = load i32, ptr %11, align 4 + %13 = load i32, ptr %b, align 4 + %14 = call i32 @test.Abc.add(i32 %12, i32 %13) + store i32 %14, ptr %result3, align 4 + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %11, ptr align 4 %result3, i32 4, i1 false) + %15 = call ptr @"std_collections_list$test.Abc$.List.get_ref"(ptr %l, i64 0) #4 + %16 = load i32, ptr %b, align 4 + %17 = call i32 @test.Abc.sub_self(ptr %15, i32 %16) + store i32 %17, ptr %result4, align 4 + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %15, ptr align 4 %result4, i32 4, i1 false) + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %c, ptr align 4 @.__const.3, i32 4, i1 false) + %18 = load i32, ptr %c, align 4 + %19 = call i32 @test.Container.get(i32 %18, i32 0) + store i32 %19, ptr %result5, align 4 + %20 = load i32, ptr %result5, align 4 + %21 = load i32, ptr %b, align 4 + %22 = call i32 @test.Abc.add(i32 %20, i32 %21) + store i32 %22, ptr %result6, align 4 + %23 = load i32, ptr %result6, align 4 + %24 = call i32 @test.Container.set(ptr %c, i32 0, i32 %23) + store i32 %24, ptr %result7, align 4 + %25 = load i32, ptr %c, align 4 + %eq8 = icmp eq i32 %25, 10 + call void @llvm.assume(i1 %eq8) + %26 = load i32, ptr %c, align 4 + %27 = call i32 @test.Container.get(i32 %26, i32 0) + store i32 %27, ptr %result9, align 4 + %28 = load i32, ptr %b, align 4 + %29 = call i32 @test.Abc.sub_self(ptr %result9, i32 %28) + store i32 %29, ptr %result10, align 4 + %30 = load i32, ptr %result10, align 4 + %31 = call i32 @test.Container.set(ptr %c, i32 0, i32 %30) + store i32 %31, ptr %result11, align 4 + %32 = load i32, ptr %c, align 4 + %eq12 = icmp eq i32 %32, 5 + call void @llvm.assume(i1 %eq12) + ret i32 0 +} diff --git a/test/unit/stdlib/math/math_complex.c3 b/test/unit/stdlib/math/math_complex.c3 index d2f370e71..5d5d7b41f 100644 --- a/test/unit/stdlib/math/math_complex.c3 +++ b/test/unit/stdlib/math/math_complex.c3 @@ -1,59 +1,59 @@ module math_tests @test; -import math_tests::complex; - -alias ComplexDouble @local = ComplexType {double}; -alias ComplexInt @local = ComplexType {int}; - -module math_tests::complex {ElementType} @test; import std::math; +import std::math::complex; -alias ComplexType = Complex {ElementType}; fn void complex_mul_imaginary() { - ComplexType i = complex::IMAGINARY {ElementType}; - assert(i.mul(i).equals((ComplexType){-1, 0})); - assert(i.mul(i).mul(i).equals((ComplexType){0, -1})); + Complex c = math::I; + assert(c * c == -1); + assert(c * c * c == -math::I); } fn void complex_add() { - ComplexType a = {3, 4}; - ComplexType b = {1, 2}; - assert(a.add(b).equals((ComplexType){4, 6})); - assert(a.add_each(1).equals((ComplexType){4, 5})); + Complex a = { 3, 4 }; + Complex b = { 1, 2 }; + assert(a + b == (Complex){ 4, 6 }); + assert(a.add_each(1).equals({4, 5})); + //a += b; + //assert(a == (Complex){ 4, 6 }); } fn void complex_sub() { - ComplexType a = {3, 4}; - ComplexType b = {1, 2}; - assert(a.sub(b).equals((ComplexType){2, 2})); - assert(a.sub_each(1).equals((ComplexType){2, 3})); + Complex a = { 3, 4 }; + Complex b = { 1, 2 }; + assert(a - b == (Complex){ 2, 2 }); + assert(a.sub_each(1).equals({2, 3})); + //a -= b; + //assert(a == (Complex){ 2, 2 }); } fn void complex_scale() { - ComplexType a = {2, 1}; - assert(a.scale(2).equals((ComplexType){4, 2})); + Complex a = {2, 1}; + assert(a * 2 == (Complex) {4, 2}); } fn void complex_conjugate() { - ComplexType a = {3, 4}; - assert(a.conjugate().equals((ComplexType){3, -4})); + Complex a = {3, 4}; + assert(a.conjugate() == (Complex) {3, -4}); } -fn void complex_inverse() @if(types::is_float(ElementType)) +fn void complex_inverse() { - ComplexType a = {3, 4}; - assert(a.inverse().mul(a).equals(complex::IDENTITY{ElementType})); + Complex a = {3, 4}; + assert(a.inverse() * a == 1); } -fn void complex_div() @if(types::is_float(ElementType)) +fn void complex_div() { - ComplexType a = {2, 5}; - ComplexType b = {4, -1}; - assert(a.div(b).equals((ComplexType){3.0/17.0, 22.0/17.0})); + Complex a = {2, 5}; + Complex b = {4, -1}; + assert(a / b == (Complex) {3.0/17.0, 22.0/17.0}); + //a /= b; + //assert(a == (Complex) {3.0/17.0, 22.0/17.0}); }