Macros generating lambdas now actually is a thing.

This commit is contained in:
Christoffer Lerno
2023-06-22 23:42:40 +02:00
parent d90fa5e292
commit 0eee9daf1d
11 changed files with 365 additions and 49 deletions

View File

@@ -540,8 +540,13 @@ typedef struct
bool attr_winmain : 1;
bool attr_dynamic : 1;
bool attr_interface : 1;
DeclId any_prototype;
Decl **generated_lambda;
bool is_lambda : 1;
union
{
DeclId any_prototype;
Decl **generated_lambda;
Decl **lambda_ct_parameters;
};
};
struct
{

View File

@@ -53,7 +53,6 @@ static inline bool compare_fps(Real left, Real right, BinaryOp op)
bool expr_const_compare(const ExprConst *left, const ExprConst *right, BinaryOp op)
{
bool is_eq;
switch (left->const_kind)
{
case CONST_BOOL:
@@ -73,18 +72,18 @@ bool expr_const_compare(const ExprConst *left, const ExprConst *right, BinaryOp
if (left->string.len != right->string.len)
{
is_eq = false;
break;
goto RETURN;
}
if (right->string.chars == left->string.chars)
{
is_eq = true;
break;
goto RETURN;
}
is_eq = !strncmp(left->string.chars, right->string.chars, left->string.len);
break;
goto RETURN;
case CONST_TYPEID:
is_eq = left->typeid == right->typeid;
break;
goto RETURN;
case CONST_ERR:
case CONST_ENUM:
{
@@ -110,25 +109,32 @@ bool expr_const_compare(const ExprConst *left, const ExprConst *right, BinaryOp
case BINARYOP_EQ:
return left_decl->enum_constant.ordinal == right_ordinal;
default:
UNREACHABLE
goto RETURN;
}
}
case CONST_BYTES:
if (left->bytes.len != right->bytes.len)
{
is_eq = false;
break;
goto RETURN;
}
if (right->bytes.ptr == left->bytes.ptr)
{
is_eq = true;
break;
goto RETURN;
}
is_eq = !memcmp(left->bytes.ptr, right->bytes.ptr, left->bytes.len);
break;
default:
UNREACHABLE
goto RETURN;
case CONST_INITIALIZER:
return false;
case CONST_UNTYPED_LIST:
return false;
case CONST_MEMBER:
is_eq = left->member.decl == right->member.decl;
goto RETURN;
}
UNREACHABLE
RETURN:
assert((op == BINARYOP_EQ) || (op == BINARYOP_NE));
return op == BINARYOP_EQ ? is_eq : !is_eq;
}

View File

@@ -979,6 +979,7 @@ static inline bool sema_expr_analyse_ct_identifier(SemaContext *context, Expr *e
assert(decl->decl_kind == DECL_VAR);
assert(decl->resolve_status == RESOLVE_DONE);
decl->var.is_read = true;
expr->ct_ident_expr.decl = decl;
expr->type = decl->type;
return true;
@@ -4260,6 +4261,7 @@ static inline bool sema_binary_analyse_ct_identifier_lvalue(SemaContext *context
return expr_poison(expr);
}
decl->var.is_read = true;
expr->ct_ident_expr.decl = decl;
expr->resolve_status = RESOLVE_DONE;
return true;
@@ -6842,7 +6844,38 @@ static inline Type *sema_evaluate_type_copy(SemaContext *context, TypeInfo *type
return type_info->type;
}
static inline Decl *sema_find_cached_lambda(SemaContext *context, Type *func_type, Decl *original)
INLINE bool lambda_parameter_match(Decl **ct_lambda_params, Decl *candidate)
{
unsigned param_count = vec_size(ct_lambda_params);
assert(vec_size(candidate->func_decl.lambda_ct_parameters) == param_count);
if (!param_count) return true;
FOREACH_BEGIN_IDX(i, Decl *param, candidate->func_decl.lambda_ct_parameters)
Decl *ct_param = ct_lambda_params[i];
if (!param->var.is_read) continue;
assert(ct_param->resolve_status == RESOLVE_DONE || param->resolve_status == RESOLVE_DONE);
assert(ct_param->var.kind == param->var.kind);
switch (ct_param->var.kind)
{
case VARDECL_LOCAL_CT_TYPE:
case VARDECL_PARAM_CT_TYPE:
if (ct_param->var.init_expr->type_expr->type->canonical !=
param->var.init_expr->type_expr->type->canonical) return false;
break;
case VARDECL_LOCAL_CT:
case VARDECL_PARAM_CT:
assert(expr_is_const(ct_param->var.init_expr));
assert(expr_is_const(param->var.init_expr));
if (!expr_const_compare(&ct_param->var.init_expr->const_expr,
&param->var.init_expr->const_expr, BINARYOP_EQ)) return false;
break;
default:
UNREACHABLE
}
FOREACH_END();
return true;
}
static inline Decl *sema_find_cached_lambda(SemaContext *context, Type *func_type, Decl *original, Decl **ct_lambda_parameters)
{
unsigned cached = vec_size(original->func_decl.generated_lambda);
if (!cached) return NULL;
@@ -6851,7 +6884,8 @@ static inline Decl *sema_find_cached_lambda(SemaContext *context, Type *func_typ
{
Type *raw = func_type->canonical->pointer->function.prototype->raw_type;
FOREACH_BEGIN(Decl *candidate, original->func_decl.generated_lambda)
if (raw == candidate->type->function.prototype->raw_type) return candidate;
if (raw == candidate->type->function.prototype->raw_type &&
lambda_parameter_match(ct_lambda_parameters, candidate)) return candidate;
FOREACH_END();
return NULL;
}
@@ -6871,7 +6905,7 @@ static inline Decl *sema_find_cached_lambda(SemaContext *context, Type *func_typ
FOREACH_END();
FOREACH_BEGIN(Decl *candidate, original->func_decl.generated_lambda)
if (sema_may_reuse_lambda(context, candidate, types)) return candidate;
if (sema_may_reuse_lambda(context, candidate, types) && lambda_parameter_match(ct_lambda_parameters, candidate)) return candidate;
FOREACH_END();
return NULL;
}
@@ -6885,10 +6919,14 @@ static inline bool sema_expr_analyse_lambda(SemaContext *context, Type *func_typ
expr->type = type_get_ptr(decl->type);
return true;
}
bool in_macro = context->current_macro;
if (in_macro && decl->resolve_status != RESOLVE_DONE)
bool multiple = context->current_macro || context->ct_locals;
// Capture CT variables
Decl **ct_lambda_parameters = copy_decl_list_single(context->ct_locals);
if (multiple && decl->resolve_status != RESOLVE_DONE)
{
Decl *decl_cached = sema_find_cached_lambda(context, func_type, decl);
Decl *decl_cached = sema_find_cached_lambda(context, func_type, decl, ct_lambda_parameters);
if (decl_cached)
{
expr->type = type_get_ptr(decl_cached->type);
@@ -6897,7 +6935,7 @@ static inline bool sema_expr_analyse_lambda(SemaContext *context, Type *func_typ
}
}
Decl *original = decl;
if (in_macro) decl = expr->lambda_expr = copy_lambda_deep(decl);
if (multiple) decl = expr->lambda_expr = copy_lambda_deep(decl);
Signature *sig = &decl->func_decl.signature;
Signature *to_sig = func_type ? func_type->canonical->pointer->function.signature : NULL;
if (!sig->rtype)
@@ -6958,15 +6996,35 @@ static inline bool sema_expr_analyse_lambda(SemaContext *context, Type *func_typ
decl->extname = decl->name = scratch_buffer_copy();
Type *lambda_type = sema_analyse_function_signature(context, decl, sig->abi, sig, true);
if (!lambda_type) return false;
if (lambda_type)
decl->func_decl.lambda_ct_parameters = ct_lambda_parameters;
decl->type = lambda_type;
decl->func_decl.is_lambda = true;
decl->alignment = type_alloca_alignment(decl->type);
// We will actually compile this into any module using it (from a macro) by necessity,
// so we'll declare it as weak and externally visible.
if (context->compilation_unit != decl->unit) decl->is_external_visible = true;
vec_add(unit->module->lambdas_to_evaluate, decl);
// Before function analysis, lambda evaluation is deferred
if (unit->module->stage < ANALYSIS_FUNCTIONS)
{
vec_add(unit->module->lambdas_to_evaluate, decl);
}
else
{
SemaContext lambda_context;
sema_context_init(&lambda_context, context->unit);
if (sema_analyse_function_body(&lambda_context, decl))
{
vec_add(unit->lambdas, decl);
}
sema_context_destroy(&lambda_context);
}
expr->type = type_get_ptr(lambda_type);
if (in_macro) vec_add(original->func_decl.generated_lambda, decl);
if (multiple)
{
vec_add(original->func_decl.generated_lambda, decl);
}
decl->resolve_status = RESOLVE_DONE;
return true;
FAIL_NO_INFER:

View File

@@ -2998,6 +2998,7 @@ bool sema_analyse_function_body(SemaContext *context, Decl *func)
context->break_target = 0;
assert(func->func_decl.body);
Ast *body = astptr(func->func_decl.body);
Decl **lambda_params = NULL;
SCOPE_START
assert(context->active_scope.depth == 1);
Decl **params = signature->params;
@@ -3005,6 +3006,14 @@ bool sema_analyse_function_body(SemaContext *context, Decl *func)
{
if (!sema_add_local(context, params[i])) return false;
}
if (func->func_decl.is_lambda)
{
lambda_params = copy_decl_list_single(func->func_decl.lambda_ct_parameters);
FOREACH_BEGIN(Decl *ct_param, lambda_params)
ct_param->var.is_read = false;
if (!sema_add_local(context, ct_param)) return false;
FOREACH_END();
}
AstId assert_first = 0;
AstId *next = &assert_first;
if (!sema_analyse_contracts(context, func->func_decl.docs, &next, INVALID_SPAN)) return false;
@@ -3052,6 +3061,12 @@ bool sema_analyse_function_body(SemaContext *context, Decl *func)
}
}
SCOPE_END;
if (lambda_params)
{
FOREACH_BEGIN_IDX(i, Decl *ct_param, lambda_params)
func->func_decl.lambda_ct_parameters[i]->var.is_read = ct_param->var.is_read;
FOREACH_END();
}
return true;
}

View File

@@ -206,6 +206,7 @@ static bool sema_resolve_type_identifier(SemaContext *context, TypeInfo *type_in
case DECL_VAR:
if (decl->var.kind == VARDECL_PARAM_CT_TYPE || decl->var.kind == VARDECL_LOCAL_CT_TYPE)
{
decl->var.is_read = true;
if (!decl->var.init_expr)
{
SEMA_ERROR(type_info, "You need to assign a type to '%s' before using it.", decl->name);

View File

@@ -1 +1 @@
#define COMPILER_VERSION "0.4.536"
#define COMPILER_VERSION "0.4.537"

View File

@@ -0,0 +1,147 @@
// #target: macos-x64
module test;
import std::io;
def FooFn = fn void(Foo* f, int x);
struct Foo
{
FooFn x;
}
struct FooTest
{
inline Foo a;
}
struct FooTest2
{
inline Foo a;
int z;
}
fn void Foo.test(Foo* f, int x)
{
f.x(f, x);
}
fn void FooTest.init(FooTest* this)
{
static FooFn foo = foo_fn(FooTest);
this.x = foo;
}
fn void FooTest.high(FooTest* t, int x)
{
io::printfn("High: %d", x);
}
fn void FooTest.low(FooTest* t, int x)
{
io::printfn("Low: %d", x);
}
fn void FooTest2.init(FooTest2* this, int z)
{
static FooFn foo = foo_fn(FooTest2);
this.x = foo;
this.z = z;
}
fn void FooTest2.high(FooTest2* t, int x)
{
io::printfn("High2: %d", x * t.z);
}
fn void FooTest2.low(FooTest2* t, int x)
{
io::printfn("Low2: %d", x * t.z);
}
macro FooFn foo_fn($FooType)
{
return fn void(Foo* f, int x) {
$FooType* z = ($FooType*)f;
if (x > 0) return z.high(x);
return z.low(x);
};
}
fn int main()
{
FooTest a;
a.init();
a.test(10);
a.test(0);
a.test(-1);
FooTest2 b;
b.init(100);
b.test(10);
b.test(0);
b.test(-1);
return 0;
}
/* #expect: test.ll
@"init$foo" = internal unnamed_addr global ptr @"test.$global$lambda1", align 8
@"init$foo.2" = internal unnamed_addr global ptr @"test.$global$lambda2", align 8
; Function Attrs: nounwind
define void @test.Foo.test(ptr %0, i32 %1) #0 {
entry:
%2 = getelementptr inbounds %Foo, ptr %0, i32 0, i32 0
%3 = load ptr, ptr %2, align 8
call void %3(ptr %0, i32 %1)
ret void
}
; Function Attrs: nounwind
define void @test.FooTest.init(ptr %0) #0 {
entry:
%1 = getelementptr inbounds %FooTest, ptr %0, i32 0, i32 0
%2 = getelementptr inbounds %Foo, ptr %1, i32 0, i32 0
%3 = load ptr, ptr @"init$foo", align 8
store ptr %3, ptr %2, align 8
ret void
}
; Function Attrs: nounwind
define internal void @"test.$global$lambda1"(ptr %0, i32 %1) #0 {
entry:
%z = alloca ptr, align 8
store ptr %0, ptr %z, align 8
%gt = icmp sgt i32 %1, 0
br i1 %gt, label %if.then, label %if.exit
if.then: ; preds = %entry
%2 = load ptr, ptr %z, align 8
call void @test.FooTest.high(ptr %2, i32 %1)
ret void
if.exit: ; preds = %entry
%3 = load ptr, ptr %z, align 8
call void @test.FooTest.low(ptr %3, i32 %1)
ret void
}
; Function Attrs: nounwind
define internal void @"test.$global$lambda2"(ptr %0, i32 %1) #0 {
entry:
%z = alloca ptr, align 8
store ptr %0, ptr %z, align 8
%gt = icmp sgt i32 %1, 0
br i1 %gt, label %if.then, label %if.exit
if.then: ; preds = %entry
%2 = load ptr, ptr %z, align 8
call void @test.FooTest2.high(ptr %2, i32 %1)
ret void
if.exit: ; preds = %entry
%3 = load ptr, ptr %z, align 8
call void @test.FooTest2.low(ptr %3, i32 %1)
ret void
}

View File

@@ -0,0 +1,82 @@
// #target: macos-x64
module test;
import std::io;
def Call = fn void();
fn int main()
{
var $x = 0;
$for (var $i = 0; $i < 10; $i++)
{
var $Type = int;
$if $i % 2 == 0:
$Type = double;
$endif
$if $i % 3 == 0:
$x++;
$endif;
Call x = fn () => (void)io::printfn("%d %s", $x, $Type.nameof);
x();
}
$endfor
return 0;
}
/* #expect: test.ll
@.str.1 = private unnamed_addr constant [7 x i8] c"double\00", align 1
@.str.3 = private unnamed_addr constant [4 x i8] c"int\00", align 1
@.str.5 = private unnamed_addr constant [4 x i8] c"int\00", align 1
@.str.7 = private unnamed_addr constant [7 x i8] c"double\00", align 1
@.str.9 = private unnamed_addr constant [7 x i8] c"double\00", align 1
@.str.11 = private unnamed_addr constant [4 x i8] c"int\00", align 1
@.str.13 = private unnamed_addr constant [4 x i8] c"int\00", align 1
; Function Attrs: nounwind
define i32 @main() #0 {
store ptr @"main$lambda1", ptr %x, align 8
%0 = load ptr, ptr %x, align 8
call void %0()
store ptr @"main$lambda2", ptr %x1, align 8
%1 = load ptr, ptr %x1, align 8
call void %1()
store ptr @"main$lambda1", ptr %x2, align 8
%2 = load ptr, ptr %x2, align 8
call void %2()
store ptr @"main$lambda3", ptr %x3, align 8
%3 = load ptr, ptr %x3, align 8
call void %3()
store ptr @"main$lambda4", ptr %x4, align 8
%4 = load ptr, ptr %x4, align 8
call void %4()
store ptr @"main$lambda3", ptr %x5, align 8
%5 = load ptr, ptr %x5, align 8
call void %5()
store ptr @"main$lambda5", ptr %x6, align 8
%6 = load ptr, ptr %x6, align 8
call void %6()
store ptr @"main$lambda6", ptr %x7, align 8
%7 = load ptr, ptr %x7, align 8
call void %7()
store ptr @"main$lambda5", ptr %x8, align 8
%8 = load ptr, ptr %x8, align 8
call void %8()
store ptr @"main$lambda7", ptr %x9, align 8
%9 = load ptr, ptr %x9, align 8
call void %9()
define internal void @"main$lambda1"() #0 {
store %"char[]" { ptr @.str.1, i64 6 }, ptr %taddr1, align 8
define internal void @"main$lambda2"() #0 {
store %"char[]" { ptr @.str.3, i64 3 }, ptr %taddr1, align 8
define internal void @"main$lambda3"() #0 {
store %"char[]" { ptr @.str.5, i64 3 }, ptr %taddr1, align 8
define internal void @"main$lambda4"() #0 {
store %"char[]" { ptr @.str.7, i64 6 }, ptr %taddr1, align 8
define internal void @"main$lambda5"() #0 {
store %"char[]" { ptr @.str.9, i64 6 }, ptr %taddr1, align 8
define internal void @"main$lambda6"() #0 {
store %"char[]" { ptr @.str.11, i64 3 }, ptr %taddr1, align 8
define internal void @"main$lambda7"() #0 {
store %"char[]" { ptr @.str.13, i64 3 }, ptr %taddr1, align 8

View File

@@ -29,17 +29,17 @@ fn void main()
/* #expect: test.ll
store ptr @"test.test$lambda1", ptr %z, align 8
%1 = call i32 %0(i32 3)
store ptr @"test.test$lambda2", ptr %z3, align 8
%7 = call double %6(double 3.300000e+00)
store ptr @"test.test$lambda2", ptr %z7, align 8
%13 = call double %12(double 3.300000e+00)
%18 = call i32 @"test.test2$lambda3"(i32 3)
%23 = call i32 @"test.test2$lambda3"(i32 3)
%28 = call double @"test.test2$lambda4"(double 3.300000e+00)
store ptr @"test.test$lambda1", ptr %z, align 8
%1 = call i32 %0(i32 3)
store ptr @"test.test$lambda2", ptr %z3, align 8
%7 = call double %6(double 3.300000e+00)
store ptr @"test.test$lambda2", ptr %z7, align 8
%13 = call double %12(double 3.300000e+00)
%18 = call i32 @"test.test2$lambda3"(i32 3)
%23 = call i32 @"test.test2$lambda3"(i32 3)
%28 = call double @"test.test2$lambda4"(double 3.300000e+00)
define internal double @"test.test2$lambda4"(double %0) #0 {
define internal i32 @"test.test2$lambda3"(i32 %0) #0 {
define internal double @"test.test$lambda2"(double %0) #0 {
define internal i32 @"test.test$lambda1"(i32 %0) #0 {
define internal i32 @"test.test$lambda1"(i32 %0) #0 {
define internal double @"test.test$lambda2"(double %0) #0 {
define internal i32 @"test.test2$lambda3"(i32 %0) #0 {
define internal double @"test.test2$lambda4"(double %0) #0 {

View File

@@ -54,28 +54,26 @@ declare i32 @"bar.get_callback$lambda1"()
// #expect: foo.ll
define i32 @"foo.get_callback$lambda1"() #0 {
%0 = call i32 @"bar.get_callback2$lambda2"()
define i32 @"foo.get_callback2$lambda2"() #0 {
%1 = load i32, ptr @foo.xz, align 4
define i32 @"foo.get_callback$lambda1"() #0 {
%0 = call i32 @"bar.get_callback2$lambda2"()
declare i32 @"bar.get_callback2$lambda2"()
// #expect: bar.ll
define i32 @"bar.get_callback$lambda1"() #0 {
entry:
%0 = call i32 @"foo.get_callback$lambda1"()
ret i32 %0
}
define i32 @"bar.get_callback2$lambda2"() #0 {
entry:
%0 = call i32 @"foo.get_callback2$lambda2"()
ret i32 %0
}
define i32 @"bar.get_callback$lambda1"() #0 {
entry:
%0 = call i32 @"foo.get_callback$lambda1"()
ret i32 %0
declare i32 @"foo.get_callback$lambda1"()
declare i32 @"foo.get_callback2$lambda2"()
declare i32 @"foo.get_callback2$lambda2"() #0
declare i32 @"foo.get_callback$lambda1"() #0

View File

@@ -25,5 +25,9 @@ entry:
%0 = call i32 @simple_lambda.xy(ptr @"simple_lambda.main$lambda1")
%5 = call i32 @simple_lambda.xy(ptr @"simple_lambda.main$lambda2")
define internal i32 @"simple_lambda.main$lambda1"() #0 {
define internal i32 @"simple_lambda.main$lambda2"() #0 {
define internal i32 @"simple_lambda.main$lambda1"() #0 {
entry:
ret i32 3
}