diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index 4e283cdf3..d817fff83 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -540,8 +540,13 @@ typedef struct bool attr_winmain : 1; bool attr_dynamic : 1; bool attr_interface : 1; - DeclId any_prototype; - Decl **generated_lambda; + bool is_lambda : 1; + union + { + DeclId any_prototype; + Decl **generated_lambda; + Decl **lambda_ct_parameters; + }; }; struct { diff --git a/src/compiler/number.c b/src/compiler/number.c index df5a8ae53..fc637b08b 100644 --- a/src/compiler/number.c +++ b/src/compiler/number.c @@ -53,7 +53,6 @@ static inline bool compare_fps(Real left, Real right, BinaryOp op) bool expr_const_compare(const ExprConst *left, const ExprConst *right, BinaryOp op) { bool is_eq; - switch (left->const_kind) { case CONST_BOOL: @@ -73,18 +72,18 @@ bool expr_const_compare(const ExprConst *left, const ExprConst *right, BinaryOp if (left->string.len != right->string.len) { is_eq = false; - break; + goto RETURN; } if (right->string.chars == left->string.chars) { is_eq = true; - break; + goto RETURN; } is_eq = !strncmp(left->string.chars, right->string.chars, left->string.len); - break; + goto RETURN; case CONST_TYPEID: is_eq = left->typeid == right->typeid; - break; + goto RETURN; case CONST_ERR: case CONST_ENUM: { @@ -110,25 +109,32 @@ bool expr_const_compare(const ExprConst *left, const ExprConst *right, BinaryOp case BINARYOP_EQ: return left_decl->enum_constant.ordinal == right_ordinal; default: - UNREACHABLE + goto RETURN; } } case CONST_BYTES: if (left->bytes.len != right->bytes.len) { is_eq = false; - break; + goto RETURN; } if (right->bytes.ptr == left->bytes.ptr) { is_eq = true; - break; + goto RETURN; } is_eq = !memcmp(left->bytes.ptr, right->bytes.ptr, left->bytes.len); - break; - default: - UNREACHABLE + goto RETURN; + case CONST_INITIALIZER: + return false; + case CONST_UNTYPED_LIST: + return false; + case CONST_MEMBER: + is_eq = left->member.decl == right->member.decl; + goto RETURN; } + UNREACHABLE +RETURN: assert((op == BINARYOP_EQ) || (op == BINARYOP_NE)); return op == BINARYOP_EQ ? is_eq : !is_eq; } diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index b15194b3e..837b80e70 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -979,6 +979,7 @@ static inline bool sema_expr_analyse_ct_identifier(SemaContext *context, Expr *e assert(decl->decl_kind == DECL_VAR); assert(decl->resolve_status == RESOLVE_DONE); + decl->var.is_read = true; expr->ct_ident_expr.decl = decl; expr->type = decl->type; return true; @@ -4260,6 +4261,7 @@ static inline bool sema_binary_analyse_ct_identifier_lvalue(SemaContext *context return expr_poison(expr); } + decl->var.is_read = true; expr->ct_ident_expr.decl = decl; expr->resolve_status = RESOLVE_DONE; return true; @@ -6842,7 +6844,38 @@ static inline Type *sema_evaluate_type_copy(SemaContext *context, TypeInfo *type return type_info->type; } -static inline Decl *sema_find_cached_lambda(SemaContext *context, Type *func_type, Decl *original) +INLINE bool lambda_parameter_match(Decl **ct_lambda_params, Decl *candidate) +{ + unsigned param_count = vec_size(ct_lambda_params); + assert(vec_size(candidate->func_decl.lambda_ct_parameters) == param_count); + if (!param_count) return true; + FOREACH_BEGIN_IDX(i, Decl *param, candidate->func_decl.lambda_ct_parameters) + Decl *ct_param = ct_lambda_params[i]; + if (!param->var.is_read) continue; + assert(ct_param->resolve_status == RESOLVE_DONE || param->resolve_status == RESOLVE_DONE); + assert(ct_param->var.kind == param->var.kind); + switch (ct_param->var.kind) + { + case VARDECL_LOCAL_CT_TYPE: + case VARDECL_PARAM_CT_TYPE: + if (ct_param->var.init_expr->type_expr->type->canonical != + param->var.init_expr->type_expr->type->canonical) return false; + break; + case VARDECL_LOCAL_CT: + case VARDECL_PARAM_CT: + assert(expr_is_const(ct_param->var.init_expr)); + assert(expr_is_const(param->var.init_expr)); + if (!expr_const_compare(&ct_param->var.init_expr->const_expr, + ¶m->var.init_expr->const_expr, BINARYOP_EQ)) return false; + break; + default: + UNREACHABLE + } + FOREACH_END(); + return true; +} + +static inline Decl *sema_find_cached_lambda(SemaContext *context, Type *func_type, Decl *original, Decl **ct_lambda_parameters) { unsigned cached = vec_size(original->func_decl.generated_lambda); if (!cached) return NULL; @@ -6851,7 +6884,8 @@ static inline Decl *sema_find_cached_lambda(SemaContext *context, Type *func_typ { Type *raw = func_type->canonical->pointer->function.prototype->raw_type; FOREACH_BEGIN(Decl *candidate, original->func_decl.generated_lambda) - if (raw == candidate->type->function.prototype->raw_type) return candidate; + if (raw == candidate->type->function.prototype->raw_type && + lambda_parameter_match(ct_lambda_parameters, candidate)) return candidate; FOREACH_END(); return NULL; } @@ -6871,7 +6905,7 @@ static inline Decl *sema_find_cached_lambda(SemaContext *context, Type *func_typ FOREACH_END(); FOREACH_BEGIN(Decl *candidate, original->func_decl.generated_lambda) - if (sema_may_reuse_lambda(context, candidate, types)) return candidate; + if (sema_may_reuse_lambda(context, candidate, types) && lambda_parameter_match(ct_lambda_parameters, candidate)) return candidate; FOREACH_END(); return NULL; } @@ -6885,10 +6919,14 @@ static inline bool sema_expr_analyse_lambda(SemaContext *context, Type *func_typ expr->type = type_get_ptr(decl->type); return true; } - bool in_macro = context->current_macro; - if (in_macro && decl->resolve_status != RESOLVE_DONE) + bool multiple = context->current_macro || context->ct_locals; + + // Capture CT variables + Decl **ct_lambda_parameters = copy_decl_list_single(context->ct_locals); + + if (multiple && decl->resolve_status != RESOLVE_DONE) { - Decl *decl_cached = sema_find_cached_lambda(context, func_type, decl); + Decl *decl_cached = sema_find_cached_lambda(context, func_type, decl, ct_lambda_parameters); if (decl_cached) { expr->type = type_get_ptr(decl_cached->type); @@ -6897,7 +6935,7 @@ static inline bool sema_expr_analyse_lambda(SemaContext *context, Type *func_typ } } Decl *original = decl; - if (in_macro) decl = expr->lambda_expr = copy_lambda_deep(decl); + if (multiple) decl = expr->lambda_expr = copy_lambda_deep(decl); Signature *sig = &decl->func_decl.signature; Signature *to_sig = func_type ? func_type->canonical->pointer->function.signature : NULL; if (!sig->rtype) @@ -6958,15 +6996,35 @@ static inline bool sema_expr_analyse_lambda(SemaContext *context, Type *func_typ decl->extname = decl->name = scratch_buffer_copy(); Type *lambda_type = sema_analyse_function_signature(context, decl, sig->abi, sig, true); if (!lambda_type) return false; - if (lambda_type) + decl->func_decl.lambda_ct_parameters = ct_lambda_parameters; decl->type = lambda_type; + decl->func_decl.is_lambda = true; decl->alignment = type_alloca_alignment(decl->type); // We will actually compile this into any module using it (from a macro) by necessity, // so we'll declare it as weak and externally visible. if (context->compilation_unit != decl->unit) decl->is_external_visible = true; - vec_add(unit->module->lambdas_to_evaluate, decl); + + // Before function analysis, lambda evaluation is deferred + if (unit->module->stage < ANALYSIS_FUNCTIONS) + { + vec_add(unit->module->lambdas_to_evaluate, decl); + } + else + { + SemaContext lambda_context; + sema_context_init(&lambda_context, context->unit); + if (sema_analyse_function_body(&lambda_context, decl)) + { + vec_add(unit->lambdas, decl); + } + sema_context_destroy(&lambda_context); + } + expr->type = type_get_ptr(lambda_type); - if (in_macro) vec_add(original->func_decl.generated_lambda, decl); + if (multiple) + { + vec_add(original->func_decl.generated_lambda, decl); + } decl->resolve_status = RESOLVE_DONE; return true; FAIL_NO_INFER: diff --git a/src/compiler/sema_stmts.c b/src/compiler/sema_stmts.c index d3a7d1950..e2c80ff9b 100644 --- a/src/compiler/sema_stmts.c +++ b/src/compiler/sema_stmts.c @@ -2998,6 +2998,7 @@ bool sema_analyse_function_body(SemaContext *context, Decl *func) context->break_target = 0; assert(func->func_decl.body); Ast *body = astptr(func->func_decl.body); + Decl **lambda_params = NULL; SCOPE_START assert(context->active_scope.depth == 1); Decl **params = signature->params; @@ -3005,6 +3006,14 @@ bool sema_analyse_function_body(SemaContext *context, Decl *func) { if (!sema_add_local(context, params[i])) return false; } + if (func->func_decl.is_lambda) + { + lambda_params = copy_decl_list_single(func->func_decl.lambda_ct_parameters); + FOREACH_BEGIN(Decl *ct_param, lambda_params) + ct_param->var.is_read = false; + if (!sema_add_local(context, ct_param)) return false; + FOREACH_END(); + } AstId assert_first = 0; AstId *next = &assert_first; if (!sema_analyse_contracts(context, func->func_decl.docs, &next, INVALID_SPAN)) return false; @@ -3052,6 +3061,12 @@ bool sema_analyse_function_body(SemaContext *context, Decl *func) } } SCOPE_END; + if (lambda_params) + { + FOREACH_BEGIN_IDX(i, Decl *ct_param, lambda_params) + func->func_decl.lambda_ct_parameters[i]->var.is_read = ct_param->var.is_read; + FOREACH_END(); + } return true; } diff --git a/src/compiler/sema_types.c b/src/compiler/sema_types.c index c71163350..84d99e64f 100644 --- a/src/compiler/sema_types.c +++ b/src/compiler/sema_types.c @@ -206,6 +206,7 @@ static bool sema_resolve_type_identifier(SemaContext *context, TypeInfo *type_in case DECL_VAR: if (decl->var.kind == VARDECL_PARAM_CT_TYPE || decl->var.kind == VARDECL_LOCAL_CT_TYPE) { + decl->var.is_read = true; if (!decl->var.init_expr) { SEMA_ERROR(type_info, "You need to assign a type to '%s' before using it.", decl->name); diff --git a/src/version.h b/src/version.h index 4c5bafb78..f1bf3f667 100644 --- a/src/version.h +++ b/src/version.h @@ -1 +1 @@ -#define COMPILER_VERSION "0.4.536" \ No newline at end of file +#define COMPILER_VERSION "0.4.537" \ No newline at end of file diff --git a/test/test_suite/lambda/ct_lambda.c3t b/test/test_suite/lambda/ct_lambda.c3t new file mode 100644 index 000000000..b175096ab --- /dev/null +++ b/test/test_suite/lambda/ct_lambda.c3t @@ -0,0 +1,147 @@ +// #target: macos-x64 +module test; +import std::io; + +def FooFn = fn void(Foo* f, int x); + +struct Foo +{ + FooFn x; +} + +struct FooTest +{ + inline Foo a; +} + +struct FooTest2 +{ + inline Foo a; + int z; +} + +fn void Foo.test(Foo* f, int x) +{ + f.x(f, x); +} + +fn void FooTest.init(FooTest* this) +{ + static FooFn foo = foo_fn(FooTest); + this.x = foo; +} + +fn void FooTest.high(FooTest* t, int x) +{ + io::printfn("High: %d", x); +} + +fn void FooTest.low(FooTest* t, int x) +{ + io::printfn("Low: %d", x); +} + +fn void FooTest2.init(FooTest2* this, int z) +{ + static FooFn foo = foo_fn(FooTest2); + this.x = foo; + this.z = z; +} + +fn void FooTest2.high(FooTest2* t, int x) +{ + io::printfn("High2: %d", x * t.z); +} + +fn void FooTest2.low(FooTest2* t, int x) +{ + io::printfn("Low2: %d", x * t.z); +} + +macro FooFn foo_fn($FooType) +{ + return fn void(Foo* f, int x) { + $FooType* z = ($FooType*)f; + if (x > 0) return z.high(x); + return z.low(x); + }; +} + +fn int main() +{ + FooTest a; + a.init(); + a.test(10); + a.test(0); + a.test(-1); + FooTest2 b; + b.init(100); + b.test(10); + b.test(0); + b.test(-1); + return 0; +} + +/* #expect: test.ll + + +@"init$foo" = internal unnamed_addr global ptr @"test.$global$lambda1", align 8 +@"init$foo.2" = internal unnamed_addr global ptr @"test.$global$lambda2", align 8 + +; Function Attrs: nounwind +define void @test.Foo.test(ptr %0, i32 %1) #0 { +entry: + %2 = getelementptr inbounds %Foo, ptr %0, i32 0, i32 0 + %3 = load ptr, ptr %2, align 8 + call void %3(ptr %0, i32 %1) + ret void +} + +; Function Attrs: nounwind +define void @test.FooTest.init(ptr %0) #0 { +entry: + %1 = getelementptr inbounds %FooTest, ptr %0, i32 0, i32 0 + %2 = getelementptr inbounds %Foo, ptr %1, i32 0, i32 0 + %3 = load ptr, ptr @"init$foo", align 8 + store ptr %3, ptr %2, align 8 + ret void +} + +; Function Attrs: nounwind + +define internal void @"test.$global$lambda1"(ptr %0, i32 %1) #0 { +entry: + %z = alloca ptr, align 8 + store ptr %0, ptr %z, align 8 + %gt = icmp sgt i32 %1, 0 + br i1 %gt, label %if.then, label %if.exit + +if.then: ; preds = %entry + %2 = load ptr, ptr %z, align 8 + call void @test.FooTest.high(ptr %2, i32 %1) + ret void + +if.exit: ; preds = %entry + %3 = load ptr, ptr %z, align 8 + call void @test.FooTest.low(ptr %3, i32 %1) + ret void +} + +; Function Attrs: nounwind +define internal void @"test.$global$lambda2"(ptr %0, i32 %1) #0 { +entry: + %z = alloca ptr, align 8 + store ptr %0, ptr %z, align 8 + %gt = icmp sgt i32 %1, 0 + br i1 %gt, label %if.then, label %if.exit + +if.then: ; preds = %entry + %2 = load ptr, ptr %z, align 8 + call void @test.FooTest2.high(ptr %2, i32 %1) + ret void + +if.exit: ; preds = %entry + %3 = load ptr, ptr %z, align 8 + call void @test.FooTest2.low(ptr %3, i32 %1) + ret void +} \ No newline at end of file diff --git a/test/test_suite/lambda/ct_lambda2.c3t b/test/test_suite/lambda/ct_lambda2.c3t new file mode 100644 index 000000000..6b77b3178 --- /dev/null +++ b/test/test_suite/lambda/ct_lambda2.c3t @@ -0,0 +1,82 @@ +// #target: macos-x64 +module test; +import std::io; + +def Call = fn void(); + +fn int main() +{ + var $x = 0; + $for (var $i = 0; $i < 10; $i++) + { + var $Type = int; + $if $i % 2 == 0: + $Type = double; + $endif + $if $i % 3 == 0: + $x++; + $endif; + Call x = fn () => (void)io::printfn("%d %s", $x, $Type.nameof); + x(); + } + $endfor + return 0; +} + +/* #expect: test.ll + +@.str.1 = private unnamed_addr constant [7 x i8] c"double\00", align 1 +@.str.3 = private unnamed_addr constant [4 x i8] c"int\00", align 1 +@.str.5 = private unnamed_addr constant [4 x i8] c"int\00", align 1 +@.str.7 = private unnamed_addr constant [7 x i8] c"double\00", align 1 +@.str.9 = private unnamed_addr constant [7 x i8] c"double\00", align 1 +@.str.11 = private unnamed_addr constant [4 x i8] c"int\00", align 1 +@.str.13 = private unnamed_addr constant [4 x i8] c"int\00", align 1 + +; Function Attrs: nounwind +define i32 @main() #0 { + store ptr @"main$lambda1", ptr %x, align 8 + %0 = load ptr, ptr %x, align 8 + call void %0() + store ptr @"main$lambda2", ptr %x1, align 8 + %1 = load ptr, ptr %x1, align 8 + call void %1() + store ptr @"main$lambda1", ptr %x2, align 8 + %2 = load ptr, ptr %x2, align 8 + call void %2() + store ptr @"main$lambda3", ptr %x3, align 8 + %3 = load ptr, ptr %x3, align 8 + call void %3() + store ptr @"main$lambda4", ptr %x4, align 8 + %4 = load ptr, ptr %x4, align 8 + call void %4() + store ptr @"main$lambda3", ptr %x5, align 8 + %5 = load ptr, ptr %x5, align 8 + call void %5() + store ptr @"main$lambda5", ptr %x6, align 8 + %6 = load ptr, ptr %x6, align 8 + call void %6() + store ptr @"main$lambda6", ptr %x7, align 8 + %7 = load ptr, ptr %x7, align 8 + call void %7() + store ptr @"main$lambda5", ptr %x8, align 8 + %8 = load ptr, ptr %x8, align 8 + call void %8() + store ptr @"main$lambda7", ptr %x9, align 8 + %9 = load ptr, ptr %x9, align 8 + call void %9() + +define internal void @"main$lambda1"() #0 { + store %"char[]" { ptr @.str.1, i64 6 }, ptr %taddr1, align 8 +define internal void @"main$lambda2"() #0 { + store %"char[]" { ptr @.str.3, i64 3 }, ptr %taddr1, align 8 +define internal void @"main$lambda3"() #0 { + store %"char[]" { ptr @.str.5, i64 3 }, ptr %taddr1, align 8 +define internal void @"main$lambda4"() #0 { + store %"char[]" { ptr @.str.7, i64 6 }, ptr %taddr1, align 8 +define internal void @"main$lambda5"() #0 { + store %"char[]" { ptr @.str.9, i64 6 }, ptr %taddr1, align 8 +define internal void @"main$lambda6"() #0 { + store %"char[]" { ptr @.str.11, i64 3 }, ptr %taddr1, align 8 +define internal void @"main$lambda7"() #0 { + store %"char[]" { ptr @.str.13, i64 3 }, ptr %taddr1, align 8 \ No newline at end of file diff --git a/test/test_suite/lambda/lambda_in_macro.c3t b/test/test_suite/lambda/lambda_in_macro.c3t index 17fbc52c7..22ed123aa 100644 --- a/test/test_suite/lambda/lambda_in_macro.c3t +++ b/test/test_suite/lambda/lambda_in_macro.c3t @@ -29,17 +29,17 @@ fn void main() /* #expect: test.ll - store ptr @"test.test$lambda1", ptr %z, align 8 - %1 = call i32 %0(i32 3) - store ptr @"test.test$lambda2", ptr %z3, align 8 - %7 = call double %6(double 3.300000e+00) - store ptr @"test.test$lambda2", ptr %z7, align 8 - %13 = call double %12(double 3.300000e+00) - %18 = call i32 @"test.test2$lambda3"(i32 3) - %23 = call i32 @"test.test2$lambda3"(i32 3) - %28 = call double @"test.test2$lambda4"(double 3.300000e+00) + store ptr @"test.test$lambda1", ptr %z, align 8 + %1 = call i32 %0(i32 3) + store ptr @"test.test$lambda2", ptr %z3, align 8 + %7 = call double %6(double 3.300000e+00) + store ptr @"test.test$lambda2", ptr %z7, align 8 + %13 = call double %12(double 3.300000e+00) + %18 = call i32 @"test.test2$lambda3"(i32 3) + %23 = call i32 @"test.test2$lambda3"(i32 3) + %28 = call double @"test.test2$lambda4"(double 3.300000e+00) -define internal double @"test.test2$lambda4"(double %0) #0 { -define internal i32 @"test.test2$lambda3"(i32 %0) #0 { -define internal double @"test.test$lambda2"(double %0) #0 { -define internal i32 @"test.test$lambda1"(i32 %0) #0 { + define internal i32 @"test.test$lambda1"(i32 %0) #0 { + define internal double @"test.test$lambda2"(double %0) #0 { + define internal i32 @"test.test2$lambda3"(i32 %0) #0 { + define internal double @"test.test2$lambda4"(double %0) #0 { diff --git a/test/test_suite/lambda/nested_lambda_def.c3t b/test/test_suite/lambda/nested_lambda_def.c3t index 20eccf82e..a83ffaad1 100644 --- a/test/test_suite/lambda/nested_lambda_def.c3t +++ b/test/test_suite/lambda/nested_lambda_def.c3t @@ -54,28 +54,26 @@ declare i32 @"bar.get_callback$lambda1"() // #expect: foo.ll -define i32 @"foo.get_callback$lambda1"() #0 { - %0 = call i32 @"bar.get_callback2$lambda2"() - define i32 @"foo.get_callback2$lambda2"() #0 { %1 = load i32, ptr @foo.xz, align 4 +define i32 @"foo.get_callback$lambda1"() #0 { + %0 = call i32 @"bar.get_callback2$lambda2"() + declare i32 @"bar.get_callback2$lambda2"() // #expect: bar.ll -define i32 @"bar.get_callback$lambda1"() #0 { -entry: - %0 = call i32 @"foo.get_callback$lambda1"() - ret i32 %0 -} - define i32 @"bar.get_callback2$lambda2"() #0 { entry: %0 = call i32 @"foo.get_callback2$lambda2"() ret i32 %0 } +define i32 @"bar.get_callback$lambda1"() #0 { +entry: + %0 = call i32 @"foo.get_callback$lambda1"() + ret i32 %0 -declare i32 @"foo.get_callback$lambda1"() -declare i32 @"foo.get_callback2$lambda2"() +declare i32 @"foo.get_callback2$lambda2"() #0 +declare i32 @"foo.get_callback$lambda1"() #0 diff --git a/test/test_suite/lambda/simple_lambda.c3t b/test/test_suite/lambda/simple_lambda.c3t index c910880a1..4c3d5a308 100644 --- a/test/test_suite/lambda/simple_lambda.c3t +++ b/test/test_suite/lambda/simple_lambda.c3t @@ -25,5 +25,9 @@ entry: %0 = call i32 @simple_lambda.xy(ptr @"simple_lambda.main$lambda1") %5 = call i32 @simple_lambda.xy(ptr @"simple_lambda.main$lambda2") +define internal i32 @"simple_lambda.main$lambda1"() #0 { + define internal i32 @"simple_lambda.main$lambda2"() #0 { -define internal i32 @"simple_lambda.main$lambda1"() #0 { \ No newline at end of file +entry: + ret i32 3 +}