diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index 060973549..8d3396c95 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -364,6 +364,7 @@ typedef struct VarDecl_ VarDeclKind kind : 8; bool constant : 1; bool unwrap : 1; + bool shadow : 1; bool vararg : 1; bool is_static : 1; bool is_threadlocal : 1; diff --git a/src/compiler/sema_name_resolution.c b/src/compiler/sema_name_resolution.c index e7ff73607..daa7fee97 100644 --- a/src/compiler/sema_name_resolution.c +++ b/src/compiler/sema_name_resolution.c @@ -452,6 +452,7 @@ bool sema_add_local(Context *context, Decl *decl) decl->module = context->module; // Ignore synthetic locals. if (decl->name_token.index == NO_TOKEN_ID.index) return true; + if (decl->decl_kind == DECL_VAR && decl->var.shadow) goto ADD_VAR; Decl *other = sema_resolve_normal_symbol(context, decl->name_token, NULL, false); assert(!other || other->module); if (other && other->module == context->module) @@ -461,6 +462,7 @@ bool sema_add_local(Context *context, Decl *decl) decl_poison(other); return false; } +ADD_VAR:; Decl ***vars = &context->active_function_for_analysis->func_decl.annotations->vars; unsigned num_vars = vec_size(*vars); if (num_vars == MAX_LOCALS - 1) diff --git a/src/compiler/sema_stmts.c b/src/compiler/sema_stmts.c index 016e89aa7..dff4440be 100644 --- a/src/compiler/sema_stmts.c +++ b/src/compiler/sema_stmts.c @@ -1858,10 +1858,10 @@ static inline bool sema_check_value_case(Context *context, Type *switch_type, As return true; } -static bool sema_analyse_switch_body(Context *context, Ast *statement, SourceSpan expr_span, Type *switch_type, Ast **cases) +static bool sema_analyse_switch_body(Context *context, Ast *statement, SourceSpan expr_span, Type *switch_type, Ast **cases, Decl *switch_decl) { bool use_type_id = false; - if (!type_is_comparable(switch_type)) + if (!type_is_comparable(switch_type) && switch_type != type_any) { sema_error_range(expr_span, "You cannot test '%s' for equality, and only values that supports '==' for comparison can be used in a switch.", type_to_error_string(switch_type)); return false; @@ -1876,6 +1876,7 @@ static bool sema_analyse_switch_body(Context *context, Ast *statement, SourceSpa unsigned case_count = vec_size(cases); bool success = true; bool max_ranged = false; + bool type_switch = switch_type == type_typeid; for (unsigned i = 0; i < case_count; i++) { Ast *stmt = cases[i]; @@ -1884,7 +1885,7 @@ static bool sema_analyse_switch_body(Context *context, Ast *statement, SourceSpa switch (stmt->ast_kind) { case AST_CASE_STMT: - if (switch_type->type_kind == TYPE_TYPEID) + if (type_switch) { if (!sema_check_type_case(context, switch_type, stmt, cases, i)) { @@ -1926,6 +1927,20 @@ static bool sema_analyse_switch_body(Context *context, Ast *statement, SourceSpa Ast *next = (i < case_count - 1) ? cases[i + 1] : NULL; PUSH_NEXT(next, statement); Ast *body = stmt->case_stmt.body; + if (stmt->ast_kind == AST_CASE_STMT && body && type_switch && switch_decl && stmt->case_stmt.expr->expr_kind == EXPR_CONST) + { + Type *type = type_get_ptr(stmt->case_stmt.expr->const_expr.typeid); + Decl *alias = decl_new_var(switch_decl->name_token, + type_info_new_base(type, stmt->case_stmt.expr->span), + VARDECL_LOCAL, VISIBLE_LOCAL); + Expr *ident_converted = expr_variable(switch_decl); + if (!cast(ident_converted, type)) return false; + alias->var.init_expr = ident_converted; + alias->var.shadow = true; + Ast *decl_ast = new_ast(AST_DECLARE_STMT, alias->span); + decl_ast->declare_stmt = alias; + vec_insert_first(body->compound_stmt.stmts, decl_ast); + } success = success && (!body || sema_analyse_compound_statement_no_scope(context, body)); POP_BREAK(); POP_NEXT(); @@ -2062,10 +2077,29 @@ static bool sema_analyse_switch_stmt(Context *context, Ast *statement) Expr *cond = statement->switch_stmt.cond; Type *switch_type; + Decl *last_decl = NULL; if (statement->ast_kind == AST_SWITCH_STMT) { if (!sema_analyse_cond(context, cond, false, false)) return false; - switch_type = VECLAST(cond->cond_expr)->type->canonical; + Expr *last = VECLAST(cond->cond_expr); + switch_type = last->type->canonical; + if (switch_type == type_any) + { + if (last->expr_kind == EXPR_DECL) + { + last_decl = last->decl_expr; + } + else if (last->expr_kind == EXPR_IDENTIFIER) + { + last_decl = last->identifier_expr.decl; + } + Expr *inner = expr_copy(last); + last->type = type_typeid; + last->expr_kind = EXPR_TYPEOFANY; + last->inner_expr = inner; + switch_type = type_typeid; + cond->type = type_typeid; + } } else { @@ -2075,7 +2109,7 @@ static bool sema_analyse_switch_stmt(Context *context, Ast *statement) statement->switch_stmt.defer = context->active_scope.defer_last; if (!sema_analyse_switch_body(context, statement, cond ? cond->span : statement->span, switch_type->canonical, - statement->switch_stmt.cases)) + statement->switch_stmt.cases, last_decl)) { return SCOPE_POP_ERROR(); } diff --git a/src/utils/lib.h b/src/utils/lib.h index a4a71c7fd..e5d7c3c9b 100644 --- a/src/utils/lib.h +++ b/src/utils/lib.h @@ -387,7 +387,6 @@ static inline void vec_resize(void *vec, uint32_t new_size) VHeader_ *header = vec; header[-1].size = new_size; } - static inline void vec_pop(void *vec) { assert(vec); @@ -441,6 +440,14 @@ static inline void* expand_(void *vec, size_t element_size) (vec_)[vec_size(vec_) - 1] = value_; \ } while (0) +#define vec_insert_first(vec_, value_) do { \ + void *__temp = expand_((vec_), sizeof(*(vec_))); \ + (vec_) = __temp; \ + unsigned __xsize = vec_size(vec_); \ + for (unsigned __x = __xsize - 1; __x > 0; __x--) (vec_)[__x] = (vec_)[__x - 1]; \ + (vec_)[0] = value_; \ + } while (0) + #if IS_GCC || IS_CLANG #define VECLAST(_vec) ({ unsigned _size = vec_size(_vec); _size ? (_vec)[_size - 1] : NULL; }) #else diff --git a/test/test_suite/variant/variant_switch.c3t b/test/test_suite/variant/variant_switch.c3t new file mode 100644 index 000000000..8a671af13 --- /dev/null +++ b/test/test_suite/variant/variant_switch.c3t @@ -0,0 +1,33 @@ +// #target: x64-darwin +module foo; + +extern fn void printf(char*, ...); + +fn void test(variant z) +{ + switch (z) + { + case int: + printf("int: %d\n", *z); + *z = 3; + case double: + printf("double %f\n", *z); + default: + printf("Unknown type.\n"); + } + if (z.typeid == int.typeid) + { + printf("int: %d\n", *(int*)(z)); + } +} +fn int main() +{ + test(&&123.0); + test(&&1); + test(&&true); + return 0; +} + +/* expect: foo.ll + +foekf \ No newline at end of file