From 3888fcb182097f81d7ecbdc63fe660ef927cfcff Mon Sep 17 00:00:00 2001 From: Christoffer Lerno Date: Sun, 13 Apr 2025 13:43:03 +0200 Subject: [PATCH] - Add `@operator_r` and `@operator_s` attributes. --- lib/std/math/math_complex.c3 | 12 +- releasenotes.md | 1 + src/compiler/compiler_internal.h | 5 +- src/compiler/context.c | 1 + src/compiler/enums.h | 9 + src/compiler/sema_decls.c | 183 ++++++++++++++---- src/compiler/sema_expr.c | 85 +++++--- src/compiler/sema_internal.h | 4 +- src/compiler/sema_stmts.c | 2 +- src/compiler/symtab.c | 2 + .../methods/unsupported_operator.c3 | 2 +- test/unit/regression/operator_overload.c3 | 4 +- 12 files changed, 230 insertions(+), 80 deletions(-) diff --git a/lib/std/math/math_complex.c3 b/lib/std/math/math_complex.c3 index d636c3d35..849ae6b64 100644 --- a/lib/std/math/math_complex.c3 +++ b/lib/std/math/math_complex.c3 @@ -12,19 +12,19 @@ union Complex (Printable) const Complex IDENTITY = { 1, 0 }; const Complex IMAGINARY = { 0, 1 }; -macro Complex Real.add_complex(self, Complex r) @operator(+) => { .v = (Real[<2>]) { self, 0 } + c.v }; -macro Complex Real.sub_complex(self, Complex r) @operator(-) => { .v = (Real[<2>]) { self, 0 } - c.v }; -macro Complex Real.scale_complex(self, Complex c) @operator(*) => { .v = self * c.v }; -macro Complex Real.div_complex(self, Complex c) @operator(/) => ((Complex) { .r = self }).div(c); + + macro Complex Complex.add(self, Complex b) @operator(+) => { .v = self.v + b.v }; -macro Complex Complex.add_real(self, Real r) @operator(+) => { .v = self.v + (Real[<2>]) { r, 0 } }; +macro Complex Complex.add_real(self, Real r) @operator_s(+) => { .v = self.v + (Real[<2>]) { r, 0 } }; macro Complex Complex.add_each(self, Real b) => { .v = self.v + b }; macro Complex Complex.sub(self, Complex b) @operator(-) => { .v = self.v - b.v }; macro Complex Complex.sub_real(self, Real r) @operator(-) => { .v = self.v - (Real[<2>]) { r, 0 } }; +macro Complex Complex.sub_from_real(self, Real r) @operator_r(-) => { .v = (Real[<2>]) { r, 0 } - self.v }; macro Complex Complex.sub_each(self, Real b) => { .v = self.v - b }; -macro Complex Complex.scale(self, Real r) @operator(*) => { .v = self.v * r }; +macro Complex Complex.scale(self, Real r) @operator_s(*) => { .v = self.v * r }; macro Complex Complex.mul(self, Complex b)@operator(*) => { self.r * b.r - self.c * b.c, self.r * b.c + b.r * self.c }; macro Complex Complex.div_real(self, Real r) @operator(/) => { .v = self.v / r }; +macro Complex Complex.real_div(Complex c, Real r) @operator_r(/) => ((Complex) { .r = self }).div(c); macro Complex Complex.div(self, Complex b) @operator(/) { Real div = b.v.dot(b.v); diff --git a/releasenotes.md b/releasenotes.md index 4a9fd4ab0..db7dbef3f 100644 --- a/releasenotes.md +++ b/releasenotes.md @@ -8,6 +8,7 @@ - Function `@require` checks are added to the caller in safe mode. #186 - Improved error message when narrowing isn't allowed. - Operator overloading for `+ - * / % & | ^ << >> ~ == !=` +- Add `@operator_r` and `@operator_s` attributes. ### Fixes - Trying to cast an enum to int and back caused the compiler to crash. diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index 8d0789445..57312aa9a 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -498,7 +498,8 @@ struct Signature_ typedef struct { TypeInfoId type_parent; - OperatorOverload operator : 8; + OperatorOverload operator : 6; + unsigned overload_type : 2; Signature signature; AstId body; AstId docs; @@ -815,6 +816,7 @@ typedef struct ExprId offset; } ExprPointerOffset; + typedef struct { BuiltinAccessKind kind : 8; @@ -1602,6 +1604,7 @@ struct CompilationUnit_ Module *module; File *file; Decl **imports; + Decl **public_imports; Decl **types; Decl **functions; Decl **lambdas; diff --git a/src/compiler/context.c b/src/compiler/context.c index 97f059e98..4de2311ab 100644 --- a/src/compiler/context.c +++ b/src/compiler/context.c @@ -311,6 +311,7 @@ bool unit_add_import(CompilationUnit *unit, Path *path, bool private_import, boo import->import.import_private_as_public = private_import; import->import.is_non_recurse = is_non_recursive; vec_add(unit->imports, import); + if (private_import) vec_add(unit->public_imports, import); DEBUG_LOG("Added import %s", path->module); return true; } diff --git a/src/compiler/enums.h b/src/compiler/enums.h index d32db6bdb..ca70fd2f9 100644 --- a/src/compiler/enums.h +++ b/src/compiler/enums.h @@ -295,6 +295,8 @@ typedef enum ATTRIBUTE_NOSTRIP, ATTRIBUTE_OBFUSCATE, ATTRIBUTE_OPERATOR, + ATTRIBUTE_OPERATOR_R, + ATTRIBUTE_OPERATOR_S, ATTRIBUTE_OPTIONAL, ATTRIBUTE_OVERLAP, ATTRIBUTE_PACKED, @@ -896,6 +898,13 @@ typedef enum OBJ_FORMAT_AOUT, } ObjectFormatType; +typedef enum +{ + OVERLOAD_TYPE_LEFT = 1, + OVERLOAD_TYPE_RIGHT = 2, + OVERLOAD_TYPE_SYMMETRIC = OVERLOAD_TYPE_LEFT | OVERLOAD_TYPE_RIGHT, +} OverloadType; + typedef enum { OVERLOAD_ELEMENT_AT = 1, diff --git a/src/compiler/sema_decls.c b/src/compiler/sema_decls.c index 14c746b3e..ebe772cbb 100755 --- a/src/compiler/sema_decls.c +++ b/src/compiler/sema_decls.c @@ -20,7 +20,8 @@ static inline bool unit_add_method(SemaContext *context, Type *parent_type, Decl static bool sema_analyse_operator_common(SemaContext *context, Decl *method, TypeInfo **rtype_ptr, Decl ***params_ptr, uint32_t parameters); static inline Decl *operator_in_module_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, - Type *method_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref); + OverloadType overload_type, Type *method_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref); +static inline Decl *operator_in_module_exact_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, OverloadType overload_type, Type *method_type, Type *param_type); static inline bool sema_analyse_operator_element_at(SemaContext *context, Decl *method); static inline bool sema_analyse_operator_element_set(SemaContext *context, Decl *method); static inline bool sema_analyse_operator_len(SemaContext *context, Decl *method); @@ -47,7 +48,8 @@ static bool sema_analyse_attributes(SemaContext *context, Decl *decl, Attr **att static bool sema_analyse_attributes_for_var(SemaContext *context, Decl *decl, bool *erase_decl); static bool sema_check_section(SemaContext *context, Attr *attr); static inline bool sema_analyse_attribute_decl(SemaContext *context, SemaContext *c, Decl *decl, bool *erase_decl); -static Decl *sema_find_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, Type *parent_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref); +static Decl *sema_find_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, OverloadType overload_type, Type *parent_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref); +static Decl *sema_find_exact_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, OverloadType overload_type, Type *parent_type, Type *binary_type); static inline bool sema_analyse_typedef(SemaContext *context, Decl *decl, bool *erase_decl); static bool sema_analyse_variable_type(SemaContext *context, Type *type, SourceSpan span); @@ -1704,15 +1706,26 @@ INLINE bool decl_matches_overload(Decl *method, Type *type, OperatorOverload ove return method->func_decl.operator == overload && typeget(method->func_decl.type_parent)->canonical == type; } -static inline Decl *operator_in_module_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, Type *method_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref) +static inline Decl *operator_in_module_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, OverloadType overload_type, Type *method_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref) { if (module->is_generic) return NULL; - Decl *found = sema_find_typed_operator_in_list(c, module->private_method_extensions, operator_overload, - method_type, binary_arg, binary_type, candidate_ref, ambiguous_ref); + Decl *found = sema_find_typed_operator_in_list(c, module->private_method_extensions, operator_overload, OVERLOAD_TYPE_SYMMETRIC, method_type, binary_arg, binary_type, candidate_ref, ambiguous_ref); if (found) return found; FOREACH(Module *, sub_module, module->sub_modules) { - return operator_in_module_typed(c, sub_module, operator_overload, method_type, binary_arg, binary_type, candidate_ref, ambiguous_ref); + return operator_in_module_typed(c, sub_module, operator_overload, overload_type, method_type, binary_arg, binary_type, candidate_ref, ambiguous_ref); + } + return NULL; +} + +static inline Decl *operator_in_module_exact_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, OverloadType overload_type, Type *method_type, Type *param_type) +{ + if (module->is_generic) return NULL; + Decl *found = sema_find_exact_typed_operator_in_list(c, module->private_method_extensions, operator_overload, overload_type, method_type, param_type); + if (found) return found; + FOREACH(Module *, sub_module, module->sub_modules) + { + return operator_in_module_exact_typed(c, sub_module, operator_overload, overload_type, method_type, param_type); } return NULL; } @@ -1757,18 +1770,34 @@ Decl *sema_find_untyped_operator(SemaContext *context, Type *type, OperatorOverl if (extension) return extension; FOREACH(Decl *, import, context->unit->imports) { + if (!import->import.import_private_as_public) continue; extension = operator_in_module_untyped(context, import->import.module, type, operator_overload); if (extension) return extension; } return NULL; } -static Decl *sema_find_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, Type *parent_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref) +static Decl *sema_find_exact_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, OverloadType overload_type, Type *parent_type, Type *binary_type) { FOREACH(Decl *, func, methods) { if (func->func_decl.operator != operator_overload) continue; if (parent_type && parent_type != typeget(func->func_decl.type_parent)) continue; + if ((overload_type & func->func_decl.overload_type) == 0) continue; + Type *first_arg = func->func_decl.signature.params[1]->type->canonical; + if (first_arg != binary_type) continue; + return func; + } + return NULL; +} + +static Decl *sema_find_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, OverloadType overload_type, Type *parent_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref) +{ + FOREACH(Decl *, func, methods) + { + if (func->func_decl.operator != operator_overload) continue; + if (parent_type && parent_type != typeget(func->func_decl.type_parent)) continue; + if ((overload_type & func->func_decl.overload_type) == 0) continue; Type *first_arg = func->func_decl.signature.params[1]->type->canonical; if (first_arg != binary_type) { @@ -1789,43 +1818,83 @@ static Decl *sema_find_typed_operator_in_list(SemaContext *context, Decl **metho } return NULL; } -Decl *sema_find_typed_operator(SemaContext *context, Type *type, OperatorOverload operator_overload, Expr *binary_arg, Type *binary_type, Decl **ambiguous_ref) + +static Decl *sema_find_exact_typed_operator(SemaContext *context, Type *type, OperatorOverload operator_overload, OverloadType overload_type, Type *param_type) { assert(operator_overload >= OVERLOAD_TYPED_START); - assert(!binary_arg || ambiguous_ref); - assert(!binary_type || !binary_arg); type = type->canonical; - if (binary_arg) binary_type = binary_arg->type->canonical; - Decl *candidate = NULL; - Decl *ambiguous = NULL; - if (type_is_user_defined(type)) - { - Decl *func = sema_find_typed_operator_in_list(context, type->decl->methods, operator_overload, type, binary_arg, - binary_type, &candidate, &ambiguous); - if (func) return func; - } - else - { - Decl *func = sema_find_typed_operator_in_list(context, compiler.context.method_extensions, - operator_overload, type, binary_arg, binary_type, &candidate, &ambiguous); - if (func) return func; - } + Decl *func = sema_find_exact_typed_operator_in_list(context, type->decl->methods, operator_overload, overload_type, type, param_type); + if (func) return func; - Decl *extension = sema_find_typed_operator_in_list(context, context->unit->local_method_extensions, - operator_overload, type, binary_arg, binary_type, &candidate, &ambiguous); + Decl *extension = sema_find_exact_typed_operator_in_list(context, context->unit->local_method_extensions, operator_overload, overload_type, type, param_type); if (extension) return extension; - extension = operator_in_module_typed(context, context->compilation_unit->module, operator_overload, type, - binary_arg, binary_type, &candidate, &ambiguous); + extension = operator_in_module_exact_typed(context, context->compilation_unit->module, operator_overload, overload_type, + type, param_type); if (extension) return extension; FOREACH(Decl *, import, context->unit->imports) { - extension = operator_in_module_typed(context, import->import.module, operator_overload, type, binary_arg, - binary_type, &candidate, &ambiguous); + if (!import->import.import_private_as_public) continue; + extension = operator_in_module_exact_typed(context, import->import.module, operator_overload, overload_type, + type, param_type); if (extension) return extension; } + return NULL; +} + +static Decl *sema_find_typed_operator_type(SemaContext *context, OperatorOverload operator_overload, OverloadType overloat_type, Type *lhs_type, Type *rhs_type, Expr *rhs, Decl **candidate_ref, Decl **ambiguous_ref) +{ + // Can we find the overload directly on the method? + Decl *func = sema_find_typed_operator_in_list(context, lhs_type->decl->methods, operator_overload, overloat_type, lhs_type, + rhs, rhs_type, candidate_ref, ambiguous_ref); + if (func) return func; + // Can we find it as a local extension? + Decl *extension = sema_find_typed_operator_in_list(context, context->unit->local_method_extensions, + operator_overload, overloat_type, lhs_type, rhs, rhs_type, candidate_ref, ambiguous_ref); + // Can we find it in the current module? + if (extension) return extension; + extension = operator_in_module_typed(context, context->compilation_unit->module, operator_overload, overloat_type, + lhs_type, rhs, rhs_type, candidate_ref, ambiguous_ref); + if (extension) return extension; + // TODO incorrect, doesn't recurse + // Look through our public imports + FOREACH(Decl *, import, context->unit->public_imports) + { + extension = operator_in_module_typed(context, import->import.module, operator_overload, overloat_type, + lhs_type, rhs, rhs_type, candidate_ref, ambiguous_ref); + if (extension) return extension; + } + return NULL; +} + +Decl *sema_find_typed_operator(SemaContext *context, OperatorOverload operator_overload, Expr *lhs, Expr *rhs, Decl **ambiguous_ref, bool *reverse) +{ + assert(operator_overload >= OVERLOAD_TYPED_START); + assert(lhs && rhs && ambiguous_ref); + Type *left_type = type_no_optional(lhs->type)->canonical; + Type *right_type = type_no_optional(rhs->type)->canonical; + + Decl *candidate = NULL; + Decl *ambiguous = NULL; + Decl *left_candidate = NULL; + *reverse = false; + if (type_is_user_defined(left_type)) + { + Decl *found = sema_find_typed_operator_type(context, operator_overload, OVERLOAD_TYPE_LEFT, left_type, right_type, rhs, &candidate, &ambiguous); + if (found) return found; + left_candidate = candidate; + } + if (type_is_user_defined(right_type)) + { + Decl *found = sema_find_typed_operator_type(context, operator_overload, OVERLOAD_TYPE_RIGHT, right_type, left_type, lhs, &candidate, &ambiguous); + if (found) + { + *reverse = true; + return found; + } + } if (ambiguous) { *ambiguous_ref = ambiguous; @@ -1833,7 +1902,7 @@ Decl *sema_find_typed_operator(SemaContext *context, Type *type, OperatorOverloa } if (candidate) { - unit_register_external_symbol(context, candidate); + *reverse = candidate != left_candidate; return candidate; } return NULL; @@ -2063,22 +2132,30 @@ INLINE bool sema_analyse_operator_method(SemaContext *context, Type *parent_type // See if the operator has already been defined. OperatorOverload operator = method->func_decl.operator; - Type *second_param = vec_size(method->func_decl.signature.params) > 1 ? method->func_decl.signature.params[1]->type : NULL; + Type *second_param = vec_size(method->func_decl.signature.params) > 1 ? method->func_decl.signature.params[1]->type->canonical : NULL; - // We don't support operator overloading on base types, because - // there seems little use for it frankly. - if (!type_is_user_defined(parent_type) && (operator < OVERLOAD_TYPED_START || !type_is_user_defined(second_param->canonical))) + if (!type_is_user_defined(parent_type)) { - sema_error_at(context, method_find_overload_span(method), - "Only overloads involving user-defined types support overloading."); - return false; + RETURN_SEMA_ERROR(method, "Only user-defined types may have overloads."); } + bool is_symmetric = method->func_decl.overload_type == OVERLOAD_TYPE_SYMMETRIC; + bool is_reverse = method->func_decl.overload_type == OVERLOAD_TYPE_RIGHT; + if (!second_param) + { + if (is_symmetric) RETURN_SEMA_ERROR(method, "Methods with single arguments cannot be have symmetric operators."); + if (is_reverse) RETURN_SEMA_ERROR(method, "Methods with single arguments cannot have reverse operators."); + } + else if (second_param->type_kind == parent_type->type_kind) + { + if (is_symmetric) RETURN_SEMA_ERROR(method, "Methods with same argument types cannot be have symmetric operators."); + if (is_reverse) RETURN_SEMA_ERROR(method, "Methods with the same argument types cannot have reverse operators."); + } Decl *other = NULL; if (operator >= OVERLOAD_TYPED_START) { - other = sema_find_typed_operator(context, parent_type, operator, NULL, second_param, NULL); + other = sema_find_exact_typed_operator(context, parent_type, operator, method->func_decl.overload_type, second_param); } else { @@ -2743,6 +2820,8 @@ static bool sema_analyse_attribute(SemaContext *context, ResolvedAttrData *attr_ [ATTRIBUTE_NOSTRIP] = ATTR_FUNC | ATTR_GLOBAL | ATTR_CONST | EXPORTED_USER_DEFINED_TYPES, [ATTRIBUTE_OBFUSCATE] = ATTR_ENUM | ATTR_FAULT, [ATTRIBUTE_OPERATOR] = ATTR_MACRO | ATTR_FUNC, + [ATTRIBUTE_OPERATOR_R] = ATTR_MACRO | ATTR_FUNC, + [ATTRIBUTE_OPERATOR_S] = ATTR_MACRO | ATTR_FUNC, [ATTRIBUTE_OPTIONAL] = ATTR_INTERFACE_METHOD, [ATTRIBUTE_OVERLAP] = ATTR_BITSTRUCT, [ATTRIBUTE_PACKED] = ATTR_STRUCT | ATTR_UNION, @@ -2847,6 +2926,8 @@ static bool sema_analyse_attribute(SemaContext *context, ResolvedAttrData *attr_ case ATTRIBUTE_TEST: decl->func_decl.attr_test = true; break; + case ATTRIBUTE_OPERATOR_R: + case ATTRIBUTE_OPERATOR_S: case ATTRIBUTE_OPERATOR: { ASSERT(decl->decl_kind == DECL_FUNC || decl->decl_kind == DECL_MACRO); @@ -2864,6 +2945,20 @@ static bool sema_analyse_attribute(SemaContext *context, ResolvedAttrData *attr_ break; case EXPR_OPERATOR_CHARS: decl->func_decl.operator = expr->overload_expr; + switch (type) + { + case ATTRIBUTE_OPERATOR: + decl->func_decl.overload_type = OVERLOAD_TYPE_LEFT; + break; + case ATTRIBUTE_OPERATOR_R: + decl->func_decl.overload_type = OVERLOAD_TYPE_RIGHT; + break; + case ATTRIBUTE_OPERATOR_S: + decl->func_decl.overload_type = OVERLOAD_TYPE_SYMMETRIC; + break; + default: + UNREACHABLE + } break; default: goto FAILED_OP_TYPE; @@ -4454,6 +4549,14 @@ static CompilationUnit *unit_copy(Module *module, CompilationUnit *unit) { CompilationUnit *copy = unit_create(unit->file); copy->imports = copy_decl_list_single(unit->imports); + copy->public_imports = NULL; + if (unit->public_imports) + { + FOREACH(Decl *, import, copy->imports) + { + if (import->import.import_private_as_public) vec_add(copy->public_imports, import); + } + } copy->global_decls = copy_decl_list_single_for_unit(unit->global_decls); copy->global_cond_decls = copy_decl_list_single_for_unit(unit->global_cond_decls); copy->module = module; diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index ff6868744..884884521 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -207,7 +207,7 @@ static inline bool sema_analyse_expr_check(SemaContext *context, Expr *expr, Che static inline Expr **sema_prepare_splat_insert(Expr **exprs, unsigned added, unsigned insert_point); static inline bool sema_analyse_maybe_dead_expr(SemaContext *, Expr *expr, bool is_dead, Type *infer_type); -static inline bool sema_insert_binary_overload(SemaContext *context, Expr *expr, Decl *overload, Expr *lhs, Expr *rhs); +static inline bool sema_insert_binary_overload(SemaContext *context, Expr *expr, Decl *overload, Expr *lhs, Expr *rhs, bool reverse); // -- implementations @@ -3356,7 +3356,7 @@ static inline bool sema_expr_analyse_subscript_lvalue(SemaContext *context, Expr if (!sema_analyse_expr(context, current_expr)) return false; Expr *var_for_len = expr_variable(temp); Expr *len_expr = expr_new(EXPR_CALL, expr->span); - if (!sema_insert_method_call(context, len_expr, len, var_for_len, NULL)) return false; + if (!sema_insert_method_call(context, len_expr, len, var_for_len, NULL, false)) return false; if (!sema_analyse_expr(context, len_expr)) return false; Expr *index_copy = expr_copy(index); if (!sema_analyse_expr(context, index_copy)) return false; @@ -3475,7 +3475,7 @@ static inline bool sema_expr_analyse_subscript(SemaContext *context, Expr *expr, if (!sema_analyse_expr(context, current_expr)) return false; Expr *var_for_len = expr_variable(temp); Expr *len_expr = expr_new(EXPR_CALL, expr->span); - if (!sema_insert_method_call(context, len_expr, len, var_for_len, NULL)) return false; + if (!sema_insert_method_call(context, len_expr, len, var_for_len, NULL, false)) return false; if (!sema_analyse_expr(context, len_expr)) return false; Expr *index_copy = expr_copy(index); if (!sema_analyse_expr(context, index_copy)) return false; @@ -3486,7 +3486,7 @@ static inline bool sema_expr_analyse_subscript(SemaContext *context, Expr *expr, } Expr **args = NULL; vec_add(args, index); - return sema_insert_method_call(context, expr, overload, current_expr, args); + return sema_insert_method_call(context, expr, overload, current_expr, args, false); } // Cast to an appropriate type for index. @@ -6143,7 +6143,7 @@ static bool sema_expr_analyse_assign(SemaContext *context, Expr *expr, Expr *lef Expr **args = NULL; vec_add(args, exprptr(left->subscript_assign_expr.index)); vec_add(args, right); - return sema_insert_method_call(context, expr, declptr(left->subscript_assign_expr.method), exprptr(left->subscript_assign_expr.expr), args); + return sema_insert_method_call(context, expr, declptr(left->subscript_assign_expr.method), exprptr(left->subscript_assign_expr.expr), args, false); } if (left->expr_kind == EXPR_BITACCESS) { @@ -6427,11 +6427,12 @@ static bool sema_replace_with_overload(SemaContext *context, Expr *expr, Expr *l { assert(!type_is_optional(left_type) && left_type->canonical == left_type); Decl *ambiguous = NULL; - Decl *overload = sema_find_typed_operator(context, left_type, *operator_overload_ref, right, NULL, &ambiguous); + bool reverse; + Decl *overload = sema_find_typed_operator(context, *operator_overload_ref, left, right, &ambiguous, &reverse); if (overload) { - *operator_overload_ref = 0; - return sema_insert_binary_overload(context, expr, overload, left, right); + *operator_overload_ref = (OperatorOverload)0; // NOLINT + return sema_insert_binary_overload(context, expr, overload, left, right, reverse); } if (ambiguous) { @@ -6804,6 +6805,7 @@ static bool sema_expr_analyse_add(SemaContext *context, Expr *expr, Expr *left, right_type = type_no_optional(right->type)->canonical; ASSERT_SPAN(expr, !cast_to_iptr); + // 4. Do a binary arithmetic promotion OperatorOverload overload = OVERLOAD_PLUS; if (!sema_binary_arithmetic_promotion(context, left, right, left_type, right_type, expr, @@ -7238,23 +7240,24 @@ static bool sema_expr_analyse_comp(SemaContext *context, Expr *expr, Expr *left, Decl *overload = NULL; bool negated_overload = false; Decl *ambiguous = NULL; + bool reverse = false; switch (expr->binary_expr.operator) { case BINARYOP_NE: - overload = sema_find_typed_operator(context, left_type, OVERLOAD_NOT_EQUAL, right, NULL, &ambiguous); + overload = sema_find_typed_operator(context, OVERLOAD_NOT_EQUAL, left, right, &ambiguous, &reverse); if (!overload && !ambiguous) { negated_overload = true; - overload = sema_find_typed_operator(context, left_type, OVERLOAD_EQUAL, right, NULL, &ambiguous); + overload = sema_find_typed_operator(context, OVERLOAD_EQUAL, left, right, &ambiguous, &reverse); } if (!overload) goto NEXT; break; case BINARYOP_EQ: - overload = sema_find_typed_operator(context, left_type, OVERLOAD_EQUAL, right, NULL, &ambiguous); + overload = sema_find_typed_operator(context, OVERLOAD_EQUAL, left, right, &ambiguous, &reverse); if (!overload && !ambiguous) { negated_overload = true; - overload = sema_find_typed_operator(context, left_type, OVERLOAD_NOT_EQUAL, right, NULL, &ambiguous); + overload = sema_find_typed_operator(context, OVERLOAD_NOT_EQUAL, left, right, &ambiguous, &reverse); } if (!overload) goto NEXT; break; @@ -7267,7 +7270,7 @@ static bool sema_expr_analyse_comp(SemaContext *context, Expr *expr, Expr *left, expr_insert_addr(right); } vec_add(args, right); - if (!sema_insert_method_call(context, expr, overload, left, args)) return false; + if (!sema_insert_method_call(context, expr, overload, left, args, reverse)) return false; if (!negated_overload) return true; assert(expr->resolve_status == RESOLVE_DONE); Expr *inner = expr_copy(expr); @@ -7652,7 +7655,7 @@ static inline bool sema_expr_analyse_neg_plus(SemaContext *context, Expr *expr) expr_replace(expr, inner); return true; } - return sema_insert_method_call(context, expr, overload, inner, NULL); + return sema_insert_method_call(context, expr, overload, inner, NULL, false); } } if (!type_may_negate(no_fail)) @@ -7720,7 +7723,7 @@ static inline bool sema_expr_analyse_bit_not(SemaContext *context, Expr *expr) if (type_is_user_defined(canonical) && canonical->type_kind != TYPE_BITSTRUCT) { Decl *overload = sema_find_untyped_operator(context, canonical, OVERLOAD_NEGATE); - if (overload) return sema_insert_method_call(context, expr, overload, inner, NULL); + if (overload) return sema_insert_method_call(context, expr, overload, inner, NULL, false); } Type *flat = type_flatten(canonical); @@ -7900,7 +7903,7 @@ static bool sema_analyse_assign_mutate_overloaded_subscript(SemaContext *context if (operator) { vec_add(args, exprptr(subscript_expr->subscript_assign_expr.index)); - if (!sema_insert_method_call(context, subscript_expr, operator, exprptr(subscript_expr->subscript_assign_expr.expr), args)) return false; + if (!sema_insert_method_call(context, subscript_expr, operator, exprptr(subscript_expr->subscript_assign_expr.expr), args, false)) return false; expr_rewrite_insert_deref(subscript_expr); main->type = subscript_expr->type; return true; @@ -7949,7 +7952,7 @@ static bool sema_analyse_assign_mutate_overloaded_subscript(SemaContext *context vec_add(args, expr_variable(index_val)); Expr *temp_val_1 = expr_variable(temp_val); expr_rewrite_insert_deref(temp_val_1); - if (!sema_insert_method_call(context, get_expr, operator, temp_val_1, args)) return false; + if (!sema_insert_method_call(context, get_expr, operator, temp_val_1, args, false)) return false; Expr *value_val_expr = expr_generate_decl(value_val, get_expr); // temp_value = func(temp, temp_index) vec_add(main->expression_list, value_val_expr); @@ -7961,7 +7964,7 @@ static bool sema_analyse_assign_mutate_overloaded_subscript(SemaContext *context vec_add(args, expr_variable(value_val)); Expr *temp_val_2 = expr_variable(temp_val); expr_rewrite_insert_deref(temp_val_2); - if (!sema_insert_method_call(context, subscript_expr, declptr(subscript_expr->subscript_assign_expr.method), temp_val_2, args)) return false; + if (!sema_insert_method_call(context, subscript_expr, declptr(subscript_expr->subscript_assign_expr.method), temp_val_2, args, false)) return false; ASSERT(subscript_expr->expr_kind == EXPR_CALL); subscript_expr->call_expr.has_optional_arg = false; vec_add(main->expression_list, subscript_expr); @@ -10716,19 +10719,42 @@ TokenType sema_splitpathref(const char *string, ArraySize len, Path **path_ref, } } -bool sema_insert_method_call(SemaContext *context, Expr *method_call, Decl *method_decl, Expr *parent, Expr **arguments) +bool sema_insert_method_call(SemaContext *context, Expr *method_call, Decl *method_decl, Expr *parent, Expr **arguments, bool reverse_overload) { SourceSpan original_span = method_call->span; + Expr *resolve = method_call; + if (reverse_overload) + { + Decl *temp = decl_new_generated_var(method_decl->func_decl.signature.params[1]->type, VARDECL_LOCAL, parent->span); + Expr *generate = expr_generate_decl(temp, expr_copy(parent)); + parent->expr_kind = EXPR_IDENTIFIER; + parent->ident_expr = temp; + parent->resolve_status = RESOLVE_DONE; + parent->type = temp->type; + Expr **list = NULL; + vec_add(list, generate); + if (!sema_analyse_expr(context, generate)) return false; + Expr *copied_method = expr_copy(method_call); + vec_add(list, copied_method); + method_call->expr_kind = EXPR_EXPRESSION_LIST; + method_call->resolve_status = RESOLVE_NOT_DONE; + method_call->expression_list = list; + Expr *arg0 = arguments[0]; + arguments[0] = parent; + parent = arg0; + method_call = copied_method; + } *method_call = (Expr) { .expr_kind = EXPR_CALL, .span = original_span, .resolve_status = RESOLVE_RUNNING, .call_expr.func_ref = declid(method_decl), .call_expr.arguments = arguments, .call_expr.is_func_ref = true, - .call_expr.is_type_method = true }; + .call_expr.is_type_method = true, + }; Type *type = parent->type->canonical; Decl *first_param = method_decl->func_decl.signature.params[0]; - Type *first = first_param->type; + Type *first = first_param->type->canonical; // Deref / addr as needed. if (type != first) { @@ -10754,22 +10780,27 @@ bool sema_insert_method_call(SemaContext *context, Expr *method_call, Decl *meth if (!sema_analyse_expr(context, parent)) return false; expr_rewrite_insert_deref(parent); } + if (!(parent && parent->type && first == parent->type->canonical)) + { + puts("TODO"); + } ASSERT_SPAN(method_call, parent && parent->type && first == parent->type->canonical); if (!sema_expr_analyse_general_call(context, method_call, method_decl, parent, false, NULL)) return expr_poison(method_call); method_call->resolve_status = RESOLVE_DONE; + if (resolve != method_call) + { + resolve->resolve_status = RESOLVE_DONE; + resolve->type = method_call->type; + } return true; } -static inline bool sema_insert_binary_overload(SemaContext *context, Expr *expr, Decl *overload, Expr *lhs, Expr *rhs) +static inline bool sema_insert_binary_overload(SemaContext *context, Expr *expr, Decl *overload, Expr *lhs, Expr *rhs, bool reverse) { Expr **args = NULL; - if (overload->func_decl.signature.params[1]->type->canonical->type_kind == TYPE_POINTER) - { - expr_insert_addr(rhs); - } vec_add(args, rhs); - return sema_insert_method_call(context, expr, overload, lhs, args); + return sema_insert_method_call(context, expr, overload, lhs, args, reverse); } // Check if the assignment fits diff --git a/src/compiler/sema_internal.h b/src/compiler/sema_internal.h index c61680d10..6276bdbf3 100644 --- a/src/compiler/sema_internal.h +++ b/src/compiler/sema_internal.h @@ -94,9 +94,9 @@ bool sema_analyse_expr_lvalue(SemaContext *context, Expr *expr, bool *failed_ref bool sema_analyse_expr_value(SemaContext *context, Expr *expr); Expr *expr_access_inline_member(Expr *parent, Decl *parent_decl); bool sema_analyse_ct_expr(SemaContext *context, Expr *expr); -Decl *sema_find_typed_operator(SemaContext *context, Type *type, OperatorOverload operator_overload, Expr *binary_arg, Type *binary_type, Decl **ambiguous_ref); +Decl *sema_find_typed_operator(SemaContext *context, OperatorOverload operator_overload, Expr *rhs, Expr *lhs, Decl **ambiguous_ref, bool *reverse); Decl *sema_find_untyped_operator(SemaContext *context, Type *type, OperatorOverload operator_overload); -bool sema_insert_method_call(SemaContext *context, Expr *method_call, Decl *method_decl, Expr *parent, Expr **arguments); +bool sema_insert_method_call(SemaContext *context, Expr *method_call, Decl *method_decl, Expr *parent, Expr **arguments, bool reverse_overload); bool sema_expr_analyse_builtin_call(SemaContext *context, Expr *expr); bool sema_expr_analyse_macro_call(SemaContext *context, Expr *call_expr, Expr *struct_var, Decl *decl, bool call_var_optional, bool *no_match_ref); diff --git a/src/compiler/sema_stmts.c b/src/compiler/sema_stmts.c index 7fe89b01a..2a3830f99 100644 --- a/src/compiler/sema_stmts.c +++ b/src/compiler/sema_stmts.c @@ -1576,7 +1576,7 @@ SKIP_OVERLOAD:; if (len) { len_call = expr_new(EXPR_CALL, enumerator->span); - if (!sema_insert_method_call(context, len_call, len, enum_val, NULL)) return false; + if (!sema_insert_method_call(context, len_call, len, enum_val, NULL, false)) return false; } else { diff --git a/src/compiler/symtab.c b/src/compiler/symtab.c index 8f41d7f57..b912b96d7 100644 --- a/src/compiler/symtab.c +++ b/src/compiler/symtab.c @@ -347,6 +347,8 @@ void symtab_init(uint32_t capacity) attribute_list[ATTRIBUTE_NOSTRIP] = KW_DEF("@nostrip"); attribute_list[ATTRIBUTE_OBFUSCATE] = KW_DEF("@obfuscate"); attribute_list[ATTRIBUTE_OPERATOR] = KW_DEF("@operator"); + attribute_list[ATTRIBUTE_OPERATOR_R] = KW_DEF("@operator_r"); + attribute_list[ATTRIBUTE_OPERATOR_S] = KW_DEF("@operator_s"); attribute_list[ATTRIBUTE_OPTIONAL] = KW_DEF("@optional"); attribute_list[ATTRIBUTE_OVERLAP] = KW_DEF("@overlap"); attribute_list[ATTRIBUTE_PACKED] = KW_DEF("@packed"); diff --git a/test/test_suite/methods/unsupported_operator.c3 b/test/test_suite/methods/unsupported_operator.c3 index 076c65b97..6c874d053 100755 --- a/test/test_suite/methods/unsupported_operator.c3 +++ b/test/test_suite/methods/unsupported_operator.c3 @@ -1,3 +1,3 @@ import std::io; -fn int int.fadd(&self, int x) @operator([]) { return 1; } // #error: Only overloads involving user-defined types support overloading \ No newline at end of file +fn int int.fadd(&self, int x) @operator([]) { return 1; } // #error: Only user-defined types may have overloads. \ No newline at end of file diff --git a/test/unit/regression/operator_overload.c3 b/test/unit/regression/operator_overload.c3 index 8fd16425e..6936a852c 100644 --- a/test/unit/regression/operator_overload.c3 +++ b/test/unit/regression/operator_overload.c3 @@ -10,8 +10,8 @@ fn bool Abc.eq(self, Abc abc) @operator(==) => self.a == abc.a; fn Abc Abc.plus(self, Abc abc) @operator(+) => { self.a + abc.a }; fn Abc Abc.plus2(self, int i) @operator(+) => { self.a + i }; fn Abc Abc.minus2(self, int i) @operator(-) => { self.a - i }; -fn Abc int.plus_abc(self, Abc abc) @operator(+) => { self + abc.a }; -fn int int.mul_abc(self, Abc abc) @operator(*) => self * abc.a; +fn Abc Abc.plus_abc(Abc abc, int self) @operator_r(+) => { self + abc.a }; +fn int Abc.mul_abc(Abc abc, int self) @operator_s(*) => self * abc.a; fn Abc Abc.negate(self) @operator(~) => { ~self.a }; fn Abc Abc.negate2(self) @operator(-) => { -self.a }; fn Abc Abc.shr2(self, int x) @operator(>>) => { self.a >> x };