diff --git a/lib/std/core/string.c3 b/lib/std/core/string.c3 index 5345fddd6..190cea02c 100644 --- a/lib/std/core/string.c3 +++ b/lib/std/core/string.c3 @@ -577,6 +577,19 @@ fn usz? String.rindex_of(self, String substr) return NOT_FOUND?; } +fn bool ZString.eq(self, ZString other) @operator(==) +{ + char* a = self; + char* b = other; + if (a == b) return true; + for (;; a++, b++) + { + char c = *a; + if (c != *b) return false; + if (!c) return true; + } +} + fn String ZString.str_view(self) { return (String)(self[:self.len()]); diff --git a/releasenotes.md b/releasenotes.md index 805752344..7487f3a5d 100644 --- a/releasenotes.md +++ b/releasenotes.md @@ -19,6 +19,7 @@ - Make accepting arguments for `main` a bit more liberal, accepting `main(int argc, ZString* argv)` - Make `$echo` and `@sprintf` correctly stringify compile time initializers and slices. - Add `--sources` build option to add additional files to compile. #2097 +- Support untyped second argument for operator overloading. ### Fixes - `-2147483648`, MIN literals work correctly. @@ -29,9 +30,12 @@ - Improve Android termux detection. - Update Android ABI. - Fixes to `@format` checking #2199. +- Distinct versions of builtin types ignore @operator overloads #2204. +- @operator macro using untyped parameter causes compiler segfault #2200. ### Stdlib changes - Deprecate `String.is_zstr` and `String.quick_zstr` #2188. +- Add comparison with `==` for ZString types. ## 0.7.2 Change list diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index 5b61ac865..bee7b5f76 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -500,6 +500,7 @@ typedef struct TypeInfoId type_parent; OperatorOverload operator : 6; unsigned overload_type : 2; + bool is_wildcard_overload : 1; Signature signature; AstId body; AstId docs; diff --git a/src/compiler/sema_decls.c b/src/compiler/sema_decls.c index ed1948082..8ada800a2 100755 --- a/src/compiler/sema_decls.c +++ b/src/compiler/sema_decls.c @@ -19,7 +19,7 @@ static inline bool unit_add_base_extension_method(SemaContext *context, Compilat static inline bool unit_add_method(SemaContext *context, Type *parent_type, Decl *method); static bool sema_analyse_operator_common(SemaContext *context, Decl *method, TypeInfo **rtype_ptr, Decl ***params_ptr, uint32_t parameters); -static inline Decl *operator_in_module_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, +static inline bool operator_in_module_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, OverloadType overload_type, Type *method_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref); static inline Decl *operator_in_module_exact_typed(Module *module, OperatorOverload operator_overload, OverloadType overload_type, Type *method_type, Type *param_type, Decl *skipped); static inline bool sema_analyse_operator_element_at(SemaContext *context, Decl *method); @@ -48,7 +48,7 @@ static bool sema_analyse_attributes(SemaContext *context, Decl *decl, Attr **att static bool sema_analyse_attributes_for_var(SemaContext *context, Decl *decl, bool *erase_decl); static bool sema_check_section(SemaContext *context, Attr *attr); static inline bool sema_analyse_attribute_decl(SemaContext *context, SemaContext *c, Decl *decl, bool *erase_decl); -static Decl *sema_find_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, OverloadType overload_type, Type *parent_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref); +static bool sema_find_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, OverloadType overload_type, Type *parent_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref); static Decl *sema_find_exact_typed_operator_in_list(Decl **methods, OperatorOverload operator_overload, OverloadType overload_type, Type *parent_type, Type *binary_type, Decl *skipped); static inline bool sema_analyse_typedef(SemaContext *context, Decl *decl, bool *erase_decl); @@ -1709,16 +1709,15 @@ INLINE bool decl_matches_overload(Decl *method, Type *type, OperatorOverload ove return method->func_decl.operator == overload && typeget(method->func_decl.type_parent)->canonical == type; } -static inline Decl *operator_in_module_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, OverloadType overload_type, Type *method_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref) +static inline bool operator_in_module_typed(SemaContext *c, Module *module, OperatorOverload operator_overload, OverloadType overload_type, Type *method_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref) { if (module->is_generic) return NULL; - Decl *found = sema_find_typed_operator_in_list(c, module->private_method_extensions, operator_overload, OVERLOAD_TYPE_SYMMETRIC, method_type, binary_arg, binary_type, candidate_ref, ambiguous_ref); - if (found) return found; + if (!sema_find_typed_operator_in_list(c, module->private_method_extensions, operator_overload, OVERLOAD_TYPE_SYMMETRIC, method_type, binary_arg, binary_type, candidate_ref, ambiguous_ref)) return false; FOREACH(Module *, sub_module, module->sub_modules) { - return operator_in_module_typed(c, sub_module, operator_overload, overload_type, method_type, binary_arg, binary_type, candidate_ref, ambiguous_ref); + if (!operator_in_module_typed(c, sub_module, operator_overload, overload_type, method_type, binary_arg, binary_type, candidate_ref, ambiguous_ref)) return false; } - return NULL; + return true; } static inline Decl *operator_in_module_exact_typed(Module *module, OperatorOverload operator_overload, OverloadType overload_type, Type *method_type, Type *param_type, Decl *skipped) @@ -1779,6 +1778,7 @@ Decl *sema_find_untyped_operator(SemaContext *context, Type *type, OperatorOverl static Decl *sema_find_exact_typed_operator_in_list(Decl **methods, OperatorOverload operator_overload, OverloadType overload_type, Type *parent_type, Type *binary_type, Decl *skipped) { + Decl *wildcard = NULL; FOREACH(Decl *, func, methods) { if (func == skipped) continue; @@ -1786,39 +1786,46 @@ static Decl *sema_find_exact_typed_operator_in_list(Decl **methods, OperatorOver if (func->func_decl.operator != operator_overload) continue; if (parent_type && parent_type != typeget(func->func_decl.type_parent)) continue; if ((overload_type & func->func_decl.overload_type) == 0) continue; - Type *first_arg = func->func_decl.signature.params[1]->type->canonical; - if (first_arg != binary_type) continue; + if (func->func_decl.is_wildcard_overload) + { + wildcard = func; + continue; + } + Type *first_arg = func->func_decl.signature.params[1]->type; + if (first_arg->canonical != binary_type) continue; return func; } - return NULL; + return wildcard; } -static Decl *sema_find_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, OverloadType overload_type, Type *parent_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref) +static bool sema_find_typed_operator_in_list(SemaContext *context, Decl **methods, OperatorOverload operator_overload, OverloadType overload_type, Type *parent_type, Expr *binary_arg, Type *binary_type, Decl **candidate_ref, Decl **ambiguous_ref) { + + Decl *candidate = *candidate_ref; FOREACH(Decl *, func, methods) { if (func->func_decl.operator != operator_overload) continue; if (parent_type && parent_type != typeget(func->func_decl.type_parent)) continue; if ((overload_type & func->func_decl.overload_type) == 0) continue; - Type *first_arg = func->func_decl.signature.params[1]->type->canonical; - if (first_arg != binary_type) + if (!func->func_decl.is_wildcard_overload) { - if (!binary_arg) continue; - if (may_cast(context, binary_arg, first_arg, false, true)) + Type *first_arg = func->func_decl.signature.params[1]->type->canonical; + if (first_arg != binary_type) { - if (*candidate_ref) - { - *ambiguous_ref = func; - continue; - } - *candidate_ref = func; + if (!binary_arg) continue; + if (!may_cast(context, binary_arg, first_arg, false, true)) continue; } - continue; } - unit_register_external_symbol(context, func); - return func; + if (candidate && !candidate->func_decl.is_wildcard_overload) + { + *ambiguous_ref = func; + *candidate_ref = candidate; + return false; + } + candidate = func; } - return NULL; + *candidate_ref = candidate; + return true; } static Decl *sema_find_exact_typed_operator(SemaContext *context, Type *type, OperatorOverload operator_overload, OverloadType overload_type, Type *param_type, Decl *skipped) @@ -1848,67 +1855,69 @@ static Decl *sema_find_exact_typed_operator(SemaContext *context, Type *type, Op return NULL; } -static Decl *sema_find_typed_operator_type(SemaContext *context, OperatorOverload operator_overload, OverloadType overloat_type, Type *lhs_type, Type *rhs_type, Expr *rhs, Decl **candidate_ref, Decl **ambiguous_ref) +static bool sema_find_typed_operator_type(SemaContext *context, OperatorOverload operator_overload, OverloadType overloat_type, Type *lhs_type, Type *rhs_type, Expr *rhs, Decl **candidate_ref, Decl **ambiguous_ref) { // Can we find the overload directly on the method? - Decl *func = sema_find_typed_operator_in_list(context, lhs_type->decl->methods, operator_overload, overloat_type, lhs_type, - rhs, rhs_type, candidate_ref, ambiguous_ref); - if (func) return func; + if (!sema_find_typed_operator_in_list(context, lhs_type->decl->methods, + operator_overload, overloat_type, lhs_type, + rhs, rhs_type, candidate_ref, ambiguous_ref)) return false; // Can we find it as a local extension? - Decl *extension = sema_find_typed_operator_in_list(context, context->unit->local_method_extensions, - operator_overload, overloat_type, lhs_type, rhs, rhs_type, candidate_ref, ambiguous_ref); + if (!sema_find_typed_operator_in_list(context, context->unit->local_method_extensions, + operator_overload, overloat_type, lhs_type, rhs, rhs_type, candidate_ref, ambiguous_ref)) return false; + // Can we find it in the current module? - if (extension) return extension; - extension = operator_in_module_typed(context, context->compilation_unit->module, operator_overload, overloat_type, - lhs_type, rhs, rhs_type, candidate_ref, ambiguous_ref); - if (extension) return extension; - // TODO incorrect, doesn't recurse - // Look through our public imports + if (!operator_in_module_typed(context, context->compilation_unit->module, operator_overload, overloat_type, + lhs_type, rhs, rhs_type, candidate_ref, ambiguous_ref)) return false; FOREACH(Decl *, import, context->unit->public_imports) { - extension = operator_in_module_typed(context, import->import.module, operator_overload, overloat_type, - lhs_type, rhs, rhs_type, candidate_ref, ambiguous_ref); - if (extension) return extension; + if (!operator_in_module_typed(context, import->import.module, operator_overload, overloat_type, + lhs_type, rhs, rhs_type, candidate_ref, ambiguous_ref)) return false; } - return NULL; + return true; + } + Decl *sema_find_typed_operator(SemaContext *context, OperatorOverload operator_overload, Expr *lhs, Expr *rhs, Decl **ambiguous_ref, bool *reverse) { assert(operator_overload >= OVERLOAD_TYPED_START); assert(lhs && rhs && ambiguous_ref); Type *left_type = type_no_optional(lhs->type)->canonical; Type *right_type = type_no_optional(rhs->type)->canonical; - Decl *candidate = NULL; Decl *ambiguous = NULL; Decl *left_candidate = NULL; *reverse = false; if (type_is_user_defined(left_type)) { - Decl *found = sema_find_typed_operator_type(context, operator_overload, OVERLOAD_TYPE_LEFT, left_type, right_type, rhs, &candidate, &ambiguous); - if (found) return found; + if (!sema_find_typed_operator_type(context, operator_overload, OVERLOAD_TYPE_LEFT, left_type, right_type, rhs, &candidate, &ambiguous)) return NULL; left_candidate = candidate; } + candidate = NULL; if (type_is_user_defined(right_type)) { - Decl *found = sema_find_typed_operator_type(context, operator_overload, OVERLOAD_TYPE_RIGHT, right_type, left_type, lhs, &candidate, &ambiguous); - if (found) - { - *reverse = true; - return found; - } + if (!sema_find_typed_operator_type(context, operator_overload, OVERLOAD_TYPE_RIGHT, right_type, left_type, lhs, &candidate, &ambiguous)) return NULL; } - if (ambiguous) + + // If one or the other is missing, pick the right one. + if (!left_candidate) { - *ambiguous_ref = ambiguous; - return NULL; - } - if (candidate) - { - *reverse = candidate != left_candidate; + *reverse = true; return candidate; } + if (!candidate) return left_candidate; + + // Both exist, prefer non-wildcard + if (left_candidate->func_decl.is_wildcard_overload && !candidate->func_decl.is_wildcard_overload) + { + *reverse = true; + return candidate; + } + if (candidate->func_decl.is_wildcard_overload && !left_candidate->func_decl.is_wildcard_overload) + { + return left_candidate; + } + *ambiguous_ref = candidate; return NULL; } @@ -2193,7 +2202,22 @@ INLINE bool sema_analyse_operator_method(SemaContext *context, Type *parent_type // See if the operator has already been defined. OperatorOverload operator = method->func_decl.operator; - Type *second_param = vec_size(method->func_decl.signature.params) > 1 ? method->func_decl.signature.params[1]->type->canonical : NULL; + + Type *second_param = NULL; + if (vec_size(method->func_decl.signature.params) > 1) + { + second_param = method->func_decl.signature.params[1]->type; + if (!second_param) + { + if (method->func_decl.overload_type & OVERLOAD_TYPE_RIGHT) + { + RETURN_SEMA_ERROR(method, "Only regular overloads can have untyped right hand parameters"); + } + method->func_decl.is_wildcard_overload = true; + second_param = type_void; + } + second_param = second_param->canonical; + } if (!type_is_user_defined(parent_type)) { diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index 962ff564f..77f782aca 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -7674,7 +7674,7 @@ static bool sema_expr_analyse_comp(SemaContext *context, Expr *expr, Expr *left, Type *left_type = type_no_optional(left->type)->canonical; Type *right_type = type_no_optional(right->type)->canonical; - if (is_equality_type_op && (!type_is_comparable(left_type) || !type_is_comparable(right_type))) + if (is_equality_type_op && (type_is_user_defined(left_type) || type_is_user_defined(right_type))) { Decl *overload = NULL; bool negated_overload = false; diff --git a/test/test_suite/methods/distinct_overload.c3t b/test/test_suite/methods/distinct_overload.c3t new file mode 100644 index 000000000..ac44c8e11 --- /dev/null +++ b/test/test_suite/methods/distinct_overload.c3t @@ -0,0 +1,37 @@ +// #target: macos-x64 +module test; +import std; +typedef ZString2 = inline char*; + +fn bool ZString2.equals(self, ZString2 other) @operator(==) +{ + io::printn("aa"); + return true; +} + +fn int main(String[] args) +{ + ZString2 a = "123"; + ZString2 b = "456"; + bool c = a == b; + return 0; +} + +/* #expect: test.ll + +entry: + %args = alloca %"char[][]", align 8 + %a = alloca ptr, align 8 + %b = alloca ptr, align 8 + %c = alloca i8, align 1 + store ptr %0, ptr %args, align 8 + %ptradd = getelementptr inbounds i8, ptr %args, i64 8 + store i64 %1, ptr %ptradd, align 8 + store ptr @.str, ptr %a, align 8 + store ptr @.str.1, ptr %b, align 8 + %2 = load ptr, ptr %a, align 8 + %3 = load ptr, ptr %b, align 8 + %4 = call i8 @test.ZString2.equals(ptr %2, ptr %3) + store i8 %4, ptr %c, align 1 + ret i32 0 +} \ No newline at end of file diff --git a/test/test_suite/methods/overload_any.c3t b/test/test_suite/methods/overload_any.c3t new file mode 100644 index 000000000..3d0c29435 --- /dev/null +++ b/test/test_suite/methods/overload_any.c3t @@ -0,0 +1,55 @@ +// #target: macos-x64 +module test; +import std::io; +fn int main(String[] args) +{ + ((*io::stdout() << "Hello, World") << 3) << "\n"; + return 0; +} + +macro File File.print(self, other) @operator(<<) +{ + (void)io::fprint(&self, other); + return self; +} + +/* #expect: test.ll + +define i32 @test.main(ptr %0, i64 %1) #0 { +entry: + %args = alloca %"char[][]", align 8 + %self = alloca %File, align 8 + %retparam = alloca i64, align 8 + %self2 = alloca %File, align 8 + %varargslots = alloca [1 x %any], align 16 + %taddr = alloca i32, align 4 + %retparam4 = alloca i64, align 8 + %taddr5 = alloca %any, align 8 + %indirectarg = alloca %"any[]", align 8 + %self7 = alloca %File, align 8 + %retparam9 = alloca i64, align 8 + store ptr %0, ptr %args, align 8 + %ptradd = getelementptr inbounds i8, ptr %args, i64 8 + store i64 %1, ptr %ptradd, align 8 + %2 = call ptr @std.io.stdout() + call void @llvm.memcpy.p0.p0.i32(ptr align 8 %self, ptr align 8 %2, i32 8, i1 false) + %3 = call i64 @std.io.File.write(ptr %retparam, ptr %self, ptr @.str, i64 12) + call void @llvm.memcpy.p0.p0.i32(ptr align 8 %self2, ptr align 8 %self, i32 8, i1 false) + %4 = insertvalue %any undef, ptr %self2, 0 + %5 = insertvalue %any %4, i64 ptrtoint (ptr @"$ct.std.io.File" to i64), 1 + store i32 3, ptr %taddr, align 4 + %6 = insertvalue %any undef, ptr %taddr, 0 + %7 = insertvalue %any %6, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + store %any %7, ptr %varargslots, align 16 + %8 = insertvalue %"any[]" undef, ptr %varargslots, 0 + %"$$temp" = insertvalue %"any[]" %8, i64 1, 1 + store %any %5, ptr %taddr5, align 8 + %lo = load i64, ptr %taddr5, align 8 + %ptradd6 = getelementptr inbounds i8, ptr %taddr5, i64 8 + %hi = load ptr, ptr %ptradd6, align 8 + store %"any[]" %"$$temp", ptr %indirectarg, align 8 + %9 = call i64 @std.io.fprintf(ptr %retparam4, i64 %lo, ptr %hi, ptr @.str.1, i64 2, ptr byval(%"any[]") align 8 %indirectarg) + call void @llvm.memcpy.p0.p0.i32(ptr align 8 %self7, ptr align 8 %self2, i32 8, i1 false) + %10 = call i64 @std.io.File.write(ptr %retparam9, ptr %self7, ptr @.str.2, i64 1) + ret i32 0 +} diff --git a/test/unit/stdlib/core/string.c3 b/test/unit/stdlib/core/string.c3 index cb043ebae..2235ad2d8 100644 --- a/test/unit/stdlib/core/string.c3 +++ b/test/unit/stdlib/core/string.c3 @@ -190,6 +190,14 @@ fn void test_treplace() assert(test.treplace("Never", "Always") == "Befriend some dragons?"); } +fn void test_zstring() +{ + String test = "hello"; + ZString test2 = "hello"; + ZString test3 = "bye"; + assert(test.zstr_tcopy() == test2); + assert(test.zstr_tcopy() != test3); +} fn void test_replace() { String test = "Befriend some dragons?";