- Allow splatting of structs. #2555

This commit is contained in:
Christoffer Lerno
2025-11-17 16:33:22 +01:00
parent 20dfdf5c5d
commit 49b8cfe267
3 changed files with 223 additions and 4 deletions

View File

@@ -13,6 +13,7 @@
- Refactored `@simd` implementation.
- Improve error message for `Foo{}` when `Foo` is not a generic type #2574.
- Support `@param` directives for `...` parameters. #2578
- Allow splatting of structs. #2555
### Fixes
- `Foo.is_eq` would return false if the type was a `typedef` and had an overload, but the underlying type was not comparable.

View File

@@ -1723,6 +1723,77 @@ INLINE bool sema_set_default_argument(SemaContext *context, CalledDecl *callee,
callee->macro, false);
}
INLINE Expr **sema_splat_struct_insert(SemaContext *context, Expr **args, Expr *arg, Decl *strukt, ArrayIndex index)
{
unsigned len = vec_size(strukt->strukt.members);
args = sema_prepare_splat_insert(args, len, index);
if (sema_cast_const(arg))
{
ASSERT(expr_is_const_initializer(arg));
ConstInitializer *initializer = arg->const_expr.initializer;
if (initializer->kind == CONST_INIT_ZERO)
{
for (ArrayIndex i = 0; i < len; i++)
{
Expr *expr = expr_calloc();
expr->span = arg->span;
expr_rewrite_to_const_zero(expr, strukt->strukt.members[i]->type);
args[i + index] = expr;
}
return args;
}
ASSERT(initializer->kind == CONST_INIT_STRUCT);
for (ArrayIndex i = 0; i < len; i++)
{
ConstInitializer *c = initializer->init_struct[i];
Expr *expr;
switch (c->kind)
{
case CONST_INIT_ZERO:
expr = expr_calloc();
expr->span = arg->span;
expr_rewrite_to_const_zero(expr, strukt->strukt.members[i]->type);
args[i + index] = expr;
break;
case CONST_INIT_VALUE:
expr = expr_copy(c->init_value);
break;
default:
expr = expr_calloc();
expr->span = arg->span;
expr_rewrite_const_initializer(expr, strukt->strukt.members[i]->type, c);
break;
}
args[i + index] = expr;
}
return args;
}
if (context->call_env.kind != CALL_ENV_FUNCTION)
{
SEMA_ERROR(arg, "Cannot splat a non-constant value in a global context.");
return NULL;
}
Decl *temp = decl_new_generated_var(arg->type, VARDECL_LOCAL, arg->span);
Expr *decl_expr = expr_generate_decl(temp, arg);
Expr *two = expr_new_expr(EXPR_TWO, arg);
two->two_expr.first = decl_expr;
Expr *access = expr_new_expr(EXPR_ACCESS_RESOLVED, arg);
access->access_resolved_expr = (ExprResolvedAccess) { .parent = expr_variable(temp), .ref = strukt->strukt.members[0] };
access->resolve_status = RESOLVE_DONE;
access->type = strukt->strukt.members[0]->type;
two->two_expr.last = access;
if (!sema_analyse_expr_rvalue(context, two)) return NULL;
args[index] = two;
for (ArrayIndex i = 1; i < len; i++)
{
access = expr_new_expr(EXPR_ACCESS_RESOLVED, arg);
access->access_resolved_expr = (ExprResolvedAccess) { .parent = expr_variable(temp), .ref = strukt->strukt.members[i] };
access->resolve_status = RESOLVE_DONE;
access->type = strukt->strukt.members[i]->type;
args[index + i] = access;
}
return args;
}
INLINE Expr **sema_splat_arraylike_insert(SemaContext *context, Expr **args, Expr *arg, ArraySize len, ArrayIndex index)
{
@@ -1965,6 +2036,7 @@ INLINE bool sema_call_evaluate_arguments(SemaContext *context, CalledDecl *calle
continue;
}
SPLAT_NORMAL:;
Expr **new_args;
Type *flat = type_flatten(inner->type);
switch (flat->type_kind)
{
@@ -1972,7 +2044,10 @@ SPLAT_NORMAL:;
case TYPE_ARRAY:
case TYPE_SLICE:
case TYPE_UNTYPED_LIST:
// These may be splatted
break;
case TYPE_STRUCT:
new_args = sema_splat_struct_insert(context, args, inner, flat->decl, i);
goto AFTER_SPLAT;
break;
default:
RETURN_SEMA_ERROR(arg, "An argument of type %s cannot be splatted.",
@@ -1985,7 +2060,8 @@ SPLAT_NORMAL:;
{
RETURN_SEMA_ERROR(arg, "A non-constant zero size splat cannot be used with raw varargs.");
}
Expr **new_args = sema_splat_arraylike_insert(context, args, inner, len, i);
new_args = sema_splat_arraylike_insert(context, args, inner, len, i);
AFTER_SPLAT:;
if (!new_args) return false;
args = new_args;
i--;
@@ -6597,14 +6673,18 @@ Expr **sema_expand_vasplat_exprs(SemaContext *context, Expr **exprs)
Expr *inner = arg->inner_expr;
if (!sema_analyse_expr_rvalue(context, inner)) return false;
Type *flat = type_flatten(inner->type);
Expr **new_args;
switch (flat->type_kind)
{
case VECTORS:
case TYPE_ARRAY:
case TYPE_SLICE:
case TYPE_UNTYPED_LIST:
// These may be splatted
// These may be splatted like arrays
break;
case TYPE_STRUCT:
new_args = sema_splat_struct_insert(context, exprs, inner, flat->decl, i);
goto SPLAT_DONE;
default:
SEMA_ERROR(arg, "An argument of type %s cannot be splatted.",
type_quoted_error_string(inner->type));
@@ -6622,7 +6702,8 @@ Expr **sema_expand_vasplat_exprs(SemaContext *context, Expr **exprs)
SEMA_ERROR(arg, "A non-constant zero size splat is not allowed.");
return NULL;
}
Expr **new_args = sema_splat_arraylike_insert(context, exprs, inner, len, i);
new_args = sema_splat_arraylike_insert(context, exprs, inner, len, i);
SPLAT_DONE:
if (!new_args) return false;
exprs = new_args;
count = vec_size(exprs);

View File

@@ -0,0 +1,137 @@
// #target: macos-x64
module test;
import std;
struct Foo
{
int a;
double b;
}
struct Bar
{
int x;
Foo f;
}
fn void test(int x, double y)
{
io::printfn("%s %s", x, y);
}
fn void test2(int x, Foo f)
{
io::printfn("%s {%s %s}", x, f.a, f.b);
}
fn int main()
{
Foo f = { 42, 3.14};
test(...f);
test(...(Foo){ 43, 3.15 });
test(...(Foo){ });
test2(...(Bar){ });
test2(...(Bar){ 42, {} });
test2(...(Bar){ 42, { .b = 4.5 } });
test2(...(Bar){ .f = { .b = 4.5 } });
Bar z = { 100, f };
test2(...z);
Bar z2 = { 100, { ...f }};
f = { ...f };
return 0;
}
/* #expect: test.ll
%.introspect = type { i8, i64, ptr, i64, i64, i64, [0 x i64] }
%Foo = type { i32, double }
%any = type { ptr, i64 }
%Bar = type { i32, %Foo }
@"$ct.test.Foo" = linkonce global %.introspect { i8 10, i64 0, ptr null, i64 16, i64 0, i64 2, [0 x i64] zeroinitializer }, align 8
@"$ct.test.Bar" = linkonce global %.introspect { i8 10, i64 0, ptr null, i64 24, i64 0, i64 2, [0 x i64] zeroinitializer }, align 8
@.str = private unnamed_addr constant [6 x i8] c"%s %s\00", align 1
@"$ct.int" = linkonce global %.introspect { i8 2, i64 0, ptr null, i64 4, i64 0, i64 0, [0 x i64] zeroinitializer }, align 8
@"$ct.double" = linkonce global %.introspect { i8 4, i64 0, ptr null, i64 8, i64 0, i64 0, [0 x i64] zeroinitializer }, align 8
@.str.1 = private unnamed_addr constant [11 x i8] c"%s {%s %s}\00", align 1
@.__const = private unnamed_addr constant %Foo { i32 42, double 3.140000e+00 }, align 8
define i32 @main() #0 {
entry:
%f = alloca %Foo, align 8
%.anon = alloca %Foo, align 8
%literal = alloca %Foo, align 8
%literal3 = alloca %Foo, align 8
%literal8 = alloca %Foo, align 8
%literal13 = alloca %Foo, align 8
%z = alloca %Bar, align 8
%.anon19 = alloca %Bar, align 8
%z2 = alloca %Bar, align 8
%.anon25 = alloca %Foo, align 8
%.assign_list = alloca %Foo, align 8
%.anon28 = alloca %Foo, align 8
call void @llvm.memcpy.p0.p0.i32(ptr align 8 %f, ptr align 8 @.__const, i32 16, i1 false)
call void @llvm.memcpy.p0.p0.i32(ptr align 8 %.anon, ptr align 8 %f, i32 16, i1 false)
%ptradd = getelementptr inbounds i8, ptr %.anon, i64 8
%0 = load i32, ptr %.anon, align 8
%1 = load double, ptr %ptradd, align 8
call void @test.test(i32 %0, double %1)
call void @test.test(i32 43, double 3.150000e+00)
call void @test.test(i32 0, double 0.000000e+00)
store i32 0, ptr %literal, align 8
%ptradd1 = getelementptr inbounds i8, ptr %literal, i64 8
store double 0.000000e+00, ptr %ptradd1, align 8
%lo = load i32, ptr %literal, align 8
%ptradd2 = getelementptr inbounds i8, ptr %literal, i64 8
%hi = load double, ptr %ptradd2, align 8
call void @test.test2(i32 0, i32 %lo, double %hi)
store i32 0, ptr %literal3, align 8
%ptradd4 = getelementptr inbounds i8, ptr %literal3, i64 8
store double 0.000000e+00, ptr %ptradd4, align 8
%lo5 = load i32, ptr %literal3, align 8
%ptradd6 = getelementptr inbounds i8, ptr %literal3, i64 8
%hi7 = load double, ptr %ptradd6, align 8
call void @test.test2(i32 42, i32 %lo5, double %hi7)
store i32 0, ptr %literal8, align 8
%ptradd9 = getelementptr inbounds i8, ptr %literal8, i64 8
store double 4.500000e+00, ptr %ptradd9, align 8
%lo10 = load i32, ptr %literal8, align 8
%ptradd11 = getelementptr inbounds i8, ptr %literal8, i64 8
%hi12 = load double, ptr %ptradd11, align 8
call void @test.test2(i32 42, i32 %lo10, double %hi12)
store i32 0, ptr %literal13, align 8
%ptradd14 = getelementptr inbounds i8, ptr %literal13, i64 8
store double 4.500000e+00, ptr %ptradd14, align 8
%lo15 = load i32, ptr %literal13, align 8
%ptradd16 = getelementptr inbounds i8, ptr %literal13, i64 8
%hi17 = load double, ptr %ptradd16, align 8
call void @test.test2(i32 0, i32 %lo15, double %hi17)
store i32 100, ptr %z, align 8
%ptradd18 = getelementptr inbounds i8, ptr %z, i64 8
call void @llvm.memcpy.p0.p0.i32(ptr align 8 %ptradd18, ptr align 8 %f, i32 16, i1 false)
call void @llvm.memcpy.p0.p0.i32(ptr align 8 %.anon19, ptr align 8 %z, i32 24, i1 false)
%ptradd20 = getelementptr inbounds i8, ptr %.anon19, i64 8
%2 = load i32, ptr %.anon19, align 8
%lo21 = load i32, ptr %ptradd20, align 8
%ptradd22 = getelementptr inbounds i8, ptr %ptradd20, i64 8
%hi23 = load double, ptr %ptradd22, align 8
call void @test.test2(i32 %2, i32 %lo21, double %hi23)
store i32 100, ptr %z2, align 8
%ptradd24 = getelementptr inbounds i8, ptr %z2, i64 8
call void @llvm.memcpy.p0.p0.i32(ptr align 8 %.anon25, ptr align 8 %f, i32 16, i1 false)
%3 = load i32, ptr %.anon25, align 8
store i32 %3, ptr %ptradd24, align 8
%ptradd26 = getelementptr inbounds i8, ptr %ptradd24, i64 8
%ptradd27 = getelementptr inbounds i8, ptr %.anon25, i64 8
%4 = load double, ptr %ptradd27, align 8
store double %4, ptr %ptradd26, align 8
call void @llvm.memcpy.p0.p0.i32(ptr align 8 %.anon28, ptr align 8 %f, i32 16, i1 false)
%5 = load i32, ptr %.anon28, align 8
store i32 %5, ptr %.assign_list, align 8
%ptradd29 = getelementptr inbounds i8, ptr %.assign_list, i64 8
%ptradd30 = getelementptr inbounds i8, ptr %.anon28, i64 8
%6 = load double, ptr %ptradd30, align 8
store double %6, ptr %ptradd29, align 8
call void @llvm.memcpy.p0.p0.i32(ptr align 8 %f, ptr align 8 %.assign_list, i32 16, i1 false)
ret i32 0
}