From 373001fd1237d4dcd3188dba529c3afaaec03e61 Mon Sep 17 00:00:00 2001 From: Christoffer Lerno Date: Sun, 3 May 2020 02:04:13 +0200 Subject: [PATCH] Error -> errset (temporarily). Catch / throw now works, but it will not yet correctly handle defer. --- missing.txt | 2 +- resources/testfragments/super_simple.c3 | 132 +++++++- src/compiler/ast.c | 8 +- src/compiler/casts.c | 109 ++++++- src/compiler/compiler_internal.h | 104 +++++- src/compiler/enums.h | 6 +- src/compiler/llvm_codegen.c | 49 ++- src/compiler/llvm_codegen_expr.c | 412 +++++++++++++++++++++--- src/compiler/llvm_codegen_function.c | 25 +- src/compiler/llvm_codegen_internal.h | 28 +- src/compiler/llvm_codegen_stmt.c | 72 +++-- src/compiler/llvm_codegen_type.c | 35 +- src/compiler/parse_expr.c | 13 + src/compiler/parse_stmt.c | 24 +- src/compiler/parser.c | 13 +- src/compiler/sema_decls.c | 17 +- src/compiler/sema_expr.c | 75 ++++- src/compiler/sema_stmts.c | 277 +++++++++++++--- src/compiler/sema_types.c | 22 ++ src/compiler/tokens.c | 2 + src/compiler/types.c | 36 ++- src/utils/lib.h | 5 + src/utils/stringutils.c | 11 + 23 files changed, 1268 insertions(+), 209 deletions(-) diff --git a/missing.txt b/missing.txt index c5d9868de..a6065f84b 100644 --- a/missing.txt +++ b/missing.txt @@ -67,7 +67,7 @@ Things missing: * Error handling - Error unions - Catch/try -- Function return channel +- Try goto/return/break/continue/throw * Enum - Values: min, max, array diff --git a/resources/testfragments/super_simple.c3 b/resources/testfragments/super_simple.c3 index 5b6c99ece..2e9fcd983 100644 --- a/resources/testfragments/super_simple.c3 +++ b/resources/testfragments/super_simple.c3 @@ -64,18 +64,17 @@ enum EnumTestOverflowAfter -/* -error Error +errset Error { BLURB, NO_SUCH_FILE, } -error OtherError +errset OtherError { FOO_BAR } -*/ + enum Inf { A, @@ -608,11 +607,47 @@ struct WithArray int[4] x; } -error Err +errset Err { TEST_ERR1 } +errset Err2 +{ + TEST_ERR2 +} + +func int testThrow1() throws Err +{ + throw Err.TEST_ERR1; +} + +func int testThrow2() throws Err, Err2 +{ + throw Err.TEST_ERR1; +} + +func int testThrow3() throws +{ + throw Err.TEST_ERR1; +} + +func int testThrow4(int x) throws +{ + throw x > 4 ? Err.TEST_ERR1 : Err2.TEST_ERR2; +} + +func int testThrow5(int x) throws +{ + throw x > 4 ? Err.TEST_ERR1 : Err.TEST_ERR1; +} + +func int testThrow6(int x) throws Err, Err2 +{ + if (x > 4) throw Err.TEST_ERR1; + throw Err2.TEST_ERR2; +} + func int testThrow(int x) throws Err { @@ -620,12 +655,98 @@ func int testThrow(int x) throws Err return x * x; } +func int testThrowAny(int x) throws +{ + if (x < 0) throw Err.TEST_ERR1; + return x * x; +} + +func int oekt() throws +{ + int x = try testThrow(-3); + return x; +} + +func void testErrorMulti() +{ + + printf("Test error multi\n"); + int z = try oekt(); + catch (Err e) + { + printf("Expected particular error.\n"); + } + catch (error e) + { + printf("Unexpected any error error.\n"); + } + printf("End\n"); +} + +func void throwAOrB(int i) throws Err, Err2 +{ + printf("AB\n"); + if (i == 1) throw Err.TEST_ERR1; + printf("B\n"); + if (i == 2) throw Err2.TEST_ERR2; + printf("None\n"); +} + func void testErrors() { + int x = try testThrow(20) else 1; printf("Value was %d, expected 400.\n", x); + + x = try testThrow(-1) else 20; printf("Value was %d, expected 20.\n", x); + + printf("Begin\n"); + int y = try testThrow(-1); + + printf("Value was %d, expected 9.\n", y); + + printf("Didn't expect this one.\n"); + catch (error e) + { + printf("Expected this catch.\n"); + } + y = try testThrow(-1); + + catch (Err e) + { + printf("Particular error.\n"); + } + testErrorMulti(); + + try throwAOrB(1); + catch (Err e) + { + printf("A1\n"); + } + catch (Err2 e) + { + printf("A2\n"); + } + catch (error e) + { + printf("Wut\n"); + } + try throwAOrB(2); + catch (Err e) + { + printf("B1\n"); + } + catch (Err2 e) + { + printf("B2\n"); + } + catch (error e) + { + printf("Wut\n"); + } + printf("End of errors\n"); } func void testArray() @@ -651,6 +772,7 @@ func void testArray() { printf("x[%d] = %d\n", i, x[i]); } + } func void testDefer() { diff --git a/src/compiler/ast.c b/src/compiler/ast.c index 1dfa02a81..a48bf4044 100644 --- a/src/compiler/ast.c +++ b/src/compiler/ast.c @@ -829,7 +829,10 @@ void fprint_decl_recursive(FILE *file, Decl *decl, int indent) case DECL_ATTRIBUTE: TODO case DECL_THROWS: - TODO + fprintf_indented(file, indent, "(throws"); + fprint_type_info_recursive(file, decl->throws, indent + 1); + fprint_endparen(file, indent); + break;; } } @@ -1033,9 +1036,6 @@ static void fprint_ast_recursive(FILE *file, Ast *ast, int indent) fprintf(file, "(throw\n"); fprint_expr_recursive(file, ast->throw_stmt.throw_value, indent + 1); break; - case AST_TRY_STMT: - TODO - break; case AST_VOLATILE_STMT: TODO break; diff --git a/src/compiler/casts.c b/src/compiler/casts.c index 53a8e4542..93ba24b42 100644 --- a/src/compiler/casts.c +++ b/src/compiler/casts.c @@ -313,6 +313,49 @@ bool ixxen(Expr *left, Type *canonical, Type *type, CastType cast_type) return ixxxi(left, canonical, type, cast_type); } +/** + * Convert from compile time int to error value + */ +bool ixxer(Expr *left, Type *canonical, Type *type, CastType cast_type) +{ + // Assigning zero = no value is always ok. + if (cast_type == CAST_TYPE_IMPLICIT) EXIT_T_MISMATCH(); + if (cast_type == CAST_TYPE_OPTIONAL_IMPLICIT) return true; + + if (left->expr_kind == EXPR_CONST) + { + if (bigint_cmp_zero(&left->const_expr.i) != CMP_GT) + { + SEMA_ERROR(left, "Cannot cast '%s' to an error value.", expr_const_to_error_string(&left->const_expr)); + return false; + } + BigInt comp; + bigint_init_unsigned(&comp, vec_size(canonical->decl->error.error_constants)); + if (bigint_cmp(&left->const_expr.i, &comp) == CMP_GT) + { + SEMA_ERROR(left, "Cannot cast '%s' to a valid '%s' error value.", expr_const_to_error_string(&left->const_expr), canonical->decl->name); + return false; + } + left->type = type; + return true; + } + assert(canonical->type_kind == TYPE_ERROR); + + if (!ixxxi(left, type_error_base->canonical, type_error_base, cast_type)) return false; + + insert_cast(left, CAST_XIERR, canonical); + + return true; +} + +/** + * Convert from compile time int to error union + */ +bool ixxeu(Expr *left, Type *type) +{ + UNREACHABLE +} + /** * Cast signed int -> signed int * @return true if this is a widening, an explicit cast or if it is an implicit assign add @@ -561,6 +604,12 @@ bool erxi(Expr* left, Type *from, Type *canonical, Type *type, CastType cast_typ TODO } +bool ereu(Expr *left) +{ + insert_cast(left, CAST_ERREU, type_error_union); + return true; +} + bool vava(Expr* left, Type *from, Type *canonical, Type *type, CastType cast_type) { TODO @@ -581,6 +630,57 @@ bool usui(Expr* left, Type *from, Type *canonical, Type *type, CastType cast_typ TODO } +bool euxi(Expr *left, Type *canonical, Type *type, CastType cast_type) +{ + if (cast_type == CAST_TYPE_OPTIONAL_IMPLICIT) return true; + if (cast_type == CAST_TYPE_IMPLICIT) + { + SEMA_ERROR(left, "Cannot implictly cast an error to '%s'.", type_to_error_string(type)); + return false; + } + left->type = type; + return true; +} + +bool xieu(Expr *left, Type *canonical, Type *type, CastType cast_type) +{ + TODO +} + +bool xierr(Expr *left, Type *canonical, Type *type, CastType cast_type) +{ + TODO +} + +/** + * Convert error union to error. This is always a required cast. + * @return false if an error was reported. + */ +bool euer(Expr *left, Type *canonical, Type *type, CastType cast_type) +{ + TODO + if (cast_type == CAST_TYPE_OPTIONAL_IMPLICIT) return true; + if (cast_type == CAST_TYPE_IMPLICIT) + { + SEMA_ERROR(left, "Cannot implictly cast an error union back to '%s'.", type_to_error_string(type)); + return false; + } + insert_cast(left, CAST_EUERR, canonical); + return true; +} + + +/** + * Convert error union to error. This is always a required cast. + * @return false if an error was reported. + */ +bool eubool(Expr *left, Type *canonical, Type *type, CastType cast_type) +{ + TODO + insert_cast(left, CAST_EUBOOL, canonical); + return true; +} + bool ptva(Expr* left, Type *from, Type *canonical, Type *type, CastType cast_type) { TODO @@ -697,7 +797,10 @@ bool cast(Expr *expr, Type *to_type, CastType cast_type) if (type_is_float(canonical)) return bofp(expr, canonical, to_type, cast_type); break; case TYPE_ERROR_UNION: - TODO + if (to_type->type_kind == TYPE_BOOL) return eubool(expr, canonical, to_type, cast_type); + if (type_is_integer(canonical)) return euxi(expr, canonical, to_type, cast_type); + if (to_type->type_kind == TYPE_ERROR) return euer(expr, canonical, to_type, cast_type); + break; case TYPE_IXX: // Compile time integers may convert into ints, floats, bools if (type_is_integer(canonical)) return ixxxi(expr, canonical, to_type, cast_type); @@ -705,6 +808,7 @@ bool cast(Expr *expr, Type *to_type, CastType cast_type) if (canonical == type_bool) return ixxbo(expr, to_type); if (canonical->type_kind == TYPE_POINTER) return xipt(expr, from_type, canonical, to_type, cast_type); if (canonical->type_kind == TYPE_ENUM) return ixxen(expr, canonical, to_type, cast_type); + if (canonical->type_kind == TYPE_ERROR) return ixxer(expr, canonical, to_type, cast_type); break; case TYPE_I8: case TYPE_I16: @@ -715,6 +819,8 @@ bool cast(Expr *expr, Type *to_type, CastType cast_type) if (type_is_float(canonical)) return sifp(expr, canonical, to_type); if (canonical == type_bool) return xibo(expr, canonical, to_type, cast_type); if (canonical->type_kind == TYPE_POINTER) return xipt(expr, from_type, canonical, to_type, cast_type); + if (canonical->type_kind == TYPE_ERROR_UNION) return xieu(expr, canonical, to_type, cast_type); + if (canonical->type_kind == TYPE_ERROR) return xierr(expr, canonical, to_type, cast_type); break; case TYPE_U8: case TYPE_U16: @@ -745,6 +851,7 @@ bool cast(Expr *expr, Type *to_type, CastType cast_type) break; case TYPE_ERROR: if (type_is_integer(canonical)) return erxi(expr, from_type, canonical, to_type, cast_type); + if (canonical == type_error_union) return ereu(expr); break; case TYPE_FUNC: if (type_is_integer(canonical)) return ptxi(expr, canonical, to_type, cast_type); diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index 1f2fcb8bd..f2c6508b9 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -234,6 +234,7 @@ typedef struct typedef struct { Decl **error_constants; + void *start_value; } ErrorDecl; typedef struct @@ -298,8 +299,9 @@ typedef struct typedef enum { ERROR_RETURN_NONE = 0, - ERROR_RETURN_PARAM = 1, - ERROR_RETURN_RETURN = 2, + ERROR_RETURN_ONE = 1, + ERROR_RETURN_MANY = 2, + ERROR_RETURN_ANY = 3, } ErrorReturn; typedef struct _FunctionSignature { @@ -308,7 +310,7 @@ typedef struct _FunctionSignature bool has_default : 1; bool throw_any : 1; bool return_param : 1; - ErrorReturn error_return : 3; + ErrorReturn error_return : 4; TypeInfo *rtype; Decl** params; Decl** throws; @@ -425,10 +427,27 @@ typedef struct _Decl }; } Decl; +typedef enum +{ + TRY_EXPR_ELSE_EXPR, + TRY_EXPR_ELSE_JUMP, + TRY_STMT, +} TryType; + typedef struct { - Expr *expr; - Expr *else_expr; + TryType type; + union + { + Expr *expr; + Ast *stmt; + }; + union + { + Expr *else_expr; + Ast *else_stmt; + }; + void *jump_target; } ExprTry; typedef struct @@ -474,12 +493,39 @@ typedef struct } ExprPostUnary; +typedef enum +{ + CATCH_TRY_ELSE, + CATCH_REGULAR, + CATCH_RETURN_ANY, + CATCH_RETURN_MANY, + CATCH_RETURN_ONE +} CatchKind; + +typedef struct +{ + CatchKind kind; + union + { + Expr *try_else; + Ast *catch; + Decl *error; + }; +} CatchInfo; + +typedef struct +{ + bool is_completely_handled; + DeferList defers; + CatchInfo *catches; +} ThrowInfo; typedef struct { bool is_struct_function : 1; Expr *function; Expr **arguments; + ThrowInfo *throw_info; } ExprCall; typedef struct @@ -695,6 +741,7 @@ typedef struct _AstCatchStmt { Decl *error_param; struct _Ast *body; + void *block; } AstCatchStmt; typedef struct _AstCtIfStmt @@ -751,7 +798,6 @@ typedef struct typedef struct { Expr *throw_value; - DeferList defers; } AstThrowStmt; typedef struct @@ -786,7 +832,6 @@ typedef struct _Ast Expr *expr_stmt; AstThrowStmt throw_stmt; Ast *volatile_stmt; - Ast *try_stmt; AstLabelStmt label_stmt; AstReturnStmt return_stmt; AstWhileStmt while_stmt; @@ -838,7 +883,7 @@ typedef struct _DynamicScope { ScopeFlags flags; ScopeFlags flags_created; - unsigned errors; + unsigned throws; Decl **local_decl_start; DeferList defers; ExitType exit; @@ -857,6 +902,26 @@ typedef struct Token next_tok; } Lexer; +typedef enum +{ + THROW_TYPE_CALL_ANY, + THROW_TYPE_CALL_THROW_MANY, + THROW_TYPE_CALL_THROW_ONE, +} ThrowType; + +typedef struct +{ + SourceRange span; + ThrowType kind : 4; + ThrowInfo *throw_info; + // The error type of the throw. + union + { + Type *throw; + Decl **throws; + }; +} Throw; + typedef struct _Context { BuildTarget *target; @@ -896,7 +961,8 @@ typedef struct _Context // Error handling struct { - Decl **errors; + Ast **throw; + Throw *error_calls; int try_nesting; }; Type *rtype; @@ -952,7 +1018,8 @@ extern Type *type_byte, *type_ushort, *type_uint, *type_ulong, *type_usize; extern Type *type_compint, *type_compfloat; extern Type *type_c_short, *type_c_int, *type_c_long, *type_c_longlong; extern Type *type_c_ushort, *type_c_uint, *type_c_ulong, *type_c_ulonglong; -extern Type *type_typeid, *type_error, *type_error_union; +extern Type *type_typeid, *type_error_union, *type_error_base; + extern const char *main_name; @@ -1081,6 +1148,7 @@ void diag_reset(void); void diag_error_range(SourceRange span, const char *message, ...); void diag_verror_range(SourceRange span, const char *message, va_list args); + #define EXPR_NEW_EXPR(_kind, _expr) expr_new(_kind, _expr->span) #define EXPR_NEW_TOKEN(_kind, _tok) expr_new(_kind, _tok.span) Expr *expr_new(ExprKind kind, SourceRange start); @@ -1176,6 +1244,20 @@ void *target_data_layout(); void *target_machine(); void *target_target(); +bool throw_completely_caught(Decl *throw, CatchInfo *catches); +static inline Throw throw_new_single(SourceRange range, ThrowType type, ThrowInfo *info, Type *throw) +{ + return (Throw) { .kind = type, .span = range, .throw_info = info, .throw = throw }; +} +static inline Throw throw_new_union(SourceRange range, ThrowType type, ThrowInfo *info) +{ + return (Throw) { .kind = type, .span = range, .throw_info = info, .throw = type_error_union }; +} +static inline Throw throw_new_multiple(SourceRange range, ThrowInfo *info, Decl **throws) +{ + return (Throw) { .kind = THROW_TYPE_CALL_THROW_MANY, .span = range, .throw_info = info, .throws = throws }; +} + #define TOKEN_MAX_LENGTH 0xFFFF #define TOK2VARSTR(_token) _token.span.length, _token.start bool token_is_type(TokenType type); @@ -1214,7 +1296,7 @@ static inline Type *type_reduced(Type *type) { Type *canonical = type->canonical; if (canonical->type_kind == TYPE_ENUM) return canonical->decl->enums.type_info->type->canonical; - if (canonical->type_kind == TYPE_ERROR) return type_error->canonical; + if (canonical->type_kind == TYPE_ERROR) return type_error_base; return canonical; } diff --git a/src/compiler/enums.h b/src/compiler/enums.h index d99193c23..571ba5bd6 100644 --- a/src/compiler/enums.h +++ b/src/compiler/enums.h @@ -98,7 +98,6 @@ typedef enum AST_DECL_EXPR_LIST, AST_SWITCH_STMT, AST_THROW_STMT, - AST_TRY_STMT, AST_NEXT_STMT, AST_VOLATILE_STMT, AST_WHILE_STMT, @@ -115,6 +114,10 @@ typedef enum typedef enum { CAST_ERROR, + CAST_ERREU, + CAST_EUERR, + CAST_EUBOOL, + CAST_XIERR, CAST_PTRPTR, CAST_PTRXI, CAST_VARRPTR, @@ -453,6 +456,7 @@ typedef enum TOKEN_VOLATILE, TOKEN_WHILE, TOKEN_TYPEOF, + TOKEN_ERRSET, TOKEN_CT_CASE, // $case TOKEN_CT_DEFAULT, // $default diff --git a/src/compiler/llvm_codegen.c b/src/compiler/llvm_codegen.c index 8292ec42e..db8dcdcb2 100644 --- a/src/compiler/llvm_codegen.c +++ b/src/compiler/llvm_codegen.c @@ -278,6 +278,47 @@ void gencontext_emit_struct_decl(GenContext *context, Decl *decl) } } +static inline uint32_t upper_power_of_two(uint32_t v) +{ + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +void gencontext_emit_error_decl(GenContext *context, Decl *decl) +{ + unsigned slots = vec_size(decl->error.error_constants) + 1; + LLVMTypeRef reserved_type = LLVMArrayType(llvm_type(type_char), slots); + char *buffer = strcat_arena(decl->external_name, "_DOMAIN"); + LLVMValueRef global_name = LLVMAddGlobal(context->module, reserved_type, buffer); + LLVMSetLinkage(global_name, LLVMInternalLinkage); + LLVMSetGlobalConstant(global_name, 1); + LLVMSetInitializer(global_name, llvm_int(type_char, 1)); + decl->error.start_value = global_name; + uint32_t min_align = upper_power_of_two(slots); + uint32_t pointer_align = type_abi_alignment(type_voidptr); + LLVMSetAlignment(global_name, pointer_align > min_align ? pointer_align : min_align); + switch (decl->visibility) + { + case VISIBLE_MODULE: + LLVMSetVisibility(global_name, LLVMProtectedVisibility); + break; + case VISIBLE_PUBLIC: + LLVMSetVisibility(global_name, LLVMDefaultVisibility); + break; + case VISIBLE_EXTERN: + case VISIBLE_LOCAL: + LLVMSetVisibility(global_name, LLVMHiddenVisibility); + break; + } +} + + static void gencontext_emit_decl(GenContext *context, Decl *decl) { switch (decl->decl_kind) @@ -303,7 +344,7 @@ static void gencontext_emit_decl(GenContext *context, Decl *decl) // TODO break; case DECL_ERROR: - // TODO + UNREACHABLE; break;; case DECL_ERROR_CONSTANT: //TODO @@ -340,6 +381,10 @@ void llvm_codegen(Context *context) { gencontext_emit_decl(&gen_context, context->types[i]); } + VECEACH(context->error_types, i) + { + gencontext_emit_error_decl(&gen_context, context->error_types[i]); + } VECEACH(context->functions, i) { Decl *decl = context->functions[i]; @@ -354,7 +399,7 @@ void llvm_codegen(Context *context) LLVMPassManagerBuilderSetOptLevel(pass_manager_builder, build_options.optimization_level); LLVMPassManagerBuilderSetSizeLevel(pass_manager_builder, build_options.size_optimization_level); LLVMPassManagerBuilderSetDisableUnrollLoops(pass_manager_builder, build_options.optimization_level == OPTIMIZATION_NONE); - LLVMPassManagerBuilderUseInlinerWithThreshold(pass_manager_builder, get_inlining_threshold()); + LLVMPassManagerBuilderUseInlinerWithThreshold(pass_manager_builder, 0); //get_inlining_threshold()); LLVMPassManagerRef pass_manager = LLVMCreatePassManager(); LLVMPassManagerRef function_pass_manager = LLVMCreateFunctionPassManagerForModule(gen_context.module); LLVMAddAnalysisPasses(target_machine(), pass_manager); diff --git a/src/compiler/llvm_codegen_expr.c b/src/compiler/llvm_codegen_expr.c index 30ead7207..e46a146a9 100644 --- a/src/compiler/llvm_codegen_expr.c +++ b/src/compiler/llvm_codegen_expr.c @@ -6,14 +6,6 @@ #include "compiler_internal.h" #include "bigint.h" -static inline LLVMValueRef gencontext_emit_const_int(GenContext *context, Type *type, uint64_t val) -{ - type = type->canonical; - assert(type_is_any_integer(type) || type->type_kind == TYPE_BOOL); - return LLVMConstInt(llvm_type(type), val, type_is_signed_integer(type)); -} - -#define llvm_int(_type, _val) gencontext_emit_const_int(context, _type, _val) static inline LLVMValueRef gencontext_emit_add_int(GenContext *context, Type *type, bool use_mod, LLVMValueRef left, LLVMValueRef right) { @@ -216,10 +208,21 @@ LLVMValueRef gencontext_emit_address(GenContext *context, Expr *expr) UNREACHABLE } +static inline LLVMValueRef gencontext_emit_error_cast(GenContext *context, LLVMValueRef value, Type *type) +{ + LLVMValueRef global = type->decl->error.start_value; + LLVMValueRef val = LLVMBuildBitCast(context->builder, global, llvm_type(type_usize), ""); + LLVMValueRef extend = LLVMBuildZExtOrBitCast(context->builder, value, llvm_type(type_usize), ""); + return LLVMBuildAdd(context->builder, val, extend, ""); +} + LLVMValueRef gencontext_emit_cast(GenContext *context, CastKind cast_kind, LLVMValueRef value, Type *type, Type *target_type) { switch (cast_kind) { + case CAST_XIERR: + // TODO Insert zero check. + return value; case CAST_ERROR: UNREACHABLE case CAST_PTRPTR: @@ -232,6 +235,12 @@ LLVMValueRef gencontext_emit_cast(GenContext *context, CastKind cast_kind, LLVMV TODO case CAST_STRPTR: TODO + case CAST_ERREU: + return gencontext_emit_error_cast(context, value, target_type); + case CAST_EUERR: + TODO + case CAST_EUBOOL: + return LLVMBuildICmp(context->builder, LLVMIntNE, value, llvm_int(type_usize, 0), "eubool"); case CAST_PTRBOOL: return LLVMBuildICmp(context->builder, LLVMIntNE, value, LLVMConstPointerNull(llvm_type(type->canonical->pointer)), "ptrbool"); case CAST_BOOLINT: @@ -779,6 +788,16 @@ static LLVMValueRef gencontext_emit_binary(GenContext *context, Expr *expr, LLVM UNREACHABLE } +LLVMBasicBlockRef gencontext_get_try_target(GenContext *context, Expr *try_expr) +{ + if (!try_expr->try_expr.jump_target) + { + try_expr->try_expr.jump_target = gencontext_create_free_block(context, "tryjump"); + } + return try_expr->try_expr.jump_target; +} + + LLVMValueRef gencontext_emit_post_unary_expr(GenContext *context, Expr *expr) { return gencontext_emit_post_inc_dec(context, expr->post_expr.expr, expr->post_expr.operator == POSTUNARYOP_INC ? 1 : -1, false); @@ -794,22 +813,20 @@ LLVMValueRef gencontext_emit_try_expr(GenContext *context, Expr *expr) { if (expr->try_expr.else_expr) { - LLVMBasicBlockRef catch_block = gencontext_create_free_block(context, "catchblock"); + LLVMBasicBlockRef else_block = gencontext_get_try_target(context, expr); LLVMBasicBlockRef after_catch = gencontext_create_free_block(context, "aftercatch"); LLVMValueRef res = gencontext_emit_alloca(context, llvm_type(expr->try_expr.else_expr->type), "catch"); - gencontext_push_catch(context, NULL, catch_block); LLVMValueRef normal_res = gencontext_emit_expr(context, expr->try_expr.expr); - gencontext_pop_catch(context); LLVMBuildStore(context->builder, normal_res, res); gencontext_emit_br(context, after_catch); - gencontext_emit_block(context, catch_block); + gencontext_emit_block(context, else_block); LLVMValueRef catch_value = gencontext_emit_expr(context, expr->try_expr.else_expr); LLVMBuildStore(context->builder, catch_value, res); gencontext_emit_br(context, after_catch); gencontext_emit_block(context, after_catch); return gencontext_emit_load(context, expr->try_expr.else_expr->type, res); } - TODO + return gencontext_emit_expr(context, expr->try_expr.expr); } static LLVMValueRef gencontext_emit_binary_expr(GenContext *context, Expr *expr) @@ -930,14 +947,34 @@ LLVMValueRef gencontext_emit_const_expr(GenContext *context, Expr *expr) 0)); return global_name; } + case TYPE_ERROR: + return llvm_int(type_error_base, expr->const_expr.error_constant->error_constant.value); default: UNREACHABLE } } - -static inline void gencontext_emit_throw_branch(GenContext *context, LLVMValueRef value, Type *error_type) +/* +static inline void gencontext_emit_throw_union_branch(GenContext *context, CatchInfo *catches, LLVMValueRef value) { LLVMBasicBlockRef after_block = gencontext_create_free_block(context, "throwafter"); + + type_error_union + value = LLVMBuildExtractValue(context->builder, value, 1, ""); + LLVMValueRef comparison = LLVMBuildICmp(context->builder, LLVMIntNE, llvm_int(error_type, 0), value, "checkerr"); + + VECEACH(catches, i) + { + LLVMBasicBlockRef ret_err_block = gencontext_create_free_block(context, "ret_err"); + gencontext_emit_cond_br(context, comparison, ret_err_block, after_block); + gencontext_emit_block(context, ret_err_block); + + if (catches->kind == CATCH_TRY_ELSE) + { + + } + } + + if (error_type == type_error_union) size_t catch_index = context->catch_stack_index; while (catch_index > 0) { @@ -946,42 +983,345 @@ static inline void gencontext_emit_throw_branch(GenContext *context, LLVMValueRe Catch *current_catch = &context->catch_stack[catch_index]; if (!current_catch->decl) { - - LLVMValueRef comparison = LLVMBuildICmp(context->builder, LLVMIntNE, llvm_int(error_type, 0), value, "checkerr"); + LLVMValueRef zero = llvm_int(type_reduced(type_error), 0); + // TODO emit defers. + if (error_type == type_error_union) + { + } + LLVMValueRef comparison = LLVMBuildICmp(context->builder, LLVMIntNE, zero, value, "checkerr"); gencontext_emit_cond_br(context, comparison, current_catch->catch_block, after_block); gencontext_emit_block(context, after_block); return; } } - TODO + // Fix defers gencontext_emit_defer(context, ast->throw_stmt.defers.start, ast->throw_stmt.defers.end); + + if (error_type == type_error_union) + { + gencontext_emit_load(context, type_error_union, context->error_out); + TODO + } + LLVMBuildRet(context->builder, value); + context->current_block = NULL; + context->current_block_is_target = false; + gencontext_emit_block(context, after_block); } +*/ + + +static inline bool gencontext_emit_throw_branch_for_single_throw_catch(GenContext *context, Decl *throw, LLVMValueRef value, Ast *catch, bool single_throw) +{ + gencontext_generate_catch_block_if_needed(context, catch); + + // Catch any is simple. + if (catch->catch_stmt.error_param->type == type_error_union) + { + // If this was a single throw, then we do a cast first. + if (single_throw) + { + value = gencontext_emit_cast(context, CAST_ERREU, value, type_error_union, throw->type); + } + // Store the value and jump to the catch. + LLVMBuildStore(context->builder, value, catch->catch_stmt.error_param->var.backend_ref); + gencontext_emit_br(context, catch->catch_stmt.block); + return true; + } + + // If we catch a single, the single throw is easy, it *must* be the same type. + if (single_throw) + { + assert(catch->catch_stmt.error_param->type->decl == throw); + // Store and jump. + LLVMBuildStore(context->builder, value, catch->catch_stmt.error_param->var.backend_ref); + gencontext_emit_br(context, catch->catch_stmt.block); + return true; + } + + // Here instead we more than one throw and we're catching a single error type + LLVMValueRef offset = LLVMBuildBitCast(context->builder, catch->catch_stmt.error_param->type->decl->error.start_value, llvm_type(type_error_union), ""); + LLVMValueRef check = LLVMBuildAnd(context->builder, offset, value, ""); + LLVMValueRef match = LLVMBuildICmp(context->builder, LLVMIntEQ, offset, check, ""); + LLVMBasicBlockRef catch_block_set = gencontext_create_free_block(context, "setval"); + LLVMBasicBlockRef continue_block = gencontext_create_free_block(context, ""); + gencontext_emit_cond_br(context, match, catch_block_set, continue_block); + value = LLVMBuildAnd(context->builder, LLVMBuildNeg(context->builder, offset, ""), value, ""); + LLVMBuildStore(context->builder, value, catch->catch_stmt.error_param->var.backend_ref); + gencontext_emit_br(context, catch->catch_stmt.block); + gencontext_emit_block(context, continue_block); + return false; +} + +static inline bool gencontext_emit_throw_branch_for_single_throw(GenContext *context, Decl *throw, LLVMValueRef value, CatchInfo *catch_infos, bool single_throw, bool union_return) +{ + TODO + /* + VECEACH(catch_infos, i) + { + CatchInfo *info = &catch_infos[i]; + switch (info->kind) + { + case CATCH_REGULAR: + if (gencontext_emit_throw_branch_for_single_throw_catch(context, throw, value, info->catch, single_throw)) + { + // Catching all should only happen on the last branch. + assert(i == vec_size(catch_infos)); + return true; + } + break;; + case CATCH_TRY_ELSE: + // Try else should only happen on the last branch. + assert(i == vec_size(catch_infos)); + gencontext_emit_br(context, info->try_else->try_expr.jump_target); + return true; + } + }*/ + + // If this is not a single throw, then we need to check first. + if (!single_throw) + { + LLVMValueRef offset = LLVMBuildBitCast(context->builder, throw->error.start_value, llvm_type(type_error_union), ""); + LLVMValueRef check = LLVMBuildAnd(context->builder, offset, value, ""); + LLVMValueRef match = LLVMBuildICmp(context->builder, LLVMIntEQ, offset, check, ""); + LLVMBasicBlockRef catch_block_set = gencontext_create_free_block(context, "setval"); + LLVMBasicBlockRef continue_block = gencontext_create_free_block(context, ""); + gencontext_emit_cond_br(context, match, catch_block_set, continue_block); + value = LLVMBuildAnd(context->builder, LLVMBuildNeg(context->builder, offset, ""), value, ""); + //LLVMBuildStore(context->builder, value, catch->catch_stmt.error_param->var.backend_ref); + //gencontext_emit_br(context, catch->catch_stmt.block); + gencontext_emit_block(context, continue_block); + + } + // Case (1) the error return is a normal return, in that case we send the unadorned type. +/* if (union_return) + { + + + assert(error_type == type_error_base); + gencontext_emit_cond_br(context, comparison, return_block, after_block); + gencontext_emit_block(context, return_block); + gencontext_emit_return_value(context, value); + gencontext_emit_block(context, after_block); + return; + } + + assert(context->cur_func_decl->func.function_signature.error_return == ERROR_RETURN_UNION); + if (error_type == type_error_base) + { + value = gencontext_emit_cast(context, CAST_EREU, value, cuf, error_type) + + gencontext_emit_cast_expr + + + return false; + */ + return false; +} + + +static inline void gencontext_emit_throw_branch(GenContext *context, LLVMValueRef value, Decl** errors, ThrowInfo *throw_info, ErrorReturn error_return) +{ + Type *call_error_type; + switch (error_return) + { + case ERROR_RETURN_NONE: + // If there is no error return, this is a no-op. + return; + case ERROR_RETURN_ONE: + call_error_type = type_error_base; + break; + case ERROR_RETURN_MANY: + case ERROR_RETURN_ANY: + call_error_type = type_error_union; + break;; + } + + LLVMBasicBlockRef after_block = gencontext_create_free_block(context, "throwafter"); + LLVMValueRef comparison = LLVMBuildICmp(context->builder, LLVMIntNE, llvm_int(call_error_type, 0), value, "checkerr"); + + unsigned catch_count = vec_size(throw_info->catches); + assert(throw_info->is_completely_handled && catch_count); + + // Special handling for a single catch. + if (catch_count == 1) + { + CatchInfo *catch = &throw_info->catches[0]; + switch (catch->kind) + { + case CATCH_RETURN_ONE: + { + LLVMBasicBlockRef else_block = gencontext_create_free_block(context, "erret_one"); + gencontext_emit_cond_br(context, comparison, else_block, after_block); + gencontext_emit_block(context, else_block); + gencontext_emit_return_value(context, value); + gencontext_emit_block(context, after_block); + return; + } + case CATCH_RETURN_MANY: + { + TODO // Check type + LLVMBasicBlockRef else_block = gencontext_create_free_block(context, "erret_many"); + gencontext_emit_cond_br(context, comparison, else_block, after_block); + gencontext_emit_block(context, else_block); + if (call_error_type == type_error_base) + { + value = gencontext_emit_cast(context, CAST_ERREU, value, type_error_union, call_error_type); + } + gencontext_emit_return_value(context, value); + gencontext_emit_block(context, after_block); + return; + } + case CATCH_TRY_ELSE: + { + LLVMBasicBlockRef else_block = gencontext_get_try_target(context, catch->try_else); + gencontext_emit_cond_br(context, comparison, else_block, after_block); + gencontext_emit_block(context, after_block); + return; + } + case CATCH_RETURN_ANY: + { + LLVMBasicBlockRef else_block = gencontext_create_free_block(context, "erret_any"); + gencontext_emit_cond_br(context, comparison, else_block, after_block); + gencontext_emit_block(context, else_block); + if (call_error_type == type_error_base) + { + value = gencontext_emit_cast(context, CAST_ERREU, value, type_error_union, errors[0]->type); + } + gencontext_emit_return_value(context, value); + gencontext_emit_block(context, after_block); + return; + } + case CATCH_REGULAR: + { + gencontext_generate_catch_block_if_needed(context, catch->catch); + LLVMBasicBlockRef else_block = gencontext_create_free_block(context, "catch_regular"); + gencontext_emit_cond_br(context, comparison, else_block, after_block); + gencontext_emit_block(context, else_block); + Decl *error_param = catch->catch->catch_stmt.error_param; + if (call_error_type == type_error_base) + { + if (error_param->type == type_error_union) + { + value = gencontext_emit_cast(context, CAST_ERREU, value, type_error_union, errors[0]->type); + } + } + LLVMBuildStore(context->builder, value, error_param->var.backend_ref); + gencontext_emit_br(context, catch->catch->catch_stmt.block); + gencontext_emit_block(context, after_block); + return; + } + } + UNREACHABLE + } + // Here we handle multiple branches. + + LLVMBasicBlockRef err_handling_block = gencontext_create_free_block(context, "errhandlingblock"); + gencontext_emit_cond_br(context, comparison, err_handling_block, after_block); + gencontext_emit_block(context, err_handling_block); + err_handling_block = NULL; + + assert(error_return != ERROR_RETURN_ONE && "Should already be handled."); + + // Below here we can assume we're handling error unions. + VECEACH(throw_info->catches, i) + { + if (err_handling_block == NULL) + { + err_handling_block = gencontext_create_free_block(context, "thrownext"); + } + CatchInfo *catch = &throw_info->catches[i]; + switch (catch->kind) + { + case CATCH_RETURN_ONE: + { + // This is always the last catch, so we can assume that we have the correct error + // type already. + LLVMValueRef offset = LLVMBuildBitCast(context->builder, catch->error->error.start_value, llvm_type(type_error_union), ""); + LLVMValueRef negated = LLVMBuildNeg(context->builder, offset, ""); + LLVMValueRef final_value = LLVMBuildAnd(context->builder, negated, value, ""); + gencontext_emit_return_value(context, final_value); + gencontext_emit_block(context, after_block); + assert(i == vec_size(throw_info->catches) - 1); + return; + } + case CATCH_RETURN_MANY: + case CATCH_RETURN_ANY: + { + // This is simple, just return our value. + gencontext_emit_return_value(context, value); + gencontext_emit_block(context, after_block); + assert(i == vec_size(throw_info->catches) - 1); + return; + } + case CATCH_TRY_ELSE: + { + // This should be the last catch. + LLVMBasicBlockRef else_block = gencontext_get_try_target(context, catch->try_else); + gencontext_emit_br(context, else_block); + gencontext_emit_block(context, after_block); + assert(i == vec_size(throw_info->catches) - 1); + return; + } + case CATCH_REGULAR: + { + Decl *param = catch->catch->catch_stmt.error_param; + gencontext_generate_catch_block_if_needed(context, catch->catch); + Decl *error_param = catch->catch->catch_stmt.error_param; + + // The wildcard catch is always the last one. + if (param->type == type_error_union) + { + // Store the value, then jump + LLVMBuildStore(context->builder, value, error_param->var.backend_ref); + gencontext_emit_br(context, catch->catch->catch_stmt.block); + gencontext_emit_block(context, after_block); + assert(i == vec_size(throw_info->catches) - 1); + return; + } + + // Here we have a normal catch. + + // Find the offset. + LLVMValueRef offset = LLVMBuildBitCast(context->builder, param->type->decl->error.start_value, llvm_type(type_error_union), ""); + + // wrapping(value - offset) < entries + 1 – this handles both cases since wrapping will make + // values below the offset big. + LLVMValueRef comp_value = LLVMBuildSub(context->builder, value, offset, ""); + LLVMValueRef entries_value = llvm_int(type_error_union, vec_size(param->type->decl->error.error_constants) + 1); + LLVMValueRef match = LLVMBuildICmp(context->builder, LLVMIntULT, comp_value, entries_value, "matcherr"); + + LLVMBasicBlockRef match_block = gencontext_create_free_block(context, "match"); + gencontext_emit_cond_br(context, match, match_block, err_handling_block); + gencontext_emit_block(context, match_block); + + LLVMBuildStore(context->builder, comp_value, error_param->var.backend_ref); + gencontext_emit_br(context, catch->catch->catch_stmt.block); + gencontext_emit_block(context, err_handling_block); + err_handling_block = NULL; + break; + } + } + } + gencontext_emit_br(context, after_block); + gencontext_emit_block(context, after_block); +} + LLVMValueRef gencontext_emit_call_expr(GenContext *context, Expr *expr) { size_t args = vec_size(expr->call_expr.arguments); Decl *function_decl = expr->call_expr.function->identifier_expr.decl; FunctionSignature *signature = &function_decl->func.function_signature; LLVMValueRef return_param = NULL; - LLVMValueRef error_param = NULL; if (signature->return_param) { return_param = gencontext_emit_alloca(context, llvm_type(signature->rtype->type), "returnparam"); args++; } - if (signature->error_return == ERROR_RETURN_PARAM) - { - error_param = gencontext_emit_alloca(context, llvm_type(type_error_union), "errorparam"); - args++; - } LLVMValueRef *values = args ? malloc_arena(args * sizeof(LLVMValueRef)) : NULL; unsigned param_index = 0; if (return_param) { values[param_index++] = return_param; } - if (error_param) - { - values[param_index++] = error_param; - } VECEACH(expr->call_expr.arguments, i) { values[param_index++] = gencontext_emit_expr(context, expr->call_expr.arguments[i]); @@ -992,19 +1332,9 @@ LLVMValueRef gencontext_emit_call_expr(GenContext *context, Expr *expr) LLVMValueRef func = function->func.backend_value; LLVMTypeRef func_type = llvm_type(function->type); LLVMValueRef call = LLVMBuildCall2(context->builder, func_type, func, values, args, "call"); - if (signature->error_return) - { - if (error_param) - { - LLVMValueRef maybe_error = gencontext_emit_load(context, type_error_union, error_param); - TODO // Incorrect, must get subset if this is 128 bits - gencontext_emit_throw_branch(context, maybe_error, type_error_union); - } - else - { - gencontext_emit_throw_branch(context, call, type_reduced(type_error)); - } - } + + gencontext_emit_throw_branch(context, call, function->func.function_signature.throws, expr->call_expr.throw_info, signature->error_return); + // If we used a return param, then load that info here. if (return_param) { diff --git a/src/compiler/llvm_codegen_function.c b/src/compiler/llvm_codegen_function.c index 99c92383c..7ea406e66 100644 --- a/src/compiler/llvm_codegen_function.c +++ b/src/compiler/llvm_codegen_function.c @@ -84,14 +84,20 @@ static inline void gencontext_emit_parameter(GenContext *context, Decl *decl, un void gencontext_emit_implicit_return(GenContext *context) { - if (context->cur_func_decl->func.function_signature.error_return == ERROR_RETURN_RETURN) + switch (context->cur_func_decl->func.function_signature.error_return) { - LLVMBuildRet(context->builder, LLVMConstInt(llvm_type(type_ulong), 0, false)); - } - else - { - LLVMBuildRetVoid(context->builder); + case ERROR_RETURN_NONE: + LLVMBuildRetVoid(context->builder); + return; + case ERROR_RETURN_ANY: + case ERROR_RETURN_MANY: + LLVMBuildRet(context->builder, llvm_int(type_usize, 0)); + return; + case ERROR_RETURN_ONE: + LLVMBuildRet(context->builder, llvm_int(type_error_base, 0)); + return; } + UNREACHABLE } void gencontext_emit_function_body(GenContext *context, Decl *decl) @@ -117,6 +123,7 @@ void gencontext_emit_function_body(GenContext *context, Decl *decl) FunctionSignature *signature = &decl->func.function_signature; int arg = 0; + if (signature->return_param) { context->return_out = LLVMGetParam(context->function, arg++); @@ -126,12 +133,6 @@ void gencontext_emit_function_body(GenContext *context, Decl *decl) context->return_out = NULL; } - if (signature->error_return == ERROR_RETURN_PARAM) - { - context->error_out = gencontext_emit_alloca(context, llvm_type(type_error_union), "errorval"); - LLVMBuildStore(context->builder, LLVMGetParam(context->function, arg++), context->error_out); - } - // Generate LLVMValueRef's for all parameters, so we can use them as local vars in code VECEACH(decl->func.function_signature.params, i) { diff --git a/src/compiler/llvm_codegen_internal.h b/src/compiler/llvm_codegen_internal.h index 988afa20e..d4a960988 100644 --- a/src/compiler/llvm_codegen_internal.h +++ b/src/compiler/llvm_codegen_internal.h @@ -45,7 +45,6 @@ typedef struct } DebugContext; #define BREAK_STACK_MAX 256 -#define CATCH_STACK_MAX 256 typedef struct { @@ -67,12 +66,10 @@ typedef struct Ast **defer_stack; DebugContext debug; Context *ast_context; - Catch catch_stack[CATCH_STACK_MAX]; - size_t catch_stack_index; BreakContinue break_continue_stack[BREAK_STACK_MAX]; size_t break_continue_stack_index; LLVMValueRef return_out; - LLVMValueRef error_out; + LLVMBasicBlockRef error_exit_block; LLVMBasicBlockRef expr_block_exit; bool current_block_is_target : 1; bool did_call_stack_save : 1; @@ -111,12 +108,13 @@ void gencontext_end_module(GenContext *context); void gencontext_add_attribute(GenContext context, unsigned attribute_id, LLVMValueRef value_to_add_attribute_to); void gencontext_emit_stmt(GenContext *context, Ast *ast); -void gencontext_push_catch(GenContext *context, Decl *error_type, LLVMBasicBlockRef catch_block); -void gencontext_pop_catch(GenContext *context); + +void gencontext_generate_catch_block_if_needed(GenContext *context, Ast *ast); LLVMValueRef gencontext_emit_call_intrinsic(GenContext *context, unsigned intrinsic_id, LLVMTypeRef *types, LLVMValueRef *values, unsigned arg_count); void gencontext_emit_panic_on_true(GenContext *context, LLVMValueRef value, const char *panic_name); void gencontext_emit_defer(GenContext *context, Ast *defer_start, Ast *defer_end); +LLVMBasicBlockRef gencontext_get_try_target(GenContext *context, Expr *try_expr); LLVMValueRef gencontext_emit_expr(GenContext *context, Expr *expr); LLVMValueRef gencontext_emit_assign_expr(GenContext *context, LLVMValueRef ref, Expr *expr); LLVMMetadataRef gencontext_get_debug_type(GenContext *context, Type *type); @@ -134,7 +132,6 @@ static inline LLVMBasicBlockRef gencontext_create_free_block(GenContext *context { return LLVMCreateBasicBlockInContext(context->context, name); } - void gencontext_emit_function_body(GenContext *context, Decl *decl); void gencontext_emit_implicit_return(GenContext *context); void gencontext_emit_function_decl(GenContext *context, Decl *decl); @@ -146,7 +143,12 @@ static inline LLVMValueRef gencontext_emit_load(GenContext *context, Type *type, assert(gencontext_get_llvm_type(context, type) == LLVMGetElementType(LLVMTypeOf(value))); return LLVMBuildLoad2(context->builder, gencontext_get_llvm_type(context, type), value, ""); } - +static inline void gencontext_emit_return_value(GenContext *context, LLVMValueRef value) +{ + LLVMBuildRet(context->builder, value); + context->current_block = NULL; + context->current_block_is_target = false; +} LLVMValueRef gencontext_emit_cast(GenContext *context, CastKind cast_kind, LLVMValueRef value, Type *type, Type *target_type); static inline bool gencontext_func_pass_return_by_param(GenContext *context, Type *first_param_type) { return false; }; static inline bool gencontext_func_pass_param_by_reference(GenContext *context, Type *param_type) { return false; } @@ -219,3 +221,13 @@ static inline LLVMCallConv llvm_call_convention_from_call(CallABI abi) #define llvm_type(type) gencontext_get_llvm_type(context, type) #define DEBUG_TYPE(type) gencontext_get_debug_type(context, type) + +static inline LLVMValueRef gencontext_emit_const_int(GenContext *context, Type *type, uint64_t val) +{ + type = type->canonical; + if (type == type_error_union) type = type_usize->canonical; + assert(type_is_any_integer(type) || type->type_kind == TYPE_BOOL); + return LLVMConstInt(llvm_type(type), val, type_is_signed_integer(type)); +} + +#define llvm_int(_type, _val) gencontext_emit_const_int(context, _type, _val) diff --git a/src/compiler/llvm_codegen_stmt.c b/src/compiler/llvm_codegen_stmt.c index 8d0579c4a..6f6a54d4e 100644 --- a/src/compiler/llvm_codegen_stmt.c +++ b/src/compiler/llvm_codegen_stmt.c @@ -146,13 +146,22 @@ static inline void gencontext_emit_throw(GenContext *context, Ast *ast) // Ensure we are on a branch that is non empty. if (!gencontext_check_block_branch_emit(context)) return; - gencontext_emit_defer(context, ast->throw_stmt.defers.start, ast->throw_stmt.defers.end); - // TODO handle throw if simply a jump - LLVMBuildRet(context->builder, LLVMConstInt(llvm_type(type_ulong), 10 + ast->throw_stmt.throw_value->identifier_expr.decl->error_constant.value, false)); + // TODO defer +// gencontext_emit_defer(context, ast->throw_stmt.defers.start, ast->throw_stmt.defers.end); - context->current_block = NULL; - LLVMBasicBlockRef post_ret_block = gencontext_create_free_block(context, "ret"); - gencontext_emit_block(context, post_ret_block); + LLVMValueRef error_val = gencontext_emit_expr(context, ast->throw_stmt.throw_value); + + // In the case that the throw actually contains a single error, but the function is throwing an error union, + // we must insert a conversion. + if (context->cur_func_decl->func.function_signature.error_return != ERROR_RETURN_ONE && + ast->throw_stmt.throw_value->type->type_kind == TYPE_ERROR) + { + error_val = gencontext_emit_cast(context, CAST_ERREU, error_val, type_error_union, ast->throw_stmt.throw_value->type); + } + + gencontext_emit_return_value(context, error_val); + LLVMBasicBlockRef post_throw_block = gencontext_create_free_block(context, "throw"); + gencontext_emit_block(context, post_throw_block); } @@ -196,23 +205,6 @@ void gencontext_emit_if(GenContext *context, Ast *ast) } -void gencontext_push_catch(GenContext *context, Decl *error_type, LLVMBasicBlockRef catch_block) -{ - size_t index = context->catch_stack_index++; - if (index == CATCH_STACK_MAX - 1) - { - error_exit("Exhausted catch stack - exceeded %d entries.", CATCH_STACK_MAX); - } - context->catch_stack[index].decl = error_type; - context->catch_stack[index].catch_block = catch_block; -} - -void gencontext_pop_catch(GenContext *context) -{ - assert(context->catch_stack_index > 0); - context->catch_stack_index--; -} - static void gencontext_push_break_continue(GenContext *context, LLVMBasicBlockRef break_block, LLVMBasicBlockRef continue_block, LLVMBasicBlockRef next_block) { @@ -596,6 +588,35 @@ void gencontext_emit_scoped_stmt(GenContext *context, Ast *ast) gencontext_emit_defer(context, ast->scoped_stmt.defers.start, ast->scoped_stmt.defers.end); } +void gencontext_generate_catch_block_if_needed(GenContext *context, Ast *ast) +{ + LLVMBasicBlockRef block = ast->catch_stmt.block; + if (block) return; + block = gencontext_create_free_block(context, "catchblock"); + ast->catch_stmt.block = block; + LLVMTypeRef type; + if (ast->catch_stmt.error_param->type == type_error_union) + { + type = llvm_type(type_error_union); + } + else + { + type = llvm_type(type_error_base); + } + ast->catch_stmt.error_param->var.backend_ref = gencontext_emit_alloca(context, type, ""); +} + +void gencontext_emit_catch_stmt(GenContext *context, Ast *ast) +{ + gencontext_generate_catch_block_if_needed(context, ast); + LLVMBasicBlockRef after_catch = gencontext_create_free_block(context, "after_catch"); + gencontext_emit_br(context, after_catch); + gencontext_emit_block(context, ast->catch_stmt.block); + gencontext_emit_stmt(context, ast->catch_stmt.body); + gencontext_emit_br(context, after_catch); + gencontext_emit_block(context, after_catch); +} + void gencontext_emit_panic_on_true(GenContext *context, LLVMValueRef value, const char *panic_name) { LLVMBasicBlockRef panic_block = gencontext_create_free_block(context, "panic"); @@ -660,9 +681,8 @@ void gencontext_emit_stmt(GenContext *context, Ast *ast) case AST_NOP_STMT: break; case AST_CATCH_STMT: - case AST_TRY_STMT: - // Should have been lowered. - UNREACHABLE + gencontext_emit_catch_stmt(context, ast); + break; case AST_THROW_STMT: gencontext_emit_throw(context, ast); break; diff --git a/src/compiler/llvm_codegen_type.c b/src/compiler/llvm_codegen_type.c index f7266d02a..4d8fd94fd 100644 --- a/src/compiler/llvm_codegen_type.c +++ b/src/compiler/llvm_codegen_type.c @@ -125,7 +125,6 @@ LLVMTypeRef llvm_func_type(LLVMContextRef context, Type *type) FunctionSignature *signature = type->func.signature; unsigned parameters = vec_size(signature->params); if (signature->return_param) parameters++; - if (signature->error_return == ERROR_RETURN_PARAM) parameters++; if (parameters) { params = malloc_arena(sizeof(LLVMTypeRef) * parameters); @@ -134,23 +133,26 @@ LLVMTypeRef llvm_func_type(LLVMContextRef context, Type *type) { params[index++] = llvm_get_type(context, type_get_ptr(signature->rtype->type)); } - if (signature->error_return == ERROR_RETURN_PARAM) - { - params[index++] = llvm_get_type(context, type_get_ptr(type_error_union)); - } VECEACH(signature->params, i) { params[index++] = llvm_get_type(context, signature->params[i]->type->canonical); } } LLVMTypeRef ret_type; - if (signature->error_return == ERROR_RETURN_RETURN) + switch (signature->error_return) { - ret_type = llvm_get_type(context, type_ulong); - } - else - { - ret_type = signature->return_param ? llvm_get_type(context, type_void) : llvm_get_type(context, type->func.signature->rtype->type); + case ERROR_RETURN_ANY: + case ERROR_RETURN_MANY: + ret_type = llvm_get_type(context, type_error_union); + break; + case ERROR_RETURN_ONE: + ret_type = llvm_get_type(context, type_error_base); + break;; + case ERROR_RETURN_NONE: + ret_type = signature->return_param ? llvm_get_type(context, type_void) : llvm_get_type(context, type->func.signature->rtype->type); + break; + default: + UNREACHABLE } LLVMTypeRef functype = LLVMFunctionType(ret_type, params, parameters, signature->variadic); return functype; @@ -167,21 +169,16 @@ LLVMTypeRef llvm_get_type(LLVMContextRef context, Type *type) switch (type->type_kind) { case TYPE_POISONED: - case TYPE_ERROR: - UNREACHABLE; case TYPE_META_TYPE: return type->backend_type = LLVMIntTypeInContext(context, type->builtin.bitsize); + case TYPE_ERROR: + return type->backend_type = llvm_get_type(context, type_error_base); case TYPE_TYPEDEF: return type->backend_type = llvm_get_type(context, type->canonical); case TYPE_ENUM: return type->backend_type = llvm_get_type(context, type->decl->enums.type_info->type->canonical); case TYPE_ERROR_UNION: - { - LLVMTypeRef types[2]; - types[0] = llvm_get_type(context, type_typeid->canonical); - types[1] = llvm_get_type(context, type_error->canonical); - return type->backend_type = LLVMStructType(types, 2, false); - } + return type->backend_type = LLVMIntTypeInContext(context, type->builtin.bitsize); case TYPE_STRUCT: case TYPE_UNION: return type->backend_type = llvm_type_from_decl(context, type->decl); diff --git a/src/compiler/parse_expr.c b/src/compiler/parse_expr.c index 8b1de94a1..d13f17cc6 100644 --- a/src/compiler/parse_expr.c +++ b/src/compiler/parse_expr.c @@ -219,6 +219,7 @@ static Expr *parse_ternary_expr(Context *context, Expr *left_side) Expr *false_expr = TRY_EXPR_OR(parse_precedence(context, PREC_TERNARY + 1), poisoned_expr); expr_ternary->ternary_expr.else_expr = false_expr; + RANGE_EXTEND_PREV(expr_ternary); return expr_ternary; } @@ -399,8 +400,20 @@ static Expr *parse_try_expr(Context *context, Expr *left) Expr *try_expr = EXPR_NEW_TOKEN(EXPR_TRY, context->tok); advance_and_verify(context, TOKEN_TRY); try_expr->try_expr.expr = TRY_EXPR_OR(parse_precedence(context, PREC_TRY + 1), poisoned_expr); + try_expr->try_expr.type = TRY_EXPR_ELSE_EXPR; if (try_consume(context, TOKEN_ELSE)) { + switch (context->tok.type) + { + case TOKEN_RETURN: + case TOKEN_BREAK: + case TOKEN_CONTINUE: + case TOKEN_THROW: + try_expr->try_expr.type = TRY_EXPR_ELSE_JUMP; + TODO + default: + break; + } try_expr->try_expr.else_expr = TRY_EXPR_OR(parse_precedence(context, PREC_ASSIGNMENT), poisoned_expr); } return try_expr; diff --git a/src/compiler/parse_stmt.c b/src/compiler/parse_stmt.c index 411d85de8..adf2d41aa 100644 --- a/src/compiler/parse_stmt.c +++ b/src/compiler/parse_stmt.c @@ -89,12 +89,18 @@ static inline Ast* parse_catch_stmt(Context *context) CONSUME_OR(TOKEN_LPAREN, poisoned_ast); TypeInfo *type = NULL; - if (!try_consume(context, TOKEN_ERROR_TYPE)) + if (context->tok.type == TOKEN_ERROR_TYPE) + { + type = type_info_new_base(type_error_union, context->tok.span); + advance(context); + } + else { type = TRY_TYPE_OR(parse_type(context), poisoned_ast); } EXPECT_IDENT_FOR_OR("error parameter", poisoned_ast); - Decl *decl = decl_new_var(context->tok, type, VARDECL_PARAM, VISIBLE_LOCAL); + Decl *decl = decl_new_var(context->tok, type, VARDECL_LOCAL, VISIBLE_LOCAL); + advance(context); catch_stmt->catch_stmt.error_param = decl; CONSUME_OR(TOKEN_RPAREN, poisoned_ast); @@ -635,6 +641,7 @@ Ast *parse_stmt(Context *context) case TOKEN_TYPEID: case TOKEN_CT_TYPE_IDENT: case TOKEN_TYPE_IDENT: + case TOKEN_ERROR_TYPE: if (context->next_tok.type == TOKEN_DOT || context->next_tok.type == TOKEN_LBRACE) { return parse_expr_stmt(context); @@ -680,12 +687,15 @@ Ast *parse_stmt(Context *context) case TOKEN_TRY: if (is_valid_try_statement(context->next_tok.type)) { - Token token = context->tok; + Expr *try_expr = EXPR_NEW_TOKEN(EXPR_TRY, context->tok); advance(context); Ast *stmt = TRY_AST(parse_stmt(context)); - Ast *try_ast = AST_NEW_TOKEN(AST_TRY_STMT, token); - try_ast->try_stmt = stmt; - return try_ast; + try_expr->try_expr.type = TRY_STMT; + try_expr->try_expr.stmt = stmt; + RANGE_EXTEND_PREV(try_expr); + Ast *ast = AST_NEW(AST_EXPR_STMT, try_expr->span); + ast->expr_stmt = try_expr; + return ast; } return parse_expr_stmt(context); case TOKEN_CONTINUE: @@ -784,7 +794,6 @@ Ast *parse_stmt(Context *context) case TOKEN_AS: case TOKEN_ELSE: case TOKEN_ENUM: - case TOKEN_ERROR_TYPE: case TOKEN_FUNC: case TOKEN_GENERIC: case TOKEN_IMPORT: @@ -794,6 +803,7 @@ Ast *parse_stmt(Context *context) case TOKEN_EXTERN: case TOKEN_STRUCT: case TOKEN_THROWS: + case TOKEN_ERRSET: case TOKEN_TYPEDEF: case TOKEN_UNION: case TOKEN_UNTIL: diff --git a/src/compiler/parser.c b/src/compiler/parser.c index f4f1b48d5..0d235ad54 100644 --- a/src/compiler/parser.c +++ b/src/compiler/parser.c @@ -187,11 +187,10 @@ static void recover_top_level(Context *context) case TOKEN_FUNC: case TOKEN_CONST: case TOKEN_TYPEDEF: - case TOKEN_ERROR_TYPE: case TOKEN_STRUCT: case TOKEN_IMPORT: case TOKEN_UNION: - case TOKEN_ENUM: + case TOKEN_ERRSET: case TOKEN_MACRO: case TOKEN_EXTERN: return; @@ -333,6 +332,9 @@ static inline TypeInfo *parse_base_type(Context *context) type_info = type_info_new(TYPE_INFO_IDENTIFIER, context->tok.span); type_info->unresolved.name_loc = context->tok; break; + case TOKEN_ERROR_TYPE: + type_found = type_error_union; + break; case TOKEN_VOID: type_found = type_void; break; @@ -719,6 +721,7 @@ bool parse_type_or_expr(Context *context, Expr **expr_ptr, TypeInfo **type_ptr) case TOKEN_C_ULONGLONG: case TOKEN_TYPE_IDENT: case TOKEN_CT_TYPE_IDENT: + case TOKEN_ERROR_TYPE: *type_ptr = parse_type(context); return parse_type_or_expr_after_type(context, expr_ptr, type_ptr); return true; @@ -1177,7 +1180,7 @@ static AttributeDomains TOKEN_TO_ATTR[TOKEN_EOF + 1] = { [TOKEN_UNION] = ATTR_UNION, [TOKEN_CONST] = ATTR_CONST, [TOKEN_TYPEDEF] = ATTR_TYPEDEF, - [TOKEN_ERROR_TYPE] = ATTR_ERROR, + [TOKEN_ERRSET] = ATTR_ERROR, }; /** @@ -1333,7 +1336,7 @@ static inline Decl *parse_macro_declaration(Context *context, Visibility visibil */ static inline Decl *parse_error_declaration(Context *context, Visibility visibility) { - advance_and_verify(context, TOKEN_ERROR_TYPE); + advance_and_verify(context, TOKEN_ERRSET); Decl *error_decl = decl_new_with_type(context->tok, DECL_ERROR, visibility); @@ -1680,7 +1683,7 @@ static inline Decl *parse_top_level(Context *context) return parse_macro_declaration(context, visibility); case TOKEN_ENUM: return parse_enum_declaration(context, visibility); - case TOKEN_ERROR_TYPE: + case TOKEN_ERRSET: return parse_error_declaration(context, visibility); case TOKEN_TYPEDEF: return parse_typedef_declaration(context, visibility); diff --git a/src/compiler/sema_decls.c b/src/compiler/sema_decls.c index 8dfa8b01a..33f4e6576 100644 --- a/src/compiler/sema_decls.c +++ b/src/compiler/sema_decls.c @@ -31,6 +31,7 @@ static inline bool sema_analyse_error(Context *context __unused, Decl *decl) break; } } + constant->type = decl->type; constant->error_constant.value = i + 1; constant->resolve_status = RESOLVE_DONE; } @@ -274,25 +275,23 @@ static inline Type *sema_analyse_function_signature(Context *context, FunctionSi } unsigned error_types = vec_size(signature->throws); - if (signature->throw_any || error_types > 1) + ErrorReturn error_return = ERROR_RETURN_NONE; + if (signature->throw_any) { - signature->error_return = ERROR_RETURN_PARAM; + error_return = ERROR_RETURN_ANY; } - else if (error_types == 1) + else if (error_types) { - signature->error_return = ERROR_RETURN_RETURN; - } - else - { - signature->error_return = ERROR_RETURN_NONE; + error_return = error_types > 1 ? ERROR_RETURN_MANY : ERROR_RETURN_ONE; } + signature->error_return = error_return; Type *return_type = signature->rtype->type->canonical; signature->return_param = false; if (return_type->type_kind != TYPE_VOID) { // TODO fix this number with ABI compatibility - if (signature->error_return == ERROR_RETURN_RETURN || type_size(return_type) > 8 * 2) + if (signature->error_return != ERROR_RETURN_NONE || type_size(return_type) > 8 * 2) { signature->return_param = true; } diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index 06473b9dc..b04e77723 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -157,7 +157,8 @@ static inline bool sema_expr_analyse_error_constant(Expr *expr, const char *name assert(error_constant->resolve_status == RESOLVE_DONE); expr->type = decl->type; expr->expr_kind = EXPR_CONST; - expr_const_set_int(&expr->const_expr, error_constant->error_constant.value, type_error->canonical->type_kind); + expr->const_expr.error_constant = error_constant; + expr->const_expr.kind = TYPE_ERROR; return true; } } @@ -252,11 +253,14 @@ static inline int find_index_of_named_parameter(Decl** func_params, Expr *expr) return -1; } + static inline bool sema_expr_analyse_func_call(Context *context, Type *to, Expr *expr, Decl *decl) { Expr **args = expr->call_expr.arguments; FunctionSignature *signature = &decl->func.function_signature; Decl **func_params = signature->params; + expr->call_expr.throw_info = CALLOCS(ThrowInfo); + expr->call_expr.throw_info->defers = context->current_scope->defers; unsigned error_params = signature->throw_any || signature->throws; if (error_params) { @@ -265,6 +269,22 @@ static inline bool sema_expr_analyse_func_call(Context *context, Type *to, Expr SEMA_ERROR(expr, "Function '%s' throws errors, this call must be prefixed 'try'.", decl->name); return false; } + // Add errors to current context. + if (signature->throw_any) + { + vec_add(context->error_calls, throw_new_union(expr->span, THROW_TYPE_CALL_ANY, expr->call_expr.throw_info)); + } + else + { + if (vec_size(signature->throws) == 1) + { + vec_add(context->error_calls, throw_new_single(expr->span, THROW_TYPE_CALL_THROW_ONE, expr->call_expr.throw_info, signature->throws[0]->type)); + } + else + { + vec_add(context->error_calls, throw_new_multiple(expr->span, expr->call_expr.throw_info, signature->throws)); + } + } } unsigned func_param_count = vec_size(func_params); unsigned num_args = vec_size(args); @@ -2210,15 +2230,63 @@ static inline bool sema_expr_analyse_post_unary(Context *context, Type *to, Expr static inline bool sema_expr_analyse_try(Context *context, Type *to, Expr *expr) { context->try_nesting++; + // Duplicates code in try for statements.. :( + unsigned prev_throws = vec_size(context->error_calls); bool success = sema_analyse_expr(context, to, expr->try_expr.expr); context->try_nesting--; if (!success) return false; + unsigned new_throws = vec_size(context->error_calls); + if (new_throws == prev_throws) + { + if (expr->try_expr.type == TRY_STMT) + { + SEMA_ERROR(expr->try_expr.stmt, "No error to 'try' in the statement that follows, please remove the 'try'."); + } + else + { + SEMA_ERROR(expr->try_expr.expr, "No error to 'try' in the expression, please remove the 'try'."); + } + return false; + } expr->type = expr->try_expr.expr->type; + bool found = false; + for (unsigned i = prev_throws; i < new_throws; i++) + { + // At least one uncaught error found! + if (!context->error_calls[i].throw_info->is_completely_handled) + { + found = true; + break; + } + } if (expr->try_expr.else_expr) { + CatchInfo info = { .kind = CATCH_TRY_ELSE, .try_else = expr }; + // Absorb all errors. + for (unsigned i = prev_throws; i < new_throws; i++) + { + Throw *throw = &context->error_calls[i]; + // Skip handled errors + if (throw[i].throw_info->is_completely_handled) continue; + throw->throw_info->is_completely_handled = true; + vec_add(throw->throw_info->catches, info); + } + // Resize to remove the throws from consideration. + vec_resize(context->error_calls, prev_throws); if (!sema_analyse_expr(context, to, expr->try_expr.else_expr)) return false; } - // TODO Check errors! + if (!found) + { + if (expr->try_expr.type == TRY_STMT) + { + SEMA_ERROR(expr->try_expr.stmt, "All errors in the following statement was caught, please remove the 'try'."); + } + else + { + SEMA_ERROR(expr->try_expr.expr, "All errors in the expression was caught, please remove the 'try'."); + } + return false; + } return true; } @@ -2514,9 +2582,6 @@ static Ast *ast_copy_from_macro(Context *context, Expr *macro, Ast *source) case AST_THROW_STMT: MACRO_COPY_EXPR(ast->throw_stmt.throw_value); return ast; - case AST_TRY_STMT: - MACRO_COPY_AST(ast->try_stmt); - return ast; case AST_NEXT_STMT: TODO return ast; diff --git a/src/compiler/sema_stmts.c b/src/compiler/sema_stmts.c index 536922b32..1c4f61afe 100644 --- a/src/compiler/sema_stmts.c +++ b/src/compiler/sema_stmts.c @@ -440,9 +440,119 @@ static inline bool sema_analyse_label(Context *context, Ast *statement) } -static bool sema_analyse_catch_stmt(Context *context __unused, Ast *statement __unused) +static inline bool throw_completely_handled_call_throw_many(Throw *throw) { - TODO + assert(throw->kind == THROW_TYPE_CALL_THROW_MANY && "Only for throw many"); + assert(!throw->throw_info->is_completely_handled && "Expected unhandled"); + Decl **throws = throw->throws; + CatchInfo *catched = throw->throw_info->catches; + unsigned catches = 0; + unsigned throw_count = vec_size(throws); + for (unsigned i = 0; i < throw_count; i++) + { + Decl *throw_decl = throws[i]; + if (throw_completely_caught(throw_decl, catched)) + { + catches++; + } + } + return catches == throw_count; +} + +/** + * Handle the catch statement + * @return true if error checking succeeds. + */ +static bool sema_analyse_catch_stmt(Context *context, Ast *statement) +{ + unsigned throws = vec_size(context->error_calls); + + // 1. If no errors are found we don't have a try to match. + if (throws <= context->current_scope->throws) + { + SEMA_ERROR(statement, "Unexpected 'catch' without a matching 'try'."); + return false; + } + + // 2. Let's check that we haven't caught all errors. + bool found = false; + for (unsigned i = context->current_scope->throws; i < throws; i++) + { + if (!context->error_calls[i].throw_info->is_completely_handled) + { + found = true; + break;; + } + } + // IMPROVE: Suppress for macro? + if (!found) + { + SEMA_ERROR(statement, "All errors are already caught, so this catch will not handle any errors."); + return false; + } + + // 3. Resolve variable + Decl *variable = statement->catch_stmt.error_param; + assert(variable->var.kind == VARDECL_LOCAL); + if (!sema_resolve_type_info(context, variable->var.type_info)) return false; + variable->type = variable->var.type_info->type; + Type *error_type = variable->type->canonical; + + CatchInfo catch = { .kind = CATCH_REGULAR, .catch = statement }; + // 4. Absorb all errors in case of a type error union. + if (error_type == type_error_union) + { + for (unsigned i = context->current_scope->throws; i < throws; i++) + { + Throw *throw = &context->error_calls[i]; + // Skip handled errors + if (throw->throw_info->is_completely_handled) continue; + vec_add(throw->throw_info->catches, catch); + throw->throw_info->is_completely_handled = true; + } + // Resize to remove the throws from consideration. + vec_resize(context->error_calls, context->current_scope->throws); + } + else + { + // 5. Otherwise, go through the list of errors and null the errors matching the current type. + for (unsigned i = context->current_scope->throws; i < throws; i++) + { + Throw *throw = &context->error_calls[i]; + // Skip handled errors + if (throw->throw_info->is_completely_handled) continue; + + switch (throw->kind) + { + case THROW_TYPE_CALL_ANY: + vec_add(throw->throw_info->catches, catch); + // An error union can never be completely handled. + break; + case THROW_TYPE_CALL_THROW_ONE: + // If there is no match, ignore. + if (throw->throw != error_type) continue; + // Otherwise add and set to completely handled. + vec_add(throw->throw_info->catches, catch); + throw->throw_info->is_completely_handled = true; + break; + case THROW_TYPE_CALL_THROW_MANY: + // The most complex situation, add and handle below. + vec_add(throw->throw_info->catches, catch); + throw->throw_info->is_completely_handled = throw_completely_handled_call_throw_many(throw); + break; + } + } + } + context_push_scope(context); + // Push the error variable + if (!sema_add_local(context, variable)) goto ERR_END_SCOPE; + if (!sema_analyse_statement(context, statement->catch_stmt.body)) goto ERR_END_SCOPE; + context_pop_scope(context); + return true; + + ERR_END_SCOPE: + context_pop_scope(context); + return false; } static bool sema_analyse_asm_stmt(Context *context __unused, Ast *statement __unused) @@ -719,29 +829,6 @@ static bool sema_analyse_switch_stmt(Context *context, Ast *statement) return success; } -static bool sema_analyse_try_stmt(Context *context, Ast *statement) -{ - context->try_nesting++; - unsigned errors = vec_size(context->errors); - if (!sema_analyse_statement(context, statement->try_stmt)) - { - context->try_nesting--; - return false; - } - unsigned new_errors = vec_size(context->errors); - if (new_errors == errors) - { - SEMA_ERROR(statement, "No error to 'try' in the statement that follows, please remove the 'try'."); - return false; - } - for (unsigned i = errors; i < new_errors; i++) - { - // At least one uncaught error found! - if (context->errors[i]) return true; - } - SEMA_ERROR(statement, "All errors in the following statement was caught, please remove the 'try'."); - return false; -} static bool sema_analyse_throw_stmt(Context *context, Ast *statement) { @@ -749,18 +836,44 @@ static bool sema_analyse_throw_stmt(Context *context, Ast *statement) UPDATE_EXIT(EXIT_THROW); if (!sema_analyse_expr(context, NULL, throw_value)) return false; Type *type = throw_value->type->canonical; - if (type->type_kind != TYPE_ERROR) + if (type->type_kind != TYPE_ERROR && type->type_kind != TYPE_ERROR_UNION) { SEMA_ERROR(throw_value, "Only 'error' types can be thrown, this is a '%s'.", type->name); return false; } - if (!context->try_nesting && !context->active_function_for_analysis->func.function_signature.error_return) + FunctionSignature *sig = &context->active_function_for_analysis->func.function_signature; + if (sig->error_return == ERROR_RETURN_NONE) { - // TODO check error type - SEMA_ERROR(statement, "This 'throw' is not handled, please add a 'throws %s' clause to the function signature or use try-catch.", type->name); + SEMA_ERROR(statement, "This throw requires that the function adds 'throws %s' to its declaration.", type->name); return false; } - vec_add(context->errors, type->decl); + + // Check if the error is actually in the list. + if (sig->error_return == ERROR_RETURN_MANY || sig->error_return == ERROR_RETURN_ONE) + { + bool found = false; + VECEACH(sig->throws, i) + { + if (sig->throws[i]->type == type) + { + found = true; + } + } + if (!found) + { + if (type != type_error_union) + { + SEMA_ERROR(statement->throw_stmt.throw_value, "'%s' must be added to the list of errors after 'throws'.", type->name); + } + else + { + SEMA_ERROR(statement, "This throw requires the function to use a wildcard 'throws' without types.", type->name); + } + return false; + } + } + + vec_add(context->throw, statement); return true; } @@ -833,8 +946,6 @@ static inline bool sema_analyse_statement_inner(Context *context, Ast *statement return sema_analyse_switch_stmt(context, statement); case AST_THROW_STMT: return sema_analyse_throw_stmt(context, statement); - case AST_TRY_STMT: - return sema_analyse_try_stmt(context, statement); case AST_NEXT_STMT: return sema_analyse_next_stmt(context, statement); case AST_VOLATILE_STMT: @@ -883,16 +994,75 @@ static inline void defer_list_walk_to_common_depth(Ast **defer_stmt, int this_de } } +static inline bool throw_add_error_return_catch(Throw *throw, Decl **func_throws) +{ + assert(throw->kind != THROW_TYPE_CALL_ANY); + Decl **throws; + unsigned throw_count; + if (throw->kind == THROW_TYPE_CALL_THROW_MANY) + { + throws = throw->throws; + throw_count = vec_size(throws); + } + else + { + throws = &throw->throw->decl; + throw_count = 1; + } + unsigned func_throw_count = vec_size(func_throws); + assert(func_throw_count); + bool catch_added = false; + for (unsigned i = 0; i < func_throw_count; i++) + { + Decl *func_throw = func_throws[i]; + for (unsigned j = 0; j < throw_count; j++) + { + if (throws[j] == func_throw->type->decl) + { + // If the throw was already caught, ignore it. + if (throw_completely_caught(throws[j], throw->throw_info->catches)) continue; + + // One of the throws was caught + if (func_throw_count > 1) + { + CatchInfo info = { .kind = CATCH_RETURN_MANY, .error = func_throw->type->decl }; + vec_add(throw->throw_info->catches, info); + } + else + { + CatchInfo info = { .kind = CATCH_RETURN_ONE, .error = func_throw->type->decl }; + vec_add(throw->throw_info->catches, info); + } + // If we only have one count, then we're done! + if (throw_count == 1) + { + throw->throw_info->is_completely_handled = true; + return true; + } + // Otherwise we simply continue. + } + } + } + // If we have already caught some, then we might have completely caught all throws. + if (throw_count > 1 && catch_added) + { + throw->throw_info->is_completely_handled = throw_completely_handled_call_throw_many(throw); + } + return catch_added; +} + bool sema_analyse_function_body(Context *context, Decl *func) { + FunctionSignature *signature = &func->func.function_signature; context->active_function_for_analysis = func; - context->rtype = func->func.function_signature.rtype->type; + context->rtype = signature->rtype->type; context->current_scope = &context->scopes[0]; // Clean out the current scope. memset(context->current_scope, 0, sizeof(*context->current_scope)); // Clear try handling - vec_resize(context->errors, 0); + vec_resize(context->throw, 0); + vec_resize(context->error_calls, 0); // Clear returns vec_resize(context->returns, 0); context->try_nesting = 0; @@ -904,7 +1074,7 @@ bool sema_analyse_function_body(Context *context, Decl *func) context->in_volatile_section = 0; func->func.annotations = CALLOCS(*func->func.annotations); context_push_scope(context); - Decl **params = func->func.function_signature.params; + Decl **params = signature->params; assert(context->current_scope == &context->scopes[1]); VECEACH(params, i) { @@ -914,7 +1084,7 @@ bool sema_analyse_function_body(Context *context, Decl *func) assert(context->current_scope == &context->scopes[1]); if (context->current_scope->exit != EXIT_RETURN && context->current_scope->exit != EXIT_THROW && context->current_scope->exit != EXIT_GOTO) { - if (func->func.function_signature.rtype->type->canonical != type_void) + if (signature->rtype->type->canonical != type_void) { // IMPROVE better pointer to end. SEMA_ERROR(func, "Missing return statement at the end of the function."); @@ -922,6 +1092,7 @@ bool sema_analyse_function_body(Context *context, Decl *func) } } + VECEACH(context->gotos, i) { Ast *goto_stmt = context->gotos[i]; @@ -966,6 +1137,40 @@ bool sema_analyse_function_body(Context *context, Decl *func) current = current->defer_stmt.prev_defer; } } + bool error_was_useful = vec_size(context->throw) > 0; + VECEACH(context->error_calls, i) + { + Throw *throw = &context->error_calls[i]; + if (throw->throw_info->is_completely_handled) continue; + + switch (signature->error_return) + { + case ERROR_RETURN_NONE: + // Nothing to do, will result in error. + break; + case ERROR_RETURN_ANY: + // Any return, then any throw is ok, add + // an implicit catch. + vec_add(throw->throw_info->catches, (CatchInfo) { .kind = CATCH_RETURN_ANY }); + throw->throw_info->is_completely_handled = true; + error_was_useful = true; + continue; + case ERROR_RETURN_MANY: + case ERROR_RETURN_ONE: + // Try to add a catch. + if (throw_add_error_return_catch(throw, signature->throws)) + { + error_was_useful = true; + } + break; + } + // If it's fully catched, then fine. + if (throw->throw_info->is_completely_handled) continue; + // Otherwise error. + SEMA_ERROR(throw, "The errors returned by the call must be completely caught in a catch or else the function current must be declared to throw."); + return false; + } + func->func.labels = context->labels; context_pop_scope(context); context->current_scope = NULL; diff --git a/src/compiler/sema_types.c b/src/compiler/sema_types.c index 71ceb7d14..5522aec91 100644 --- a/src/compiler/sema_types.c +++ b/src/compiler/sema_types.c @@ -18,6 +18,28 @@ static inline bool sema_resolve_ptr_type(Context *context, TypeInfo *type_info) return true; } +bool throw_completely_caught(Decl *throw, CatchInfo *catches) +{ + VECEACH(catches, i) + { + CatchInfo *catch_info = &catches[i]; + switch (catch_info->kind) + { + case CATCH_REGULAR: + if (throw == catch_info->catch->catch_stmt.error_param->type->decl) return true; + break; + case CATCH_TRY_ELSE: + case CATCH_RETURN_ANY: + return true; + case CATCH_RETURN_MANY: + case CATCH_RETURN_ONE: + if (throw == catch_info->error) return true; + break; + } + } + return false; +} + static inline bool sema_resolve_array_type(Context *context, TypeInfo *type) { diff --git a/src/compiler/tokens.c b/src/compiler/tokens.c index 51ccf5f89..0a94f3773 100644 --- a/src/compiler/tokens.c +++ b/src/compiler/tokens.c @@ -200,6 +200,8 @@ const char *token_type_to_string(TokenType type) return "extern"; case TOKEN_ERROR_TYPE: return "error"; + case TOKEN_ERRSET: + return "errset"; case TOKEN_FALSE: return "false"; case TOKEN_FOR: diff --git a/src/compiler/types.c b/src/compiler/types.c index 17307dd26..9c4c29627 100644 --- a/src/compiler/types.c +++ b/src/compiler/types.c @@ -10,8 +10,7 @@ static Type t_f32, t_f64, t_fxx; static Type t_usz, t_isz; static Type t_cus, t_cui, t_cul, t_cull; static Type t_cs, t_ci, t_cl, t_cll; -static Type t_voidstar, t_typeid; -static Type t_err, t_error_union; +static Type t_voidstar, t_typeid, t_error_union; Type *type_bool = &t_u1; Type *type_void = &t_u0; @@ -19,8 +18,6 @@ Type *type_string = &t_str; Type *type_voidptr = &t_voidstar; Type *type_float = &t_f32; Type *type_double = &t_f64; -Type *type_error = &t_err; -Type *type_error_union = &t_error_union; Type *type_typeid = &t_typeid; Type *type_char = &t_i8; Type *type_short = &t_i16; @@ -42,6 +39,13 @@ Type *type_c_ushort = &t_cus; Type *type_c_uint = &t_cui; Type *type_c_ulong = &t_cul; Type *type_c_ulonglong = &t_cull; +Type *type_error_union = &t_error_union; +Type *type_error_base = &t_ci; + +static unsigned size_subarray; +static unsigned alignment_subarray; +unsigned size_error_code; +unsigned alignment_error_code; #define PTR_OFFSET 0 #define VAR_ARRAY_OFFSET 1 @@ -116,7 +120,7 @@ const char *type_to_error_string(Type *type) asprintf(&buffer, "%s[:]", type_to_error_string(type->array.base)); return buffer; case TYPE_ERROR_UNION: - TODO + return "error"; } UNREACHABLE } @@ -197,7 +201,7 @@ size_t type_size(Type *canonical) case TYPE_ENUM: return canonical->decl->enums.type_info->type->canonical->builtin.bytesize; case TYPE_ERROR: - return type_error->canonical->builtin.bytesize; + return alignment_error_code; case TYPE_STRUCT: case TYPE_UNION: return canonical->decl->strukt.size; @@ -212,13 +216,12 @@ size_t type_size(Type *canonical) case TYPE_POINTER: case TYPE_VARARRAY: case TYPE_STRING: + case TYPE_ERROR_UNION: return t_usz.canonical->builtin.bytesize; case TYPE_ARRAY: return type_size(canonical->array.base) * canonical->array.len; case TYPE_SUBARRAY: - TODO - case TYPE_ERROR_UNION: - TODO + return size_subarray; } UNREACHABLE } @@ -235,7 +238,7 @@ unsigned int type_abi_alignment(Type *canonical) case TYPE_ENUM: return canonical->decl->enums.type_info->type->canonical->builtin.abi_alignment; case TYPE_ERROR: - return type_error->canonical->builtin.abi_alignment; + return alignment_error_code; case TYPE_STRUCT: case TYPE_UNION: return canonical->decl->strukt.abi_alignment; @@ -243,6 +246,7 @@ unsigned int type_abi_alignment(Type *canonical) case TYPE_BOOL: case ALL_INTS: case ALL_FLOATS: + case TYPE_ERROR_UNION: return canonical->builtin.abi_alignment; case TYPE_FUNC: case TYPE_POINTER: @@ -252,9 +256,7 @@ unsigned int type_abi_alignment(Type *canonical) case TYPE_ARRAY: return type_abi_alignment(canonical->array.base); case TYPE_SUBARRAY: - TODO - case TYPE_ERROR_UNION: - TODO + return alignment_subarray; } UNREACHABLE } @@ -456,11 +458,12 @@ type_create(#_name, &_shortname, _type, _bits, target->align_ ## _align, target- type_create_alias("c_short", &t_cs, type_signed_int_by_bitsize(target->width_c_short)); type_create_alias("c_int", &t_ci, type_signed_int_by_bitsize(target->width_c_int)); - // TODO fix error size - type_create_alias("error", &t_err, type_signed_int_by_bitsize(target->width_c_int)); type_create_alias("c_long", &t_cl, type_signed_int_by_bitsize(target->width_c_long)); type_create_alias("c_longlong", &t_cll, type_signed_int_by_bitsize(target->width_c_long_long)); + alignment_subarray = MAX(type_abi_alignment(&t_voidstar), type_abi_alignment(t_usz.canonical)); + size_subarray = alignment_subarray * 2; + type_create("error", &t_error_union, TYPE_ERROR_UNION, target->width_pointer, target->align_pointer, target->align_pref_pointer); } /** @@ -670,7 +673,8 @@ Type *type_find_max_type(Type *type, Type *other) // some way? return NULL; case TYPE_ERROR: - TODO + if (other->type_kind == TYPE_ERROR) return type_error_union; + return NULL; case TYPE_FUNC: case TYPE_UNION: case TYPE_ERROR_UNION: diff --git a/src/utils/lib.h b/src/utils/lib.h index c86aabbf7..ebd1d0833 100644 --- a/src/utils/lib.h +++ b/src/utils/lib.h @@ -303,5 +303,10 @@ static inline bool is_all_lower(const char* string) #define __printflike(x, y) #endif +char *strcat_arena(const char *a, const char *b); char *strformat(const char *var, ...) __printflike(1, 2); +#define MAX(_a, _b) ({ \ + typeof(_a) __a__ = (_a); \ + typeof(_b) __b__ = (_b); \ + __a__ > __b__ ? __a__ : __b__; }) \ No newline at end of file diff --git a/src/utils/stringutils.c b/src/utils/stringutils.c index 5162857e7..255b31d26 100644 --- a/src/utils/stringutils.c +++ b/src/utils/stringutils.c @@ -21,4 +21,15 @@ char *strformat(const char *var, ...) va_end(list); assert(len == new_len); return buffer; +} + +char *strcat_arena(const char *a, const char *b) +{ + unsigned a_len = strlen(a); + unsigned b_len = strlen(b); + char *buffer = malloc_arena(a_len + b_len + 1); + memcpy(buffer, a, a_len); + memcpy(buffer + a_len, b, b_len); + buffer[a_len + b_len] = '\0'; + return buffer; } \ No newline at end of file