diff --git a/releasenotes.md b/releasenotes.md index 63b4c0629..c8fe1018b 100644 --- a/releasenotes.md +++ b/releasenotes.md @@ -19,6 +19,7 @@ - Add `inline` to enums #1819. - Cleaner error message when missing comma in struct initializer #1941. - Distinct inline void causes unexpected error if used in slice #1946. +- Allow `fn int test() => @pool() { return 1; }` short function syntax usage #1906. ### Fixes - Fix issue requiring prefix on a generic interface declaration. diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index 7ca0c2ecc..7f07edda6 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -750,6 +750,7 @@ typedef struct bool must_use : 1; bool is_optional_return : 1; bool va_is_splat : 1; + bool is_outer_call : 1; Expr **arguments; union { Expr **varargs; diff --git a/src/compiler/parse_stmt.c b/src/compiler/parse_stmt.c index 37d076f22..81af7f441 100644 --- a/src/compiler/parse_stmt.c +++ b/src/compiler/parse_stmt.c @@ -1523,27 +1523,35 @@ Ast* parse_compound_stmt(ParseContext *c) return ast; } -Ast *parse_short_body(ParseContext *c, TypeInfoId return_type, bool require_eos) +Ast *parse_short_body(ParseContext *c, TypeInfoId return_type, bool is_regular_fn) { advance(c); Ast *ast = ast_new_curr(c, AST_COMPOUND_STMT); AstId *next = &ast->compound_stmt.first_stmt; + Ast *ret = ast_new_curr(c, AST_RETURN_STMT); + ast_append(&next, ret); TypeInfo *rtype = return_type ? type_infoptr(return_type) : NULL; - if (!rtype || (rtype->resolve_status != RESOLVE_DONE || rtype->type->type_kind != TYPE_VOID)) + bool is_void_return = rtype && rtype->resolve_status == RESOLVE_DONE && rtype->type->type_kind == TYPE_VOID; + ASSIGN_EXPR_OR_RET(Expr *expr, parse_expr(c), poisoned_ast); + if (expr->expr_kind == EXPR_CALL && expr->call_expr.macro_body) { - Ast *ret = ast_new_curr(c, AST_RETURN_STMT); - ast_append(&next, ret); - ASSIGN_EXPR_OR_RET(ret->return_stmt.expr, parse_expr(c), poisoned_ast); + ret->ast_kind = AST_EXPR_STMT; + ret->expr_stmt = expr; + is_regular_fn = false; + expr->call_expr.is_outer_call = true; + goto END; } - else + if (is_void_return) { - Ast *stmt = new_ast(AST_EXPR_STMT, c->span); - ASSIGN_EXPR_OR_RET(stmt->expr_stmt, parse_expr(c), poisoned_ast); - ast_append(&next, stmt); + ret->ast_kind = AST_EXPR_STMT; + ret->expr_stmt = expr; + goto END; } + ret->return_stmt.expr = expr; +END:; RANGE_EXTEND_PREV(ast); - if (require_eos) + if (is_regular_fn) { CONSUME_EOS_OR_RET(poisoned_ast); } diff --git a/src/compiler/parser_internal.h b/src/compiler/parser_internal.h index 5206b028c..fef77c45e 100644 --- a/src/compiler/parser_internal.h +++ b/src/compiler/parser_internal.h @@ -42,7 +42,7 @@ Expr *parse_decl_or_expr(ParseContext *c, Decl **decl_ref); void recover_top_level(ParseContext *c); Expr *parse_cond(ParseContext *c); Ast* parse_compound_stmt(ParseContext *c); -Ast *parse_short_body(ParseContext *c, TypeInfoId return_type, bool require_eos); +Ast *parse_short_body(ParseContext *c, TypeInfoId return_type, bool is_regular_fn); bool parse_attribute(ParseContext *c, Attr **attribute_ref, bool expect_eos); diff --git a/src/compiler/sema_decls.c b/src/compiler/sema_decls.c index 6652dffd6..0301cfdfe 100755 --- a/src/compiler/sema_decls.c +++ b/src/compiler/sema_decls.c @@ -3865,11 +3865,11 @@ static inline bool sema_analyse_macro(SemaContext *context, Decl *decl, bool *er type_infoptrzero(decl->func_decl.type_parent), false, deprecated, decl->span)) return false; - if (!decl->func_decl.signature.is_at_macro && decl->func_decl.body_param && !decl->func_decl.signature.is_safemacro) + DeclId body_param = decl->func_decl.body_param; + if (!decl->func_decl.signature.is_at_macro && body_param && !decl->func_decl.signature.is_safemacro) { RETURN_SEMA_ERROR(decl, "Names of macros with a trailing body must start with '@'."); } - DeclId body_param = decl->func_decl.body_param; Decl **body_parameters = body_param ? declptr(body_param)->body_params : NULL; if (!sema_analyse_macro_body(context, body_parameters)) return false; bool pure = false; diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index 61ab49d0f..52eecfb31 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -2081,6 +2081,7 @@ bool sema_expr_analyse_macro_call(SemaContext *context, Expr *call_expr, Expr *s bool call_var_optional, bool *no_match_ref) { bool is_always_const = decl->func_decl.signature.attrs.always_const; + bool is_outer = call_expr->call_expr.is_outer_call; ASSERT_SPAN(call_expr, decl->decl_kind == DECL_MACRO); if (context->macro_call_depth > 256) @@ -2442,6 +2443,12 @@ NOT_CT: call_expr->macro_block.block_exit = block_exit_ref; call_expr->macro_block.is_noreturn = is_no_return; EXIT: + if (is_outer && !type_is_void(call_expr->type)) + { + RETURN_SEMA_ERROR(call_expr, "The macro itself returns %s here, but only 'void' is permitted " + "when a macro with trailing body is used directly after '=>'.", + type_quoted_error_string(rtype)); + } ASSERT_SPAN(call_expr, context->active_scope.defer_last == context->active_scope.defer_start); context->active_scope = old_scope; if (is_no_return) context->active_scope.jump_end = true; diff --git a/test/test_suite/macros/short_trailing_body.c3t b/test/test_suite/macros/short_trailing_body.c3t new file mode 100644 index 000000000..88e69c7b8 --- /dev/null +++ b/test/test_suite/macros/short_trailing_body.c3t @@ -0,0 +1,122 @@ +// #target: macos-x64 +module test; +import std; + +fn int foo(int x) => @pool() +{ + String s = string::tformat("%d", x); + io::printn(s); + return 2; +} + +fn void main() +{ + foo(3); +} + +/* #expect: test.ll + +entry: + %current = alloca ptr, align 8 + %mark = alloca i64, align 8 + %s = alloca %"char[]", align 8 + %varargslots = alloca [1 x %any], align 16 + %taddr = alloca i32, align 4 + %result = alloca %"char[]", align 8 + %x = alloca %"char[]", align 8 + %x1 = alloca %"char[]", align 8 + %len = alloca i64, align 8 + %error_var = alloca i64, align 8 + %x2 = alloca %"char[]", align 8 + %retparam = alloca i64, align 8 + %error_var5 = alloca i64, align 8 + %error_var11 = alloca i64, align 8 + %1 = load ptr, ptr @std.core.mem.allocator.thread_temp_allocator, align 8 + %i2nb = icmp eq ptr %1, null + br i1 %i2nb, label %if.then, label %if.exit + +if.then: ; preds = %entry + call void @std.core.mem.allocator.init_default_temp_allocators() + br label %if.exit + +if.exit: ; preds = %if.then, %entry + %2 = load ptr, ptr @std.core.mem.allocator.thread_temp_allocator, align 8 + store ptr %2, ptr %current, align 8 + %3 = load ptr, ptr %current, align 8 + %ptradd = getelementptr inbounds i8, ptr %3, i64 24 + %4 = load i64, ptr %ptradd, align 8 + store i64 %4, ptr %mark, align 8 + store i32 %0, ptr %taddr, align 4 + %5 = insertvalue %any undef, ptr %taddr, 0 + %6 = insertvalue %any %5, i64 ptrtoint (ptr @"$ct.int" to i64), 1 + store %any %6, ptr %varargslots, align 16 + %7 = call { ptr, i64 } @std.core.string.tformat(ptr @.str, i64 2, ptr %varargslots, i64 1) + store { ptr, i64 } %7, ptr %result, align 8 + call void @llvm.memcpy.p0.p0.i32(ptr align 8 %s, ptr align 8 %result, i32 16, i1 false) + call void @llvm.memcpy.p0.p0.i32(ptr align 8 %x, ptr align 8 %s, i32 16, i1 false) + %8 = call ptr @std.io.stdout() + call void @llvm.memcpy.p0.p0.i32(ptr align 8 %x1, ptr align 8 %x, i32 16, i1 false) + call void @llvm.memcpy.p0.p0.i32(ptr align 8 %x2, ptr align 8 %x1, i32 16, i1 false) + %lo = load ptr, ptr %x2, align 8 + %ptradd4 = getelementptr inbounds i8, ptr %x2, i64 8 + %hi = load i64, ptr %ptradd4, align 8 + %9 = call i64 @std.io.File.write(ptr %retparam, ptr %8, ptr %lo, i64 %hi) + %not_err = icmp eq i64 %9, 0 + %10 = call i1 @llvm.expect.i1(i1 %not_err, i1 true) + br i1 %10, label %after_check, label %assign_optional + +assign_optional: ; preds = %if.exit + store i64 %9, ptr %error_var, align 8 + br label %guard_block + +after_check: ; preds = %if.exit + br label %noerr_block + +guard_block: ; preds = %assign_optional + br label %voiderr + +noerr_block: ; preds = %after_check + %11 = load i64, ptr %retparam, align 8 + store i64 %11, ptr %len, align 8 + %12 = call i64 @std.io.File.write_byte(ptr %8, i8 zeroext 10) + %not_err6 = icmp eq i64 %12, 0 + %13 = call i1 @llvm.expect.i1(i1 %not_err6, i1 true) + br i1 %13, label %after_check8, label %assign_optional7 + +assign_optional7: ; preds = %noerr_block + store i64 %12, ptr %error_var5, align 8 + br label %guard_block9 + +after_check8: ; preds = %noerr_block + br label %noerr_block10 + +guard_block9: ; preds = %assign_optional7 + br label %voiderr + +noerr_block10: ; preds = %after_check8 + %14 = call i64 @std.io.File.flush(ptr %8) + %not_err12 = icmp eq i64 %14, 0 + %15 = call i1 @llvm.expect.i1(i1 %not_err12, i1 true) + br i1 %15, label %after_check14, label %assign_optional13 + +assign_optional13: ; preds = %noerr_block10 + store i64 %14, ptr %error_var11, align 8 + br label %guard_block15 + +after_check14: ; preds = %noerr_block10 + br label %noerr_block16 + +guard_block15: ; preds = %assign_optional13 + br label %voiderr + +noerr_block16: ; preds = %after_check14 + %16 = load i64, ptr %len, align 8 + %add = add i64 %16, 1 + br label %voiderr + +voiderr: ; preds = %noerr_block16, %guard_block15, %guard_block9, %guard_block + %17 = load ptr, ptr %current, align 8 + %18 = load i64, ptr %mark, align 8 + call void @std.core.mem.allocator.TempAllocator.reset(ptr %17, i64 %18) + ret i32 2 +} diff --git a/test/test_suite/macros/trailing_body_type.c3 b/test/test_suite/macros/trailing_body_type.c3 new file mode 100644 index 000000000..b73746620 --- /dev/null +++ b/test/test_suite/macros/trailing_body_type.c3 @@ -0,0 +1,10 @@ +macro @test(int x = 1; @body()) +{ + if (x > 0) @body(); + return x; +} + +fn int foo(int x) => @test(x) // #error: The macro itself returns 'int' here +{ + return 2; +}