From 0ef99c23a838fd3f9a40a70faf637e6af1b69567 Mon Sep 17 00:00:00 2001 From: Christoffer Lerno Date: Tue, 4 Mar 2025 02:18:24 +0100 Subject: [PATCH] Allow swizzling assign, eg. `abc.xz += { 5, 10 };` --- releasenotes.md | 1 + src/compiler/codegen_internal.h | 6 +- src/compiler/compiler_internal.h | 1 + src/compiler/llvm_codegen_expr.c | 71 +++++++++++++++---- src/compiler/sema_expr.c | 27 +++++-- .../vector/vector_conversion_scalar.c3 | 2 +- test/unit/regression/vector_ops.c3 | 28 ++++++++ 7 files changed, 116 insertions(+), 20 deletions(-) diff --git a/releasenotes.md b/releasenotes.md index f0afd793b..9e7cd9b11 100644 --- a/releasenotes.md +++ b/releasenotes.md @@ -16,6 +16,7 @@ - Remove `.allocator = allocator` syntax for functions. - Remove `@operator(construct)`. - Removal of "any-switch". +- Allow swizzling assign, eg. `abc.xz += { 5, 10 };` ### Fixes - Fix address sanitizer to work on MachO targets (e.g. MacOS). diff --git a/src/compiler/codegen_internal.h b/src/compiler/codegen_internal.h index 49c8b4a0e..5f8523000 100644 --- a/src/compiler/codegen_internal.h +++ b/src/compiler/codegen_internal.h @@ -113,10 +113,10 @@ UNUSED static inline bool abi_type_is_promotable_integer_or_bool(AbiType type) return false; } -static inline bool expr_is_vector_index(Expr *expr) +static inline bool expr_is_vector_index_or_swizzle(Expr *expr) { - return expr->expr_kind == EXPR_SUBSCRIPT - && type_lowering(exprtype(expr->subscript_expr.expr))->type_kind == TYPE_VECTOR; + return (expr->expr_kind == EXPR_SUBSCRIPT && type_lowering(exprtype(expr->subscript_expr.expr))->type_kind == TYPE_VECTOR) + || (expr->expr_kind == EXPR_SWIZZLE && type_lowering(exprtype(expr->swizzle_expr.parent))->type_kind == TYPE_VECTOR); } const char *codegen_create_asm(Ast *ast); diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index c2d341b72..7878b138c 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -1066,6 +1066,7 @@ typedef struct typedef struct { ExprId parent; + bool is_overlapping; const char *swizzle; } ExprSwizzle; diff --git a/src/compiler/llvm_codegen_expr.c b/src/compiler/llvm_codegen_expr.c index 02bf2992d..01434e812 100644 --- a/src/compiler/llvm_codegen_expr.c +++ b/src/compiler/llvm_codegen_expr.c @@ -16,6 +16,7 @@ static LLVMValueRef llvm_emit_dynamic_search(GenContext *c, LLVMValueRef type_id static inline void llvm_emit_bitassign_array(GenContext *c, LLVMValueRef result, BEValue parent, Decl *parent_decl, Decl *member); static inline void llvm_emit_builtin_access(GenContext *c, BEValue *be_value, Expr *expr); static inline void llvm_emit_const_initialize_reference(GenContext *c, BEValue *ref, Expr *expr); +static void llvm_emit_swizzle_from_value(GenContext *c, LLVMValueRef vector_value, BEValue *value, Expr *expr); static inline void llvm_emit_optional(GenContext *c, BEValue *be_value, Expr *expr); static inline void llvm_emit_inc_dec_change(GenContext *c, BEValue *addr, BEValue *after, BEValue *before, Expr *expr, int diff, @@ -4343,17 +4344,58 @@ static void llvm_emit_vector_assign_expr(GenContext *c, BEValue *be_value, Expr Expr *left = exprptr(expr->binary_expr.left); BinaryOp binary_op = expr->binary_expr.operator; BEValue addr; - BEValue index; + bool is_swizzle = left->expr_kind == EXPR_SWIZZLE; + + if (left->expr_kind == EXPR_SWIZZLE) + { + // Emit the variable + llvm_emit_exprid(c, &addr, left->swizzle_expr.parent); + } + else + { + // Emit the variable + llvm_emit_exprid(c, &addr, left->subscript_expr.expr); + } // Emit the variable - llvm_emit_exprid(c, &addr, left->subscript_expr.expr); llvm_value_addr(c, &addr); - LLVMValueRef vector_value = llvm_load_value_store(c, &addr); + LLVMValueRef vector_value = llvm_load_value_store(c, &addr); + if (is_swizzle) + { + if (addr.type->array.base == type_bool) vector_value = llvm_emit_trunc_bool(c, vector_value); + if (binary_op > BINARYOP_ASSIGN) + { + BEValue lhs; + llvm_emit_swizzle_from_value(c, llvm_load_value_store(c, &addr), &lhs, left); + BinaryOp base_op = binaryop_assign_base_op(binary_op); + ASSERT(base_op != BINARYOP_ERROR); + llvm_value_rvalue(c, &lhs); + llvm_emit_binary(c, be_value, expr, &lhs, base_op); + } + else + { + llvm_emit_expr(c, be_value, exprptr(expr->binary_expr.right)); + } + llvm_value_rvalue(c, be_value); + const char *sw_ptr = left->swizzle_expr.swizzle; + unsigned vec_len = be_value->type->array.len; + LLVMValueRef result = be_value->value; + for (unsigned i = 0; i < vec_len; i++) + { + int index = (swizzle[(int)sw_ptr[i]] - 1) & 0xF; + LLVMValueRef val = llvm_emit_extract_value(c, result, i); + vector_value = llvm_emit_insert_value(c, vector_value, val, index); + } + llvm_value_set(be_value, vector_value, addr.type); + llvm_store(c, &addr, be_value); + llvm_value_set(be_value, result, expr->type); + return; + } // Emit the index + BEValue index; llvm_emit_exprid(c, &index, left->subscript_expr.index.expr); LLVMValueRef index_val = llvm_load_value_store(c, &index); - if (binary_op > BINARYOP_ASSIGN) { BinaryOp base_op = binaryop_assign_base_op(binary_op); @@ -4367,15 +4409,17 @@ static void llvm_emit_vector_assign_expr(GenContext *c, BEValue *be_value, Expr llvm_emit_expr(c, be_value, exprptr(expr->binary_expr.right)); } - LLVMValueRef new_value = LLVMBuildInsertElement(c->builder, vector_value, llvm_load_value_store(c, be_value), index_val, "elemset"); + LLVMValueRef new_value = LLVMBuildInsertElement(c->builder, vector_value, llvm_load_value_store(c, be_value), + index_val, "elemset"); llvm_store_raw(c, &addr, new_value); + } static void llvm_emit_binary_expr(GenContext *c, BEValue *be_value, Expr *expr) { BinaryOp binary_op = expr->binary_expr.operator; // Vector assign is handled separately. - if (binary_op >= BINARYOP_ASSIGN && expr_is_vector_index(exprptr(expr->binary_expr.left))) + if (binary_op >= BINARYOP_ASSIGN && expr_is_vector_index_or_swizzle(exprptr(expr->binary_expr.left))) { llvm_emit_vector_assign_expr(c, be_value, expr); return; @@ -6675,12 +6719,8 @@ static void llmv_emit_test_hook(GenContext *c, BEValue *value, Expr *expr) llvm_value_set_address_abi_aligned(value, get_global, expr->type); } - -static void llvm_emit_swizzle(GenContext *c, BEValue *value, Expr *expr) +static void llvm_emit_swizzle_from_value(GenContext *c, LLVMValueRef vector_value, BEValue *value, Expr *expr) { - llvm_emit_exprid(c, value, expr->swizzle_expr.parent); - llvm_value_rvalue(c, value); - LLVMValueRef parent = value->value; LLVMTypeRef result_type = llvm_get_type(c, expr->type); unsigned vec_len = LLVMGetVectorSize(result_type); LLVMValueRef mask_val[4]; @@ -6691,10 +6731,17 @@ static void llvm_emit_swizzle(GenContext *c, BEValue *value, Expr *expr) int index = (swizzle[(int)sw_ptr[i]] - 1) & 0xF; mask_val[i] = llvm_const_int(c, type_uint, index); } - LLVMValueRef res = LLVMBuildShuffleVector(c->builder, parent, LLVMGetUndef(LLVMTypeOf(parent)), LLVMConstVector(mask_val, vec_len), sw_ptr); + LLVMValueRef res = LLVMBuildShuffleVector(c->builder, vector_value, LLVMGetUndef(LLVMTypeOf(vector_value)), LLVMConstVector(mask_val, vec_len), sw_ptr); llvm_value_set(value, res, expr->type); } +static void llvm_emit_swizzle(GenContext *c, BEValue *value, Expr *expr) +{ + llvm_emit_exprid(c, value, expr->swizzle_expr.parent); + llvm_value_rvalue(c, value); + llvm_emit_swizzle_from_value(c, value->value, value, expr); +} + static void llvm_emit_default_arg(GenContext *c, BEValue *value, Expr *expr) { if (llvm_use_debug(c)) diff --git a/src/compiler/sema_expr.c b/src/compiler/sema_expr.c index 954edbd13..773995406 100644 --- a/src/compiler/sema_expr.c +++ b/src/compiler/sema_expr.c @@ -481,7 +481,11 @@ static bool sema_binary_is_expr_lvalue(SemaContext *context, Expr *top_expr, Exp UNREACHABLE case EXPR_SWIZZLE: if (failed_ref) goto FAILED_REF; - RETURN_SEMA_ERROR(expr, "You cannot use swizzling to assign to multiple elements, use element-wise assign instead."); + if (expr->swizzle_expr.is_overlapping) + { + RETURN_SEMA_ERROR(expr, "You cannot use swizzling to assign to multiple elements if they are overlapping"); + } + return sema_binary_is_expr_lvalue(context, top_expr, exprptr(expr->swizzle_expr.parent), failed_ref); case EXPR_LAMBDA: if (failed_ref) goto FAILED_REF; RETURN_SEMA_ERROR(expr, "This expression is a value and cannot be assigned to."); @@ -4892,6 +4896,7 @@ static inline bool sema_expr_analyse_swizzle(SemaContext *context, Expr *expr, E if (is_lvalue) check = CHECK_VALUE; ASSERT_SPAN(expr, len > 0); int index; + bool is_overlapping = false; for (unsigned i = 0; i < len; i++) { char val = swizzle[(int)kw[i]] - 1; @@ -4907,6 +4912,18 @@ static inline bool sema_expr_analyse_swizzle(SemaContext *context, Expr *expr, E { RETURN_SEMA_ERROR(expr, "Mixing [xyzw] and [rgba] is not permitted, you will need to select one of them."); } + if (!is_overlapping) + { + for (int j = 0; j < i; j++) + { + char prev = swizzle[(int)kw[j]] - 1; + if (val == prev) + { + is_overlapping = true; + break; + } + } + } } index &= 0xF; if (len == 1) @@ -4941,7 +4958,8 @@ static inline bool sema_expr_analyse_swizzle(SemaContext *context, Expr *expr, E } Type *result = type_get_vector(indexed_type, len); expr->expr_kind = EXPR_SWIZZLE; - expr->swizzle_expr = (ExprSwizzle) { exprid(parent), kw }; + expr->swizzle_expr = (ExprSwizzle) { .parent = exprid(parent), .swizzle = kw, .is_overlapping = is_overlapping }; + expr->type = result; return true; } @@ -6045,10 +6063,11 @@ static bool sema_expr_analyse_op_assign(SemaContext *context, Expr *expr, Expr * Type *no_fail = type_no_optional(left->type); Type *flat = type_flatten(no_fail); - // 3. If this is only defined for ints (*%, ^= |= &= %=) verify that this is an int. + // 3. If this is only defined for ints (^= |= &= %=) verify that this is an int. if (int_only && !type_flat_is_intlike(flat)) { - if (is_bit_op && flat->type_kind == TYPE_BITSTRUCT) goto BITSTRUCT_OK; + if (is_bit_op && (flat->type_kind == TYPE_BITSTRUCT || flat == type_bool || type_flat_is_bool_vector(flat))) goto BITSTRUCT_OK; + RETURN_SEMA_ERROR(left, "Expected an integer here, not a value of type %s.", type_quoted_error_string(left->type)); } diff --git a/test/test_suite/vector/vector_conversion_scalar.c3 b/test/test_suite/vector/vector_conversion_scalar.c3 index 4dc84189c..b60038bfb 100644 --- a/test/test_suite/vector/vector_conversion_scalar.c3 +++ b/test/test_suite/vector/vector_conversion_scalar.c3 @@ -4,7 +4,7 @@ fn void main() { int[<2>] y = 1; y[..] = 3; - y.xy = 3; // #error: cannot use swizzling + y.xxy = 3; // #error: cannot use swizzling y *= 2; y = 3; // #error: explicit cast test2(3); // #error: explicit cast diff --git a/test/unit/regression/vector_ops.c3 b/test/unit/regression/vector_ops.c3 index 59aa1133f..ab6333588 100644 --- a/test/unit/regression/vector_ops.c3 +++ b/test/unit/regression/vector_ops.c3 @@ -10,6 +10,34 @@ fn void test_int_mod() @test assert((int[<2>]){ 10, 99 } / { 3, 5 } == { 3, 19 }); } +fn void test_swizzle_assign() @test +{ + int[<3>] abc; + abc.xy = { 3, 4 }; + assert(abc == { 3, 4, 0 }); + abc.xz += { 2, 7 }; + assert(abc == { 5, 4, 7 }); + abc.xy += { 4, -100 }; + assert(abc == { 9, -96, 7 }); + abc.xz = { 0, 0 }; + assert(abc == { 0, -96, 0}); +} + +fn void test_swizzle_assign_bool() @test +{ + bool[<3>] abc; + abc.xy = { true, true }; + assert(abc == { true, true, false }); + abc.yz ^= { true, true }; + assert(abc == { true, false, true }); + abc ^= true; + assert(abc == { false, true, false }); + abc.xy ^= true; + assert(abc == { true, false, false }); + assert((abc.yx ^= true) == { true, false }); + assert(abc == { false, true, false }); +} + fn void test_conv() @test { float[<4>] y = { 1, 2, 3, 4 };