- $$MASK_TO_INT and $$INT_TO_MASK to create bool masks from integers and back.

- Fix bug when creating bool vectors in certain cases.
This commit is contained in:
Christoffer Lerno
2025-12-25 20:55:11 +01:00
parent 18b246c577
commit f3b71ed7eb
13 changed files with 193 additions and 30 deletions

View File

@@ -26,17 +26,19 @@ Checks: >
# Turn all the warnings from the checks above into errors.
WarningsAsErrors: "*"
CheckOptions:
- { key: readability-function-cognitive-complexity.Threshold, value: 100 }
- { key: readability-identifier-naming.StructCase, value: CamelCase }
- { key: readability-identifier-naming.FunctionCase, value: lower_case }
- { key: readability-identifier-naming.VariableCase, value: lower_case }
- { key: readability-identifier-naming.MacroDefinitionCase, value: UPPER_CASE }
- { key: readability-identifier-naming.EnumConstantCase, value: UPPER_CASE }
- { key: readability-identifier-naming.ConstexprVariableCase, value: CamelCase }
- { key: readability-identifier-naming.ConstexprVariablePrefix, value: k }
- { key: readability-identifier-naming.GlobalConstantCase, value: CamelCase }
- { key: readability-identifier-naming.GlobalConstantPrefix, value: k }
- { key: readability-identifier-naming.StaticConstantCase, value: CamelCase }
- { key: readability-identifier-naming.StaticConstantPrefix, value: k }
- { key: readability-function-cognitive-complexity.Threshold, value: 100 }
- { key: readability-identifier-naming.StructCase, value: CamelCase }
- { key: readability-identifier-naming.FunctionCase, value: lower_case }
- { key: readability-identifier-naming.VariableCase, value: lower_case }
- { key: readability-identifier-naming.MacroDefinitionCase, value: UPPER_CASE }
- { key: readability-identifier-naming.EnumConstantCase, value: UPPER_CASE }
- { key: readability-identifier-naming.ConstexprVariableCase, value: CamelCase }
- { key: readability-identifier-naming.ConstexprVariablePrefix, value: k }
- { key: readability-identifier-naming.GlobalConstantCase, value: CamelCase }
- { key: readability-identifier-naming.GlobalConstantPrefix, value: k }
- { key: readability-identifier-naming.StaticConstantCase, value: CamelCase }
- { key: readability-identifier-naming.StaticConstantPrefix, value: k }
- { key: readability-identifier-naming.MinimumParameterNameLength, value: 0 }
MinimumParameterNameLength: 0

View File

@@ -462,6 +462,7 @@ macro swizzle(v, ...) @builtin
return $$swizzle(v, $vasplat);
}
<*
Shuffle two vectors by a common index from arranging the vectors sequentially in memory

View File

@@ -3,6 +3,20 @@
module std::math::vector;
import std::math;
<*
@require $Type.kindof == VECTOR &&& $kindof(($Type){}[0]) == BOOL : "Expected a bool vector"
@require $kindof(mask).is_int() : "Expected an integer mask"
*>
macro bool[<*>] mask_from_int($Type, mask)
{
return $$int_to_mask(mask, $Type.len);
}
macro bool[<*>].mask_to_int(self)
{
return $$mask_to_int(self);
}
macro double double[<*>].sq_magnitude(self) => self.dot(self);
macro float float[<*>].sq_magnitude(self) => self.dot(self);

View File

@@ -9,6 +9,7 @@
- Fixed bug where constants would get modified when slicing them. #2660
- Support for NetBSD.
- Testing for the presence of methods at the top level is prohibited previous to method registration.
- `$$MASK_TO_INT` and `$$INT_TO_MASK` to create bool masks from integers and back.
### Fixes
- Regression with npot vector in struct triggering an assert #2219.
@@ -30,6 +31,7 @@
- Incorrect rounding for decimals in formatter in some cases. #2657
- Incorrectly using LLVMStructType when emitting dynamic functions on MachO #2666
- FixedThreadPool join did not work correctly.
- Fix bug when creating bool vectors in certain cases.
### Stdlib changes
- Add `ThreadPool` join function to wait for all threads to finish in the pool without destroying the threads.

View File

@@ -464,11 +464,13 @@ typedef enum
BUILTIN_FSHR,
BUILTIN_GATHER,
BUILTIN_GET_ROUNDING_MODE,
BUILTIN_INT_TO_MASK,
BUILTIN_LOG,
BUILTIN_LOG10,
BUILTIN_LOG2,
BUILTIN_MATRIX_MUL,
BUILTIN_MATRIX_TRANSPOSE,
BUILTIN_MASK_TO_INT,
BUILTIN_MASKED_LOAD,
BUILTIN_MASKED_STORE,
BUILTIN_MAX,

View File

@@ -595,6 +595,35 @@ static void llvm_emit_gather(GenContext *c, BEValue *be_value, Expr *expr)
llvm_value_set(be_value, result, expr->type);
}
static void llvm_emit_mask_to_int(GenContext *c, BEValue *be_value, Expr *expr)
{
Expr **args = expr->call_expr.arguments;
ASSERT(vec_size(args) == 1);
LLVMValueRef val = llvm_emit_expr_to_rvalue(c, args[0]);
LLVMTypeRef mask_type = LLVMTypeOf(val);
unsigned bits = LLVMGetVectorSize(mask_type);
val = LLVMBuildBitCast(c->builder, val, LLVMIntTypeInContext(c->context, bits), "");
unsigned target_bits = next_highest_power_of_2(bits);
if (target_bits < 8) target_bits = 8;
if (target_bits < bits) val = LLVMBuildZExt(c->builder, val, llvm_get_type(c, expr->type), "");
llvm_value_set(be_value, val, expr->type);
}
static void llvm_emit_int_to_mask(GenContext *c, BEValue *be_value, Expr *expr)
{
Expr **args = expr->call_expr.arguments;
ASSERT(vec_size(args) == 2);
LLVMValueRef val = llvm_emit_expr_to_rvalue(c, args[0]);
unsigned bits = (unsigned)args[1]->const_expr.ixx.i.low;
unsigned int_len = type_bit_size(args[0]->type);
if (int_len > bits)
{
val = LLVMBuildTrunc(c->builder, val, LLVMIntTypeInContext(c->context, bits), "");
}
val = LLVMBuildBitCast(c->builder, val, LLVMVectorType(c->bool_type, bits), "");
llvm_value_set(be_value, val, expr->type);
}
static void llvm_emit_masked_store(GenContext *c, BEValue *be_value, Expr *expr)
{
Expr **args = expr->call_expr.arguments;
@@ -915,6 +944,12 @@ void llvm_emit_builtin_call(GenContext *c, BEValue *result_value, Expr *expr)
case BUILTIN_SCATTER:
llvm_emit_scatter(c, result_value, expr);
return;
case BUILTIN_MASK_TO_INT:
llvm_emit_mask_to_int(c, result_value, expr);
return;
case BUILTIN_INT_TO_MASK:
llvm_emit_int_to_mask(c, result_value, expr);
return;
case BUILTIN_MASKED_STORE:
llvm_emit_masked_store(c, result_value, expr);
return;

View File

@@ -6357,8 +6357,7 @@ static inline void llvm_emit_vector_initializer_list(GenContext *c, BEValue *val
FOREACH_IDX(i, Expr *, element, elements)
{
llvm_emit_expr(c, &val, element);
llvm_value_rvalue(c, &val);
vec_value = llvm_update_vector(c, vec_value, val.value, (ArrayIndex)i);
vec_value = llvm_update_vector(c, vec_value, llvm_load_value_store(c, &val), (ArrayIndex)i);
}
}
else
@@ -6382,18 +6381,18 @@ static inline void llvm_emit_vector_initializer_list(GenContext *c, BEValue *val
ASSERT(vec_size(designator->designator_expr.path) == 1);
DesignatorElement *element = designator->designator_expr.path[0];
llvm_emit_expr(c, &val, designator->designator_expr.value);
llvm_value_rvalue(c, &val);
LLVMValueRef value = llvm_load_value_store(c, &val);
switch (element->kind)
{
case DESIGNATOR_ARRAY:
{
vec_value = llvm_update_vector(c, vec_value, val.value, element->index);
vec_value = llvm_update_vector(c, vec_value, value, element->index);
break;
}
case DESIGNATOR_RANGE:
for (ArrayIndex idx = element->index; idx <= element->index_end; idx++)
{
vec_value = llvm_update_vector(c, vec_value, val.value, idx);
vec_value = llvm_update_vector(c, vec_value, value, idx);
}
break;
case DESIGNATOR_FIELD:

View File

@@ -1068,6 +1068,44 @@ bool sema_expr_analyse_builtin_call(SemaContext *context, Expr *expr)
rtype = type_void;
break;
}
case BUILTIN_INT_TO_MASK:
{
ASSERT(arg_count == 2);
if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_INTEGER, BA_INTEGER }, 2)) return false;
Expr *len = args[1];
if (!sema_cast_const(len) || !expr_is_const_int(len))
{
RETURN_SEMA_ERROR(len, "Expected constant integer for the vector length.");
}
Type *type = type_flatten(args[0]->type);
int size = (int)type_bit_size(type);
if (int_icomp(len->const_expr.ixx, size, BINARYOP_GT))
{
RETURN_SEMA_ERROR(args[1], "The vector length (%s) cannot be greater than the bit width of the integer (%d).",
int_to_str(len->const_expr.ixx, 10, false), size);
}
int bits = (int)len->const_expr.ixx.i.low;
if (!bits)
{
RETURN_SEMA_ERROR(args[1], "The vector length cannot be zero");
}
rtype = type_get_vector(type_bool, TYPE_VECTOR, bits);
break;
}
case BUILTIN_MASK_TO_INT:
{
ASSERT(arg_count == 1);
if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_BOOLVEC }, 1)) return false;
Type *vec = type_flatten(args[0]->type);
ArraySize len = vec->array.len;
if (len > 128)
{
RETURN_SEMA_ERROR(args[0], "Masks must be 128 or fewer bits to convert them to an integer.");
}
if (len < 8) len = 8;
rtype = type_int_unsigned_by_bitsize(next_highest_power_of_2(len));
break;
}
case BUILTIN_MASKED_LOAD:
{
ASSERT(arg_count == 4);
@@ -1376,6 +1414,7 @@ static inline int builtin_expected_args(BuiltinFunction func)
case BUILTIN_LOG:
case BUILTIN_LRINT:
case BUILTIN_LROUND:
case BUILTIN_MASK_TO_INT:
case BUILTIN_NEARBYINT:
case BUILTIN_POPCOUNT:
case BUILTIN_REDUCE_ADD:
@@ -1394,15 +1433,16 @@ static inline int builtin_expected_args(BuiltinFunction func)
case BUILTIN_SIN:
case BUILTIN_SQRT:
case BUILTIN_STR_HASH:
case BUILTIN_STR_UPPER:
case BUILTIN_STR_LOWER:
case BUILTIN_STR_SNAKECASE:
case BUILTIN_STR_PASCALCASE:
case BUILTIN_STR_SNAKECASE:
case BUILTIN_STR_UPPER:
case BUILTIN_TRUNC:
case BUILTIN_VOLATILE_LOAD:
case BUILTIN_WASM_MEMORY_SIZE:
return 1;
case BUILTIN_STR_FIND:
case BUILTIN_VOLATILE_STORE:
case BUILTIN_ANY_MAKE:
case BUILTIN_COPYSIGN:
case BUILTIN_EXACT_ADD:
case BUILTIN_EXACT_DIV:
@@ -1410,6 +1450,7 @@ static inline int builtin_expected_args(BuiltinFunction func)
case BUILTIN_EXACT_MUL:
case BUILTIN_EXACT_SUB:
case BUILTIN_EXPECT:
case BUILTIN_INT_TO_MASK:
case BUILTIN_MAX:
case BUILTIN_MIN:
case BUILTIN_POW:
@@ -1417,19 +1458,18 @@ static inline int builtin_expected_args(BuiltinFunction func)
case BUILTIN_REDUCE_FADD:
case BUILTIN_REDUCE_FMUL:
case BUILTIN_SAT_ADD:
case BUILTIN_SAT_MUL:
case BUILTIN_SAT_SHL:
case BUILTIN_SAT_SUB:
case BUILTIN_SAT_MUL:
case BUILTIN_VOLATILE_STORE:
case BUILTIN_VECCOMPNE:
case BUILTIN_VECCOMPLT:
case BUILTIN_VECCOMPLE:
case BUILTIN_STR_FIND:
case BUILTIN_UNALIGNED_LOAD:
case BUILTIN_VECCOMPEQ:
case BUILTIN_VECCOMPGE:
case BUILTIN_VECCOMPGT:
case BUILTIN_VECCOMPEQ:
case BUILTIN_VECCOMPLE:
case BUILTIN_VECCOMPLT:
case BUILTIN_VECCOMPNE:
case BUILTIN_WASM_MEMORY_GROW:
case BUILTIN_ANY_MAKE:
case BUILTIN_UNALIGNED_LOAD:
return 2;
case BUILTIN_EXPECT_WITH_PROBABILITY:
case BUILTIN_FMA:

View File

@@ -235,9 +235,11 @@ void symtab_init(uint32_t capacity)
builtin_list[BUILTIN_FSHR] = KW_DEF("fshr");
builtin_list[BUILTIN_GATHER] = KW_DEF("gather");
builtin_list[BUILTIN_GET_ROUNDING_MODE] = KW_DEF("get_rounding_mode");
builtin_list[BUILTIN_INT_TO_MASK] = KW_DEF("int_to_mask");
builtin_list[BUILTIN_LOG] = KW_DEF("log");
builtin_list[BUILTIN_LOG2] = KW_DEF("log2");
builtin_list[BUILTIN_LOG10] = KW_DEF("log10");
builtin_list[BUILTIN_MASK_TO_INT] = KW_DEF("mask_to_int");
builtin_list[BUILTIN_MASKED_LOAD] = KW_DEF("masked_load");
builtin_list[BUILTIN_MASKED_STORE] = KW_DEF("masked_store");
builtin_list[BUILTIN_MATRIX_MUL] = KW_DEF("matrix_mul");

View File

@@ -0,0 +1,38 @@
// #target: macos-x64
module test;
fn void test()
{
bool[<8>] x = $$int_to_mask(0b1111000, 8);
bool[<32>] y = $$int_to_mask(0b1111000, 32);
}
fn void test2()
{
char x = $$mask_to_int((bool[<8>]){ false, false, false, false, true, true, true, false });
short x2 = $$mask_to_int((bool[<15>]){});
uint128 y = $$mask_to_int((bool[<128>]){});
}
/* #expect: test.ll
define void @test.test() #0 {
entry:
%x = alloca <8 x i8>, align 8
%y = alloca <32 x i8>, align 32
%0 = sext <8 x i1> bitcast (<1 x i8> splat (i8 120) to <8 x i1>) to <8 x i8>
store <8 x i8> %0, ptr %x, align 8
%1 = sext <32 x i1> bitcast (<1 x i32> splat (i32 120) to <32 x i1>) to <32 x i8>
store <32 x i8> %1, ptr %y, align 32
ret void
}
define void @test.test2() #0 {
entry:
%x = alloca i8, align 1
%x2 = alloca i16, align 2
%y = alloca i128, align 16
store i8 bitcast (<8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 true, i1 true, i1 true, i1 false> to i8), ptr %x, align 1
store i16 0, ptr %x2, align 2
store i128 0, ptr %y, align 16
ret void
}

View File

@@ -0,0 +1,18 @@
fn void test1()
{
$$int_to_mask(1u, 32);
$$int_to_mask(1u, 33); // #error: The vector length (33) cannot be greater than the bit width of the integer (32)
}
fn void test2()
{
$$int_to_mask(1ULL, 128);
$$int_to_mask(1ULL, 129); // #error: The vector length (129) cannot be greater than the bit width of the integer (128)
}
fn void test3()
{
$$mask_to_int((bool[<128>]){});
$$mask_to_int((bool[<129>]){}); // #error: Masks must be 128 or fewer bits to convert them to an integer
}

View File

@@ -1,6 +1,6 @@
module test2;
import test3;
struct Bar @if($defined(Foo.b)) // #error: "There might be a method 'b' for 'Foo', but methods for the type have not yet been completely registered, so this yields an error.
struct Bar @if($defined(Foo.b)) // #error: There might be a method 'b' for 'Foo', but methods for the type have not yet been completely registered, so this yields an error
{
int a;
}

View File

@@ -0,0 +1,10 @@
module vector_mask @test;
import std::io, std::math;
fn void to_from_mask()
{
int x = (bool[<9>]){ true, false, true, false, false, false, false, false, true }.mask_to_int();
test::eq(x, 0b100000101);
bool[<10>] mask = vector::mask_from_int(bool[<10>], x);
test::eq(mask, (bool[<10>]){ true, false, true, false, false, false, false, false, true, false });
}