From 223501eeca1203ef7d5212c14e80d10e859e5ae2 Mon Sep 17 00:00:00 2001 From: Christoffer Lerno Date: Sat, 7 Sep 2024 05:26:43 +0200 Subject: [PATCH] Support splat for varargs #1352. --- releasenotes.md | 1 + src/compiler/sema_expr.c | 174 +++++++++++++++++++---- test/test_suite/functions/multisplat.c3t | 116 +++++++++++++++++ test/test_suite/functions/raw_splat.c3t | 150 ++++++++++++++++++++++ 4 files changed, 418 insertions(+), 23 deletions(-) create mode 100644 test/test_suite/functions/multisplat.c3t create mode 100644 test/test_suite/functions/raw_splat.c3t diff --git a/releasenotes.md b/releasenotes.md index ece882bd5..a0903fc10 100644 --- a/releasenotes.md +++ b/releasenotes.md @@ -4,6 +4,7 @@ ### Changes / improvements - Introduce `arg: x` named arguments instead of `.arg = x`, deprecate old style. +- Support splat for varargs #1352. ### Fixes - Issue where a lambda wasn't correctly registered as external. #1408 diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index a9d16032d..31a245b18 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -1456,6 +1456,108 @@ INLINE bool sema_call_splat_vasplat(SemaContext *context, Expr *arg, Expr ***arg *args_ref = args; return true; } + + +// Expands `arg` into `len` subscript expressions `temp[0]` .. `temp[len - 1]` appended to `args`, +// where `temp` is a generated local created from `arg`. The declaration and the first subscript +// are wrapped in one expression list, which is the only part analysed here; the remaining +// subscripts are appended as-is. Returns the updated vector, or NULL if analysis fails. +INLINE Expr **sema_splat_arraylike_append(SemaContext *context, Expr **args, Expr *arg, ArrayIndex len) +{ + Decl *temp = decl_new_generated_var(arg->type, VARDECL_LOCAL, arg->span); + Expr *decl_expr = expr_generate_decl(temp, arg); + Expr *list = expr_new_expr(EXPR_EXPRESSION_LIST, arg); + vec_add(list->expression_list, decl_expr); + Expr *subscript = expr_new_expr(EXPR_SUBSCRIPT, arg); + subscript->subscript_expr.range.start = exprid(expr_new_const_int(arg->span, type_usz, 0)); + subscript->subscript_expr.expr = exprid(expr_variable(temp)); + vec_add(list->expression_list, subscript); + if (!sema_analyse_expr(context, list)) return NULL; + vec_add(args, list); + for (ArrayIndex i = 1; i < len; i++) + { + subscript = expr_new_expr(EXPR_SUBSCRIPT, 
arg); + subscript->subscript_expr.range.start = exprid(expr_new_const_int(arg->span, type_usz, i)); + subscript->subscript_expr.expr = exprid(expr_variable(temp)); + vec_add(args, subscript); + } + return args; +} +// Replaces the splat argument at `index` in `*args_ref` with `len` element accesses of `arg`, +// keeping all other arguments, including any that follow the splat, in their original order. +// Requires `len > 0`. Returns false if expanding `arg` fails, otherwise true. +INLINE bool sema_call_splat_arraylike(SemaContext *context, Expr *arg, Expr ***args_ref, int index, ArrayIndex len) +{ + ASSERT_SPAN(arg, len > 0); + + Expr **args = *args_ref; + unsigned num_args = vec_size(args); + unsigned splat_index = (unsigned)index; + // If it was the last element then just drop it and append. + if (splat_index == num_args - 1) + { + vec_pop(args); + args = sema_splat_arraylike_append(context, args, arg, len); + if (!args) return false; + *args_ref = args; + return true; + } + // Otherwise append to the end, giving [.., splat, tail.., new..]. + args = sema_splat_arraylike_append(context, args, arg, len); + if (!args) return false; + unsigned new_size = vec_size(args); + ASSERT_SPAN(arg, new_size != num_args); + unsigned added_elements = new_size - num_args; + unsigned tail_elements = num_args - splat_index - 1; + // Drop the splat slot by shifting everything after it one step left: [.., tail.., new..]. + for (unsigned j = splat_index; j + 1 < new_size; j++) + { + args[j] = args[j + 1]; + } + vec_pop(args); 
+ // Move each new element in front of the tail, one at a time, so that both the new + // elements and the trailing arguments keep their order: [.., new.., tail..]. + for (unsigned j = 0; j < added_elements; j++) + { + unsigned dest = splat_index + j; + Expr *moved = args[dest + tail_elements]; + for (unsigned k = dest + tail_elements; k > dest; k--) + { + args[k] = args[k - 1]; + } + args[dest] = moved; + } + *args_ref = args; + return true; +} + +// Returns the element count of `expr` as far as it can be determined here: the declared length +// for arrays and vectors, sema_len_from_const() for untyped lists and for slices accepted by +// sema_cast_const(), and range_const_len() of the range for EXPR_SLICE expressions. Returns -1 +// for every other type and for slices that are neither of those two cases. +static inline ArrayIndex sema_len_from_expr(Expr *expr) +{ + Type *type = type_flatten(expr->type); + switch (type->type_kind) + { + case TYPE_VECTOR: + case TYPE_ARRAY: + return type->array.len; + case TYPE_UNTYPED_LIST: + return sema_len_from_const(expr); + case TYPE_SLICE: + break; + default: + return -1; + } + if (sema_cast_const(expr)) + { + return sema_len_from_const(expr); + } + if (expr->expr_kind != EXPR_SLICE) return -1; + return range_const_len(&expr->subscript_expr.range); +} + INLINE bool sema_call_evaluate_arguments(SemaContext *context, CalledDecl *callee, Expr *call, bool *optional, bool *no_match_ref) { @@ -1481,12 +1583,13 @@ INLINE bool sema_call_evaluate_arguments(SemaContext *context, CalledDecl *calle // We might have a typed variadic call e.g. foo(int, double...) // get that type. Type *variadic_type = NULL; + Type *variadic_slot_type = NULL; if (variadic == VARIADIC_TYPED || variadic == VARIADIC_ANY) { // 7a. 
The parameter type is [], so we get the - Type *vararg_slot_type = decl_params[vaarg_index]->type; - ASSERT_SPAN(call, vararg_slot_type->type_kind == TYPE_SLICE); - variadic_type = vararg_slot_type->array.base; + variadic_slot_type = decl_params[vaarg_index]->type; + ASSERT_SPAN(call, variadic_slot_type->type_kind == TYPE_SLICE); + variadic_type = variadic_slot_type->array.base; } Expr **args = call->call_expr.arguments; @@ -1525,27 +1609,56 @@ INLINE bool sema_call_evaluate_arguments(SemaContext *context, CalledDecl *calle { RETURN_SEMA_ERROR(arg, "Splat is only possible with variadic functions."); } - if (!variadic_type) - { - RETURN_SEMA_ERROR(arg, "Splat may not be used with raw varargs."); - } - if (i != vaarg_index) - { - RETURN_SEMA_ERROR(arg, "Expected a splat only in the vaarg slot."); - } - call->call_expr.va_is_splat = true; + Expr *inner = arg->inner_expr; - // Potentially should be inferred + if (!sema_analyse_expr(context, inner)) return false; - if (!expr_may_splat_as_vararg(inner, variadic_type)) + // Let's try fit up a slice to the in the vaslot + if (variadic_type && i == vaarg_index) { - RETURN_SEMA_ERROR(inner, "It's not possible to splat %s as vararg of type %s", - type_quoted_error_string(inner->type), - type_quoted_error_string(variadic_type)); + // Is it not the last and not a named argument, then we do a normal splat. + if (i + 1 < num_args && args[i + 1]->expr_kind != EXPR_NAMED_ARGUMENT) goto SPLAT_NORMAL; + + // Convert an array/vector to an address of an array. + Expr *inner_new = inner; + if (type_is_arraylike(inner->type)) + { + inner_new = expr_copy(inner); + expr_insert_addr(inner_new); + } + if (!cast_implicit_silent(context, inner_new, variadic_slot_type, false)) goto SPLAT_NORMAL; + if (inner != inner_new) expr_replace(inner, inner_new); + // We splat it in the right spot! 
+ call->call_expr.va_is_splat = true; + *optional |= IS_OPTIONAL(inner); + call->call_expr.vasplat = inner; + continue; } - *optional |= IS_OPTIONAL(inner); - call->call_expr.vasplat = inner; +SPLAT_NORMAL:; + Type *flat = type_flatten(inner->type); + switch (flat->type_kind) + { + case TYPE_VECTOR: + case TYPE_ARRAY: + case TYPE_SLICE: + case TYPE_UNTYPED_LIST: + // These may be splatted + break; + default: + RETURN_SEMA_ERROR(arg, "An argument of type %s cannot be splatted.", + type_quoted_error_string(inner->type)); + } + // This is the fallback: just splat like vasplat: + ArrayIndex len = sema_len_from_expr(inner); + if (len == -1) RETURN_SEMA_ERROR(arg, "Splat may not be used with raw varargs if the length is not known."); + if (len == 0 && !expr_is_const(arg)) + { + RETURN_SEMA_ERROR(arg, "A non-constant zero size splat cannot be used with raw varargs."); + } + if (!sema_call_splat_arraylike(context, inner, &args, i, len)) return false; + i--; + num_args = vec_size(args); continue; } if (arg->expr_kind == EXPR_NAMED_ARGUMENT) @@ -4951,10 +5064,6 @@ Expr **sema_expand_vasplat_exprs(SemaContext *c, Expr **exprs) return exprs; } - - - - static inline bool sema_expr_analyse_expr_list(SemaContext *context, Expr *expr) { bool success = true; diff --git a/test/test_suite/functions/multisplat.c3t b/test/test_suite/functions/multisplat.c3t new file mode 100644 index 000000000..70943449a --- /dev/null +++ b/test/test_suite/functions/multisplat.c3t @@ -0,0 +1,116 @@ +// #target: macos-x64 +module test; +extern fn void foo(args...); + +fn void main() +{ + int[3] y = { 33, 44, 55 }; + foo(...y); + foo(1, ...y); + int[] z = &y; + foo(...z[1..2]); + foo(1, ...z[1..2], 5); +} + +/* #expect: test.ll + +define void @test.main() #0 { +entry: + %y = alloca [3 x i32], align 4 + %varargslots = alloca [3 x %any], align 16 + %.anon = alloca [3 x i32], align 4 + %varargslots4 = alloca [4 x %any], align 16 + %taddr = alloca i32, align 4 + %.anon5 = alloca [3 x i32], align 4 + %z = 
alloca %"int[]", align 8 + %varargslots11 = alloca [2 x %any], align 16 + %.anon12 = alloca %"int[]", align 8 + %varargslots16 = alloca [4 x %any], align 16 + %taddr17 = alloca i32, align 4 + %.anon18 = alloca %"int[]", align 8 + %taddr23 = alloca i32, align 4 + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %y, ptr align 4 @.__const, i32 12, i1 false) + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %.anon, ptr align 4 %y, i32 12, i1 false) + %0 = insertvalue %any undef, ptr %.anon, 0 + %1 = insertvalue %any %0, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + store %any %1, ptr %varargslots, align 16 + %ptradd = getelementptr inbounds i8, ptr %.anon, i64 4 + %2 = insertvalue %any undef, ptr %ptradd, 0 + %3 = insertvalue %any %2, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + %ptradd1 = getelementptr inbounds i8, ptr %varargslots, i64 16 + store %any %3, ptr %ptradd1, align 16 + %ptradd2 = getelementptr inbounds i8, ptr %.anon, i64 8 + %4 = insertvalue %any undef, ptr %ptradd2, 0 + %5 = insertvalue %any %4, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + %ptradd3 = getelementptr inbounds i8, ptr %varargslots, i64 32 + store %any %5, ptr %ptradd3, align 16 + call void @foo(ptr %varargslots, i64 3) + store i32 1, ptr %taddr, align 4 + %6 = insertvalue %any undef, ptr %taddr, 0 + %7 = insertvalue %any %6, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + store %any %7, ptr %varargslots4, align 16 + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %.anon5, ptr align 4 %y, i32 12, i1 false) + %8 = insertvalue %any undef, ptr %.anon5, 0 + %9 = insertvalue %any %8, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + %ptradd6 = getelementptr inbounds i8, ptr %varargslots4, i64 16 + store %any %9, ptr %ptradd6, align 16 + %ptradd7 = getelementptr inbounds i8, ptr %.anon5, i64 4 + %10 = insertvalue %any undef, ptr %ptradd7, 0 + %11 = insertvalue %any %10, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + %ptradd8 = getelementptr inbounds i8, ptr %varargslots4, i64 32 + store %any %11, ptr %ptradd8, align 16 + %ptradd9 = 
getelementptr inbounds i8, ptr %.anon5, i64 8 + %12 = insertvalue %any undef, ptr %ptradd9, 0 + %13 = insertvalue %any %12, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + %ptradd10 = getelementptr inbounds i8, ptr %varargslots4, i64 48 + store %any %13, ptr %ptradd10, align 16 + call void @foo(ptr %varargslots4, i64 4) + %14 = insertvalue %"int[]" undef, ptr %y, 0 + %15 = insertvalue %"int[]" %14, i64 3, 1 + store %"int[]" %15, ptr %z, align 8 + %16 = load %"int[]", ptr %z, align 8 + %17 = extractvalue %"int[]" %16, 0 + %ptradd13 = getelementptr inbounds i8, ptr %17, i64 4 + %18 = insertvalue %"int[]" undef, ptr %ptradd13, 0 + %19 = insertvalue %"int[]" %18, i64 2, 1 + store %"int[]" %19, ptr %.anon12, align 8 + %20 = load ptr, ptr %.anon12, align 8 + %21 = insertvalue %any undef, ptr %20, 0 + %22 = insertvalue %any %21, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + store %any %22, ptr %varargslots11, align 16 + %23 = load ptr, ptr %.anon12, align 8 + %ptradd14 = getelementptr inbounds i8, ptr %23, i64 4 + %24 = insertvalue %any undef, ptr %ptradd14, 0 + %25 = insertvalue %any %24, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + %ptradd15 = getelementptr inbounds i8, ptr %varargslots11, i64 16 + store %any %25, ptr %ptradd15, align 16 + call void @foo(ptr %varargslots11, i64 2) + store i32 1, ptr %taddr17, align 4 + %26 = insertvalue %any undef, ptr %taddr17, 0 + %27 = insertvalue %any %26, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + store %any %27, ptr %varargslots16, align 16 + %28 = load %"int[]", ptr %z, align 8 + %29 = extractvalue %"int[]" %28, 0 + %ptradd19 = getelementptr inbounds i8, ptr %29, i64 4 + %30 = insertvalue %"int[]" undef, ptr %ptradd19, 0 + %31 = insertvalue %"int[]" %30, i64 2, 1 + store %"int[]" %31, ptr %.anon18, align 8 + %32 = load ptr, ptr %.anon18, align 8 + %33 = insertvalue %any undef, ptr %32, 0 + %34 = insertvalue %any %33, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + %ptradd20 = getelementptr inbounds i8, ptr %varargslots16, i64 16 + store %any %34, 
ptr %ptradd20, align 16 + %35 = load ptr, ptr %.anon18, align 8 + %ptradd21 = getelementptr inbounds i8, ptr %35, i64 4 + %36 = insertvalue %any undef, ptr %ptradd21, 0 + %37 = insertvalue %any %36, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + %ptradd22 = getelementptr inbounds i8, ptr %varargslots16, i64 32 + store %any %37, ptr %ptradd22, align 16 + store i32 5, ptr %taddr23, align 4 + %38 = insertvalue %any undef, ptr %taddr23, 0 + %39 = insertvalue %any %38, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + %ptradd24 = getelementptr inbounds i8, ptr %varargslots16, i64 48 + store %any %39, ptr %ptradd24, align 16 + call void @foo(ptr %varargslots16, i64 4) + ret void +} diff --git a/test/test_suite/functions/raw_splat.c3t b/test/test_suite/functions/raw_splat.c3t new file mode 100644 index 000000000..83c3be48b --- /dev/null +++ b/test/test_suite/functions/raw_splat.c3t @@ -0,0 +1,150 @@ +// #target: macos-x64 +module rawsplat; + +int x; +macro void foo(...) +{ + $for (var $i = 0; $i < $vacount; $i++) + x += $vaarg[$i] * $i; + $endfor +} + +fn void main() +{ + int[3] y = { 33, 44, 55 }; + foo(...y); + foo(1, ...y); + int[] z = &y; + foo(...z[1..2]); + foo(1, ...z[1..2], 5); +} + +/* #expect: rawsplat.ll + +define void @rawsplat.main() #0 { +entry: + %y = alloca [3 x i32], align 4 + %.anon = alloca [3 x i32], align 4 + %.anon1 = alloca i32, align 4 + %.anon2 = alloca i32, align 4 + %.anon4 = alloca i32, align 4 + %.anon9 = alloca [3 x i32], align 4 + %.anon10 = alloca i32, align 4 + %.anon12 = alloca i32, align 4 + %.anon14 = alloca i32, align 4 + %z = alloca %"int[]", align 8 + %.anon22 = alloca %"int[]", align 8 + %.anon24 = alloca i32, align 4 + %.anon26 = alloca i32, align 4 + %.anon31 = alloca %"int[]", align 8 + %.anon33 = alloca i32, align 4 + %.anon35 = alloca i32, align 4 + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %y, ptr align 4 @.__const, i32 12, i1 false) + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %.anon, ptr align 4 %y, i32 12, i1 false) + %0 = load 
i32, ptr %.anon, align 4 + store i32 %0, ptr %.anon1, align 4 + %ptradd = getelementptr inbounds i8, ptr %.anon, i64 4 + %1 = load i32, ptr %ptradd, align 4 + store i32 %1, ptr %.anon2, align 4 + %ptradd3 = getelementptr inbounds i8, ptr %.anon, i64 8 + %2 = load i32, ptr %ptradd3, align 4 + store i32 %2, ptr %.anon4, align 4 + %3 = load i32, ptr @rawsplat.x, align 4 + %4 = load i32, ptr %.anon1, align 4 + %mul = mul i32 %4, 0 + %add = add i32 %3, %mul + store i32 %add, ptr @rawsplat.x, align 4 + %5 = load i32, ptr @rawsplat.x, align 4 + %6 = load i32, ptr %.anon2, align 4 + %mul5 = mul i32 %6, 1 + %add6 = add i32 %5, %mul5 + store i32 %add6, ptr @rawsplat.x, align 4 + %7 = load i32, ptr @rawsplat.x, align 4 + %8 = load i32, ptr %.anon4, align 4 + %mul7 = mul i32 %8, 2 + %add8 = add i32 %7, %mul7 + store i32 %add8, ptr @rawsplat.x, align 4 + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %.anon9, ptr align 4 %y, i32 12, i1 false) + %9 = load i32, ptr %.anon9, align 4 + store i32 %9, ptr %.anon10, align 4 + %ptradd11 = getelementptr inbounds i8, ptr %.anon9, i64 4 + %10 = load i32, ptr %ptradd11, align 4 + store i32 %10, ptr %.anon12, align 4 + %ptradd13 = getelementptr inbounds i8, ptr %.anon9, i64 8 + %11 = load i32, ptr %ptradd13, align 4 + store i32 %11, ptr %.anon14, align 4 + %12 = load i32, ptr @rawsplat.x, align 4 + %add15 = add i32 %12, 0 + store i32 %add15, ptr @rawsplat.x, align 4 + %13 = load i32, ptr @rawsplat.x, align 4 + %14 = load i32, ptr %.anon10, align 4 + %mul16 = mul i32 %14, 1 + %add17 = add i32 %13, %mul16 + store i32 %add17, ptr @rawsplat.x, align 4 + %15 = load i32, ptr @rawsplat.x, align 4 + %16 = load i32, ptr %.anon12, align 4 + %mul18 = mul i32 %16, 2 + %add19 = add i32 %15, %mul18 + store i32 %add19, ptr @rawsplat.x, align 4 + %17 = load i32, ptr @rawsplat.x, align 4 + %18 = load i32, ptr %.anon14, align 4 + %mul20 = mul i32 %18, 3 + %add21 = add i32 %17, %mul20 + store i32 %add21, ptr @rawsplat.x, align 4 + %19 = insertvalue %"int[]" 
undef, ptr %y, 0 + %20 = insertvalue %"int[]" %19, i64 3, 1 + store %"int[]" %20, ptr %z, align 8 + %21 = load %"int[]", ptr %z, align 8 + %22 = extractvalue %"int[]" %21, 0 + %ptradd23 = getelementptr inbounds i8, ptr %22, i64 4 + %23 = insertvalue %"int[]" undef, ptr %ptradd23, 0 + %24 = insertvalue %"int[]" %23, i64 2, 1 + store %"int[]" %24, ptr %.anon22, align 8 + %25 = load ptr, ptr %.anon22, align 8 + %26 = load i32, ptr %25, align 4 + store i32 %26, ptr %.anon24, align 4 + %27 = load ptr, ptr %.anon22, align 8 + %ptradd25 = getelementptr inbounds i8, ptr %27, i64 4 + %28 = load i32, ptr %ptradd25, align 4 + store i32 %28, ptr %.anon26, align 4 + %29 = load i32, ptr @rawsplat.x, align 4 + %30 = load i32, ptr %.anon24, align 4 + %mul27 = mul i32 %30, 0 + %add28 = add i32 %29, %mul27 + store i32 %add28, ptr @rawsplat.x, align 4 + %31 = load i32, ptr @rawsplat.x, align 4 + %32 = load i32, ptr %.anon26, align 4 + %mul29 = mul i32 %32, 1 + %add30 = add i32 %31, %mul29 + store i32 %add30, ptr @rawsplat.x, align 4 + %33 = load %"int[]", ptr %z, align 8 + %34 = extractvalue %"int[]" %33, 0 + %ptradd32 = getelementptr inbounds i8, ptr %34, i64 4 + %35 = insertvalue %"int[]" undef, ptr %ptradd32, 0 + %36 = insertvalue %"int[]" %35, i64 2, 1 + store %"int[]" %36, ptr %.anon31, align 8 + %37 = load ptr, ptr %.anon31, align 8 + %38 = load i32, ptr %37, align 4 + store i32 %38, ptr %.anon33, align 4 + %39 = load ptr, ptr %.anon31, align 8 + %ptradd34 = getelementptr inbounds i8, ptr %39, i64 4 + %40 = load i32, ptr %ptradd34, align 4 + store i32 %40, ptr %.anon35, align 4 + %41 = load i32, ptr @rawsplat.x, align 4 + %add36 = add i32 %41, 0 + store i32 %add36, ptr @rawsplat.x, align 4 + %42 = load i32, ptr @rawsplat.x, align 4 + %43 = load i32, ptr %.anon33, align 4 + %mul37 = mul i32 %43, 1 + %add38 = add i32 %42, %mul37 + store i32 %add38, ptr @rawsplat.x, align 4 + %44 = load i32, ptr @rawsplat.x, align 4 + %45 = load i32, ptr %.anon35, align 4 + %mul39 = mul i32 %45, 2 
+ %add40 = add i32 %44, %mul39 + store i32 %add40, ptr @rawsplat.x, align 4 + %46 = load i32, ptr @rawsplat.x, align 4 + %add41 = add i32 %46, 15 + store i32 %add41, ptr @rawsplat.x, align 4 + ret void +}