diff --git a/lib/std/math/bigint.c3 b/lib/std/math/bigint.c3 index 65f9a26e2..79a8629db 100644 --- a/lib/std/math/bigint.c3 +++ b/lib/std/math/bigint.c3 @@ -172,7 +172,7 @@ macro uint find_length(uint* data, uint length) return length; } -fn BigInt BigInt.mult(self, BigInt bi2) +fn BigInt BigInt.mult(self, BigInt bi2) @operator(*) { self.mult_this(bi2); return self; @@ -270,7 +270,7 @@ fn void BigInt.negate(&self) macro bool BigInt.is_zero(&self) => self.len == 1 && self.data[0] == 0; -fn BigInt BigInt.sub(self, BigInt other) +fn BigInt BigInt.sub(self, BigInt other) @operator(-) { self.sub_this(other); return self; @@ -325,7 +325,7 @@ fn int BigInt.bitcount(&self) return bits; } -fn BigInt BigInt.unary_minus(&self) +fn BigInt BigInt.unary_minus(&self) @operator(-) { if (self.is_zero()) return *self; BigInt result = *self; @@ -334,7 +334,7 @@ fn BigInt BigInt.unary_minus(&self) } -macro BigInt BigInt.div(self, BigInt other) +macro BigInt BigInt.div(self, BigInt other) @operator(/) { self.div_this(other); return self; @@ -377,7 +377,7 @@ fn void BigInt.div_this(&self, BigInt other) *self = quotient; } -fn BigInt BigInt.mod(self, BigInt bi2) +fn BigInt BigInt.mod(self, BigInt bi2) @operator(%) { self.mod_this(bi2); return self; @@ -428,13 +428,13 @@ fn void BigInt.bit_negate_this(&self) self.reduce_len(); } -fn BigInt BigInt.bit_negate(self) +fn BigInt BigInt.bit_negate(self) @operator(~) { self.bit_negate_this(); return self; } -fn BigInt BigInt.shr(self, int shift) +fn BigInt BigInt.shr(self, int shift) @operator(>>) { self.shr_this(shift); return self; @@ -445,13 +445,13 @@ fn void BigInt.shr_this(self, int shift) self.len = shift_right(&self.data, self.len, shift); } -fn BigInt BigInt.shl(self, int shift) +fn BigInt BigInt.shl(self, int shift) @operator(<<) { self.shl_this(shift); return self; } -macro bool BigInt.equals(&self, BigInt other) +macro bool BigInt.equals(&self, BigInt other) @operator(==) { if (self.len != other.len) return false; return self.data[:self.len] == other.data[:self.len]; @@ -764,7 +764,7 @@ fn BigInt BigInt.sqrt(&self) return result; } -fn BigInt BigInt.bit_and(self, BigInt bi2) +fn BigInt BigInt.bit_and(self, BigInt bi2) @operator(&) { self.bit_and_this(bi2); return self; @@ -782,7 +782,7 @@ fn void BigInt.bit_and_this(&self, BigInt bi2) self.reduce_len(); } -fn BigInt BigInt.bit_or(self, BigInt bi2) +fn BigInt BigInt.bit_or(self, BigInt bi2) @operator(|) { self.bit_or_this(bi2); return self; @@ -800,7 +800,7 @@ fn void BigInt.bit_or_this(&self, BigInt bi2) self.reduce_len(); } -fn BigInt BigInt.bit_xor(self, BigInt bi2) +fn BigInt BigInt.bit_xor(self, BigInt bi2) @operator(^) { self.bit_xor_this(bi2); return self; diff --git a/lib/std/math/math_complex.c3 b/lib/std/math/math_complex.c3 index df3a39543..d636c3d35 100644 --- a/lib/std/math/math_complex.c3 +++ b/lib/std/math/math_complex.c3 @@ -1,6 +1,7 @@ module std::math::complex{Real}; +import std::io; -union Complex +union Complex (Printable) { struct { @@ -11,13 +12,20 @@ union Complex const Complex IDENTITY = { 1, 0 }; const Complex IMAGINARY = { 0, 1 }; -macro Complex Complex.add(self, Complex b) => { .v = self.v + b.v }; +macro Complex Real.add_complex(self, Complex r) @operator(+) => { .v = (Real[<2>]) { self, 0 } + c.v }; +macro Complex Real.sub_complex(self, Complex r) @operator(-) => { .v = (Real[<2>]) { self, 0 } - c.v }; +macro Complex Real.scale_complex(self, Complex c) @operator(*) => { .v = self * c.v }; +macro Complex Real.div_complex(self, Complex c) @operator(/) => ((Complex) { .r = self }).div(c); +macro Complex Complex.add(self, Complex b) @operator(+) => { .v = self.v + b.v }; +macro Complex Complex.add_real(self, Real r) @operator(+) => { .v = self.v + (Real[<2>]) { r, 0 } }; macro Complex Complex.add_each(self, Real b) => { .v = self.v + b }; -macro Complex Complex.sub(self, Complex b) => { .v = self.v - b.v }; +macro Complex Complex.sub(self, Complex b) @operator(-) => { .v = self.v - b.v }; +macro Complex Complex.sub_real(self, Real r) @operator(-) => { .v = self.v - (Real[<2>]) { r, 0 } }; macro Complex Complex.sub_each(self, Real b) => { .v = self.v - b }; -macro Complex Complex.scale(self, Real s) => { .v = self.v * s }; -macro Complex Complex.mul(self, Complex b) => { self.r * b.r - self.c * b.c, self.r * b.c + b.r * self.c }; -macro Complex Complex.div(self, Complex b) +macro Complex Complex.scale(self, Real r) @operator(*) => { .v = self.v * r }; +macro Complex Complex.mul(self, Complex b)@operator(*) => { self.r * b.r - self.c * b.c, self.r * b.c + b.r * self.c }; +macro Complex Complex.div_real(self, Real r) @operator(/) => { .v = self.v / r }; +macro Complex Complex.div(self, Complex b) @operator(/) { Real div = b.v.dot(b.v); return { (self.r * b.r + self.c * b.c) / div, (self.c * b.r - self.r * b.c) / div }; @@ -28,4 +36,11 @@ macro Complex Complex.inverse(self) return { self.r / sqr, -self.c / sqr }; } macro Complex Complex.conjugate(self) => { .r = self.r, .c = -self.c }; -macro bool Complex.equals(self, Complex b) => self.v == b.v; +macro Complex Complex.negate(self) @operator(-) => { .v = -self.v }; +macro bool Complex.equals(self, Complex b) @operator(==) => self.v == b.v; +macro bool Complex.not_equals(self, Complex b) @operator(!=) => self.v != b.v; + +fn usz? Complex.to_format(&self, Formatter* f) @dynamic +{ + return f.printf("%g%+gi", self.r, self.c); +} \ No newline at end of file diff --git a/lib/std/math/math_matrix.c3 b/lib/std/math/math_matrix.c3 index d30dab835..45ac9eaf8 100644 --- a/lib/std/math/math_matrix.c3 +++ b/lib/std/math/math_matrix.c3 @@ -43,7 +43,7 @@ struct Matrix4x4 } } -fn Real[<2>] Matrix2x2.apply(&self, Real[<2>] vec) +fn Real[<2>] Matrix2x2.apply(&self, Real[<2>] vec) @operator(*) { return { self.m00 * vec[0] + self.m01 * vec[1], @@ -51,7 +51,7 @@ fn Real[<2>] Matrix2x2.apply(&self, Real[<2>] vec) }; } -fn Real[<3>] Matrix3x3.apply(&self, Real[<3>] vec) +fn Real[<3>] Matrix3x3.apply(&self, Real[<3>] vec) @operator(*) { return { self.m00 * vec[0] + self.m01 * vec[1] + self.m02 * vec[2], @@ -60,7 +60,7 @@ fn Real[<3>] Matrix3x3.apply(&self, Real[<3>] vec) }; } -fn Real[<4>] Matrix4x4.apply(&self, Real[<4>] vec) +fn Real[<4>] Matrix4x4.apply(&self, Real[<4>] vec) @operator(*) { return { self.m00 * vec[0] + self.m01 * vec[1] + self.m02 * vec[2] + self.m03 * vec[3], @@ -71,7 +71,7 @@ fn Real[<4>] Matrix4x4.apply(&self, Real[<4>] vec) } -fn Matrix2x2 Matrix2x2.mul(&self, Matrix2x2 b) +fn Matrix2x2 Matrix2x2.mul(&self, Matrix2x2 b) @operator(*) { return { self.m00 * b.m00 + self.m01 * b.m10, self.m00 * b.m01 + self.m01 * b.m11, @@ -79,7 +79,7 @@ fn Matrix2x2 Matrix2x2.mul(&self, Matrix2x2 b) }; } -fn Matrix3x3 Matrix3x3.mul(&self, Matrix3x3 b) +fn Matrix3x3 Matrix3x3.mul(&self, Matrix3x3 b) @operator(*) { return { self.m00 * b.m00 + self.m01 * b.m10 + self.m02 * b.m20, @@ -96,7 +96,7 @@ fn Matrix3x3 Matrix3x3.mul(&self, Matrix3x3 b) }; } -fn Matrix4x4 Matrix4x4.mul(Matrix4x4* self, Matrix4x4 b) +fn Matrix4x4 Matrix4x4.mul(Matrix4x4* self, Matrix4x4 b) @operator(*) { return { self.m00 * b.m00 + self.m01 * b.m10 + self.m02 * b.m20 + self.m03 * b.m30, @@ -125,13 +125,25 @@ fn Matrix2x2 Matrix2x2.component_mul(&self, Real s) => matrix_component_mul(self fn Matrix3x3 Matrix3x3.component_mul(&self, Real s) => matrix_component_mul(self, s); fn Matrix4x4 Matrix4x4.component_mul(&self, Real s) => matrix_component_mul(self, s); -fn Matrix2x2 Matrix2x2.add(&self, Matrix2x2 mat2) => matrix_add(self, mat2); -fn Matrix3x3 Matrix3x3.add(&self, Matrix3x3 mat2) => matrix_add(self, mat2); -fn Matrix4x4 Matrix4x4.add(&self, Matrix4x4 mat2) => matrix_add(self, mat2); +fn Matrix2x2 Matrix2x2.add(&self, Matrix2x2 mat2) @operator(+) => matrix_add(self, mat2); +fn Matrix3x3 Matrix3x3.add(&self, Matrix3x3 mat2) @operator(+) => matrix_add(self, mat2); +fn Matrix4x4 Matrix4x4.add(&self, Matrix4x4 mat2) @operator(+) => matrix_add(self, mat2); -fn Matrix2x2 Matrix2x2.sub(&self, Matrix2x2 mat2) => matrix_sub(self, mat2); -fn Matrix3x3 Matrix3x3.sub(&self, Matrix3x3 mat2) => matrix_sub(self, mat2); -fn Matrix4x4 Matrix4x4.sub(&self, Matrix4x4 mat2) => matrix_sub(self, mat2); +fn Matrix2x2 Matrix2x2.sub(&self, Matrix2x2 mat2) @operator(-) => matrix_sub(self, mat2); +fn Matrix3x3 Matrix3x3.sub(&self, Matrix3x3 mat2) @operator(-) => matrix_sub(self, mat2); +fn Matrix4x4 Matrix4x4.sub(&self, Matrix4x4 mat2) @operator(-) => matrix_sub(self, mat2); + +fn Matrix2x2 Matrix2x2.negate(&self) @operator(-) => { .m = (Real[<4>])self.m }; +fn Matrix3x3 Matrix3x3.negate(&self) @operator(-) => { .m = (Real[<9>])self.m }; +fn Matrix4x4 Matrix4x4.negate(&self) @operator(-) => { .m = (Real[<16>])self.m }; + +fn bool Matrix2x2.eq(&self, Matrix2x2 mat2) @operator(==) => (Real[<4>])self.m == (Real[<4>])mat2.m; +fn bool Matrix3x3.eq(&self, Matrix3x3 mat2) @operator(==) => (Real[<9>])self.m == (Real[<9>])mat2.m; +fn bool Matrix4x4.eq(&self, Matrix4x4 mat2) @operator(==) => (Real[<16>])self.m == (Real[<16>])mat2.m; + +fn bool Matrix2x2.neq(&self, Matrix2x2 mat2) @operator(!=) => (Real[<4>])self.m != (Real[<4>])mat2.m; +fn bool Matrix3x3.neq(&self, Matrix3x3 mat2) @operator(!=) => (Real[<9>])self.m != (Real[<9>])mat2.m; +fn bool Matrix4x4.neq(&self, Matrix4x4 mat2) @operator(!=) => (Real[<16>])self.m != (Real[<16>])mat2.m; fn Matrix4x4 look_at(Real[<3>] eye, Real[<3>] target, Real[<3>] up) => matrix_look_at(Matrix4x4, eye, target, up); diff --git a/lib/std/math/math_quaternion.c3 b/lib/std/math/math_quaternion.c3 index ea6892236..d07f260d3 100644 --- a/lib/std/math/math_quaternion.c3 +++ b/lib/std/math/math_quaternion.c3 @@ -11,11 +11,12 @@ union Quaternion const Quaternion IDENTITY = { 0, 0, 0, 1 }; -macro Quaternion Quaternion.add(self, Quaternion b) => { .v = self.v + b.v }; +macro Quaternion Quaternion.add(self, Quaternion b) @operator(+) => { .v = self.v + b.v }; macro Quaternion Quaternion.add_each(self, Real b) => { .v = self.v + b }; -macro Quaternion Quaternion.sub(self, Quaternion b) => { .v = self.v - b.v }; +macro Quaternion Quaternion.sub(self, Quaternion b) @operator(-) => { .v = self.v - b.v }; +macro Quaternion Quaternion.negate(self) @operator(-) => { .v = -self.v }; macro Quaternion Quaternion.sub_each(self, Real b) => { .v = self.v - b }; -macro Quaternion Quaternion.scale(self, Real s) => { .v = self.v * s }; +macro Quaternion Quaternion.scale(self, Real s) @operator(*) => { .v = self.v * s }; macro Quaternion Quaternion.normalize(self) => { .v = self.v.normalize() }; macro Real Quaternion.length(self) => self.v.length(); macro Quaternion Quaternion.lerp(self, Quaternion q2, Real amount) => { .v = self.v.lerp(q2.v, amount) }; @@ -60,7 +61,7 @@ fn Quaternion Quaternion.slerp(self, Quaternion q2, Real amount) return { .v = q1v * ratio_a + q2v * ratio_b }; } -fn Quaternion Quaternion.mul(self, Quaternion b) +fn Quaternion Quaternion.mul(self, Quaternion b) @operator(+) { return { self.i * b.l + self.l * b.i + self.j * b.k - self.k * b.j, self.j * b.l + self.l * b.j + self.k * b.i - self.i * b.k, diff --git a/lib/std/time/time.c3 b/lib/std/time/time.c3 index b1aa50416..b50ebefe0 100644 --- a/lib/std/time/time.c3 +++ b/lib/std/time/time.c3 @@ -87,7 +87,7 @@ fn Time Time.add_minutes(time, long minutes) => time + (Time)(minutes * (long)MI fn Time Time.add_hours(time, long hours) => time + (Time)(hours * (long)HOUR); fn Time Time.add_days(time, long days) => time + (Time)(days * (long)DAY); fn Time Time.add_weeks(time, long weeks) => time + (Time)(weeks * (long)WEEK); -fn Time Time.add_duration(time, Duration duration) => time + (Time)duration; +fn Time Time.add_duration(time, Duration duration) @operator(+) => time + (Time)duration; fn int Time.compare_to(time, Time other) { @@ -96,7 +96,7 @@ fn int Time.compare_to(time, Time other) } fn double Time.to_seconds(time) => (long)time / (double)SEC; -fn Duration Time.diff_us(time, Time other) => (Duration)(time - other); +fn Duration Time.diff_us(time, Time other) @operator(-) => (Duration)(time - other); fn double Time.diff_sec(time, Time other) => (long)time.diff_us(other) / (double)SEC; fn double Time.diff_min(time, Time other) => (long)time.diff_us(other) / (double)MIN; fn double Time.diff_hour(time, Time other) => (long)time.diff_us(other) / (double)HOUR; diff --git a/releasenotes.md b/releasenotes.md index 80b3aa744..4a9fd4ab0 100644 --- a/releasenotes.md +++ b/releasenotes.md @@ -7,6 +7,7 @@ - Better errors trying to convert an enum to an int and vice versa. - Function `@require` checks are added to the caller in safe mode. #186 - Improved error message when narrowing isn't allowed. +- Operator overloading for `+ - * / % & | ^ << >> ~ == !=` ### Fixes - Trying to cast an enum to int and back caused the compiler to crash. diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index 56ec6a13d..7173f932e 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -613,7 +613,7 @@ typedef struct Decl_ bool attr_nopadding : 1; bool attr_compact : 1; bool resolved_attributes : 1; - OperatorOverload operator : 4; + OperatorOverload operator : 5; union { void *backend_ref; diff --git a/src/compiler/enums.h b/src/compiler/enums.h index 85fb351fd..d32db6bdb 100644 --- a/src/compiler/enums.h +++ b/src/compiler/enums.h @@ -902,6 +902,22 @@ typedef enum OVERLOAD_ELEMENT_REF, OVERLOAD_ELEMENT_SET, OVERLOAD_LEN, + OVERLOAD_NEGATE, + OVERLOAD_UNARY_MINUS, + OVERLOAD_PLUS, + OVERLOAD_TYPED_START = OVERLOAD_PLUS, + OVERLOAD_MINUS, + OVERLOAD_MULTIPLY, + OVERLOAD_DIVIDE, + OVERLOAD_REMINDER, + OVERLOAD_AND, + OVERLOAD_OR, + OVERLOAD_XOR, + OVERLOAD_SHL, + OVERLOAD_SHR, + OVERLOAD_EQUAL, + OVERLOAD_NOT_EQUAL, + OVERLOADS_COUNT = OVERLOAD_NOT_EQUAL } OperatorOverload; typedef enum diff --git a/src/compiler/parse_global.c b/src/compiler/parse_global.c index 833b021e3..d6e7cb532 100644 --- a/src/compiler/parse_global.c +++ b/src/compiler/parse_global.c @@ -952,10 +952,64 @@ Decl *parse_var_decl(ParseContext *c) // --- Parse parameters & throws & attributes +static Expr *parse_overload_from_token(ParseContext *c, TokenType token) +{ + OperatorOverload overload; + switch (token) + { + case TOKEN_PLUS: + overload = OVERLOAD_PLUS; + break; + case TOKEN_MINUS: + overload = OVERLOAD_MINUS; + break; + case TOKEN_STAR: + overload = OVERLOAD_MULTIPLY; + break; + case TOKEN_DIV: + overload = OVERLOAD_DIVIDE; + break; + case TOKEN_MOD: + overload = OVERLOAD_REMINDER; + break; + case TOKEN_AMP: + overload = OVERLOAD_AND; + break; + case TOKEN_BIT_OR: + overload = OVERLOAD_OR; + break; + case TOKEN_BIT_XOR: + overload = OVERLOAD_XOR; + break; + case TOKEN_SHL: + overload = OVERLOAD_SHL; + break; + case TOKEN_SHR: + overload = OVERLOAD_SHR; + break; + case TOKEN_BIT_NOT: + overload = OVERLOAD_NEGATE; + break; + case TOKEN_EQEQ: + overload = OVERLOAD_EQUAL; + break; + case TOKEN_NOT_EQUAL: + overload = OVERLOAD_NOT_EQUAL; + break; + default: + UNREACHABLE; + } + Expr *expr = EXPR_NEW_TOKEN(EXPR_OPERATOR_CHARS); + expr->resolve_status = RESOLVE_DONE; + expr->overload_expr = overload; + advance(c); + RANGE_EXTEND_PREV(expr); + return expr; +} /** * attribute ::= (AT_IDENT | path_prefix? AT_TYPE_IDENT) attr_params? * attr_params ::= '(' attr_param (',' attr_param)* ')' - * attr_param ::= const_expr | '&' '[' ']' || '[' ']' '='? + * attr_param ::= const_expr | '&' '[' ']' | '[' ']' '='? | '-' | '+' | '/' | '%' | '==' | '<=>' | '<<' | '>>' | '|' | '&' | '^' | '~' */ bool parse_attribute(ParseContext *c, Attr **attribute_ref, bool expect_eos) { @@ -1012,9 +1066,30 @@ bool parse_attribute(ParseContext *c, Attr **attribute_ref, bool expect_eos) while (1) { Expr *expr; + bool next_is_rparen = c->lexer.token_type == TOKEN_RPAREN; switch (c->tok) { + case TOKEN_PLUS: + case TOKEN_MINUS: + case TOKEN_STAR: + case TOKEN_DIV: + case TOKEN_MOD: + case TOKEN_BIT_NOT: + case TOKEN_BIT_OR: + case TOKEN_BIT_XOR: + case TOKEN_SHL: + case TOKEN_SHR: + case TOKEN_EQEQ: + case TOKEN_NOT_EQUAL: + if (!next_is_rparen) goto PARSE_EXPR; + expr = parse_overload_from_token(c, c->tok); + break; case TOKEN_AMP: + if (next_is_rparen) + { + expr = parse_overload_from_token(c, c->tok); + break; + } // &[] expr = EXPR_NEW_TOKEN(EXPR_OPERATOR_CHARS); expr->resolve_status = RESOLVE_DONE; @@ -1034,6 +1109,7 @@ bool parse_attribute(ParseContext *c, Attr **attribute_ref, bool expect_eos) RANGE_EXTEND_PREV(expr); break; default: +PARSE_EXPR: expr = parse_constant_expr(c); if (!expr_ok(expr)) return false; break; diff --git a/src/compiler/sema_decls.c b/src/compiler/sema_decls.c index 63092cad0..05972fd33 100755 --- a/src/compiler/sema_decls.c +++ b/src/compiler/sema_decls.c @@ -19,10 +19,11 @@ static inline bool unit_add_base_extension_method(SemaContext *context, Compilat static inline bool unit_add_method(SemaContext *context, Type *parent_type, Decl *method); static bool sema_analyse_operator_common(SemaContext *context, Decl *method, TypeInfo **rtype_ptr, Decl ***params_ptr, uint32_t parameters); -static inline Decl *operator_in_module(SemaContext *c, Module *module, OperatorOverload operator_overload); +static inline Decl *operator_in_module_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, + Type *method_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref); static inline bool sema_analyse_operator_element_at(SemaContext *context, Decl *method); static inline bool sema_analyse_operator_element_set(SemaContext *context, Decl *method); -static inline bool sema_analyse_operator_len(Decl *method, SemaContext *context); +static inline bool sema_analyse_operator_len(SemaContext *context, Decl *method); static bool sema_check_operator_method_validity(SemaContext *context, Decl *method); static inline const char *method_name_by_decl(Decl *method_like); @@ -46,6 +47,7 @@ static bool sema_analyse_attributes(SemaContext *context, Decl *decl, Attr **att static bool sema_analyse_attributes_for_var(SemaContext *context, Decl *decl, bool *erase_decl); static bool sema_check_section(SemaContext *context, Attr *attr); static inline bool sema_analyse_attribute_decl(SemaContext *context, SemaContext *c, Decl *decl, bool *erase_decl); +static Decl *sema_find_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, Type *parent_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref); static inline bool sema_analyse_typedef(SemaContext *context, Decl *decl, bool *erase_decl); static bool sema_analyse_variable_type(SemaContext *context, Type *type, SourceSpan span); @@ -1669,7 +1671,6 @@ static inline const char *method_name_by_decl(Decl *method_like) } - static bool sema_analyse_operator_common(SemaContext *context, Decl *method, TypeInfo **rtype_ptr, Decl ***params_ptr, uint32_t parameters) { @@ -1698,13 +1699,30 @@ static bool sema_analyse_operator_common(SemaContext *context, Decl *method, Typ return true; } -static inline Decl *operator_in_module(SemaContext *c, Module *module, OperatorOverload operator_overload) +INLINE bool decl_matches_overload(Decl *method, Type *type, OperatorOverload overload) +{ + return method->operator == overload && typeget(method->func_decl.type_parent)->canonical == type; +} + +static inline Decl *operator_in_module_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, Type *method_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref) { if (module->is_generic) return NULL; - Decl **extensions = module->private_method_extensions; - FOREACH(Decl *, extension, extensions) + Decl *found = sema_find_typed_operator_in_list(c, module->private_method_extensions, operator_overload, + method_type, binary_arg, binary_type, candidate_ref, ambiguous_ref); + if (found) return found; + FOREACH(Module *, sub_module, module->sub_modules) { - if (extension->operator == operator_overload) + return operator_in_module_typed(c, sub_module, operator_overload, method_type, binary_arg, binary_type, candidate_ref, ambiguous_ref); + } + return NULL; +} + +static inline Decl *operator_in_module_untyped(SemaContext *c, Module *module, Type *type, OperatorOverload operator_overload) +{ + if (module->is_generic) return NULL; + FOREACH(Decl *, extension, module->private_method_extensions) + { + if (decl_matches_overload(extension, type, operator_overload)) { unit_register_external_symbol(c, extension); return extension; @@ -1712,14 +1730,15 @@ static inline Decl *operator_in_module(SemaContext *c, Module *module, OperatorO } FOREACH(Module *, sub_module, module->sub_modules) { - return operator_in_module(c, sub_module, operator_overload); + return operator_in_module_untyped(c, sub_module, type, operator_overload); } return NULL; } -Decl *sema_find_operator(SemaContext *context, Type *type, OperatorOverload operator_overload) +Decl *sema_find_untyped_operator(SemaContext *context, Type *type, OperatorOverload operator_overload) { type = type->canonical; + assert(operator_overload < OVERLOAD_TYPED_START); if (!type_may_have_sub_elements(type)) return NULL; Decl *def = type->decl; FOREACH(Decl *, func, def->methods) @@ -1730,17 +1749,150 @@ Decl *sema_find_operator(SemaContext *context, Type *type, OperatorOverload oper return func; } } - Decl *extension = operator_in_module(context, context->compilation_unit->module, operator_overload); + FOREACH(Decl *, extension, context->unit->local_method_extensions) + { + if (decl_matches_overload(extension, type, operator_overload)) return extension; + } + Decl *extension = operator_in_module_untyped(context, context->compilation_unit->module, type, operator_overload); + if (extension) return extension; + FOREACH(Decl *, import, context->unit->imports) + { + extension = operator_in_module_untyped(context, import->import.module, type, operator_overload); + if (extension) return extension; + } + return NULL; +} + +static Decl *sema_find_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, Type *parent_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref) +{ + FOREACH(Decl *, func, methods) + { + if (func->operator != operator_overload) continue; + if (parent_type && parent_type != typeget(func->func_decl.type_parent)) continue; + Type *first_arg = func->func_decl.signature.params[1]->type->canonical; + if (first_arg != binary_type) + { + if (!binary_arg) continue; + if (may_cast(context, binary_arg, first_arg, false, true)) + { + if (*candidate_ref) + { + *ambiguous_ref = func; + continue; + } + *candidate_ref = func; + } + continue; + } + unit_register_external_symbol(context, func); + return func; + } + return NULL; +} +Decl *sema_find_typed_operator(SemaContext *context, Type *type, OperatorOverload operator_overload, Expr *binary_arg, Type *binary_type, Decl **ambiguous_ref) +{ + assert(operator_overload >= OVERLOAD_TYPED_START); + assert(!binary_arg || ambiguous_ref); + assert(!binary_type || !binary_arg); + type = type->canonical; + if (binary_arg) binary_type = binary_arg->type->canonical; + Decl *candidate = NULL; + Decl *ambiguous = NULL; + + if (type_is_user_defined(type)) + { + Decl *func = sema_find_typed_operator_in_list(context, type->decl->methods, operator_overload, type, binary_arg, + binary_type, &candidate, &ambiguous); + if (func) return func; + } + else + { + Decl *func = sema_find_typed_operator_in_list(context, compiler.context.method_extensions, + operator_overload, type, binary_arg, binary_type, &candidate, &ambiguous); + if (func) return func; + } + + Decl *extension = sema_find_typed_operator_in_list(context, context->unit->local_method_extensions, + operator_overload, type, binary_arg, binary_type, &candidate, &ambiguous); + if (extension) return extension; + + extension = operator_in_module_typed(context, context->compilation_unit->module, operator_overload, type, + binary_arg, binary_type, &candidate, &ambiguous); if (extension) return extension; FOREACH(Decl *, import, context->unit->imports) { - extension = operator_in_module(context, import->import.module, operator_overload); + extension = operator_in_module_typed(context, import->import.module, operator_overload, type, binary_arg, + binary_type, &candidate, &ambiguous); if (extension) return extension; } + if (ambiguous) + { + *ambiguous_ref = ambiguous; + return NULL; + } + if (candidate) + { + unit_register_external_symbol(context, candidate); + return candidate; + } return NULL; } +static inline bool sema_analyse_operator_unary(SemaContext *context, Decl *method, OperatorOverload operator_overload) +{ + assert(operator_overload == OVERLOAD_NEGATE || operator_overload == OVERLOAD_UNARY_MINUS); + TypeInfo *rtype; + Decl **params; + if (!sema_analyse_operator_common(context, method, &rtype, ¶ms, 1)) return false; + if (!rtype) RETURN_SEMA_ERROR(method, "The return value must be explicitly typed for '%s'.", method->name); + if (rtype->type->canonical != typeget(method->func_decl.type_parent)->canonical) + { + RETURN_SEMA_ERROR(rtype, "The return value must be %s but was %s.", type_quoted_error_string(typeget(method->func_decl.type_parent)), + type_quoted_error_string(rtype->type)); + } + return true; +} + +static inline bool sema_analyse_operator_arithmetics(SemaContext *context, Decl *method, OperatorOverload operator_overload) +{ + if (operator_overload == OVERLOAD_MINUS && vec_size(method->func_decl.signature.params) < 2) + { + return sema_analyse_operator_unary(context, method, method->operator = OVERLOAD_UNARY_MINUS); + } + Signature *signature = &method->func_decl.signature; + Decl **params = signature->params; + uint32_t param_count = vec_size(params); + if (param_count > 2) + { + RETURN_SEMA_ERROR(params[2], "Too many parameters, '%s' expects only 2 parameters.", method->name); + } + if (param_count < 2) + { + RETURN_SEMA_ERROR(method, "Not enough parameters, '%s' requires 2 parameters.", method->name); + } + if (!signature->rtype) RETURN_SEMA_ERROR(method, "The return value must be explicitly typed for '%s'.", method->name); + TypeInfo *rtype = type_infoptr(signature->rtype); + if (IS_OPTIONAL(rtype)) + { + RETURN_SEMA_ERROR(rtype, "The return type may not be an optional."); + } + if (operator_overload == OVERLOAD_EQUAL || operator_overload == OVERLOAD_NOT_EQUAL) + { + if (rtype->type->canonical != type_bool) + { + RETURN_SEMA_ERROR(rtype, "The return type was %s, but it must be bool for comparisons.", type_quoted_error_string(rtype->type)); + } + } + if (type_is_void(rtype->type)) + { + RETURN_SEMA_ERROR(rtype, "The return type may not be %s.", type_quoted_error_string(rtype->type)); + } + // Set it as pure + method->func_decl.signature.attrs.is_pure = true; + return true; +} + static inline bool sema_analyse_operator_element_at(SemaContext *context, Decl *method) { TypeInfo *rtype; @@ -1777,7 +1929,7 @@ static inline bool sema_analyse_operator_element_set(SemaContext *context, Decl return sema_analyse_operator_common(context, method, &rtype, ¶ms, 3); } -static inline bool sema_analyse_operator_len(Decl *method, SemaContext *context) +static inline bool sema_analyse_operator_len(SemaContext *context, Decl *method) { TypeInfo *rtype; Decl **params; @@ -1791,7 +1943,8 @@ static inline bool sema_analyse_operator_len(Decl *method, SemaContext *context) static bool sema_check_operator_method_validity(SemaContext *context, Decl *method) { - switch (method->operator) + OperatorOverload operator = method->operator; + switch (operator) { case OVERLOAD_ELEMENT_SET: return sema_analyse_operator_element_set(context, method); @@ -1799,7 +1952,25 @@ static bool sema_check_operator_method_validity(SemaContext *context, Decl *meth case OVERLOAD_ELEMENT_REF: return sema_analyse_operator_element_at(context, method); case OVERLOAD_LEN: - return sema_analyse_operator_len(method, context); + return sema_analyse_operator_len(context, method); + case OVERLOAD_PLUS: + case OVERLOAD_MULTIPLY: + case OVERLOAD_MINUS: + case OVERLOAD_DIVIDE: + case OVERLOAD_REMINDER: + case OVERLOAD_EQUAL: + case OVERLOAD_NOT_EQUAL: + case OVERLOAD_AND: + case OVERLOAD_OR: + case OVERLOAD_XOR: + case OVERLOAD_SHL: + case OVERLOAD_SHR: + return sema_analyse_operator_arithmetics(context, method, operator); + case OVERLOAD_NEGATE: + return sema_analyse_operator_unary(context, method, operator); + case OVERLOAD_UNARY_MINUS: + // Changed in OVERLOAD_MINUS analysis + UNREACHABLE } UNREACHABLE } @@ -1888,20 +2059,31 @@ INLINE bool sema_analyse_operator_method(SemaContext *context, Type *parent_type // Check it's valid for the operator type. if (!sema_check_operator_method_validity(context, method)) return false; - // We don't support operator overloading on base types, because - // there seems little use for it frankly. - if (!type_is_user_defined(parent_type)) - { - sema_error_at(context, method_find_overload_span(method), - "Only user-defined types support operator overloading."); - return false; - } - // See if the operator has already been defined. OperatorOverload operator = method->operator; - Decl *other = sema_find_operator(context, parent_type, operator); - if (other != method) + Type *second_param = vec_size(method->func_decl.signature.params) > 1 ? method->func_decl.signature.params[1]->type : NULL; + + // We don't support operator overloading on base types, because + // there seems little use for it frankly. + if (!type_is_user_defined(parent_type) && (operator < OVERLOAD_TYPED_START || !type_is_user_defined(second_param->canonical))) + { + sema_error_at(context, method_find_overload_span(method), + "Only overloads involving user-defined types support overloading."); + return false; + } + + + Decl *other = NULL; + if (operator >= OVERLOAD_TYPED_START) + { + other = sema_find_typed_operator(context, parent_type, operator, NULL, second_param, NULL); + } + else + { + other = sema_find_untyped_operator(context, parent_type, operator); + } + if (other && other != method) { SourceSpan span = method_find_overload_span(method); sema_error_at(context, span, "This operator is already defined for '%s'.", parent_type->name); @@ -1919,14 +2101,14 @@ INLINE bool sema_analyse_operator_method(SemaContext *context, Type *parent_type return true; case OVERLOAD_ELEMENT_AT: // [] compares &[] - other = sema_find_operator(context, parent_type, OVERLOAD_ELEMENT_REF); + other = sema_find_untyped_operator(context, parent_type, OVERLOAD_ELEMENT_REF); if (other && decl_ok(other)) { sema_get_overload_arguments(other, &value, &index_type); break; } // And []= - other = sema_find_operator(context, parent_type, OVERLOAD_ELEMENT_SET); + other = sema_find_untyped_operator(context, parent_type, OVERLOAD_ELEMENT_SET); if (other && decl_ok(other)) { sema_get_overload_arguments(other, &value, &index_type); @@ -1935,14 +2117,14 @@ INLINE bool sema_analyse_operator_method(SemaContext *context, Type *parent_type return true; case OVERLOAD_ELEMENT_REF: // &[] compares [] - other = sema_find_operator(context, parent_type, OVERLOAD_ELEMENT_AT); + other = sema_find_untyped_operator(context, parent_type, OVERLOAD_ELEMENT_AT); if (other && decl_ok(other)) { sema_get_overload_arguments(other, &value, &index_type); break; } // And []= - other = sema_find_operator(context, parent_type, OVERLOAD_ELEMENT_SET); + other = sema_find_untyped_operator(context, parent_type, OVERLOAD_ELEMENT_SET); if (other && decl_ok(other)) { sema_get_overload_arguments(other, &value, &index_type); @@ -1951,20 +2133,35 @@ INLINE bool sema_analyse_operator_method(SemaContext *context, Type *parent_type return true; case OVERLOAD_ELEMENT_SET: // []= compares &[] - other = sema_find_operator(context, parent_type, OVERLOAD_ELEMENT_REF); + other = sema_find_untyped_operator(context, parent_type, OVERLOAD_ELEMENT_REF); if (other && decl_ok(other)) { sema_get_overload_arguments(other, &value, &index_type); break; } // And [] - other = sema_find_operator(context, parent_type, OVERLOAD_ELEMENT_AT); + other = sema_find_untyped_operator(context, parent_type, OVERLOAD_ELEMENT_AT); if (other && decl_ok(other)) { sema_get_overload_arguments(other, &value, &index_type); break; } return true; + case OVERLOAD_SHL: + case OVERLOAD_SHR: + case OVERLOAD_PLUS: + case OVERLOAD_DIVIDE: + case OVERLOAD_REMINDER: + case OVERLOAD_UNARY_MINUS: + case OVERLOAD_AND: + case OVERLOAD_OR: + case OVERLOAD_XOR: + case OVERLOAD_NEGATE: + case OVERLOAD_MULTIPLY: + case OVERLOAD_MINUS: + case OVERLOAD_EQUAL: + case OVERLOAD_NOT_EQUAL: + return true; default: UNREACHABLE } diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index 245ef0ba5..ff6868744 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -87,7 +87,7 @@ static bool sema_expr_analyse_add(SemaContext *context, Expr *expr, Expr *left, static bool sema_expr_analyse_mult(SemaContext *context, Expr *expr, Expr *left, Expr *right); static bool sema_expr_analyse_div(SemaContext *context, Expr *expr, Expr *left, Expr *right); static bool sema_expr_analyse_mod(SemaContext *context, Expr *expr, Expr *left, Expr *right); -static bool sema_expr_analyse_bit(SemaContext *context, Expr *expr, Expr *left, Expr *right); +static bool sema_expr_analyse_bit(SemaContext *context, Expr *expr, Expr *left, Expr *right, OperatorOverload overload); static bool sema_expr_analyse_enum_add_sub(SemaContext *context, Expr *expr, Expr *left, Expr *right); static bool sema_expr_analyse_shift(SemaContext *context, Expr *expr, Expr *left, Expr *right); static bool sema_expr_check_shift_rhs(SemaContext *context, Expr *expr, Type *left_type, Type *left_type_flat, Expr *right, Type *right_type_flat); @@ -134,11 +134,12 @@ static bool sema_analyse_assign_mutate_overloaded_subscript(SemaContext *context static void expr_binary_unify_failability(Expr *expr, Expr *left, Expr *right); static inline bool sema_binary_analyse_subexpr(SemaContext *context, Expr *left, Expr *right); -static inline bool sema_binary_analyse_arithmetic_subexpr(SemaContext *context, Expr *expr, const char *error, bool bool_and_bitstruct_is_allowed); +static inline bool sema_binary_analyse_arithmetic_subexpr(SemaContext *context, Expr *expr, const char *error, + bool bool_and_bitstruct_is_allowed, OperatorOverload *operator_overload_red); static bool sema_binary_check_unclear_op_precedence(Expr *left_side, Expr * main_expr, Expr *right_side); static bool sema_binary_analyse_ct_op_assign(SemaContext *context, Expr *expr, Expr *left); static bool sema_binary_arithmetic_promotion(SemaContext *context, Expr *left, Expr *right, Type *left_type, Type *right_type, - Expr *parent, const char *error_message, bool allow_bool_vec); + Expr *parent, const char *error_message, bool allow_bool_vec, OperatorOverload *operator_overload_ref); static bool sema_binary_is_unsigned_always_same_comparison(SemaContext *context, Expr *expr, Expr *left, Expr *right, Type *lhs_type, Type *rhs_type); static bool sema_binary_is_expr_lvalue(SemaContext *context, Expr *top_expr, Expr *expr, bool *failed_ref); @@ -206,7 +207,7 @@ static inline bool sema_analyse_expr_check(SemaContext *context, Expr *expr, Che static inline Expr **sema_prepare_splat_insert(Expr **exprs, unsigned added, unsigned insert_point); static inline bool sema_analyse_maybe_dead_expr(SemaContext *, Expr *expr, bool is_dead, Type *infer_type); - +static inline bool sema_insert_binary_overload(SemaContext *context, Expr *expr, Decl *overload, Expr *lhs, Expr *rhs); // -- implementations @@ -1217,7 +1218,8 @@ static inline bool sema_binary_analyse_subexpr(SemaContext *context, Expr *left, return sema_analyse_expr(context, left) && sema_analyse_expr(context, right); } -static inline bool sema_binary_analyse_arithmetic_subexpr(SemaContext *context, Expr *expr, const char *error, bool bool_and_bitstruct_is_allowed) +static inline bool sema_binary_analyse_arithmetic_subexpr(SemaContext *context, Expr *expr, const char *error, + bool bool_and_bitstruct_is_allowed, OperatorOverload *operator_overload_ref) { Expr *left = exprptr(expr->binary_expr.left); Expr *right = exprptr(expr->binary_expr.right); @@ -1225,8 +1227,6 @@ static inline bool sema_binary_analyse_arithmetic_subexpr(SemaContext *context, // 1. Analyse both sides. if (!sema_binary_analyse_subexpr(context, left, right)) return false; - //if (!sema_binary_promote_top_down(context, expr, left, right)) return false; - Type *left_type = type_no_optional(left->type)->canonical; Type *right_type = type_no_optional(right->type)->canonical; @@ -1236,7 +1236,7 @@ static inline bool sema_binary_analyse_arithmetic_subexpr(SemaContext *context, if (left_type == type_bool && right_type == type_bool) return true; } // 2. Perform promotion to a common type. - return sema_binary_arithmetic_promotion(context, left, right, left_type, right_type, expr, error, bool_and_bitstruct_is_allowed); + return sema_binary_arithmetic_promotion(context, left, right, left_type, right_type, expr, error, bool_and_bitstruct_is_allowed, operator_overload_ref); } static inline int sema_call_find_index_of_named_parameter(SemaContext *context, Decl **func_params, Expr *expr) @@ -3153,7 +3153,7 @@ static Expr *sema_expr_find_subscript_type_or_overload_for_subscript(SemaContext Decl **overload_ptr) { Decl *overload = NULL; - overload = sema_find_operator(context, current_expr->type, overload_type); + overload = sema_find_untyped_operator(context, current_expr->type, overload_type); if (overload) { // Overload for []= @@ -3340,7 +3340,7 @@ static inline bool sema_expr_analyse_subscript_lvalue(SemaContext *context, Expr { if (start_from_end) { - Decl *len = sema_find_operator(context, current_expr->type, OVERLOAD_LEN); + Decl *len = sema_find_untyped_operator(context, current_expr->type, OVERLOAD_LEN); if (!len) { if (check_valid) goto VALID_FAIL_POISON; @@ -3459,7 +3459,7 @@ static inline bool sema_expr_analyse_subscript(SemaContext *context, Expr *expr, { if (start_from_end) { - Decl *len = sema_find_operator(context, current_expr->type, OVERLOAD_LEN); + Decl *len = sema_find_untyped_operator(context, current_expr->type, OVERLOAD_LEN); if (!len) { if (check_valid) goto VALID_FAIL_POISON; @@ -6423,11 +6423,33 @@ END: return true; } - +static bool sema_replace_with_overload(SemaContext *context, Expr *expr, Expr *left, Expr *right, Type *left_type, OperatorOverload* operator_overload_ref) +{ + assert(!type_is_optional(left_type) && left_type->canonical == left_type); + Decl *ambiguous = NULL; + Decl *overload = sema_find_typed_operator(context, left_type, *operator_overload_ref, right, NULL, &ambiguous); + if (overload) + { + *operator_overload_ref = 0; + return sema_insert_binary_overload(context, expr, overload, left, right); + } + if (ambiguous) + { + RETURN_SEMA_ERROR(expr, "Overload was ambiguous for types %s and %s.", + type_quoted_error_string(left->type), type_quoted_error_string(right->type)); + } + return true; +} static bool sema_binary_arithmetic_promotion(SemaContext *context, Expr *left, Expr *right, Type *left_type, Type *right_type, - Expr *parent, const char *error_message, bool allow_bool_vec) + Expr *parent, const char *error_message, bool allow_bool_vec, OperatorOverload *operator_overload_ref) { + if (type_is_user_defined(left_type) || type_is_user_defined(right_type)) + { + if (!sema_replace_with_overload(context, parent, left, right, left_type, operator_overload_ref)) return false; + if (!*operator_overload_ref) return true; + } + Type *max = cast_numeric_arithmetic_promotion(type_find_max_type(left_type, right_type)); if (!max || (!type_underlying_is_numeric(max) && !(allow_bool_vec && type_flat_is_bool_vector(max)))) { @@ -6642,16 +6664,13 @@ static bool sema_expr_analyse_sub(SemaContext *context, Expr *expr, Expr *left, right_type = type_no_optional(right->type)->canonical; // 7. Attempt arithmetic promotion, to promote both to a common type. - if (!sema_binary_arithmetic_promotion(context, - left, - right, - left_type, - right_type, - expr, - "The subtraction %s - %s is not possible.", false)) + OperatorOverload overload = OVERLOAD_MINUS; + if (!sema_binary_arithmetic_promotion(context, left, right, left_type, right_type, expr, + "The subtraction %s - %s is not possible.", false, &overload)) { return false; } + if (!overload) return true; left_type = left->type->canonical; @@ -6786,16 +6805,13 @@ static bool sema_expr_analyse_add(SemaContext *context, Expr *expr, Expr *left, ASSERT_SPAN(expr, !cast_to_iptr); // 4. Do a binary arithmetic promotion - if (!sema_binary_arithmetic_promotion(context, - left, - right, - left_type, - right_type, - expr, - "Cannot do the addition %s + %s.", false)) + OperatorOverload overload = OVERLOAD_PLUS; + if (!sema_binary_arithmetic_promotion(context, left, right, left_type, right_type, expr, + "Cannot do the addition %s + %s.", false, &overload)) { return false; } + if (!overload) return true; // 5. Handle the "both const" case. We should only see ints and floats at this point. if (expr_both_const(left, right) && sema_constant_fold_ops(left)) @@ -6833,8 +6849,10 @@ static bool sema_expr_analyse_mult(SemaContext *context, Expr *expr, Expr *left, { // 1. Analyse the sub expressions and promote to a common type - if (!sema_binary_analyse_arithmetic_subexpr(context, expr, "It is not possible to multiply %s by %s.", false)) return false; + OperatorOverload overload = OVERLOAD_MULTIPLY; + if (!sema_binary_analyse_arithmetic_subexpr(context, expr, "It is not possible to multiply %s by %s.", false, &overload)) return false; + if (!overload) return true; // 2. Handle constant folding. if (expr_both_const(left, right) && sema_constant_fold_ops(left)) @@ -6866,7 +6884,9 @@ static bool sema_expr_analyse_mult(SemaContext *context, Expr *expr, Expr *left, static bool sema_expr_analyse_div(SemaContext *context, Expr *expr, Expr *left, Expr *right) { // 1. Analyse sub expressions and promote to a common type - if (!sema_binary_analyse_arithmetic_subexpr(context, expr, "Cannot divide %s by %s.", false)) return false; + OperatorOverload overload = OVERLOAD_DIVIDE; + if (!sema_binary_analyse_arithmetic_subexpr(context, expr, "Cannot divide %s by %s.", false, &overload)) return false; + if (!overload) return true; // 2. Check for a constant 0 on the rhs. if (sema_cast_const(right)) @@ -6921,7 +6941,9 @@ static bool sema_expr_analyse_div(SemaContext *context, Expr *expr, Expr *left, static bool sema_expr_analyse_mod(SemaContext *context, Expr *expr, Expr *left, Expr *right) { // 1. Analyse both sides and promote to a common type - if (!sema_binary_analyse_arithmetic_subexpr(context, expr, NULL, false)) return false; + OperatorOverload overload = OVERLOAD_REMINDER; + if (!sema_binary_analyse_arithmetic_subexpr(context, expr, "Cannot calculate the reminder %s %% %s", false, &overload)) return false; + if (!overload) return true; Type *flat = type_flatten(left->type); if (type_is_float(flat)) @@ -6962,11 +6984,11 @@ static bool sema_expr_analyse_mod(SemaContext *context, Expr *expr, Expr *left, * Analyse a ^ b, a | b, a & b * @return true if the analysis succeeded. */ -static bool sema_expr_analyse_bit(SemaContext *context, Expr *expr, Expr *left, Expr *right) +static bool sema_expr_analyse_bit(SemaContext *context, Expr *expr, Expr *left, Expr *right, OperatorOverload overload) { - // 1. Convert to common type if possible. - if (!sema_binary_analyse_arithmetic_subexpr(context, expr, NULL, true)) return false; + if (!sema_binary_analyse_arithmetic_subexpr(context, expr, NULL, true, &overload)) return false; + if (!overload) return true; // 2. Check that both are integers or bools. bool is_bool = left->type->canonical == type_bool; @@ -7067,8 +7089,18 @@ static bool sema_expr_analyse_shift(SemaContext *context, Expr *expr, Expr *left // 1. Analyze both sides. if (!sema_binary_analyse_subexpr(context, left, right)) return false; + Type *lhs_type = type_no_optional(left->type)->canonical; + bool shr = expr->binary_expr.operator == BINARYOP_SHR; + + if (type_is_user_defined(lhs_type)) + { + OperatorOverload overload = shr ? OVERLOAD_SHR : OVERLOAD_SHL; + if (!sema_replace_with_overload(context, expr, left, right, lhs_type, &overload)) return false; + if (!overload) return true; + } + // 3. Promote lhs using the usual numeric promotion. - if (!cast_implicit_binary(context, left, cast_numeric_arithmetic_promotion(type_no_optional(left->type)), false)) return false; + if (!cast_implicit_binary(context, left, cast_numeric_arithmetic_promotion(lhs_type), false)) return false; Type *flat_left = type_flatten(left->type); Type *flat_right = type_flatten(right->type); @@ -7085,8 +7117,6 @@ static bool sema_expr_analyse_shift(SemaContext *context, Expr *expr, Expr *left // Fold constant expressions. if (expr_is_const_int(right) && sema_cast_const(left)) { - - bool shr = expr->binary_expr.operator == BINARYOP_SHR; expr_replace(expr, left); if (shr) { @@ -7200,9 +7230,56 @@ static bool sema_expr_analyse_comp(SemaContext *context, Expr *expr, Expr *left, bool is_equality_type_op = expr->binary_expr.operator == BINARYOP_NE || expr->binary_expr.operator == BINARYOP_EQ; + Type *left_type = type_no_optional(left->type)->canonical; + Type *right_type = type_no_optional(right->type)->canonical; + + if (is_equality_type_op && (!type_is_comparable(left_type) || !type_is_comparable(right_type))) + { + Decl *overload = NULL; + bool negated_overload = false; + Decl *ambiguous = NULL; + switch (expr->binary_expr.operator) + { + case BINARYOP_NE: + overload = sema_find_typed_operator(context, left_type, OVERLOAD_NOT_EQUAL, right, NULL, &ambiguous); + if (!overload && !ambiguous) + { + negated_overload = true; + overload = sema_find_typed_operator(context, left_type, OVERLOAD_EQUAL, right, NULL, &ambiguous); + } + if (!overload) goto NEXT; + break; + case BINARYOP_EQ: + overload = sema_find_typed_operator(context, left_type, OVERLOAD_EQUAL, right, NULL, &ambiguous); + if (!overload && !ambiguous) + { + negated_overload = true; + overload = sema_find_typed_operator(context, left_type, OVERLOAD_NOT_EQUAL, right, NULL, &ambiguous); + } + if (!overload) goto NEXT; + break; + default: + UNREACHABLE + } + Expr **args = NULL; + if (overload->func_decl.signature.params[1]->type->canonical->type_kind == TYPE_POINTER) + { + expr_insert_addr(right); + } + vec_add(args, right); + if (!sema_insert_method_call(context, expr, overload, left, args)) return false; + if (!negated_overload) return true; + assert(expr->resolve_status == RESOLVE_DONE); + Expr *inner = expr_copy(expr); + expr->expr_kind = EXPR_UNARY; + expr->unary_expr = (ExprUnary) { .expr = inner, .operator = UNARYOP_NOT }; + expr->resolve_status = RESOLVE_NOT_DONE; + return sema_analyse_expr(context, expr); + } +NEXT: // Flatten distinct/optional - Type *left_type = type_flat_distinct_inline(type_no_optional(left->type)->canonical)->canonical; - Type *right_type = type_flat_distinct_inline(type_no_optional(right->type)->canonical)->canonical; + left_type = type_flat_distinct_inline(left_type)->canonical; + right_type = type_flat_distinct_inline(right_type)->canonical; // 2. Handle the case of signed comparisons. // This happens when either side has a definite integer type @@ -7216,6 +7293,7 @@ static bool sema_expr_analyse_comp(SemaContext *context, Expr *expr, Expr *left, goto DONE; } + // 3. In the normal case, treat this as a binary op, finding the max type. Type *max = type_find_max_type(left_type, right_type); @@ -7235,12 +7313,6 @@ static bool sema_expr_analyse_comp(SemaContext *context, Expr *expr, Expr *left, if (!type_is_comparable(max)) { - if (type_is_user_defined(max)) - { - RETURN_SEMA_ERROR(expr, - "%s does not support comparisons, you need to manually implement a comparison if you need it.", - type_quoted_error_string(left->type)); - } RETURN_SEMA_ERROR(expr, "%s does not support comparisons.", type_quoted_error_string(left->type)); } @@ -7269,6 +7341,8 @@ static bool sema_expr_analyse_comp(SemaContext *context, Expr *expr, Expr *left, if (!cast_implicit(context, left, max, false) || !cast_implicit(context, right, max, false)) return false; bool success = cast_explicit(context, left, max) && cast_explicit(context, right, max); ASSERT_SPAN(expr, success); + + DONE: // 7. Do constant folding. @@ -7564,6 +7638,23 @@ static inline bool sema_expr_analyse_neg_plus(SemaContext *context, Expr *expr) // 2. Check if it's possible to negate this (i.e. is it an int, float or vector) Type *no_fail = type_no_optional(inner->type); + Type *canonical = no_fail->canonical; + + // Check for overload + if (type_is_user_defined(canonical)) + { + Decl *overload = sema_find_untyped_operator(context, canonical, OVERLOAD_UNARY_MINUS); + if (overload) + { + // Plus just returns inner + if (is_plus) + { + expr_replace(expr, inner); + return true; + } + return sema_insert_method_call(context, expr, overload, inner, NULL); + } + } if (!type_may_negate(no_fail)) { if (is_plus) @@ -7625,6 +7716,13 @@ static inline bool sema_expr_analyse_bit_not(SemaContext *context, Expr *expr) // 2. Check that it's a vector, bool Type *canonical = type_no_optional(inner->type)->canonical; + + if (type_is_user_defined(canonical) && canonical->type_kind != TYPE_BITSTRUCT) + { + Decl *overload = sema_find_untyped_operator(context, canonical, OVERLOAD_NEGATE); + if (overload) return sema_insert_method_call(context, expr, overload, inner, NULL); + } + Type *flat = type_flatten(canonical); bool is_bitstruct = flat->type_kind == TYPE_BITSTRUCT; if (!type_is_integer_or_bool_kind(flat) && !is_bitstruct) @@ -7797,7 +7895,7 @@ static bool sema_analyse_assign_mutate_overloaded_subscript(SemaContext *context Expr *increased = exprptr(subscript_expr->subscript_assign_expr.expr); Type *type_check = increased->type->canonical; Expr *index = exprptr(subscript_expr->subscript_assign_expr.index); - Decl *operator = sema_find_operator(context, type_check, OVERLOAD_ELEMENT_REF); + Decl *operator = sema_find_untyped_operator(context, type_check, OVERLOAD_ELEMENT_REF); Expr **args = NULL; if (operator) { @@ -7807,7 +7905,7 @@ static bool sema_analyse_assign_mutate_overloaded_subscript(SemaContext *context main->type = subscript_expr->type; return true; } - operator = sema_find_operator(context, type_check, OVERLOAD_ELEMENT_AT); + operator = sema_find_untyped_operator(context, type_check, OVERLOAD_ELEMENT_AT); if (!operator) { RETURN_SEMA_ERROR(main, "There is no overload for [] for %s.", type_quoted_error_string(increased->type)); @@ -8117,9 +8215,11 @@ static inline bool sema_expr_analyse_binary(SemaContext *context, Type *infer_ty case BINARYOP_OR: return sema_expr_analyse_and_or(context, expr, left, right); case BINARYOP_BIT_OR: + return sema_expr_analyse_bit(context, expr, left, right, OVERLOAD_OR); case BINARYOP_BIT_XOR: + return sema_expr_analyse_bit(context, expr, left, right, OVERLOAD_XOR); case BINARYOP_BIT_AND: - return sema_expr_analyse_bit(context, expr, left, right); + return sema_expr_analyse_bit(context, expr, left, right, OVERLOAD_AND); case BINARYOP_VEC_NE: case BINARYOP_VEC_EQ: case BINARYOP_VEC_GT: @@ -10661,6 +10761,17 @@ bool sema_insert_method_call(SemaContext *context, Expr *method_call, Decl *meth return true; } +static inline bool sema_insert_binary_overload(SemaContext *context, Expr *expr, Decl *overload, Expr *lhs, Expr *rhs) +{ + Expr **args = NULL; + if (overload->func_decl.signature.params[1]->type->canonical->type_kind == TYPE_POINTER) + { + expr_insert_addr(rhs); + } + vec_add(args, rhs); + return sema_insert_method_call(context, expr, overload, lhs, args); +} + // Check if the assignment fits bool sema_bit_assignment_check(SemaContext *context, Expr *right, Decl *member) { diff --git a/src/compiler/sema_internal.h b/src/compiler/sema_internal.h index 76b9d3b20..c61680d10 100644 --- a/src/compiler/sema_internal.h +++ b/src/compiler/sema_internal.h @@ -94,7 +94,8 @@ bool sema_analyse_expr_lvalue(SemaContext *context, Expr *expr, bool *failed_ref bool sema_analyse_expr_value(SemaContext *context, Expr *expr); Expr *expr_access_inline_member(Expr *parent, Decl *parent_decl); bool sema_analyse_ct_expr(SemaContext *context, Expr *expr); -Decl *sema_find_operator(SemaContext *context, Type *type, OperatorOverload operator_overload); +Decl *sema_find_typed_operator(SemaContext *context, Type *type, OperatorOverload operator_overload, Expr *binary_arg, Type *binary_type, Decl **ambiguous_ref); +Decl *sema_find_untyped_operator(SemaContext *context, Type *type, OperatorOverload operator_overload); bool sema_insert_method_call(SemaContext *context, Expr *method_call, Decl *method_decl, Expr *parent, Expr **arguments); bool sema_expr_analyse_builtin_call(SemaContext *context, Expr *expr); diff --git a/src/compiler/sema_stmts.c b/src/compiler/sema_stmts.c index 29eb45938..7fe89b01a 100644 --- a/src/compiler/sema_stmts.c +++ b/src/compiler/sema_stmts.c @@ -1477,9 +1477,9 @@ static inline bool sema_analyse_foreach_stmt(SemaContext *context, Ast *statemen if (!value_type || canonical->type_kind == TYPE_DISTINCT) { - len = sema_find_operator(context, enumerator->type, OVERLOAD_LEN); - Decl *by_val = sema_find_operator(context, enumerator->type, OVERLOAD_ELEMENT_AT); - Decl *by_ref = sema_find_operator(context, enumerator->type, OVERLOAD_ELEMENT_REF); + len = sema_find_untyped_operator(context, enumerator->type, OVERLOAD_LEN); + Decl *by_val = sema_find_untyped_operator(context, enumerator->type, OVERLOAD_ELEMENT_AT); + Decl *by_ref = sema_find_untyped_operator(context, enumerator->type, OVERLOAD_ELEMENT_REF); if (!len || (!by_val && !by_ref)) { if (value_type) goto SKIP_OVERLOAD; diff --git a/test/test_suite/methods/unsupported_operator.c3 b/test/test_suite/methods/unsupported_operator.c3 index 8610d8f9d..076c65b97 100755 --- a/test/test_suite/methods/unsupported_operator.c3 +++ b/test/test_suite/methods/unsupported_operator.c3 @@ -1,3 +1,3 @@ import std::io; -fn int int.fadd(&self, int x) @operator([]) { return 1; } // #error: Only user-defined types support operator overloading \ No newline at end of file +fn int int.fadd(&self, int x) @operator([]) { return 1; } // #error: Only overloads involving user-defined types support overloading \ No newline at end of file diff --git a/test/unit/regression/operator_overload.c3 b/test/unit/regression/operator_overload.c3 new file mode 100644 index 000000000..8fd16425e --- /dev/null +++ b/test/unit/regression/operator_overload.c3 @@ -0,0 +1,47 @@ +module operator_overload; +import std; + +struct Abc +{ int a; } + +int b; +fn bool Abc.neq(self, Abc abc) @operator(!=) => self.a != abc.a; +fn bool Abc.eq(self, Abc abc) @operator(==) => self.a == abc.a; +fn Abc Abc.plus(self, Abc abc) @operator(+) => { self.a + abc.a }; +fn Abc Abc.plus2(self, int i) @operator(+) => { self.a + i }; +fn Abc Abc.minus2(self, int i) @operator(-) => { self.a - i }; +fn Abc int.plus_abc(self, Abc abc) @operator(+) => { self + abc.a }; +fn int int.mul_abc(self, Abc abc) @operator(*) => self * abc.a; +fn Abc Abc.negate(self) @operator(~) => { ~self.a }; +fn Abc Abc.negate2(self) @operator(-) => { -self.a }; +fn Abc Abc.shr2(self, int x) @operator(>>) => { self.a >> x }; +fn Abc Abc.shl2(self, int x) @operator(<<) @local { b++; return { self.a << x }; } +fn Abc Abc.shl(self, Abc x) @operator(<<) => { self.a << x.a }; +fn Abc Abc.div(self, Abc abc) @operator(/) => { self.a / abc.a }; +fn Abc Abc.rem(self, Abc abc) @operator(%) => { self.a % abc.a }; +fn Abc Abc.xor(self, Abc abc) @operator(^) => { self.a ^ abc.a }; +fn Abc Abc.or(self, Abc abc) @operator(|) => { self.a | abc.a }; +fn Abc Abc.and(self, Abc abc) @operator(&) => { self.a & abc.a }; + +fn void test_struct_overload() @test +{ + Abc x = { 2 }; + Abc y = { 3 }; + assert(x != y); + assert(x == x); + assert(x + 2 == { 4 }); + assert(x - 2 == { 0 }); + assert(2 + x == { 4 }); + assert(3 * x == 6); + assert(~x == { -3 }); + assert(-x == { -2 }); + assert(+x == { 2 }); + assert(y >> 1 == { 1 }); + assert(y << x == { 12 }); + assert(x << 4 == { 32 }); + assert((Abc) { 123 } / x == { 61 }); + assert((Abc) { 123 } % x == { 1 }); + assert(x ^ y == { 1 }); + assert(x | y == { 3 }); + assert(x & y == { 2 }); +} \ No newline at end of file diff --git a/test/unit/stdlib/math/matrix.c3 b/test/unit/stdlib/math/matrix.c3 index ff2831905..6b817e294 100644 --- a/test/unit/stdlib/math/matrix.c3 +++ b/test/unit/stdlib/math/matrix.c3 @@ -8,6 +8,7 @@ fn void test_mat4() Matrix4 mat2 = { 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1 }; Matrix4 calc = mat.mul(mat2); assert(calc.m == mat.m); + assert(mat * mat2 == mat); Matrix4 translated = mat.translate({0.0, 0.0, 0.0}); assert(translated.m == mat.m); @@ -19,6 +20,7 @@ fn void test_mat4() Matrix4 calc = mat.mul(mat2); Matrix4 value = { 56, 46, 36, 26, 152, 126, 100, 74, 56, 46, 36, 26, 152, 126, 100, 74 }; assert(calc.m == value.m); + assert(mat * mat2 == value); }; {