Add $$matrix_mul and $$matrix_transpose builtins.

This commit is contained in:
Christoffer Lerno
2025-05-27 00:50:16 +02:00
parent 61a4dcc807
commit 966e8107f8
9 changed files with 127 additions and 0 deletions

View File

@@ -17,6 +17,7 @@
- Limit vector max size, default is 4096 bits, but may be increased using --max-vector-size.
- Allow the use of `has_tagof` on builtin types.
- `@jump` now included in `--list-attributes` #2155.
- Add `$$matrix_mul` and `$$matrix_transpose` builtins.
### Fixes
- Assert triggered when casting from `int[2]` to `uint[2]` #2115

View File

@@ -459,6 +459,8 @@ typedef enum
BUILTIN_LOG,
BUILTIN_LOG10,
BUILTIN_LOG2,
BUILTIN_MATRIX_MUL,
BUILTIN_MATRIX_TRANSPOSE,
BUILTIN_MASKED_LOAD,
BUILTIN_MASKED_STORE,
BUILTIN_MAX,

View File

@@ -824,6 +824,8 @@ static void llvm_codegen_setup()
intrinsic_id.masked_store = lookup_intrinsic("llvm.masked.store");
intrinsic_id.maximum = lookup_intrinsic("llvm.maximum");
intrinsic_id.maxnum = lookup_intrinsic("llvm.maxnum");
intrinsic_id.matrix_multiply = lookup_intrinsic("llvm.matrix.multiply");
intrinsic_id.matrix_transpose = lookup_intrinsic("llvm.matrix.transpose");
intrinsic_id.memcpy = lookup_intrinsic("llvm.memcpy");
intrinsic_id.memcpy_inline = lookup_intrinsic("llvm.memcpy.inline");
intrinsic_id.memmove = lookup_intrinsic("llvm.memmove");

View File

@@ -526,6 +526,34 @@ void llvm_emit_simple_builtin(GenContext *c, BEValue *be_value, Expr *expr, unsi
llvm_value_set(be_value, result, expr->type);
}
// Emit a call to llvm.matrix.multiply for the $$matrix_mul builtin.
// Args (validated in sema): lhs vector, rhs vector, outer-rows, inner, outer-cols.
void llvm_emit_matrix_multiply(GenContext *c, BEValue *be_value, Expr *expr)
{
Expr **args = expr->call_expr.arguments;
unsigned count = vec_size(args);
ASSERT(count == 5);
LLVMValueRef arg_slots[5];
llvm_emit_intrinsic_args(c, args, arg_slots, count);
LLVMTypeRef type = LLVMTypeOf(arg_slots[0]);
LLVMTypeRef result_type = llvm_get_type(c, expr->type);
// llvm.matrix.multiply is overloaded on three vector types:
// result, lhs and rhs (e.g. @llvm.matrix.multiply.v6i32.v2i32.v3i32).
LLVMTypeRef call_type[3] = { result_type, type, LLVMTypeOf(arg_slots[1]) };
LLVMValueRef result = llvm_emit_call_intrinsic(c, intrinsic_id.matrix_multiply, call_type, 3, arg_slots, count);
llvm_value_set(be_value, result, expr->type);
}
// Emit a call to llvm.matrix.transpose for the $$matrix_transpose builtin.
// Args (validated in sema): input vector, rows, cols.
// NOTE(review): the builtin dispatch currently routes BUILTIN_MATRIX_TRANSPOSE
// through llvm_emit_simple_builtin instead of this function — confirm which
// path is intended; as committed this function is unreachable.
void llvm_emit_matrix_transpose(GenContext *c, BEValue *be_value, Expr *expr)
{
	Expr **args = expr->call_expr.arguments;
	unsigned count = vec_size(args);
	ASSERT(count == 3);
	LLVMValueRef arg_slots[3];
	llvm_emit_intrinsic_args(c, args, arg_slots, count);
	// llvm.matrix.transpose is overloaded on a SINGLE vector type
	// (e.g. @llvm.matrix.transpose.v6i32): the result has the same element
	// type and length as the input, so one overload type suffices. The old
	// code declared call_type[3], initialized only 2 slots and passed 2
	// overload types, which would mangle the intrinsic name incorrectly.
	LLVMTypeRef call_type[1] = { LLVMTypeOf(arg_slots[0]) };
	LLVMValueRef result = llvm_emit_call_intrinsic(c, intrinsic_id.matrix_transpose, call_type, 1, arg_slots, count);
	llvm_value_set(be_value, result, expr->type);
}
static void llvm_emit_masked_load(GenContext *c, BEValue *be_value, Expr *expr)
{
Expr **args = expr->call_expr.arguments;
@@ -788,6 +816,12 @@ void llvm_emit_builtin_call(GenContext *c, BEValue *result_value, Expr *expr)
case BUILTIN_MEMSET_INLINE:
llvm_emit_memset_builtin(c, intrinsic_id.memset_inline, result_value, expr);
return;
case BUILTIN_MATRIX_MUL:
llvm_emit_matrix_multiply(c, result_value, expr);
return;
case BUILTIN_MATRIX_TRANSPOSE:
llvm_emit_simple_builtin(c, result_value, expr, intrinsic_id.matrix_transpose);
return;
case BUILTIN_SYSCLOCK:
llvm_value_set(result_value, llvm_emit_call_intrinsic(c, intrinsic_id.readcyclecounter, NULL, 0, NULL, 0), expr->type);
return;

View File

@@ -199,6 +199,8 @@ typedef struct
unsigned masked_store;
unsigned maximum;
unsigned maxnum;
unsigned matrix_multiply;
unsigned matrix_transpose;
unsigned memcpy;
unsigned memcpy_inline;
unsigned memmove;

View File

@@ -704,6 +704,46 @@ bool sema_expr_analyse_builtin_call(SemaContext *context, Expr *expr)
if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_INTLIKE}, 1)) return false;
rtype = args[0]->type;
break;
case BUILTIN_MATRIX_TRANSPOSE:
{
ASSERT(arg_count == 3);
if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_VEC, BA_INTEGER, BA_INTEGER}, 3)) return false;
if (!sema_check_builtin_args_const(context, &args[1], 2)) return false;
ArraySize vec_len = type_flatten(args[0]->type)->array.len;
Int sum = int_mul(args[1]->const_expr.ixx, args[2]->const_expr.ixx);
if (!int_icomp(sum, vec_len, BINARYOP_EQ))
{
RETURN_SEMA_ERROR(args[1], "Expected row * col to equal %d.", vec_len);
}
rtype = args[0]->type;
break;
}
case BUILTIN_MATRIX_MUL:
ASSERT(arg_count == 5);
if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_VEC, BA_VEC, BA_INTEGER, BA_INTEGER, BA_INTEGER }, 5)) return false;
if (!sema_check_builtin_args_const(context, &args[2], 3)) return false;
Type *flat1 = type_flatten(args[0]->type);
Type *flat2 = type_flatten(args[1]->type);
if (flat1->array.base != flat2->array.base)
{
RETURN_SEMA_ERROR(args[1], "Expected both matrices to be of the same type.");
}
ArraySize vec_len1 = flat1->array.len;
ArraySize vec_len2 = flat2->array.len;
Int128 sum = i128_mult(args[2]->const_expr.ixx.i, args[3]->const_expr.ixx.i);
args[2]->type = type_int;
args[3]->type = type_int;
if (sum.high != 0 || sum.low != vec_len1)
{
RETURN_SEMA_ERROR(args[3], "Expected outer row * inner to equal %d.", vec_len1);
}
sum = i128_mult(args[3]->const_expr.ixx.i, args[4]->const_expr.ixx.i);
if (sum.high != 0 || sum.low != vec_len2)
{
RETURN_SEMA_ERROR(args[4], "Expected inner * outer col to equal %d.", vec_len2);
}
rtype = type_get_vector(flat1->array.base, i128_mult(args[2]->const_expr.ixx.i, args[4]->const_expr.ixx.i).low);
break;
case BUILTIN_SAT_SHL:
case BUILTIN_SAT_SUB:
case BUILTIN_SAT_ADD:
@@ -1258,6 +1298,7 @@ static inline int builtin_expected_args(BuiltinFunction func)
case BUILTIN_ATOMIC_LOAD:
case BUILTIN_UNALIGNED_STORE:
case BUILTIN_SELECT:
case BUILTIN_MATRIX_TRANSPOSE:
return 3;
case BUILTIN_ATOMIC_STORE:
case BUILTIN_MASKED_STORE:
@@ -1276,6 +1317,7 @@ static inline int builtin_expected_args(BuiltinFunction func)
case BUILTIN_ATOMIC_FETCH_MIN:
case BUILTIN_ATOMIC_FETCH_SUB:
case BUILTIN_ATOMIC_FETCH_DEC_WRAP:
case BUILTIN_MATRIX_MUL:
return 5;
case BUILTIN_MEMCOPY:
case BUILTIN_MEMCOPY_INLINE:

View File

@@ -236,6 +236,8 @@ void symtab_init(uint32_t capacity)
builtin_list[BUILTIN_LOG10] = KW_DEF("log10");
builtin_list[BUILTIN_MASKED_LOAD] = KW_DEF("masked_load");
builtin_list[BUILTIN_MASKED_STORE] = KW_DEF("masked_store");
builtin_list[BUILTIN_MATRIX_MUL] = KW_DEF("matrix_mul");
builtin_list[BUILTIN_MATRIX_TRANSPOSE] = KW_DEF("matrix_transpose");
builtin_list[BUILTIN_MEMCOPY] = KW_DEF("memcpy");
builtin_list[BUILTIN_MEMCOPY_INLINE] = KW_DEF("memcpy_inline");
builtin_list[BUILTIN_MEMMOVE] = KW_DEF("memmove");

View File

@@ -0,0 +1,38 @@
// #target: macos-x64
module test;
fn int main()
{
	// 2x2 * 2x2: flat 4-vector operands, result is also a flat 4-vector.
	int[<4>] x = { 1, 2, 3, 4 };
	int[<4>] z = $$matrix_mul(x, x, 2, 2, 2);
	// Column vector (2x1) times row vector (1x3) -> 2x3 matrix as a 6-vector;
	// exercises the v6i32.v2i32.v3i32 overload mangling in the expected IR.
	int[<2>] a = { 1, 2 };
	int[<3>] b = { 1, 2, 3 };
	int[<6>] c = $$matrix_mul(a, b, 2, 1, 3);
	// Transpose the 2x3 result in place (same vector length, 3x2 layout).
	c = $$matrix_transpose(c, 2, 3);
	return 0;
}
/* #expect: test.ll
define i32 @main() #0 {
entry:
%x = alloca <4 x i32>, align 16
%z = alloca <4 x i32>, align 16
%a = alloca <2 x i32>, align 8
%b = alloca <3 x i32>, align 16
%c = alloca <6 x i32>, align 32
store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %x, align 16
%0 = load <4 x i32>, ptr %x, align 16
%1 = load <4 x i32>, ptr %x, align 16
%2 = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %0, <4 x i32> %1, i32 2, i32 2, i32 2)
store <4 x i32> %2, ptr %z, align 16
store <2 x i32> <i32 1, i32 2>, ptr %a, align 8
store <3 x i32> <i32 1, i32 2, i32 3>, ptr %b, align 16
%3 = load <2 x i32>, ptr %a, align 8
%4 = load <3 x i32>, ptr %b, align 16
%5 = call <6 x i32> @llvm.matrix.multiply.v6i32.v2i32.v3i32(<2 x i32> %3, <3 x i32> %4, i32 2, i32 1, i32 3)
store <6 x i32> %5, ptr %c, align 32
%6 = load <6 x i32>, ptr %c, align 32
%7 = call <6 x i32> @llvm.matrix.transpose.v6i32(<6 x i32> %6, i32 2, i32 3)
store <6 x i32> %7, ptr %c, align 32
ret i32 0
}

View File

@@ -15,6 +15,7 @@
#include "llvm-c/Target.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/IR/Function.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
@@ -155,6 +156,8 @@ bool llvm_run_passes(LLVMModuleRef m, LLVMTargetMachineRef tm,
llvm::FunctionAnalysisManager FAM;
llvm::CGSCCAnalysisManager CGAM;
llvm::ModuleAnalysisManager MAM;
PB.registerLoopAnalyses(LAM);
PB.registerFunctionAnalyses(FAM);
PB.registerCGSCCAnalyses(CGAM);
@@ -196,6 +199,7 @@ bool llvm_run_passes(LLVMModuleRef m, LLVMTargetMachineRef tm,
#else
llvm::ModulePassManager MPM = PB.buildPerModuleDefaultPipeline(level, false);
#endif
MPM.addPass(llvm::createModuleToFunctionPassAdaptor(llvm::LowerMatrixIntrinsicsPass(false)));
if (passes->should_verify)
{
MPM.addPass(llvm::VerifierPass());