diff --git a/src/compiler/ast.c b/src/compiler/ast.c index 97054f4d0..791d65a54 100644 --- a/src/compiler/ast.c +++ b/src/compiler/ast.c @@ -225,6 +225,8 @@ bool expr_is_pure(Expr *expr) return true; case EXPR_BITASSIGN: return false; + case EXPR_VARIANTSWITCH: + return false; case EXPR_BINARY: if (expr->binary_expr.operator >= BINARYOP_ASSIGN) return false; return expr_is_pure(expr->binary_expr.right) && expr_is_pure(expr->binary_expr.left); diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index 63db8610a..857dec0a7 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -920,6 +920,22 @@ typedef struct Token identifier; BuiltinFunction builtin; } ExprBuiltin; + +typedef struct +{ + bool is_assign : 1; + bool is_deref : 1; + union + { + struct + { + TokenId new_ident; + Expr *variant_expr; + }; + Decl *variable; + }; +} ExprVariantSwitch; + struct Expr_ { ExprKind expr_kind : 8; @@ -927,6 +943,7 @@ struct Expr_ SourceSpan span; Type *type; union { + ExprVariantSwitch variant_switch; ExprLen len_expr; ExprCast cast_expr; TypeInfo *type_expr; diff --git a/src/compiler/copying.c b/src/compiler/copying.c index bd44ba96e..12cee3922 100644 --- a/src/compiler/copying.c +++ b/src/compiler/copying.c @@ -70,6 +70,7 @@ Expr *copy_expr(Expr *source_expr) switch (source_expr->expr_kind) { case EXPR_MACRO_BODY_EXPANSION: + case EXPR_VARIANTSWITCH: UNREACHABLE case EXPR_FLATPATH: case EXPR_UNDEF: diff --git a/src/compiler/enums.h b/src/compiler/enums.h index 0fb8b9311..86b80056f 100644 --- a/src/compiler/enums.h +++ b/src/compiler/enums.h @@ -183,6 +183,7 @@ typedef enum EXPR_COMPOUND_LITERAL, EXPR_CONST, EXPR_CONST_IDENTIFIER, + EXPR_CT_CALL, EXPR_CT_IDENT, EXPR_COND, EXPR_DECL, @@ -219,7 +220,7 @@ typedef enum EXPR_TYPEINFO, EXPR_UNARY, EXPR_UNDEF, - EXPR_CT_CALL, + EXPR_VARIANTSWITCH, EXPR_NOP, } ExprKind; diff --git a/src/compiler/linker.c b/src/compiler/linker.c index 9085d9ec8..febd1b098 100644 --- a/src/compiler/linker.c +++ b/src/compiler/linker.c @@ -3,14 +3,16 @@ #include // for LLVM_VERSION_STRING #ifdef PLATFORM_WINDOWS + #include "utils/find_msvc.h" + #endif -extern bool llvm_link_elf(const char **args, int arg_count, const char** error_string); -extern bool llvm_link_macho(const char **args, int arg_count, const char** error_string); -extern bool llvm_link_coff(const char **args, int arg_count, const char** error_string); -extern bool llvm_link_wasm(const char **args, int arg_count, const char** error_string); -extern bool llvm_link_mingw(const char **args, int arg_count, const char** error_string); +extern bool llvm_link_elf(const char **args, int arg_count, const char **error_string); +extern bool llvm_link_macho(const char **args, int arg_count, const char **error_string); +extern bool llvm_link_coff(const char **args, int arg_count, const char **error_string); +extern bool llvm_link_wasm(const char **args, int arg_count, const char **error_string); +extern bool llvm_link_mingw(const char **args, int arg_count, const char **error_string); static void add_files(const char ***args, const char **files_to_link, unsigned file_count) { @@ -37,15 +39,16 @@ static void prepare_msys2_linker_flags(const char ***args, const char **files_to add_arg("-m"); add_arg("i386pep"); add_arg("-Bdynamic"); - add_arg(join_strings((const char *[]){root, "\\x86_64-w64-mingw32\\lib\\crt2.o"}, 2)); - add_arg(join_strings((const char *[]){root, "\\x86_64-w64-mingw32\\lib\\crtbegin.o"}, 2)); - add_arg(join_strings((const char *[]){"-L", root, "\\x86_64-w64-mingw32\\lib"}, 3)); - add_arg(join_strings((const char *[]){"-L", root, "\\lib"}, 3)); - add_arg(join_strings((const char *[]){"-L", root, "\\x86_64-w64-mingw32\\sys-root\\mingw\\lib"}, 3)); - add_arg(join_strings((const char *[]){"-L", root, "\\lib\\clang\\", LLVM_VERSION_STRING, "\\lib\\windows"}, 5)); + add_arg(join_strings((const char *[]){ root, "\\x86_64-w64-mingw32\\lib\\crt2.o" }, 2)); + add_arg(join_strings((const char *[]){ root, "\\x86_64-w64-mingw32\\lib\\crtbegin.o" }, 2)); + add_arg(join_strings((const char *[]){ "-L", root, "\\x86_64-w64-mingw32\\lib" }, 3)); + add_arg(join_strings((const char *[]){ "-L", root, "\\lib" }, 3)); + add_arg(join_strings((const char *[]){ "-L", root, "\\x86_64-w64-mingw32\\sys-root\\mingw\\lib" }, 3)); + add_arg(join_strings((const char *[]){ "-L", root, "\\lib\\clang\\", LLVM_VERSION_STRING, "\\lib\\windows" }, 5)); add_files(args, files_to_link, file_count); add_arg("-lmingw32"); - add_arg(join_strings((const char *[]){root, "\\lib\\clang\\", LLVM_VERSION_STRING, "\\lib\\windows\\libclang_rt.builtins-x86_64.a"}, 4)); + add_arg(join_strings((const char *[]){ root, "\\lib\\clang\\", LLVM_VERSION_STRING, + "\\lib\\windows\\libclang_rt.builtins-x86_64.a" }, 4)); add_arg("-lunwind"); add_arg("-lmoldname"); add_arg("-lmingwex"); @@ -55,13 +58,14 @@ static void prepare_msys2_linker_flags(const char ***args, const char **files_to add_arg("-luser32"); add_arg("-lkernel32"); add_arg("-lmingw32"); - add_arg(join_strings((const char *[]){root, "\\lib\\clang\\", LLVM_VERSION_STRING, "\\lib\\windows\\libclang_rt.builtins-x86_64.a"}, 4)); + add_arg(join_strings((const char *[]){ root, "\\lib\\clang\\", LLVM_VERSION_STRING, + "\\lib\\windows\\libclang_rt.builtins-x86_64.a" }, 4)); add_arg("-lunwind"); add_arg("-lmoldname"); add_arg("-lmingwex"); add_arg("-lmsvcrt"); add_arg("-lkernel32"); - add_arg(join_strings((const char *[]){root, "\\x86_64-w64-mingw32\\lib\\crtend.o"}, 2)); + add_arg(join_strings((const char *[]){ root, "\\x86_64-w64-mingw32\\lib\\crtend.o" }, 2)); #undef add_arg } @@ -76,8 +80,8 @@ static bool link_exe(const char *output_file, const char **files_to_link, unsign else { #endif - vec_add(args, "-o"); - vec_add(args, output_file); + vec_add(args, "-o"); + vec_add(args, output_file); #ifdef _MSC_VER } #endif @@ -87,7 +91,7 @@ static bool link_exe(const char *output_file, const char **files_to_link, unsign } const char *error = NULL; // This isn't used in most cases, but its contents should get freed after linking. - WindowsLinkPathsUTF8 windows_paths = {0}; + WindowsLinkPathsUTF8 windows_paths = { 0 }; switch (platform_target.os) { @@ -129,7 +133,7 @@ static bool link_exe(const char *output_file, const char **files_to_link, unsign { return false; } - } + } break; case OS_TYPE_MACOSX: add_files(&args, files_to_link, file_count); @@ -216,7 +220,14 @@ static bool link_exe(const char *output_file, const char **files_to_link, unsign switch (platform_target.object_format) { case OBJ_FORMAT_COFF: - success = (platform_target.x64.is_mingw64 ? llvm_link_mingw : llvm_link_coff)(args, (int)vec_size(args), &error); + if (platform_target.x64.is_mingw64) + { + success = llvm_link_mingw(args, (int)vec_size(args), &error); + } + else + { + success = llvm_link_coff(args, (int)vec_size(args), &error); + } // This is only defined if compiling with MSVC #ifdef _MSC_VER if (windows_paths.windows_sdk_um_library_path) { @@ -320,17 +331,17 @@ void platform_linker(const char *output_file, const char **files, unsigned file_ printf("Program linked to executable '%s'.\n", output_file); } -void platform_compiler(const char **files, unsigned file_count, const char* flags) +void platform_compiler(const char **files, unsigned file_count, const char *flags) { const char **parts = NULL; vec_add(parts, active_target.cc); const bool pie_set = - flags != NULL && - ( strstr(flags, "-fno-PIE") || // This is a weird case, but probably don't set PIE if - strstr(flags, "-fno-pie") || // it is being set in user defined cflags. - strstr(flags, "-fpie") || - strstr(flags, "-fPIE") ); // strcasestr is apparently nonstandard >:( + flags != NULL && + (strstr(flags, "-fno-PIE") || // This is a weird case, but probably don't set PIE if + strstr(flags, "-fno-pie") || // it is being set in user defined cflags. + strstr(flags, "-fpie") || + strstr(flags, "-fPIE")); // strcasestr is apparently nonstandard >:( if (!pie_set) { switch (platform_target.pie) @@ -349,7 +360,7 @@ void platform_compiler(const char **files, unsigned file_count, const char* flag break; } } - + vec_add(parts, "-c"); if (flags) vec_add(parts, flags); for (unsigned i = 0; i < file_count; i++) diff --git a/src/compiler/llvm_codegen_expr.c b/src/compiler/llvm_codegen_expr.c index e6d131a38..d4ed6a435 100644 --- a/src/compiler/llvm_codegen_expr.c +++ b/src/compiler/llvm_codegen_expr.c @@ -5167,6 +5167,7 @@ void llvm_emit_expr(GenContext *c, BEValue *value, Expr *expr) case EXPR_PLACEHOLDER: case EXPR_CT_CALL: case EXPR_FLATPATH: + case EXPR_VARIANTSWITCH: UNREACHABLE case EXPR_TRY_UNWRAP_CHAIN: llvm_emit_try_unwrap_chain(c, value, expr); @@ -5187,6 +5188,7 @@ void llvm_emit_expr(GenContext *c, BEValue *value, Expr *expr) TODO case EXPR_DECL: llvm_emit_local_decl(c, expr->decl_expr); + llvm_value_set_decl_address(value, expr->decl_expr); return; case EXPR_SLICE_ASSIGN: llvm_emit_slice_assign(c, value, expr); diff --git a/src/compiler/parse_expr.c b/src/compiler/parse_expr.c index 07a40c4d9..a4c359f76 100644 --- a/src/compiler/parse_expr.c +++ b/src/compiler/parse_expr.c @@ -161,7 +161,7 @@ static inline Expr *parse_try_unwrap(Context *context) static inline Expr *parse_try_unwrap_chain(Context *context) { Expr **unwraps = NULL; -ASSIGN_EXPR_ELSE(Expr *first_unwrap , parse_try_unwrap(context), poisoned_expr); + ASSIGN_EXPR_ELSE(Expr *first_unwrap , parse_try_unwrap(context), poisoned_expr); vec_add(unwraps, first_unwrap); while (try_consume(context, TOKEN_AND)) { diff --git a/src/compiler/sema_casts.c b/src/compiler/sema_casts.c index 0e2ce49fa..8814b274e 100644 --- a/src/compiler/sema_casts.c +++ b/src/compiler/sema_casts.c @@ -803,6 +803,7 @@ Expr *recursive_may_narrow_float(Expr *expr, Type *type) case EXPR_SUBSCRIPT_ADDR: case EXPR_TYPEOFANY: case EXPR_PTR: + case EXPR_VARIANTSWITCH: UNREACHABLE case EXPR_POST_UNARY: return recursive_may_narrow_float(expr->unary_expr.expr, type); @@ -956,6 +957,7 @@ Expr *recursive_may_narrow_int(Expr *expr, Type *type) case EXPR_SUBSCRIPT_ADDR: case EXPR_TYPEOFANY: case EXPR_PTR: + case EXPR_VARIANTSWITCH: UNREACHABLE case EXPR_POST_UNARY: return recursive_may_narrow_int(expr->unary_expr.expr, type); diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index 7155fc450..b6d6f9398 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -316,6 +316,8 @@ bool expr_is_constant_eval(Expr *expr, ConstantEvalKind eval_kind) case EXPR_ACCESS: expr = expr->access_expr.parent; goto RETRY; + case EXPR_VARIANTSWITCH: + return false; case EXPR_BITASSIGN: return false; case EXPR_BINARY: @@ -6720,6 +6722,7 @@ static inline bool sema_analyse_expr_dispatch(Context *context, Expr *expr) case EXPR_TRY_UNWRAP: case EXPR_CATCH_UNWRAP: case EXPR_PTR: + case EXPR_VARIANTSWITCH: UNREACHABLE case EXPR_DECL: if (!sema_analyse_var_decl(context, expr->decl_expr, true)) return false; diff --git a/src/compiler/sema_stmts.c b/src/compiler/sema_stmts.c index dff4440be..e50ea9e79 100644 --- a/src/compiler/sema_stmts.c +++ b/src/compiler/sema_stmts.c @@ -8,6 +8,13 @@ static bool sema_analyse_compound_stmt(Context *context, Ast *statement); +typedef enum +{ + COND_TYPE_UNWRAP_BOOL, + COND_TYPE_UNWRAP, + COND_TYPE_EVALTYPE_VALUE, +} CondType; + static void sema_unwrappable_from_catch_in_else(Context *c, Expr *cond) { assert(cond->expr_kind == EXPR_COND); @@ -276,18 +283,25 @@ static inline bool sema_analyse_try_unwrap(Context *context, Expr *expr) expr->resolve_status = RESOLVE_DONE; return true; } -static inline bool sema_analyse_try_unwrap_chain(Context *context, Expr *expr) + + +static inline bool sema_analyse_try_unwrap_chain(Context *context, Expr *expr, CondType cond_type) { + assert(cond_type == COND_TYPE_UNWRAP_BOOL || cond_type == COND_TYPE_UNWRAP); + assert(expr->expr_kind == EXPR_TRY_UNWRAP_CHAIN); + Expr **chain = expr->try_unwrap_chain_expr; + unsigned elements = vec_size(chain); + VECEACH(expr->try_unwrap_chain_expr, i) { - Expr *chain = expr->try_unwrap_chain_expr[i]; - if (chain->expr_kind == EXPR_TRY_UNWRAP) + Expr *chain_element = chain[i]; + if (chain_element->expr_kind == EXPR_TRY_UNWRAP) { - if (!sema_analyse_try_unwrap(context, chain)) return false; + if (!sema_analyse_try_unwrap(context, chain_element)) return false; continue; } - if (!sema_analyse_cond_expr(context, chain)) return false; + if (!sema_analyse_cond_expr(context, chain_element)) return false; } expr->type = type_bool; expr->resolve_status = RESOLVE_DONE; @@ -413,27 +427,81 @@ static void sema_remove_unwraps_from_try(Context *c, Expr *cond) } } -static inline bool sema_analyse_last_cond(Context *context, Expr *expr, bool may_unwrap) +static inline bool sema_analyse_last_cond(Context *context, Expr *expr, CondType cond_type) { switch (expr->expr_kind) { case EXPR_TRY_UNWRAP_CHAIN: - if (!may_unwrap) + if (cond_type != COND_TYPE_UNWRAP_BOOL && cond_type != COND_TYPE_UNWRAP) { SEMA_ERROR(expr, "Try unwrapping is only allowed inside of a 'while' or 'if' conditional."); return false; } - return sema_analyse_try_unwrap_chain(context, expr); + return sema_analyse_try_unwrap_chain(context, expr, cond_type); case EXPR_CATCH_UNWRAP: - if (!may_unwrap) + if (cond_type != COND_TYPE_UNWRAP_BOOL && cond_type != COND_TYPE_UNWRAP) { SEMA_ERROR(expr, "Catch unwrapping is only allowed inside of a 'while' or 'if' conditional, maybe catch(...) will do what you need?"); return false; } return sema_analyse_catch_unwrap(context, expr); default: - return sema_analyse_expr(context, expr); + break; } + + if (cond_type != COND_TYPE_EVALTYPE_VALUE) goto NORMAL_EXPR; + + // Now we're analysing the last expression in a switch. + // Case 1: switch (var = variant_expr) + if (expr->expr_kind == EXPR_BINARY && expr->binary_expr.operator == BINARYOP_ASSIGN) + { + // No variable on the lhs? Then it can't be a variant unwrap. + Expr *left = expr->binary_expr.left; + if (left->resolve_status == RESOLVE_DONE || left->expr_kind != EXPR_IDENTIFIER || left->identifier_expr.path) goto NORMAL_EXPR; + + // Does the identifier exist in the parent scope? + // then again it can't be a variant unwrap. + Decl *decl_for_ident = sema_resolve_normal_symbol(context, left->identifier_expr.identifier, NULL, false); + if (decl_for_ident) goto NORMAL_EXPR; + + Expr *right = expr->binary_expr.right; + bool is_deref = right->expr_kind == EXPR_UNARY && right->unary_expr.operator == UNARYOP_DEREF; + if (is_deref) right = right->unary_expr.expr; + if (!sema_analyse_expr_rhs(context, NULL, right, false)) return false; + if (right->type == type_get_ptr(type_any) && is_deref) + { + is_deref = false; + right = expr->binary_expr.right; + if (!sema_analyse_expr_rhs(context, NULL, right, false)) return false; + } + if (right->type != type_any) goto NORMAL_EXPR; + // Found an expansion here + expr->expr_kind = EXPR_VARIANTSWITCH; + expr->variant_switch.new_ident = left->identifier_expr.identifier; + expr->variant_switch.variant_expr = right; + expr->variant_switch.is_deref = is_deref; + expr->variant_switch.is_assign = true; + expr->resolve_status = RESOLVE_DONE; + expr->type = type_typeid; + return true; + } + if (!sema_analyse_expr(context, expr)) return false; + if (expr->type != type_any) return true; + if (expr->expr_kind == EXPR_IDENTIFIER) + { + Decl *decl = expr->identifier_expr.decl; + expr->expr_kind = EXPR_VARIANTSWITCH; + expr->variant_switch.is_deref = false; + expr->variant_switch.is_assign = false; + expr->variant_switch.variable = decl; + expr->type = type_typeid; + expr->resolve_status = RESOLVE_DONE; + return true; + } + return true; + +NORMAL_EXPR: + return sema_analyse_expr(context, expr); } /** * An decl-expr-list is a list of a mixture of declarations and expressions. @@ -443,7 +511,7 @@ static inline bool sema_analyse_last_cond(Context *context, Expr *expr, bool may * * In this case the final value is 4.0 and the type is float. */ -static inline bool sema_analyse_cond_list(Context *context, Expr *expr, bool may_unwrap) +static inline bool sema_analyse_cond_list(Context *context, Expr *expr, CondType cond_type) { assert(expr->expr_kind == EXPR_COND); @@ -463,13 +531,14 @@ static inline bool sema_analyse_cond_list(Context *context, Expr *expr, bool may if (!sema_analyse_expr(context, dexprs[i])) return false; } - if (!sema_analyse_last_cond(context, dexprs[entries - 1], may_unwrap)) return false; + if (!sema_analyse_last_cond(context, dexprs[entries - 1], cond_type)) return false; expr->type = dexprs[entries - 1]->type; expr->resolve_status = RESOLVE_DONE; return true; } + /** * Analyse a conditional expression: * @@ -482,12 +551,13 @@ static inline bool sema_analyse_cond_list(Context *context, Expr *expr, bool may * @param cast_to_bool if the result is to be cast to bool after * @return true if it passes analysis. */ -static inline bool sema_analyse_cond(Context *context, Expr *expr, bool cast_to_bool, bool may_unwrap) +static inline bool sema_analyse_cond(Context *context, Expr *expr, CondType cond_type) { + bool cast_to_bool = cond_type == COND_TYPE_UNWRAP_BOOL; assert(expr->expr_kind == EXPR_COND && "Conditional expressions should always be of type EXPR_DECL_LIST"); // 1. Analyse the declaration list. - if (!sema_analyse_cond_list(context, expr, may_unwrap)) return false; + if (!sema_analyse_cond_list(context, expr, cond_type)) return false; // 2. If we get "void", either through a void call or an empty list, // signal that. @@ -568,7 +638,7 @@ static inline bool sema_analyse_while_stmt(Context *context, Ast *statement) SCOPE_START_WITH_LABEL(statement->while_stmt.flow.label) // 2. Analyze the condition - if (!sema_analyse_cond(context, cond, true, true)) + if (!sema_analyse_cond(context, cond, COND_TYPE_UNWRAP_BOOL)) { // 2a. In case of error, pop context and exit. return SCOPE_POP_ERROR(); @@ -833,7 +903,7 @@ static inline bool sema_analyse_for_stmt(Context *context, Ast *statement) Expr *cond = statement->for_stmt.cond; if (cond->expr_kind == EXPR_COND) { - success = sema_analyse_cond(context, cond, true, true); + success = sema_analyse_cond(context, cond, COND_TYPE_UNWRAP_BOOL); } else { @@ -1371,8 +1441,9 @@ static inline bool sema_analyse_if_stmt(Context *context, Ast *statement) Expr *cond = statement->if_stmt.cond; SCOPE_OUTER_START - bool cast_to_bool = statement->if_stmt.then_body->ast_kind != AST_IF_CATCH_SWITCH_STMT; - success = sema_analyse_cond(context, cond, cast_to_bool, true); + CondType cond_type = statement->if_stmt.then_body->ast_kind == AST_IF_CATCH_SWITCH_STMT + ? COND_TYPE_UNWRAP : COND_TYPE_UNWRAP_BOOL; + success = sema_analyse_cond(context, cond, cond_type); Ast *then = statement->if_stmt.then_body; bool then_has_braces = then->ast_kind == AST_COMPOUND_STMT || then->ast_kind == AST_IF_CATCH_SWITCH_STMT; @@ -1858,10 +1929,10 @@ static inline bool sema_check_value_case(Context *context, Type *switch_type, As return true; } -static bool sema_analyse_switch_body(Context *context, Ast *statement, SourceSpan expr_span, Type *switch_type, Ast **cases, Decl *switch_decl) +static bool sema_analyse_switch_body(Context *context, Ast *statement, SourceSpan expr_span, Type *switch_type, Ast **cases, ExprVariantSwitch *variant, Decl *var_holder) { bool use_type_id = false; - if (!type_is_comparable(switch_type) && switch_type != type_any) + if (!type_is_comparable(switch_type)) { sema_error_range(expr_span, "You cannot test '%s' for equality, and only values that supports '==' for comparison can be used in a switch.", type_to_error_string(switch_type)); return false; @@ -1927,19 +1998,41 @@ static bool sema_analyse_switch_body(Context *context, Ast *statement, SourceSpa Ast *next = (i < case_count - 1) ? cases[i + 1] : NULL; PUSH_NEXT(next, statement); Ast *body = stmt->case_stmt.body; - if (stmt->ast_kind == AST_CASE_STMT && body && type_switch && switch_decl && stmt->case_stmt.expr->expr_kind == EXPR_CONST) + if (stmt->ast_kind == AST_CASE_STMT && body && type_switch && var_holder && stmt->case_stmt.expr->expr_kind == EXPR_CONST) { - Type *type = type_get_ptr(stmt->case_stmt.expr->const_expr.typeid); - Decl *alias = decl_new_var(switch_decl->name_token, - type_info_new_base(type, stmt->case_stmt.expr->span), - VARDECL_LOCAL, VISIBLE_LOCAL); - Expr *ident_converted = expr_variable(switch_decl); - if (!cast(ident_converted, type)) return false; - alias->var.init_expr = ident_converted; - alias->var.shadow = true; - Ast *decl_ast = new_ast(AST_DECLARE_STMT, alias->span); - decl_ast->declare_stmt = alias; - vec_insert_first(body->compound_stmt.stmts, decl_ast); + if (variant->is_assign) + { + Type *real_type = type_get_ptr(stmt->case_stmt.expr->const_expr.typeid); + TokenId name = variant->new_ident; + Decl *new_var = decl_new_var(name, + type_info_new_base(variant->is_deref + ? real_type->pointer : real_type, source_span_from_token_id(name)), + VARDECL_LOCAL, VISIBLE_LOCAL); + Expr *var_result = expr_variable(var_holder); + if (!cast(var_result, real_type)) return false; + if (variant->is_deref) + { + expr_insert_deref(var_result); + } + new_var->var.init_expr = var_result; + Ast *decl_ast = new_ast(AST_DECLARE_STMT, new_var->span); + decl_ast->declare_stmt = new_var; + vec_insert_first(body->compound_stmt.stmts, decl_ast); + } + else + { + Type *type = type_get_ptr(stmt->case_stmt.expr->const_expr.typeid); + Decl *alias = decl_new_var(var_holder->name_token, + type_info_new_base(type, stmt->case_stmt.expr->span), + VARDECL_LOCAL, VISIBLE_LOCAL); + Expr *ident_converted = expr_variable(var_holder); + if (!cast(ident_converted, type)) return false; + alias->var.init_expr = ident_converted; + alias->var.shadow = true; + Ast *decl_ast = new_ast(AST_DECLARE_STMT, alias->span); + decl_ast->declare_stmt = alias; + vec_insert_first(body->compound_stmt.stmts, decl_ast); + } } success = success && (!body || sema_analyse_compound_statement_no_scope(context, body)); POP_BREAK(); @@ -2077,29 +2170,42 @@ static bool sema_analyse_switch_stmt(Context *context, Ast *statement) Expr *cond = statement->switch_stmt.cond; Type *switch_type; - Decl *last_decl = NULL; + ExprVariantSwitch var_switch; + Decl *variant_decl = NULL; if (statement->ast_kind == AST_SWITCH_STMT) { - if (!sema_analyse_cond(context, cond, false, false)) return false; + if (!sema_analyse_cond(context, cond, COND_TYPE_EVALTYPE_VALUE)) return false; Expr *last = VECLAST(cond->cond_expr); switch_type = last->type->canonical; - if (switch_type == type_any) + if (last->expr_kind == EXPR_VARIANTSWITCH) { - if (last->expr_kind == EXPR_DECL) + var_switch = last->variant_switch; + + + Expr *inner; + if (var_switch.is_assign) { - last_decl = last->decl_expr; + inner = expr_new(EXPR_DECL, last->span); + variant_decl = decl_new_generated_var(".variant", type_any, VARDECL_LOCAL, last->span); + variant_decl->var.init_expr = var_switch.variant_expr; + inner->decl_expr = variant_decl; + if (!sema_analyse_expr(context, inner)) return false; } - else if (last->expr_kind == EXPR_IDENTIFIER) + else { - last_decl = last->identifier_expr.decl; + inner = expr_new(EXPR_IDENTIFIER, last->span); + variant_decl = var_switch.variable; + inner->identifier_expr.decl = variant_decl; + inner->type = type_any; + inner->resolve_status = RESOLVE_DONE; } - Expr *inner = expr_copy(last); last->type = type_typeid; last->expr_kind = EXPR_TYPEOFANY; last->inner_expr = inner; switch_type = type_typeid; cond->type = type_typeid; } + } else { @@ -2109,7 +2215,7 @@ static bool sema_analyse_switch_stmt(Context *context, Ast *statement) statement->switch_stmt.defer = context->active_scope.defer_last; if (!sema_analyse_switch_body(context, statement, cond ? cond->span : statement->span, switch_type->canonical, - statement->switch_stmt.cases, last_decl)) + statement->switch_stmt.cases, variant_decl ? &var_switch : NULL, variant_decl)) { return SCOPE_POP_ERROR(); } @@ -2216,7 +2322,7 @@ bool sema_analyse_assert_stmt(Context *context, Ast *statement) } if (expr->expr_kind == EXPR_TRY_UNWRAP_CHAIN) { - if (!sema_analyse_try_unwrap_chain(context, expr)) return false; + if (!sema_analyse_try_unwrap_chain(context, expr, COND_TYPE_UNWRAP_BOOL)) return false; } else { diff --git a/test/test_suite/variant/variant_assign.c3t b/test/test_suite/variant/variant_assign.c3t new file mode 100644 index 000000000..212ae47cf --- /dev/null +++ b/test/test_suite/variant/variant_assign.c3t @@ -0,0 +1,360 @@ +// #target: x64-darwin +module foo; + +extern fn void printf(char*, ...); + +fn void test(variant z) +{ + switch (z) + { + case int: + printf("int: %d\n", *z); + case double: + printf("double %f\n", *z); + default: + printf("Unknown type.\n"); + } +} +fn void test2(variant y) +{ + switch (z = y) + { + case int: + y = &&12; + printf("int: %d\n", *z); + case double: + printf("double %f\n", *z); + default: + printf("Unknown type.\n"); + } +} + +fn void test3(variant y) +{ + switch (z = *y) + { + case int: + printf("int: %d\n", z); + case double: + printf("double %f\n", z); + default: + printf("Unknown type.\n"); + } +} + +fn int main() +{ + test(&&123.0); + test(&&1); + test(&&true); + test2(&&123.5); + test2(&&1); + test2(&&true); + test3(&&124.0); + test3(&&2); + test3(&&true); + return 0; +} + +/* #expect: foo.ll + +define void @foo.test(i64 %0, i8* %1) #0 { +entry: + %z = alloca %variant, align 8 + %switch = alloca i64, align 8 + %z1 = alloca i32*, align 8 + %z4 = alloca double*, align 8 + %pair = bitcast %variant* %z to { i64, i8* }* + %2 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %pair, i32 0, i32 0 + store i64 %0, i64* %2, align 8 + %3 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %pair, i32 0, i32 1 + store i8* %1, i8** %3, align 8 + %4 = getelementptr inbounds %variant, %variant* %z, i32 0, i32 1 + %5 = load i64, i64* %4, align 8 + store i64 %5, i64* %switch, align 8 + br label %switch.entry + +switch.entry: ; preds = %entry + %6 = load i64, i64* %switch, align 8 + %eq = icmp eq i64 5, %6 + br i1 %eq, label %switch.case, label %next_if + +switch.case: ; preds = %switch.entry + %7 = getelementptr inbounds %variant, %variant* %z, i32 0, i32 0 + %8 = bitcast i8** %7 to i32** + %9 = load i32*, i32** %8, align 8 + store i32* %9, i32** %z1, align 8 + %10 = load i32*, i32** %z1, align 8 + %11 = load i32, i32* %10, align 8 + call void (i8*, ...) @printf(i8* getelementptr inbounds ([9 x i8], [9 x i8]* @.str, i32 0, i32 0), i32 %11) + br label %switch.exit + +next_if: ; preds = %switch.entry + %eq2 = icmp eq i64 15, %6 + br i1 %eq2, label %switch.case3, label %next_if5 + +switch.case3: ; preds = %next_if + %12 = getelementptr inbounds %variant, %variant* %z, i32 0, i32 0 + %13 = bitcast i8** %12 to double** + %14 = load double*, double** %13, align 8 + store double* %14, double** %z4, align 8 + %15 = load double*, double** %z4, align 8 + %16 = load double, double* %15, align 8 + call void (i8*, ...) @printf(i8* getelementptr inbounds ([11 x i8], [11 x i8]* @.str.1, i32 0, i32 0), double %16) + br label %switch.exit + +next_if5: ; preds = %next_if + br label %switch.default + +switch.default: ; preds = %next_if5 + call void (i8*, ...) @printf(i8* getelementptr inbounds ([15 x i8], [15 x i8]* @.str.2, i32 0, i32 0)) + br label %switch.exit + +switch.exit: ; preds = %switch.default, %switch.case3, %switch.case + ret void +} + +define void @foo.test2(i64 %0, i8* %1) #0 { +entry: + %y = alloca %variant, align 8 + %.variant = alloca %variant, align 8 + %switch = alloca i64, align 8 + %z = alloca i32*, align 8 + %taddr = alloca i32, align 4 + %z3 = alloca double*, align 8 + %pair = bitcast %variant* %y to { i64, i8* }* + %2 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %pair, i32 0, i32 0 + store i64 %0, i64* %2, align 8 + %3 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %pair, i32 0, i32 1 + store i8* %1, i8** %3, align 8 + %4 = bitcast %variant* %.variant to i8* + %5 = bitcast %variant* %y to i8* + call void @llvm.memcpy.p0i8.p0i8.i32(i8* align 8 %4, i8* align 8 %5, i32 16, i1 false) + %6 = getelementptr inbounds %variant, %variant* %.variant, i32 0, i32 1 + %7 = load i64, i64* %6, align 8 + store i64 %7, i64* %switch, align 8 + br label %switch.entry + +switch.entry: ; preds = %entry + %8 = load i64, i64* %switch, align 8 + %eq = icmp eq i64 5, %8 + br i1 %eq, label %switch.case, label %next_if + +switch.case: ; preds = %switch.entry + %9 = getelementptr inbounds %variant, %variant* %.variant, i32 0, i32 0 + %10 = bitcast i8** %9 to i32** + %11 = load i32*, i32** %10, align 8 + store i32* %11, i32** %z, align 8 + store i32 12, i32* %taddr, align 4 + %12 = bitcast i32* %taddr to i8* + %13 = insertvalue %variant undef, i8* %12, 0 + %14 = insertvalue %variant %13, i64 5, 1 + store %variant %14, %variant* %y, align 8 + %15 = load i32*, i32** %z, align 8 + %16 = load i32, i32* %15, align 8 + call void (i8*, ...) @printf(i8* getelementptr inbounds ([9 x i8], [9 x i8]* @.str.3, i32 0, i32 0), i32 %16) + br label %switch.exit + +next_if: ; preds = %switch.entry + %eq1 = icmp eq i64 15, %8 + br i1 %eq1, label %switch.case2, label %next_if4 + +switch.case2: ; preds = %next_if + %17 = getelementptr inbounds %variant, %variant* %.variant, i32 0, i32 0 + %18 = bitcast i8** %17 to double** + %19 = load double*, double** %18, align 8 + store double* %19, double** %z3, align 8 + %20 = load double*, double** %z3, align 8 + %21 = load double, double* %20, align 8 + call void (i8*, ...) @printf(i8* getelementptr inbounds ([11 x i8], [11 x i8]* @.str.4, i32 0, i32 0), double %21) + br label %switch.exit + +next_if4: ; preds = %next_if + br label %switch.default + +switch.default: ; preds = %next_if4 + call void (i8*, ...) @printf(i8* getelementptr inbounds ([15 x i8], [15 x i8]* @.str.5, i32 0, i32 0)) + br label %switch.exit + +switch.exit: ; preds = %switch.default, %switch.case2, %switch.case + ret void +} + +define void @foo.test3(i64 %0, i8* %1) #0 { +entry: + %y = alloca %variant, align 8 + %.variant = alloca %variant, align 8 + %switch = alloca i64, align 8 + %z = alloca i32, align 4 + %z3 = alloca double, align 8 + %pair = bitcast %variant* %y to { i64, i8* }* + %2 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %pair, i32 0, i32 0 + store i64 %0, i64* %2, align 8 + %3 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %pair, i32 0, i32 1 + store i8* %1, i8** %3, align 8 + %4 = bitcast %variant* %.variant to i8* + %5 = bitcast %variant* %y to i8* + call void @llvm.memcpy.p0i8.p0i8.i32(i8* align 8 %4, i8* align 8 %5, i32 16, i1 false) + %6 = getelementptr inbounds %variant, %variant* %.variant, i32 0, i32 1 + %7 = load i64, i64* %6, align 8 + store i64 %7, i64* %switch, align 8 + br label %switch.entry + +switch.entry: ; preds = %entry + %8 = load i64, i64* %switch, align 8 + %eq = icmp eq i64 5, %8 + br i1 %eq, label %switch.case, label %next_if + +switch.case: ; preds = %switch.entry + %9 = getelementptr inbounds %variant, %variant* %.variant, i32 0, i32 0 + %10 = bitcast i8** %9 to i32** + %11 = load i32*, i32** %10, align 8 + %12 = load i32, i32* %11, align 8 + store i32 %12, i32* %z, align 4 + %13 = load i32, i32* %z, align 4 + call void (i8*, ...) @printf(i8* getelementptr inbounds ([9 x i8], [9 x i8]* @.str.6, i32 0, i32 0), i32 %13) + br label %switch.exit + +next_if: ; preds = %switch.entry + %eq1 = icmp eq i64 15, %8 + br i1 %eq1, label %switch.case2, label %next_if4 + +switch.case2: ; preds = %next_if + %14 = getelementptr inbounds %variant, %variant* %.variant, i32 0, i32 0 + %15 = bitcast i8** %14 to double** + %16 = load double*, double** %15, align 8 + %17 = load double, double* %16, align 8 + store double %17, double* %z3, align 8 + %18 = load double, double* %z3, align 8 + call void (i8*, ...) @printf(i8* getelementptr inbounds ([11 x i8], [11 x i8]* @.str.7, i32 0, i32 0), double %18) + br label %switch.exit + +next_if4: ; preds = %next_if + br label %switch.default + +switch.default: ; preds = %next_if4 + call void (i8*, ...) @printf(i8* getelementptr inbounds ([15 x i8], [15 x i8]* @.str.8, i32 0, i32 0)) + br label %switch.exit + +switch.exit: ; preds = %switch.default, %switch.case2, %switch.case + ret void +} + +define i32 @main() #0 { +entry: + %taddr = alloca double, align 8 + %taddr1 = alloca %variant, align 8 + %taddr2 = alloca i32, align 4 + %taddr3 = alloca %variant, align 8 + %taddr6 = alloca i8, align 1 + %taddr7 = alloca %variant, align 8 + %taddr10 = alloca double, align 8 + %taddr11 = alloca %variant, align 8 + %taddr14 = alloca i32, align 4 + %taddr15 = alloca %variant, align 8 + %taddr18 = alloca i8, align 1 + %taddr19 = alloca %variant, align 8 + %taddr22 = alloca double, align 8 + %taddr23 = alloca %variant, align 8 + %taddr26 = alloca i32, align 4 + %taddr27 = alloca %variant, align 8 + %taddr30 = alloca i8, align 1 + %taddr31 = alloca %variant, align 8 + store double 1.230000e+02, double* %taddr, align 8 + %0 = bitcast double* %taddr to i8* + %1 = insertvalue %variant undef, i8* %0, 0 + %2 = insertvalue %variant %1, i64 15, 1 + store %variant %2, %variant* %taddr1, align 8 + %3 = bitcast %variant* %taddr1 to { i64, i8* }* + %4 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %3, i32 0, i32 0 + %lo = load i64, i64* %4, align 8 + %5 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %3, i32 0, i32 1 + %hi = load i8*, i8** %5, align 8 + call void @foo.test(i64 %lo, i8* %hi) + store i32 1, i32* %taddr2, align 4 + %6 = bitcast i32* %taddr2 to i8* + %7 = insertvalue %variant undef, i8* %6, 0 + %8 = insertvalue %variant %7, i64 5, 1 + store %variant %8, %variant* %taddr3, align 8 + %9 = bitcast %variant* %taddr3 to { i64, i8* }* + %10 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %9, i32 0, i32 0 + %lo4 = load i64, i64* %10, align 8 + %11 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %9, i32 0, i32 1 + %hi5 = load i8*, i8** %11, align 8 + call void @foo.test(i64 %lo4, i8* %hi5) + store i8 1, i8* %taddr6, align 1 + %12 = insertvalue %variant undef, i8* %taddr6, 0 + %13 = insertvalue %variant %12, i64 2, 1 + store %variant %13, %variant* %taddr7, align 8 + %14 = bitcast %variant* %taddr7 to { i64, i8* }* + %15 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %14, i32 0, i32 0 + %lo8 = load i64, i64* %15, align 8 + %16 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %14, i32 0, i32 1 + %hi9 = load i8*, i8** %16, align 8 + call void @foo.test(i64 %lo8, i8* %hi9) + store double 1.235000e+02, double* %taddr10, align 8 + %17 = bitcast double* %taddr10 to i8* + %18 = insertvalue %variant undef, i8* %17, 0 + %19 = insertvalue %variant %18, i64 15, 1 + store %variant %19, %variant* %taddr11, align 8 + %20 = bitcast %variant* %taddr11 to { i64, i8* }* + %21 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %20, i32 0, i32 0 + %lo12 = load i64, i64* %21, align 8 + %22 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %20, i32 0, i32 1 + %hi13 = load i8*, i8** %22, align 8 + call void @foo.test2(i64 %lo12, i8* %hi13) + store i32 1, i32* %taddr14, align 4 + %23 = bitcast i32* %taddr14 to i8* + %24 = insertvalue %variant undef, i8* %23, 0 + %25 = insertvalue %variant %24, i64 5, 1 + store %variant %25, %variant* %taddr15, align 8 + %26 = bitcast %variant* %taddr15 to { i64, i8* }* + %27 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %26, i32 0, i32 0 + %lo16 = load i64, i64* %27, align 8 + %28 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %26, i32 0, i32 1 + %hi17 = load i8*, i8** %28, align 8 + call void @foo.test2(i64 %lo16, i8* %hi17) + store i8 1, i8* %taddr18, align 1 + %29 = insertvalue %variant undef, i8* %taddr18, 0 + %30 = insertvalue %variant %29, i64 2, 1 + store %variant %30, %variant* %taddr19, align 8 + %31 = bitcast %variant* %taddr19 to { i64, i8* }* + %32 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %31, i32 0, i32 0 + %lo20 = load i64, i64* %32, align 8 + %33 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %31, i32 0, i32 1 + %hi21 = load i8*, i8** %33, align 8 + call void @foo.test2(i64 %lo20, i8* %hi21) + store double 1.240000e+02, double* %taddr22, align 8 + %34 = bitcast double* %taddr22 to i8* + %35 = insertvalue %variant undef, i8* %34, 0 + %36 = insertvalue %variant %35, i64 15, 1 + store %variant %36, %variant* %taddr23, align 8 + %37 = bitcast %variant* %taddr23 to { i64, i8* }* + %38 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %37, i32 0, i32 0 + %lo24 = load i64, i64* %38, align 8 + %39 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %37, i32 0, i32 1 + %hi25 = load i8*, i8** %39, align 8 + call void @foo.test3(i64 %lo24, i8* %hi25) + store i32 2, i32* %taddr26, align 4 + %40 = bitcast i32* %taddr26 to i8* + %41 = insertvalue %variant undef, i8* %40, 0 + %42 = insertvalue %variant %41, i64 5, 1 + store %variant %42, %variant* %taddr27, align 8 + %43 = bitcast %variant* %taddr27 to { i64, i8* }* + %44 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %43, i32 0, i32 0 + %lo28 = load i64, i64* %44, align 8 + %45 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %43, i32 0, i32 1 + %hi29 = load i8*, i8** %45, align 8 + call void @foo.test3(i64 %lo28, i8* %hi29) + store i8 1, i8* %taddr30, align 1 + %46 = insertvalue %variant undef, i8* %taddr30, 0 + %47 = insertvalue %variant %46, i64 2, 1 + store %variant %47, %variant* %taddr31, align 8 + %48 = bitcast %variant* %taddr31 to { i64, i8* }* + %49 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %48, i32 0, i32 0 + %lo32 = load i64, i64* %49, align 8 + %50 = getelementptr inbounds { i64, i8* }, { i64, i8* }* %48, i32 0, i32 1 + %hi33 = load i8*, i8** %50, align 8 + call void @foo.test3(i64 %lo32, i8* %hi33) + ret i32 0 +}