From 3f45ed14b9b1662c85198a797a9ccaba3da3b3ea Mon Sep 17 00:00:00 2001 From: Christoffer Lerno Date: Fri, 12 Jul 2024 23:54:07 +0200 Subject: [PATCH] Compare @compact structs. --- src/compiler/compiler_internal.h | 10 ++- src/compiler/llvm_codegen_expr.c | 145 +++++++++++++++++++++++++++++-- src/compiler/sema_decls.c | 26 +++--- src/compiler/types.c | 5 +- 4 files changed, 159 insertions(+), 27 deletions(-) diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index e3b587821..d20efc912 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -18,6 +18,7 @@ typedef uint64_t ByteSize; typedef uint32_t TypeSize; typedef int32_t IndexDiff; typedef int32_t ArrayIndex; +typedef uint16_t StructIndex; typedef uint32_t AlignSize; typedef int32_t ScopeId; typedef uint32_t ArraySize; @@ -34,9 +35,9 @@ typedef uint16_t SectionId; #define INITIAL_SYMBOL_MAP 0x10000 #define INITIAL_GENERIC_SYMBOL_MAP 0x1000 #define MAX_MACRO_ITERATIONS 0xFFFFFF -#define MAX_PARAMS 127 +#define MAX_PARAMS 255 #define MAX_BITSTRUCT 0x1000 -#define MAX_MEMBERS ((ArrayIndex)(((uint64_t)2) << 28)) +#define MAX_MEMBERS (((StructIndex)1) << 15) /* whole expansion parenthesized so the shift binds before any surrounding operator */ #define MAX_ALIGNMENT ((ArrayIndex)(((uint64_t)2) << 28)) #define MAX_PRIORITY 0xFFFF #define MAX_TYPE_SIZE UINT32_MAX @@ -44,6 +45,7 @@ typedef uint16_t SectionId; #define MAX_ASM_INSTRUCTION_PARAMS 6 #define CLOBBER_FLAG_ELEMENTS 4 #define MAX_CLOBBER_FLAGS (64 * CLOBBER_FLAG_ELEMENTS) +#define MEMCMP_INLINE_REGS 8 extern const char *project_default_keys[][2]; extern const int project_default_keys_count; @@ -427,9 +429,9 @@ typedef struct typedef struct { TypeSize size; - ArrayIndex union_rep; Decl **members; - Decl *padded_decl; + DeclId padded_decl_id; + StructIndex union_rep; AlignSize padding : 16; } StructDecl; diff --git a/src/compiler/llvm_codegen_expr.c b/src/compiler/llvm_codegen_expr.c index 3a9103eaf..6b3f9c660 100644 --- a/src/compiler/llvm_codegen_expr.c +++ 
b/src/compiler/llvm_codegen_expr.c @@ -33,6 +33,7 @@ static inline void llvm_emit_try_unwrap(GenContext *c, BEValue *value, Expr *exp static inline void llvm_emit_vector_initializer_list(GenContext *c, BEValue *value, Expr *expr); static inline void llvm_extract_bitvalue_from_array(GenContext *c, BEValue *be_value, Decl *member, Decl *parent_decl); static inline void llvm_emit_type_from_any(GenContext *c, BEValue *be_value); +static inline void llvm_emit_memcmp(GenContext *c, BEValue *be_value, LLVMValueRef ptr, LLVMValueRef other_ptr, BinaryOp binary_op, AlignSize lhs_align, AlignSize rhs_align, ByteSize size); static void llvm_convert_vector_comparison(GenContext *c, BEValue *be_value, LLVMValueRef val, Type *vector_type, bool is_equals); static void llvm_emit_any_pointer(GenContext *c, BEValue *any, BEValue *pointer); @@ -3553,6 +3554,16 @@ static void llvm_emit_any_comparison(GenContext *c, BEValue *result, BEValue *lh llvm_value_set(result, res, type_bool); } +static void llvm_emit_struct_comparison(GenContext *c, BEValue *result, BEValue *lhs, BEValue *rhs, BinaryOp binary_op) +{ + llvm_value_fold_optional(c, lhs); + llvm_value_fold_optional(c, rhs); + llvm_value_addr(c, lhs); + llvm_value_addr(c, rhs); + llvm_emit_memcmp(c, result, lhs->value, rhs->value, binary_op, lhs->alignment, rhs->alignment, + type_size(lhs->type)); +} + static inline LLVMValueRef llvm_emit_mult_int(GenContext *c, Type *type, LLVMValueRef left, LLVMValueRef right, SourceSpan loc) { if (active_target.feature.trap_on_wrap) @@ -3671,6 +3682,112 @@ INLINE bool should_inline_array_comp(ArraySize len, Type *base_type_lowered) return len <= 16; } } + +static void llvm_emit_memcmp_inline(GenContext *c, BEValue *be_value, LLVMValueRef lhs, + LLVMValueRef rhs, ByteSize element_size, + AlignSize lhs_align, AlignSize rhs_align, int len, bool want_match) +{ + assert(element_size <= platform_target.width_register / 8); + lhs_align = type_min_alignment(element_size, lhs_align); + rhs_align = 
type_min_alignment(element_size, rhs_align); + LLVMTypeRef element_type = LLVMIntTypeInContext(c->context, element_size * 8); + LLVMBasicBlockRef exit = llvm_basic_block_new(c, "array_cmp_exit"); + LLVMBasicBlockRef loop_begin = llvm_basic_block_new(c, "array_loop_start"); + LLVMBasicBlockRef comparison = llvm_basic_block_new(c, "array_loop_comparison"); + LLVMBasicBlockRef comparison_phi; + LLVMBasicBlockRef loop_begin_phi; + LLVMValueRef len_val = llvm_const_int(c, type_usz, len); + LLVMValueRef one = llvm_const_int(c, type_usz, 1); + BEValue index_var; + llvm_value_set_address_abi_aligned(&index_var, llvm_emit_alloca_aligned(c, type_usz, "cmp.idx"), type_usz); + llvm_store_raw(c, &index_var, llvm_get_zero(c, type_usz)); + + llvm_emit_br(c, loop_begin); + llvm_emit_block(c, loop_begin); + + AlignSize align_lhs; + BEValue lhs_v; + BEValue index_copy = index_var; + llvm_value_rvalue(c, &index_copy); + + LLVMValueRef index_val = index_copy.value; + LLVMValueRef lhs_ptr = llvm_emit_pointer_inbounds_gep_raw(c, element_type, lhs, index_val); + LLVMValueRef rhs_ptr = llvm_emit_pointer_inbounds_gep_raw(c, element_type, rhs, index_val); + LLVMValueRef lhs_value = llvm_load(c, element_type, lhs_ptr, lhs_align, "lhs"); + LLVMValueRef rhs_value = llvm_load(c, element_type, rhs_ptr, rhs_align, "rhs"); + LLVMValueRef comp_val = LLVMBuildICmp(c->builder, LLVMIntEQ, lhs_value, rhs_value, "cmp"); + loop_begin_phi = c->current_block; + llvm_emit_cond_br_raw(c, comp_val, comparison, exit); + llvm_emit_block(c, comparison); + + LLVMValueRef new_index = LLVMBuildAdd(c->builder, index_copy.value, one, "inc"); + llvm_store_raw(c, &index_var, new_index); + BEValue comp; + llvm_emit_int_comp_raw(c, &comp, type_usz, type_usz, new_index, len_val, BINARYOP_LT); + comparison_phi = c->current_block; + llvm_emit_cond_br(c, &comp, loop_begin, exit); + llvm_emit_block(c, exit); + LLVMValueRef success = LLVMConstInt(c->bool_type, want_match ? 
1 : 0, false); + LLVMValueRef failure = LLVMConstInt(c->bool_type, want_match ? 0 : 1, false); + llvm_new_phi(c, be_value, "array_cmp_phi", type_bool, success, comparison_phi, failure, loop_begin_phi); + +} +static void llvm_emit_memcmp_unrolled(GenContext *c, BEValue *be_value, LLVMValueRef lhs_ptr, + LLVMValueRef rhs_ptr, ByteSize element_size, + AlignSize lhs_align, AlignSize rhs_align, int len, bool want_match) +{ + assert(len < 17); + assert(element_size <= platform_target.width_register / 8); + LLVMTypeRef element_type = LLVMIntTypeInContext(c->context, element_size * 8); + LLVMBasicBlockRef blocks[17]; + LLVMValueRef value_block[17]; + LLVMBasicBlockRef ok_block = llvm_basic_block_new(c, "match"); + LLVMBasicBlockRef exit_block = llvm_basic_block_new(c, "exit"); + LLVMValueRef success = LLVMConstInt(c->bool_type, want_match ? 1 : 0, false); + LLVMValueRef failure = LLVMConstInt(c->bool_type, want_match ? 0 : 1, false); + LLVMValueRef one = llvm_const_int(c, type_usz, 1); + for (unsigned i = 0; i < len; i++) + { + value_block[i] = failure; + if (i > 0) + { + lhs_ptr = llvm_emit_pointer_inbounds_gep_raw(c, element_type, lhs_ptr, one); + rhs_ptr = llvm_emit_pointer_inbounds_gep_raw(c, element_type, rhs_ptr, one); + } + AlignSize lhs_align_current = type_min_alignment(lhs_align + i * element_size, lhs_align); + AlignSize rhs_align_current = type_min_alignment(rhs_align + i * element_size, rhs_align); + LLVMValueRef lhs_value = llvm_load(c, element_type, lhs_ptr, lhs_align_current, "lhs"); + LLVMValueRef rhs_value = llvm_load(c, element_type, rhs_ptr, rhs_align_current, "rhs"); + LLVMValueRef comp = LLVMBuildICmp(c->builder, LLVMIntEQ, lhs_value, rhs_value, "cmp"); + blocks[i] = c->current_block; + LLVMBasicBlockRef block = ok_block; + block = i < len - 1 ? 
llvm_basic_block_new(c, "next_check") : block; + llvm_emit_cond_br_raw(c, comp, block, exit_block); + llvm_emit_block(c, block); + } + llvm_emit_br(c, exit_block); + llvm_emit_block(c, exit_block); + value_block[len] = success; + blocks[len] = ok_block; + LLVMValueRef phi = LLVMBuildPhi(c->builder, c->bool_type, "memcmp_phi"); + LLVMAddIncoming(phi, value_block, blocks, len + 1); + llvm_value_set(be_value, phi, type_bool); +} +static inline void llvm_emit_memcmp(GenContext *c, BEValue *be_value, LLVMValueRef ptr, LLVMValueRef other_ptr, BinaryOp binary_op, AlignSize lhs_align, AlignSize rhs_align, ByteSize size) +{ + + ByteSize element_size = lhs_align > platform_target.width_register / 8 ? platform_target.width_register / 8 : lhs_align; + if (element_size > rhs_align) element_size = rhs_align; + if (element_size > size) element_size = size; + ByteSize repeats = size / element_size; + assert(size % element_size == 0 && "Expected size padded to alignment"); + if (repeats <= MEMCMP_INLINE_REGS) + { + llvm_emit_memcmp_unrolled(c, be_value, ptr, other_ptr, element_size, lhs_align, rhs_align, repeats, binary_op == BINARYOP_EQ); + return; + } + llvm_emit_memcmp_inline(c, be_value, ptr, other_ptr, element_size, lhs_align, rhs_align, repeats, binary_op == BINARYOP_EQ); +} static void llvm_emit_array_comp(GenContext *c, BEValue *be_value, BEValue *lhs, BEValue *rhs, BinaryOp binary_op) { bool want_match = binary_op == BINARYOP_EQ; @@ -3857,10 +3974,12 @@ void llvm_emit_comp(GenContext *c, BEValue *result, BEValue *lhs, BEValue *rhs, llvm_emit_any_comparison(c, result, lhs, rhs, binary_op); return; case LOWERED_TYPES: - case TYPE_STRUCT: - case TYPE_UNION: case TYPE_FLEXIBLE_ARRAY: UNREACHABLE + case TYPE_STRUCT: + case TYPE_UNION: + llvm_emit_struct_comparison(c, result, lhs, rhs, binary_op); + return; case TYPE_SLICE: llvm_emit_slice_comp(c, result, lhs, rhs, binary_op); return; @@ -4065,6 +4184,19 @@ void llvm_emit_bitstruct_binary_op(GenContext *c, BEValue *be_value, 
BEValue *lh llvm_value_set_address(be_value, store, lhs->type, lhs->alignment); } +INLINE void llvm_fold_for_compare(GenContext *c, BEValue *be_value) +{ + switch (be_value->type->type_kind) + { + case TYPE_ARRAY: + case TYPE_STRUCT: + case TYPE_UNION: + break; + default: + llvm_value_rvalue(c, be_value); + break; + } +} void llvm_emit_binary(GenContext *c, BEValue *be_value, Expr *expr, BEValue *lhs_loaded, BinaryOp binary_op) { // foo ?? bar @@ -4096,13 +4228,12 @@ void llvm_emit_binary(GenContext *c, BEValue *be_value, Expr *expr, BEValue *lhs llvm_emit_expr(c, &lhs, exprptr(expr->binary_expr.left)); } // We need the rvalue. - if (lhs.type->type_kind != TYPE_ARRAY) llvm_value_rvalue(c, &lhs); + llvm_fold_for_compare(c, &lhs); // Evaluate rhs BEValue rhs; llvm_emit_expr(c, &rhs, exprptr(expr->binary_expr.right)); - if (rhs.type->type_kind != TYPE_ARRAY) llvm_value_rvalue(c, &rhs); - + llvm_fold_for_compare(c, &rhs); EMIT_LOC(c, expr); // Comparison <=> if (binary_op >= BINARYOP_GT && binary_op <= BINARYOP_EQ) @@ -6707,8 +6838,7 @@ static inline void llvm_emit_builtin_access(GenContext *c, BEValue *be_value, Ex llvm_value_set(be_value, llvm_get_zero(c, type_usz), type_usz); return; } - Type *inner_type = type_flatten(inner->type); - assert(inner_type->type_kind == TYPE_FAULTTYPE); + assert(type_flatten(inner->type)->type_kind == TYPE_FAULTTYPE); llvm_value_rvalue(c, be_value); BEValue zero; LLVMBasicBlockRef exit_block = llvm_basic_block_new(c, "faultordinal_exit"); @@ -6729,6 +6859,7 @@ static inline void llvm_emit_builtin_access(GenContext *c, BEValue *be_value, Ex case ACCESS_FAULTNAME: { Type *inner_type = type_no_optional(inner->type)->canonical; + (void)inner_type; assert(inner_type->type_kind == TYPE_FAULTTYPE || inner_type->type_kind == TYPE_ANYFAULT); llvm_value_rvalue(c, be_value); LLVMValueRef val = llvm_emit_alloca_aligned(c, type_chars, "faultname_zero"); diff --git a/src/compiler/sema_decls.c b/src/compiler/sema_decls.c index 966ac2b00..876e6e619 
100644 --- a/src/compiler/sema_decls.c +++ b/src/compiler/sema_decls.c @@ -275,18 +275,16 @@ static inline bool sema_analyse_struct_member(SemaContext *context, Decl *parent static inline bool sema_check_struct_holes(SemaContext *context, Decl *decl, Decl *member, Type *member_type) { - if (member_type->type_kind == TYPE_STRUCT || member_type->type_kind == TYPE_UNION) + member_type = type_flatten(member_type); + if (member_type->type_kind != TYPE_STRUCT && member_type->type_kind != TYPE_UNION) return true; + if (!member_type->decl->strukt.padded_decl_id) return true; + if (!decl->strukt.padded_decl_id) decl->strukt.padded_decl_id = member_type->decl->strukt.padded_decl_id; + if (decl->attr_compact) { - if (member_type->decl->strukt.padded_decl) - { - if (!decl->strukt.padded_decl) decl->strukt.padded_decl = member_type->decl->strukt.padded_decl; - if (decl->attr_compact) - { - SEMA_ERROR(member, "This member has holes."); - SEMA_NOTE(member_type->decl->strukt.padded_decl, "Padding would be added for this type."); - return false; - } - } + SEMA_ERROR(member, "%s has padding and can't be used as the type of '%s', because members of a `@compact` type must all have zero padding.", type_quoted_error_string(member_type), member->name); + SEMA_NOTE(declptr(member_type->decl->strukt.padded_decl_id), "The first padded field in %s is here.", + type_quoted_error_string(member_type)); + return false; } return true; } @@ -592,7 +590,7 @@ static bool sema_analyse_struct_members(SemaContext *context, Decl *decl) if (align_offset - offset != 0) { - if (!decl->strukt.padded_decl) decl->strukt.padded_decl = member; + if (!decl->strukt.padded_decl_id) decl->strukt.padded_decl_id = declid(member); if (decl->attr_nopadding || member->attr_nopadding) { RETURN_SEMA_ERROR(member, "%d bytes of padding would be added to align this member.", align_offset - offset); @@ -649,10 +647,10 @@ static bool sema_analyse_struct_members(SemaContext *context, Decl *decl) if (size != offset) { - if 
(!decl->strukt.padded_decl) decl->strukt.padded_decl = decl; + if (!decl->strukt.padded_decl_id) decl->strukt.padded_decl_id = declid(decl); if (decl->attr_nopadding) { - RETURN_SEMA_ERROR(decl, "%d bytes of padding would be added to the end this struct.", size - offset); + RETURN_SEMA_ERROR(decl, "%d bytes of padding would be added to the end this struct which is not allowed with `@nopadding` and `@compact`.", size - offset); } } diff --git a/src/compiler/types.c b/src/compiler/types.c index 06107836c..e81ffb78f 100644 --- a/src/compiler/types.c +++ b/src/compiler/types.c @@ -540,12 +540,13 @@ bool type_is_comparable(Type *type) case TYPE_POISONED: UNREACHABLE case TYPE_VOID: - case TYPE_UNION: - case TYPE_STRUCT: case TYPE_FLEXIBLE_ARRAY: case TYPE_OPTIONAL: case TYPE_MEMBER: return false; + case TYPE_UNION: + case TYPE_STRUCT: + return type->decl->attr_compact; case TYPE_BITSTRUCT: type = type->decl->bitstruct.base_type->type; goto RETRY;