From 8f5676b48878822624b2df232ce48161e30e2315 Mon Sep 17 00:00:00 2001 From: Christoffer Lerno Date: Tue, 21 Feb 2023 20:10:03 +0100 Subject: [PATCH] Add defer catch/try. Fix missing defer invoked on `return a > 0 ? Foo.ABC! : 1` --- src/compiler/compiler_internal.h | 3 + src/compiler/copying.c | 11 +- src/compiler/llvm_codegen_expr.c | 2 +- src/compiler/llvm_codegen_stmt.c | 9 +- src/compiler/parse_stmt.c | 18 +++ src/compiler/sema_expr.c | 7 +- src/compiler/sema_internal.h | 2 +- src/compiler/sema_liveness.c | 4 + src/compiler/sema_stmts.c | 58 ++++++--- src/compiler/semantic_analyser.c | 16 ++- src/version.h | 2 +- test/test_suite/defer/defer_catch_try.c3t | 148 ++++++++++++++++++++++ 12 files changed, 247 insertions(+), 33 deletions(-) create mode 100644 test/test_suite/defer/defer_catch_try.c3t diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index b8cd331b7..c92c99ddb 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -1195,6 +1195,7 @@ typedef struct { Expr *expr; // May be NULL AstId cleanup; + AstId cleanup_fail; BlockExit** block_exit_ref; // For block exits } AstReturnStmt; @@ -1300,6 +1301,8 @@ typedef struct { AstId prev_defer; AstId body; // Compound statement + bool is_try : 1; + bool is_catch : 1; } AstDeferStmt; diff --git a/src/compiler/copying.c b/src/compiler/copying.c index 857f8cea2..4e749af70 100644 --- a/src/compiler/copying.c +++ b/src/compiler/copying.c @@ -661,7 +661,16 @@ RETRY: case AST_BLOCK_EXIT_STMT: case AST_RETURN_STMT: MACRO_COPY_EXPR(ast->return_stmt.expr); - MACRO_COPY_ASTID(ast->return_stmt.cleanup); + if (ast->return_stmt.cleanup == ast->return_stmt.cleanup_fail) + { + MACRO_COPY_ASTID(ast->return_stmt.cleanup); + ast->return_stmt.cleanup_fail = ast->return_stmt.cleanup; + } + else + { + MACRO_COPY_ASTID(ast->return_stmt.cleanup); + MACRO_COPY_ASTID(ast->return_stmt.cleanup_fail); + } break; case AST_SWITCH_STMT: case AST_IF_CATCH_SWITCH_STMT: diff --git a/src/compiler/llvm_codegen_expr.c b/src/compiler/llvm_codegen_expr.c index 410241d6c..274b70285 100644 --- a/src/compiler/llvm_codegen_expr.c +++ b/src/compiler/llvm_codegen_expr.c @@ -5532,7 +5532,7 @@ static inline void llvm_emit_return_block(GenContext *c, BEValue *be_value, Type // Defers? In that case we also use the default behaviour. // We might optimize this later. - if (value->return_stmt.cleanup) break; + if (value->return_stmt.cleanup || value->return_stmt.cleanup_fail) break; Expr *ret_expr = value->return_stmt.expr; diff --git a/src/compiler/llvm_codegen_stmt.c b/src/compiler/llvm_codegen_stmt.c index 28daab15f..caf3f4fe8 100644 --- a/src/compiler/llvm_codegen_stmt.c +++ b/src/compiler/llvm_codegen_stmt.c @@ -174,7 +174,7 @@ static inline void llvm_emit_return(GenContext *c, Ast *ast) { BEValue be_value; llvm_emit_expr(c, &be_value, expr->inner_expr); - llvm_emit_statement_chain(c, ast->return_stmt.cleanup); + llvm_emit_statement_chain(c, ast->return_stmt.cleanup_fail); llvm_emit_return_abi(c, NULL, &be_value); return; } @@ -200,7 +200,6 @@ static inline void llvm_emit_return(GenContext *c, Ast *ast) POP_OPT(); - llvm_emit_statement_chain(c, ast->return_stmt.cleanup); // Are we in an expression block? @@ -216,6 +215,7 @@ static inline void llvm_emit_return(GenContext *c, Ast *ast) if (error_return_block && LLVMGetFirstUse(LLVMBasicBlockAsValue(error_return_block))) { llvm_emit_block(c, error_return_block); + llvm_emit_statement_chain(c, ast->return_stmt.cleanup_fail); BEValue value; llvm_value_set_address_abi_aligned(&value, error_out, type_anyerr); llvm_emit_return_abi(c, NULL, &value); @@ -242,7 +242,7 @@ static inline void llvm_emit_block_exit_return(GenContext *c, Ast *ast) BEValue return_value = { 0 }; if (ret_expr) { - if (ast->return_stmt.cleanup && IS_OPTIONAL(ret_expr)) + if (ast->return_stmt.cleanup_fail && IS_OPTIONAL(ret_expr)) { assert(c->catch_block); err_cleanup_block = llvm_basic_block_new(c, "opt_block_cleanup"); @@ -255,7 +255,8 @@ static inline void llvm_emit_block_exit_return(GenContext *c, Ast *ast) POP_OPT(); AstId cleanup = ast->return_stmt.cleanup; - AstId err_cleanup = err_cleanup_block && cleanup ? astid(copy_ast_defer(astptr(cleanup))) : 0; + AstId cleanup_fail = ast->return_stmt.cleanup_fail; + AstId err_cleanup = err_cleanup_block && cleanup_fail ? astid(copy_ast_defer(astptr(cleanup_fail))) : 0; llvm_emit_statement_chain(c, cleanup); if (exit->block_return_out && return_value.value) { diff --git a/src/compiler/parse_stmt.c b/src/compiler/parse_stmt.c index edf23e2bd..df728a8db 100644 --- a/src/compiler/parse_stmt.c +++ b/src/compiler/parse_stmt.c @@ -467,6 +467,24 @@ static inline Ast* parse_defer_stmt(ParseContext *c) { advance_and_verify(c, TOKEN_DEFER); Ast *defer_stmt = new_ast(AST_DEFER_STMT, c->span); + if (try_consume(c, TOKEN_TRY)) + { + defer_stmt->defer_stmt.is_try = true; + if (tok_is(c, TOKEN_LPAREN)) + { + SEMA_ERROR_HERE("Expected a '{' or a non-'try' statement after 'defer try'."); + return poisoned_ast; + } + } + else if (try_consume(c, TOKEN_CATCH)) + { + defer_stmt->defer_stmt.is_catch = true; + if (tok_is(c, TOKEN_LPAREN)) + { + SEMA_ERROR_HERE("Expected a '{' or a non-'catch' statement after 'defer catch'."); + return poisoned_ast; + } + } ASSIGN_ASTID_OR_RET(defer_stmt->defer_stmt.body, parse_stmt(c), poisoned_ast); return defer_stmt; } diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index 4fea54230..5f93d1542 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -560,9 +560,8 @@ static inline bool sema_cast_ident_rvalue(SemaContext *context, Expr *expr) SEMA_ERROR(expr, "Did you forget a '!' after '%s'?", decl->name); return expr_poison(expr); case DECL_ENUM_CONSTANT: - TODO - //expr_replace(expr, decl->enum_constant.expr); - return true; + // This can't happen, inferred identifiers are folded to consts they are never identifiers. + UNREACHABLE; case DECL_VAR: break; case DECL_DISTINCT: @@ -5887,7 +5886,7 @@ static inline bool sema_expr_analyse_rethrow(SemaContext *context, Expr *expr) SEMA_ERROR(expr, "Returns are not allowed inside of defers."); return false; } - expr->rethrow_expr.cleanup = context_get_defers(context, context->active_scope.defer_last, 0); + expr->rethrow_expr.cleanup = context_get_defers(context, context->active_scope.defer_last, 0, false); if (inner->type == type_anyfail) { SEMA_ERROR(expr, "This expression will always throw, which isn't allowed."); diff --git a/src/compiler/sema_internal.h b/src/compiler/sema_internal.h index 196132e1f..5a8dbfc42 100644 --- a/src/compiler/sema_internal.h +++ b/src/compiler/sema_internal.h @@ -50,7 +50,7 @@ Decl **global_context_acquire_locals_list(void); void generic_context_release_locals_list(Decl **); Type *global_context_string_type(void); -AstId context_get_defers(SemaContext *context, AstId defer_top, AstId defer_bottom); +AstId context_get_defers(SemaContext *context, AstId defer_top, AstId defer_bottom, bool is_success); void context_pop_defers(SemaContext *context, AstId *next); void context_pop_defers_and_replace_ast(SemaContext *context, Ast *ast); void context_change_scope_for_label(SemaContext *context, Decl *label); diff --git a/src/compiler/sema_liveness.c b/src/compiler/sema_liveness.c index ca78f7a2b..731350c66 100644 --- a/src/compiler/sema_liveness.c +++ b/src/compiler/sema_liveness.c @@ -104,6 +104,10 @@ static void sema_trace_stmt_liveness(Ast *ast) case AST_BLOCK_EXIT_STMT: sema_trace_expr_liveness(ast->return_stmt.expr); sema_trace_astid_liveness(ast->return_stmt.cleanup); + if (ast->return_stmt.cleanup != ast->return_stmt.cleanup_fail) + { + sema_trace_astid_liveness(ast->return_stmt.cleanup_fail); + } return; case AST_ASM_BLOCK_STMT: if (ast->asm_block_stmt.is_string) diff --git a/src/compiler/sema_stmts.c b/src/compiler/sema_stmts.c index 9d2f6d14e..680f06b42 100644 --- a/src/compiler/sema_stmts.c +++ b/src/compiler/sema_stmts.c @@ -23,6 +23,7 @@ static inline bool sema_analyse_nextcase_stmt(SemaContext *context, Ast *stateme static inline bool sema_analyse_return_stmt(SemaContext *context, Ast *statement); static inline bool sema_analyse_switch_stmt(SemaContext *context, Ast *statement); +static inline bool sema_defer_by_result(AstId defer_top, AstId defer_bottom); static inline bool sema_analyse_block_exit_stmt(SemaContext *context, Ast *statement); static inline bool sema_analyse_defer_stmt_body(SemaContext *context, Ast *statement, Ast *body); static inline bool sema_analyse_for_cond(SemaContext *context, ExprId *cond_ref, bool *infinite); @@ -126,7 +127,6 @@ static inline bool sema_analyse_assert_stmt(SemaContext *context, Ast *statement return true; } - /** * break and break LABEL; */ @@ -174,7 +174,7 @@ static inline bool sema_analyse_break_stmt(SemaContext *context, Ast *statement) statement->contbreak_stmt.ast = astid(parent); // Append the defers. - statement->contbreak_stmt.defers = context_get_defers(context, context->active_scope.defer_last, defer_begin); + statement->contbreak_stmt.defers = context_get_defers(context, context->active_scope.defer_last, defer_begin, true); return true; } @@ -233,7 +233,7 @@ static inline bool sema_analyse_continue_stmt(SemaContext *context, Ast *stateme // Link the parent and add the defers. statement->contbreak_stmt.ast = astid(parent); - statement->contbreak_stmt.defers = context_get_defers(context, context->active_scope.defer_last, defer_id); + statement->contbreak_stmt.defers = context_get_defers(context, context->active_scope.defer_last, defer_id, true); return true; } @@ -316,6 +316,28 @@ static inline bool assert_create_from_contract(SemaContext *context, Ast *direct return true; } +static inline bool sema_defer_by_result(AstId defer_top, AstId defer_bottom) +{ + AstId first = 0; + AstId *next = &first; + while (defer_bottom != defer_top) + { + Ast *defer = astptr(defer_top); + if (defer->defer_stmt.is_catch || defer->defer_stmt.is_try) return true; + defer_top = defer->defer_stmt.prev_defer; + } + return false; +} + +static inline void sema_inline_return_defers(SemaContext *context, Ast *stmt, AstId defer_top, AstId defer_bottom) +{ + stmt->return_stmt.cleanup_fail = stmt->return_stmt.cleanup = context_get_defers(context, defer_top, defer_bottom, true); + if (stmt->return_stmt.expr && IS_OPTIONAL(stmt->return_stmt.expr) && sema_defer_by_result(context->active_scope.defer_last, context->block_return_defer)) + { + stmt->return_stmt.cleanup_fail = context_get_defers(context, context->active_scope.defer_last, context->block_return_defer, false); + } +} + /** * Handle exit in a macro or in an expression block. * @param context @@ -348,7 +370,7 @@ static inline bool sema_analyse_block_exit_stmt(SemaContext *context, Ast *state } } statement->return_stmt.block_exit_ref = context->block_exit_ref; - statement->return_stmt.cleanup = context_get_defers(context, context->active_scope.defer_last, context->block_return_defer); + sema_inline_return_defers(context, statement, context->active_scope.defer_last, context->block_return_defer); vec_add(context->returns, statement); return true; } @@ -438,12 +460,12 @@ static inline bool sema_analyse_return_stmt(SemaContext *context, Ast *statement SEMA_ERROR(statement, "Expected to return a result of type %s.", type_to_error_string(expected_rtype)); return false; } - statement->return_stmt.cleanup = context_get_defers(context, context->active_scope.defer_last, 0); + statement->return_stmt.cleanup = context_get_defers(context, context->active_scope.defer_last, 0, true); return true; } // Process any ensures. - AstId cleanup = context_get_defers(context, context->active_scope.defer_last, 0); + sema_inline_return_defers(context, statement, context->active_scope.defer_last, 0); if (context->call_env.ensures) { AstId first = 0; @@ -460,21 +482,25 @@ static inline bool sema_analyse_return_stmt(SemaContext *context, Ast *statement } doc_directive = directive->next; } - if (cleanup) + if (!first) goto SKIP_ENSURE; + if (statement->return_stmt.cleanup) { - Ast *last = ast_last(astptr(cleanup)); + // If we have the same ast on cleanup / cleanup-fail we need to separate them. + if (type_is_optional(expected_rtype) && statement->return_stmt.cleanup == statement->return_stmt.cleanup_fail) + { + statement->return_stmt.cleanup_fail = astid(copy_ast_defer(astptr(statement->return_stmt.cleanup))); + } + Ast *last = ast_last(astptr(statement->return_stmt.cleanup)); last->next = first; } else { - cleanup = first; + statement->return_stmt.cleanup = statement->return_stmt.cleanup_fail = first; } } - - statement->return_stmt.cleanup = cleanup; +SKIP_ENSURE:; assert(type_no_optional(statement->return_stmt.expr->type)->canonical == type_no_optional(expected_rtype)->canonical); - return true; } @@ -1005,7 +1031,6 @@ bool sema_analyse_defer_stmt_body(SemaContext *context, Ast *statement, Ast *bod SEMA_ERROR(body, "A defer may not have a body consisting of a raw 'defer', this looks like a mistake."); return false; } - // TODO special parsing of "catch" bool success; SCOPE_START @@ -1033,7 +1058,6 @@ bool sema_analyse_defer_stmt_body(SemaContext *context, Ast *statement, Ast *bod } static inline bool sema_analyse_defer_stmt(SemaContext *context, Ast *statement) { - // TODO special parsing of "catch" if (!sema_analyse_defer_stmt_body(context, statement, astptr(statement->defer_stmt.body))) return false; statement->defer_stmt.prev_defer = context->active_scope.defer_last; @@ -1766,7 +1790,7 @@ static bool sema_analyse_nextcase_stmt(SemaContext *context, Ast *statement) if (!statement->nextcase_stmt.expr) { assert(context->next_target); - statement->nextcase_stmt.defer_id = context_get_defers(context, context->active_scope.defer_last, parent->switch_stmt.defer); + statement->nextcase_stmt.defer_id = context_get_defers(context, context->active_scope.defer_last, parent->switch_stmt.defer, true); statement->nextcase_stmt.case_switch_stmt = astid(context->next_target); return true; } @@ -1777,7 +1801,7 @@ static bool sema_analyse_nextcase_stmt(SemaContext *context, Ast *statement) TypeInfo *type_info = statement->nextcase_stmt.expr->type_expr; if (!sema_resolve_type_info(context, type_info)) return false; Ast **cases; - statement->nextcase_stmt.defer_id = context_get_defers(context, context->active_scope.defer_last, parent->switch_stmt.defer); + statement->nextcase_stmt.defer_id = context_get_defers(context, context->active_scope.defer_last, parent->switch_stmt.defer, true); if (cond->type->canonical != type_typeid) { SEMA_ERROR(statement, "Unexpected 'type' in as an 'nextcase' destination."); @@ -1818,7 +1842,7 @@ static bool sema_analyse_nextcase_stmt(SemaContext *context, Ast *statement) if (!sema_analyse_expr_rhs(context, expected_type, target, false)) return false; - statement->nextcase_stmt.defer_id = context_get_defers(context, context->active_scope.defer_last, parent->switch_stmt.defer); + statement->nextcase_stmt.defer_id = context_get_defers(context, context->active_scope.defer_last, parent->switch_stmt.defer, true); if (target->expr_kind == EXPR_CONST) { diff --git a/src/compiler/semantic_analyser.c b/src/compiler/semantic_analyser.c index 243bbbf66..55266d7e5 100644 --- a/src/compiler/semantic_analyser.c +++ b/src/compiler/semantic_analyser.c @@ -60,13 +60,18 @@ void context_change_scope_for_label(SemaContext *context, Decl *label) } } -AstId context_get_defers(SemaContext *context, AstId defer_top, AstId defer_bottom) +AstId context_get_defers(SemaContext *context, AstId defer_top, AstId defer_bottom, bool is_success) { AstId first = 0; AstId *next = &first; while (defer_bottom != defer_top) { Ast *defer = astptr(defer_top); + if ((is_success && defer->defer_stmt.is_catch) || (!is_success && defer->defer_stmt.is_try)) + { + defer_top = defer->defer_stmt.prev_defer; + continue; + } Ast *defer_body = copy_ast_defer(astptr(defer->defer_stmt.body)); *next = astid(defer_body); next = &defer_body->next; @@ -84,9 +89,12 @@ void context_pop_defers(SemaContext *context, AstId *next) while (defer_current != defer_start) { Ast *defer = astptr(defer_current); - Ast *defer_body = copy_ast_defer(astptr(defer->defer_stmt.body)); - *next = astid(defer_body); - next = &defer_body->next; + if (!defer->defer_stmt.is_catch) + { + Ast *defer_body = copy_ast_defer(astptr(defer->defer_stmt.body)); + *next = astid(defer_body); + next = &defer_body->next; + } defer_current = defer->defer_stmt.prev_defer; } } diff --git a/src/version.h b/src/version.h index 6ebc08eba..1975de5c7 100644 --- a/src/version.h +++ b/src/version.h @@ -1 +1 @@ -#define COMPILER_VERSION "0.4.79" \ No newline at end of file +#define COMPILER_VERSION "0.4.80" \ No newline at end of file diff --git a/test/test_suite/defer/defer_catch_try.c3t b/test/test_suite/defer/defer_catch_try.c3t new file mode 100644 index 000000000..da7f94bdf --- /dev/null +++ b/test/test_suite/defer/defer_catch_try.c3t @@ -0,0 +1,148 @@ +// #target: macos-x64 +module test; +extern fn void printf(char*, ...); + +fault Abc +{ + FOO +} +fn int! abc(int x) +{ + printf("Enter abc\n"); + defer catch printf("Abc catch %d\n", x); + defer try printf("Abc try %d\n", x); + defer printf("Abc normal %d\n", x); + return x > 0 ? Abc.FOO! : 0; +} +fn int! bcd(int x) +{ + printf("Enter bcd\n"); + for (int i = 0; i < 10; i++) + { + printf("bcd loop\n"); + defer catch printf("Bcd %d catch %d\n", i, x); + defer try printf("Bcd %d try %d\n", i, x); + defer printf("Bcd %d normal %d\n", i, x); + if (i == 1) continue; + printf("bcd check\n"); + if (i == 2) return x > 0 ? Abc.FOO! : 0; + } + return 0; +} + +fn int main() +{ + (void)abc(3); + (void)abc(-1); + (void)bcd(3); + (void)bcd(-1); + return 1; +} + + +/* #expect: test.ll + +define i64 @test.abc(ptr %0, i32 %1) #0 { +entry: + %reterr = alloca i64, align 8 + call void (ptr, ...) @printf(ptr @.str) + %gt = icmp sgt i32 %1, 0 + br i1 %gt, label %cond.lhs, label %cond.rhs + +cond.lhs: ; preds = %entry + store i64 ptrtoint (ptr @"test.Abc$FOO" to i64), ptr %reterr, align 8 + br label %err_retblock + +cond.rhs: ; preds = %entry + br label %cond.phi + +cond.phi: ; preds = %cond.rhs + call void (ptr, ...) @printf(ptr @.str.1, i32 %1) + call void (ptr, ...) @printf(ptr @.str.2, i32 %1) + store i32 0, ptr %0, align 4 + ret i64 0 + +err_retblock: ; preds = %cond.lhs + call void (ptr, ...) @printf(ptr @.str.3, i32 %1) + call void (ptr, ...) @printf(ptr @.str.4, i32 %1) + %2 = load i64, ptr %reterr, align 8 + ret i64 %2 +} + +define i64 @test.bcd(ptr %0, i32 %1) #0 { +entry: + %i = alloca i32, align 4 + %reterr = alloca i64, align 8 + %reterr4 = alloca i64, align 8 + call void (ptr, ...) @printf(ptr @.str.5) + store i32 0, ptr %i, align 4 + br label %loop.cond + +loop.cond: ; preds = %loop.inc, %entry + %2 = load i32, ptr %i, align 4 + %lt = icmp slt i32 %2, 10 + br i1 %lt, label %loop.body, label %loop.exit + +loop.body: ; preds = %loop.cond + call void (ptr, ...) @printf(ptr @.str.6) + %3 = load i32, ptr %i, align 4 + %eq = icmp eq i32 %3, 1 + br i1 %eq, label %if.then, label %if.exit + +if.then: ; preds = %loop.body + %4 = load i32, ptr %i, align 4 + call void (ptr, ...) @printf(ptr @.str.7, i32 %4, i32 %1) + %5 = load i32, ptr %i, align 4 + call void (ptr, ...) @printf(ptr @.str.8, i32 %5, i32 %1) + br label %loop.inc + +if.exit: ; preds = %loop.body + call void (ptr, ...) @printf(ptr @.str.9) + %6 = load i32, ptr %i, align 4 + %eq1 = icmp eq i32 %6, 2 + br i1 %eq1, label %if.then2, label %if.exit3 + +if.then2: ; preds = %if.exit + %gt = icmp sgt i32 %1, 0 + br i1 %gt, label %cond.lhs, label %cond.rhs + +cond.lhs: ; preds = %if.then2 + store i64 ptrtoint (ptr @"test.Abc$FOO" to i64), ptr %reterr, align 8 + br label %err_retblock + +cond.rhs: ; preds = %if.then2 + br label %cond.phi + +cond.phi: ; preds = %cond.rhs + %7 = load i32, ptr %i, align 4 + call void (ptr, ...) @printf(ptr @.str.10, i32 %7, i32 %1) + %8 = load i32, ptr %i, align 4 + call void (ptr, ...) @printf(ptr @.str.11, i32 %8, i32 %1) + store i32 0, ptr %0, align 4 + ret i64 0 + +err_retblock: ; preds = %cond.lhs + %9 = load i32, ptr %i, align 4 + call void (ptr, ...) @printf(ptr @.str.12, i32 %9, i32 %1) + %10 = load i32, ptr %i, align 4 + call void (ptr, ...) @printf(ptr @.str.13, i32 %10, i32 %1) + %11 = load i64, ptr %reterr, align 8 + ret i64 %11 + +if.exit3: ; preds = %if.exit + %12 = load i32, ptr %i, align 4 + call void (ptr, ...) @printf(ptr @.str.14, i32 %12, i32 %1) + %13 = load i32, ptr %i, align 4 + call void (ptr, ...) @printf(ptr @.str.15, i32 %13, i32 %1) + br label %loop.inc + +loop.inc: ; preds = %if.exit3, %if.then + %14 = load i32, ptr %i, align 4 + %add = add i32 %14, 1 + store i32 %add, ptr %i, align 4 + br label %loop.cond + +loop.exit: ; preds = %loop.cond + store i32 0, ptr %0, align 4 + ret i64 0 +}