Compare @compact structs.

This commit is contained in:
Christoffer Lerno
2024-07-12 23:54:07 +02:00
parent ca4b782912
commit 3f45ed14b9
4 changed files with 159 additions and 27 deletions

View File

@@ -33,6 +33,7 @@ static inline void llvm_emit_try_unwrap(GenContext *c, BEValue *value, Expr *exp
static inline void llvm_emit_vector_initializer_list(GenContext *c, BEValue *value, Expr *expr);
static inline void llvm_extract_bitvalue_from_array(GenContext *c, BEValue *be_value, Decl *member, Decl *parent_decl);
static inline void llvm_emit_type_from_any(GenContext *c, BEValue *be_value);
// Emits an element-wise comparison of `size` bytes at `ptr` and `other_ptr`; the boolean result is stored in `be_value`.
static inline void llvm_emit_memcmp(GenContext *c, BEValue *be_value, LLVMValueRef ptr, LLVMValueRef other_ptr, BinaryOp binary_op, AlignSize lhs_align, AlignSize rhs_align, ByteSize size);
static void llvm_convert_vector_comparison(GenContext *c, BEValue *be_value, LLVMValueRef val, Type *vector_type,
bool is_equals);
static void llvm_emit_any_pointer(GenContext *c, BEValue *any, BEValue *pointer);
@@ -3553,6 +3554,16 @@ static void llvm_emit_any_comparison(GenContext *c, BEValue *result, BEValue *lh
llvm_value_set(result, res, type_bool);
}
/**
 * Emit a comparison of two struct or union values by comparing their storage.
 *
 * Both operands are passed through llvm_value_fold_optional and then
 * llvm_value_addr, after which the work is handed to llvm_emit_memcmp using
 * the size of the left-hand type as the number of bytes to compare.
 *
 * @param c the codegen context
 * @param result receives the outcome produced by llvm_emit_memcmp
 * @param lhs left operand, updated in place
 * @param rhs right operand, updated in place
 * @param binary_op comparison operator, forwarded unchanged to llvm_emit_memcmp
 */
static void llvm_emit_struct_comparison(GenContext *c, BEValue *result, BEValue *lhs, BEValue *rhs, BinaryOp binary_op)
{
	// Fold both operands first, then take both addresses, in that order.
	llvm_value_fold_optional(c, lhs);
	llvm_value_fold_optional(c, rhs);
	llvm_value_addr(c, lhs);
	llvm_value_addr(c, rhs);

	// The left-hand type decides how many bytes are compared.
	ByteSize byte_count = type_size(lhs->type);
	LLVMValueRef left_ptr = lhs->value;
	LLVMValueRef right_ptr = rhs->value;
	AlignSize left_align = lhs->alignment;
	AlignSize right_align = rhs->alignment;

	llvm_emit_memcmp(c, result, left_ptr, right_ptr, binary_op, left_align, right_align, byte_count);
}
static inline LLVMValueRef llvm_emit_mult_int(GenContext *c, Type *type, LLVMValueRef left, LLVMValueRef right, SourceSpan loc)
{
if (active_target.feature.trap_on_wrap)
@@ -3671,6 +3682,112 @@ INLINE bool should_inline_array_comp(ArraySize len, Type *base_type_lowered)
return len <= 16;
}
}
/**
 * Emit a loop that compares `len` integer elements of `element_size` bytes
 * each, read from `lhs` and `rhs`, leaving the loop at the first mismatch.
 *
 * Generated control flow:
 *   array_loop_start:      load the current element from both sides and compare;
 *                          branch to array_cmp_exit when they differ.
 *   array_loop_comparison: increment the index and loop while index < len.
 *   array_cmp_exit:        phi selecting the boolean result.
 *
 * The first element is loaded before any bounds check, so `len` must be at least 1.
 *
 * @param c the codegen context
 * @param be_value receives the boolean phi result
 * @param lhs pointer to the first region
 * @param rhs pointer to the second region
 * @param element_size width in bytes of each compared element, at most a register wide
 * @param lhs_align alignment of `lhs`
 * @param rhs_align alignment of `rhs`
 * @param len number of elements to compare
 * @param want_match true if the result should be true when all elements are equal,
 *                   false if it should be true when some element differs
 */
static void llvm_emit_memcmp_inline(GenContext *c, BEValue *be_value, LLVMValueRef lhs,
                                    LLVMValueRef rhs, ByteSize element_size,
                                    AlignSize lhs_align, AlignSize rhs_align, int len, bool want_match)
{
	assert(element_size <= platform_target.width_register / 8);

	// The same alignment is used for every load in the loop, so combine the
	// element size with each base alignment once up front.
	lhs_align = type_min_alignment(element_size, lhs_align);
	rhs_align = type_min_alignment(element_size, rhs_align);

	LLVMTypeRef element_type = LLVMIntTypeInContext(c->context, element_size * 8);
	LLVMBasicBlockRef exit = llvm_basic_block_new(c, "array_cmp_exit");
	LLVMBasicBlockRef loop_begin = llvm_basic_block_new(c, "array_loop_start");
	LLVMBasicBlockRef comparison = llvm_basic_block_new(c, "array_loop_comparison");
	LLVMBasicBlockRef comparison_phi;
	LLVMBasicBlockRef loop_begin_phi;
	LLVMValueRef len_val = llvm_const_int(c, type_usz, len);
	LLVMValueRef one = llvm_const_int(c, type_usz, 1);

	// The loop index lives in a stack slot initialised to zero.
	BEValue index_var;
	llvm_value_set_address_abi_aligned(&index_var, llvm_emit_alloca_aligned(c, type_usz, "cmp.idx"), type_usz);
	llvm_store_raw(c, &index_var, llvm_get_zero(c, type_usz));
	llvm_emit_br(c, loop_begin);

	// Loop head: load the current element from both sides and compare them.
	llvm_emit_block(c, loop_begin);
	BEValue index_copy = index_var;
	llvm_value_rvalue(c, &index_copy);
	LLVMValueRef index_val = index_copy.value;
	LLVMValueRef lhs_ptr = llvm_emit_pointer_inbounds_gep_raw(c, element_type, lhs, index_val);
	LLVMValueRef rhs_ptr = llvm_emit_pointer_inbounds_gep_raw(c, element_type, rhs, index_val);
	LLVMValueRef lhs_value = llvm_load(c, element_type, lhs_ptr, lhs_align, "lhs");
	LLVMValueRef rhs_value = llvm_load(c, element_type, rhs_ptr, rhs_align, "rhs");
	LLVMValueRef comp_val = LLVMBuildICmp(c->builder, LLVMIntEQ, lhs_value, rhs_value, "cmp");
	// Remember the block the mismatch edge leaves from; the phi below needs it.
	loop_begin_phi = c->current_block;
	llvm_emit_cond_br_raw(c, comp_val, comparison, exit);

	// Elements matched: advance the index and continue while index < len.
	llvm_emit_block(c, comparison);
	LLVMValueRef new_index = LLVMBuildAdd(c->builder, index_copy.value, one, "inc");
	llvm_store_raw(c, &index_var, new_index);
	BEValue comp;
	llvm_emit_int_comp_raw(c, &comp, type_usz, type_usz, new_index, len_val, BINARYOP_LT);
	// Remember the block the "all elements compared" edge leaves from.
	comparison_phi = c->current_block;
	llvm_emit_cond_br(c, &comp, loop_begin, exit);

	// Exit: arriving from the index check means every element matched,
	// arriving from the loop head means a mismatch was found.
	llvm_emit_block(c, exit);
	LLVMValueRef success = LLVMConstInt(c->bool_type, want_match ? 1 : 0, false);
	LLVMValueRef failure = LLVMConstInt(c->bool_type, want_match ? 0 : 1, false);
	llvm_new_phi(c, be_value, "array_cmp_phi", type_bool, success, comparison_phi, failure, loop_begin_phi);
}
/**
 * Emit a fully unrolled comparison of `len` integer elements of
 * `element_size` bytes each, read from `lhs_ptr` and `rhs_ptr`.
 *
 * One load/compare pair is emitted per element. Each compare branches to the
 * "exit" block on mismatch, otherwise to the next check (or to the "match"
 * block after the last element). The result is a phi in "exit" with one
 * incoming edge per check block plus one from "match".
 *
 * @param c the codegen context
 * @param be_value receives the boolean phi result
 * @param lhs_ptr pointer to the first region
 * @param rhs_ptr pointer to the second region
 * @param element_size width in bytes of each compared element, at most a register wide
 * @param lhs_align alignment of `lhs_ptr`
 * @param rhs_align alignment of `rhs_ptr`
 * @param len number of elements to compare, in the range 1..16
 * @param want_match true if the result should be true when all elements are equal,
 *                   false if it should be true when some element differs
 */
static void llvm_emit_memcmp_unrolled(GenContext *c, BEValue *be_value, LLVMValueRef lhs_ptr,
                                      LLVMValueRef rhs_ptr, ByteSize element_size,
                                      AlignSize lhs_align, AlignSize rhs_align, int len, bool want_match)
{
	// Upper bound on the number of unrolled checks; the phi needs one extra
	// slot for the edge coming from the "match" block.
	enum { MAX_UNROLLED_COMPARES = 16 };

	// With len == 0 the "match" block would never be emitted although the phi
	// refers to it, so require at least one element.
	assert(len > 0 && len <= MAX_UNROLLED_COMPARES);
	assert(element_size <= platform_target.width_register / 8);

	LLVMTypeRef element_type = LLVMIntTypeInContext(c->context, element_size * 8);
	LLVMBasicBlockRef blocks[MAX_UNROLLED_COMPARES + 1];
	LLVMValueRef value_block[MAX_UNROLLED_COMPARES + 1];
	LLVMBasicBlockRef ok_block = llvm_basic_block_new(c, "match");
	LLVMBasicBlockRef exit_block = llvm_basic_block_new(c, "exit");
	LLVMValueRef success = LLVMConstInt(c->bool_type, want_match ? 1 : 0, false);
	LLVMValueRef failure = LLVMConstInt(c->bool_type, want_match ? 0 : 1, false);
	LLVMValueRef one = llvm_const_int(c, type_usz, 1);

	for (int i = 0; i < len; i++)
	{
		// Reaching "exit" from this check means a mismatch was found.
		value_block[i] = failure;
		if (i > 0)
		{
			// Step both pointers forward by one element.
			lhs_ptr = llvm_emit_pointer_inbounds_gep_raw(c, element_type, lhs_ptr, one);
			rhs_ptr = llvm_emit_pointer_inbounds_gep_raw(c, element_type, rhs_ptr, one);
		}
		// Alignment used for the load of element i on each side.
		AlignSize lhs_align_current = type_min_alignment(lhs_align + (ByteSize)i * element_size, lhs_align);
		AlignSize rhs_align_current = type_min_alignment(rhs_align + (ByteSize)i * element_size, rhs_align);
		LLVMValueRef lhs_value = llvm_load(c, element_type, lhs_ptr, lhs_align_current, "lhs");
		LLVMValueRef rhs_value = llvm_load(c, element_type, rhs_ptr, rhs_align_current, "rhs");
		LLVMValueRef comp = LLVMBuildICmp(c->builder, LLVMIntEQ, lhs_value, rhs_value, "cmp");
		// Record the block the mismatch edge leaves from for the phi.
		blocks[i] = c->current_block;
		// The last check continues into "match", all others into a fresh check block.
		LLVMBasicBlockRef block = i < len - 1 ? llvm_basic_block_new(c, "next_check") : ok_block;
		llvm_emit_cond_br_raw(c, comp, block, exit_block);
		llvm_emit_block(c, block);
	}

	// We are now in "match": every element compared equal.
	llvm_emit_br(c, exit_block);
	llvm_emit_block(c, exit_block);
	value_block[len] = success;
	blocks[len] = ok_block;
	LLVMValueRef phi = LLVMBuildPhi(c->builder, c->bool_type, "memcmp_phi");
	LLVMAddIncoming(phi, value_block, blocks, (unsigned)(len + 1));
	llvm_value_set(be_value, phi, type_bool);
}
/**
 * Emit an element-wise equality test over `size` bytes at `ptr` and `other_ptr`.
 *
 * The element width is the smallest of: the register width in bytes,
 * `lhs_align`, `rhs_align` and `size`. If the resulting element count is at
 * most MEMCMP_INLINE_REGS the comparison is fully unrolled, otherwise a loop
 * is emitted.
 *
 * @param c the codegen context
 * @param be_value receives the boolean result
 * @param ptr pointer to the first region
 * @param other_ptr pointer to the second region
 * @param binary_op BINARYOP_EQ yields true when the regions are equal; any
 *                  other operator yields true when they differ
 * @param lhs_align alignment of `ptr`
 * @param rhs_align alignment of `other_ptr`
 * @param size number of bytes to compare; must be a multiple of the chosen element width
 */
static inline void llvm_emit_memcmp(GenContext *c, BEValue *be_value, LLVMValueRef ptr, LLVMValueRef other_ptr, BinaryOp binary_op, AlignSize lhs_align, AlignSize rhs_align, ByteSize size)
{
	bool want_match = binary_op == BINARYOP_EQ;

	// Nothing to compare: two empty regions are always equal. Handling this
	// here also avoids dividing by a zero element size below.
	if (size == 0)
	{
		llvm_value_set(be_value, LLVMConstInt(c->bool_type, want_match ? 1 : 0, false), type_bool);
		return;
	}

	// Pick the widest element both sides can load: bounded by the register
	// width, both alignments and the total size.
	ByteSize register_bytes = platform_target.width_register / 8;
	ByteSize element_size = lhs_align > register_bytes ? register_bytes : lhs_align;
	if (element_size > rhs_align) element_size = rhs_align;
	if (element_size > size) element_size = size;
	assert(element_size > 0 && "Expected non-zero alignment");

	ByteSize repeats = size / element_size;
	assert(size % element_size == 0 && "Expected size padded to alignment");

	if (repeats <= MEMCMP_INLINE_REGS)
	{
		llvm_emit_memcmp_unrolled(c, be_value, ptr, other_ptr, element_size, lhs_align, rhs_align, (int)repeats, want_match);
		return;
	}
	llvm_emit_memcmp_inline(c, be_value, ptr, other_ptr, element_size, lhs_align, rhs_align, (int)repeats, want_match);
}
static void llvm_emit_array_comp(GenContext *c, BEValue *be_value, BEValue *lhs, BEValue *rhs, BinaryOp binary_op)
{
bool want_match = binary_op == BINARYOP_EQ;
@@ -3857,10 +3974,12 @@ void llvm_emit_comp(GenContext *c, BEValue *result, BEValue *lhs, BEValue *rhs,
llvm_emit_any_comparison(c, result, lhs, rhs, binary_op);
return;
case LOWERED_TYPES:
case TYPE_STRUCT:
case TYPE_UNION:
case TYPE_FLEXIBLE_ARRAY:
UNREACHABLE
case TYPE_STRUCT:
case TYPE_UNION:
llvm_emit_struct_comparison(c, result, lhs, rhs, binary_op);
return;
case TYPE_SLICE:
llvm_emit_slice_comp(c, result, lhs, rhs, binary_op);
return;
@@ -4065,6 +4184,19 @@ void llvm_emit_bitstruct_binary_op(GenContext *c, BEValue *be_value, BEValue *lh
llvm_value_set_address(be_value, store, lhs->type, lhs->alignment);
}
/**
 * Prepare an operand before it is used in a binary expression.
 *
 * Arrays, structs and unions are left untouched; every other type kind is
 * passed through llvm_value_rvalue.
 *
 * @param c the codegen context
 * @param be_value the operand, updated in place
 */
INLINE void llvm_fold_for_compare(GenContext *c, BEValue *be_value)
{
	switch (be_value->type->type_kind)
	{
		case TYPE_ARRAY:
		case TYPE_STRUCT:
		case TYPE_UNION:
			// Aggregates are kept as they are.
			return;
		default:
			break;
	}
	llvm_value_rvalue(c, be_value);
}
void llvm_emit_binary(GenContext *c, BEValue *be_value, Expr *expr, BEValue *lhs_loaded, BinaryOp binary_op)
{
// foo ?? bar
@@ -4096,13 +4228,12 @@ void llvm_emit_binary(GenContext *c, BEValue *be_value, Expr *expr, BEValue *lhs
llvm_emit_expr(c, &lhs, exprptr(expr->binary_expr.left));
}
// We need the rvalue.
if (lhs.type->type_kind != TYPE_ARRAY) llvm_value_rvalue(c, &lhs);
llvm_fold_for_compare(c, &lhs);
// Evaluate rhs
BEValue rhs;
llvm_emit_expr(c, &rhs, exprptr(expr->binary_expr.right));
if (rhs.type->type_kind != TYPE_ARRAY) llvm_value_rvalue(c, &rhs);
llvm_fold_for_compare(c, &rhs);
EMIT_LOC(c, expr);
// Comparison <=>
if (binary_op >= BINARYOP_GT && binary_op <= BINARYOP_EQ)
@@ -6707,8 +6838,7 @@ static inline void llvm_emit_builtin_access(GenContext *c, BEValue *be_value, Ex
llvm_value_set(be_value, llvm_get_zero(c, type_usz), type_usz);
return;
}
Type *inner_type = type_flatten(inner->type);
assert(inner_type->type_kind == TYPE_FAULTTYPE);
assert(type_flatten(inner->type)->type_kind == TYPE_FAULTTYPE);
llvm_value_rvalue(c, be_value);
BEValue zero;
LLVMBasicBlockRef exit_block = llvm_basic_block_new(c, "faultordinal_exit");
@@ -6729,6 +6859,7 @@ static inline void llvm_emit_builtin_access(GenContext *c, BEValue *be_value, Ex
case ACCESS_FAULTNAME:
{
Type *inner_type = type_no_optional(inner->type)->canonical;
(void)inner_type;
assert(inner_type->type_kind == TYPE_FAULTTYPE || inner_type->type_kind == TYPE_ANYFAULT);
llvm_value_rvalue(c, be_value);
LLVMValueRef val = llvm_emit_alloca_aligned(c, type_chars, "faultname_zero");