From 966e8107f86a15de833b6ab83d513165ec99e87f Mon Sep 17 00:00:00 2001 From: Christoffer Lerno Date: Tue, 27 May 2025 00:50:16 +0200 Subject: [PATCH] Add `$$matrix_mul` and `$$matrix_transpose` builtins. --- releasenotes.md | 1 + src/compiler/enums.h | 2 + src/compiler/llvm_codegen.c | 2 + src/compiler/llvm_codegen_builtins.c | 34 +++++++++++++++++ src/compiler/llvm_codegen_internal.h | 2 + src/compiler/sema_builtins.c | 42 +++++++++++++++++++++ src/compiler/symtab.c | 2 + test/test_suite/builtins/matrix_builtin.c3t | 38 +++++++++++++++++++ wrapper/src/wrapper.cpp | 4 ++ 9 files changed, 127 insertions(+) create mode 100644 test/test_suite/builtins/matrix_builtin.c3t diff --git a/releasenotes.md b/releasenotes.md index 5bf78c95d..c30b42f72 100644 --- a/releasenotes.md +++ b/releasenotes.md @@ -17,6 +17,7 @@ - Limit vector max size, default is 4096 bits, but may be increased using --max-vector-size. - Allow the use of `has_tagof` on builtin types. - `@jump` now included in `--list-attributes` #2155. +- Add `$$matrix_mul` and `$$matrix_transpose` builtins. ### Fixes - Assert triggered when casting from `int[2]` to `uint[2]` #2115 diff --git a/src/compiler/enums.h b/src/compiler/enums.h index 67f9dc30c..d15ad23cd 100644 --- a/src/compiler/enums.h +++ b/src/compiler/enums.h @@ -459,6 +459,8 @@ typedef enum BUILTIN_LOG, BUILTIN_LOG10, BUILTIN_LOG2, + BUILTIN_MATRIX_MUL, + BUILTIN_MATRIX_TRANSPOSE, BUILTIN_MASKED_LOAD, BUILTIN_MASKED_STORE, BUILTIN_MAX, diff --git a/src/compiler/llvm_codegen.c b/src/compiler/llvm_codegen.c index c755bc590..47e8b6599 100644 --- a/src/compiler/llvm_codegen.c +++ b/src/compiler/llvm_codegen.c @@ -824,6 +824,8 @@ static void llvm_codegen_setup() intrinsic_id.masked_store = lookup_intrinsic("llvm.masked.store"); intrinsic_id.maximum = lookup_intrinsic("llvm.maximum"); intrinsic_id.maxnum = lookup_intrinsic("llvm.maxnum"); + intrinsic_id.matrix_multiply = lookup_intrinsic("llvm.matrix.multiply"); + intrinsic_id.matrix_transpose = lookup_intrinsic("llvm.matrix.transpose"); intrinsic_id.memcpy = lookup_intrinsic("llvm.memcpy"); intrinsic_id.memcpy_inline = lookup_intrinsic("llvm.memcpy.inline"); intrinsic_id.memmove = lookup_intrinsic("llvm.memmove"); diff --git a/src/compiler/llvm_codegen_builtins.c b/src/compiler/llvm_codegen_builtins.c index 7b261ad1b..8953421c6 100644 --- a/src/compiler/llvm_codegen_builtins.c +++ b/src/compiler/llvm_codegen_builtins.c @@ -526,6 +526,34 @@ void llvm_emit_simple_builtin(GenContext *c, BEValue *be_value, Expr *expr, unsi llvm_value_set(be_value, result, expr->type); } +void llvm_emit_matrix_multiply(GenContext *c, BEValue *be_value, Expr *expr) +{ + Expr **args = expr->call_expr.arguments; + unsigned count = vec_size(args); + ASSERT(count == 5); + LLVMValueRef arg_slots[5]; + llvm_emit_intrinsic_args(c, args, arg_slots, count); + LLVMTypeRef type = LLVMTypeOf(arg_slots[0]); + LLVMTypeRef result_type = llvm_get_type(c, expr->type); + LLVMTypeRef call_type[3] = { result_type, type, LLVMTypeOf(arg_slots[1]) }; + LLVMValueRef result = llvm_emit_call_intrinsic(c, intrinsic_id.matrix_multiply, call_type, 3, arg_slots, count); + llvm_value_set(be_value, result, expr->type); +} + +void llvm_emit_matrix_transpose(GenContext *c, BEValue *be_value, Expr *expr) +{ + Expr **args = expr->call_expr.arguments; + unsigned count = vec_size(args); + ASSERT(count == 3); + LLVMValueRef arg_slots[3]; + llvm_emit_intrinsic_args(c, args, arg_slots, count); + LLVMTypeRef type = LLVMTypeOf(arg_slots[0]); + LLVMTypeRef result_type = llvm_get_type(c, expr->type); + LLVMTypeRef call_type[3] = { result_type, type }; + LLVMValueRef result = llvm_emit_call_intrinsic(c, intrinsic_id.matrix_transpose, call_type, 2, arg_slots, count); + llvm_value_set(be_value, result, expr->type); +} + static void llvm_emit_masked_load(GenContext *c, BEValue *be_value, Expr *expr) { Expr **args = expr->call_expr.arguments; @@ -788,6 +816,12 @@ void llvm_emit_builtin_call(GenContext *c, BEValue *result_value, Expr *expr) case BUILTIN_MEMSET_INLINE: llvm_emit_memset_builtin(c, intrinsic_id.memset_inline, result_value, expr); return; + case BUILTIN_MATRIX_MUL: + llvm_emit_matrix_multiply(c, result_value, expr); + return; + case BUILTIN_MATRIX_TRANSPOSE: + llvm_emit_simple_builtin(c, result_value, expr, intrinsic_id.matrix_transpose); + return; case BUILTIN_SYSCLOCK: llvm_value_set(result_value, llvm_emit_call_intrinsic(c, intrinsic_id.readcyclecounter, NULL, 0, NULL, 0), expr->type); return; diff --git a/src/compiler/llvm_codegen_internal.h b/src/compiler/llvm_codegen_internal.h index 536e5f5d2..57e2e3993 100644 --- a/src/compiler/llvm_codegen_internal.h +++ b/src/compiler/llvm_codegen_internal.h @@ -199,6 +199,8 @@ typedef struct unsigned masked_store; unsigned maximum; unsigned maxnum; + unsigned matrix_multiply; + unsigned matrix_transpose; unsigned memcpy; unsigned memcpy_inline; unsigned memmove; diff --git a/src/compiler/sema_builtins.c b/src/compiler/sema_builtins.c index d41885be6..f09fb7686 100644 --- a/src/compiler/sema_builtins.c +++ b/src/compiler/sema_builtins.c @@ -704,6 +704,46 @@ bool sema_expr_analyse_builtin_call(SemaContext *context, Expr *expr) if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_INTLIKE}, 1)) return false; rtype = args[0]->type; break; + case BUILTIN_MATRIX_TRANSPOSE: + { + ASSERT(arg_count == 3); + if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_VEC, BA_INTEGER, BA_INTEGER}, 3)) return false; + if (!sema_check_builtin_args_const(context, &args[1], 2)) return false; + ArraySize vec_len = type_flatten(args[0]->type)->array.len; + Int sum = int_mul(args[1]->const_expr.ixx, args[2]->const_expr.ixx); + if (!int_icomp(sum, vec_len, BINARYOP_EQ)) + { + RETURN_SEMA_ERROR(args[1], "Expected row * col to equal %d.", vec_len); + } + rtype = args[0]->type; + break; + } + case BUILTIN_MATRIX_MUL: + ASSERT(arg_count == 5); + if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_VEC, BA_VEC, BA_INTEGER, BA_INTEGER, BA_INTEGER }, 5)) return false; + if (!sema_check_builtin_args_const(context, &args[2], 3)) return false; + Type *flat1 = type_flatten(args[0]->type); + Type *flat2 = type_flatten(args[1]->type); + if (flat1->array.base != flat2->array.base) + { + RETURN_SEMA_ERROR(args[1], "Expected both matrices to be of the same type."); + } + ArraySize vec_len1 = flat1->array.len; + ArraySize vec_len2 = flat2->array.len; + Int128 sum = i128_mult(args[2]->const_expr.ixx.i, args[3]->const_expr.ixx.i); + args[2]->type = type_int; + args[3]->type = type_int; + if (sum.high != 0 || sum.low != vec_len1) + { + RETURN_SEMA_ERROR(args[3], "Expected outer row * inner to equal %d.", vec_len1); + } + sum = i128_mult(args[3]->const_expr.ixx.i, args[4]->const_expr.ixx.i); + if (sum.high != 0 || sum.low != vec_len2) + { + RETURN_SEMA_ERROR(args[4], "Expected inner * outer col to equal %d.", vec_len2); + } + rtype = type_get_vector(flat1->array.base, i128_mult(args[2]->const_expr.ixx.i, args[4]->const_expr.ixx.i).low); + break; case BUILTIN_SAT_SHL: case BUILTIN_SAT_SUB: case BUILTIN_SAT_ADD: @@ -1258,6 +1298,7 @@ static inline int builtin_expected_args(BuiltinFunction func) case BUILTIN_ATOMIC_LOAD: case BUILTIN_UNALIGNED_STORE: case BUILTIN_SELECT: + case BUILTIN_MATRIX_TRANSPOSE: return 3; case BUILTIN_ATOMIC_STORE: case BUILTIN_MASKED_STORE: @@ -1276,6 +1317,7 @@ static inline int builtin_expected_args(BuiltinFunction func) case BUILTIN_ATOMIC_FETCH_MIN: case BUILTIN_ATOMIC_FETCH_SUB: case BUILTIN_ATOMIC_FETCH_DEC_WRAP: + case BUILTIN_MATRIX_MUL: return 5; case BUILTIN_MEMCOPY: case BUILTIN_MEMCOPY_INLINE: diff --git a/src/compiler/symtab.c b/src/compiler/symtab.c index 1464ebff8..2bdd8918b 100644 --- a/src/compiler/symtab.c +++ b/src/compiler/symtab.c @@ -236,6 +236,8 @@ void symtab_init(uint32_t capacity) builtin_list[BUILTIN_LOG10] = KW_DEF("log10"); builtin_list[BUILTIN_MASKED_LOAD] = KW_DEF("masked_load"); builtin_list[BUILTIN_MASKED_STORE] = KW_DEF("masked_store"); + builtin_list[BUILTIN_MATRIX_MUL] = KW_DEF("matrix_mul"); + builtin_list[BUILTIN_MATRIX_TRANSPOSE] = KW_DEF("matrix_transpose"); builtin_list[BUILTIN_MEMCOPY] = KW_DEF("memcpy"); builtin_list[BUILTIN_MEMCOPY_INLINE] = KW_DEF("memcpy_inline"); builtin_list[BUILTIN_MEMMOVE] = KW_DEF("memmove"); diff --git a/test/test_suite/builtins/matrix_builtin.c3t b/test/test_suite/builtins/matrix_builtin.c3t new file mode 100644 index 000000000..165e4fc37 --- /dev/null +++ b/test/test_suite/builtins/matrix_builtin.c3t @@ -0,0 +1,38 @@ +// #target: macos-x64 +module test; +fn int main() +{ + int[<4>] x = { 1, 2, 3, 4 }; + int[<4>] z = $$matrix_mul(x, x, 2, 2, 2); + int[<2>] a = { 1, 2 }; + int[<3>] b = { 1, 2, 3 }; + int[<6>] c = $$matrix_mul(a, b, 2, 1, 3); + c = $$matrix_transpose(c, 2, 3); + return 0; +} + +/* #expect: test.ll + +define i32 @main() #0 { +entry: + %x = alloca <4 x i32>, align 16 + %z = alloca <4 x i32>, align 16 + %a = alloca <2 x i32>, align 8 + %b = alloca <3 x i32>, align 16 + %c = alloca <6 x i32>, align 32 + store <4 x i32> , ptr %x, align 16 + %0 = load <4 x i32>, ptr %x, align 16 + %1 = load <4 x i32>, ptr %x, align 16 + %2 = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %0, <4 x i32> %1, i32 2, i32 2, i32 2) + store <4 x i32> %2, ptr %z, align 16 + store <2 x i32> , ptr %a, align 8 + store <3 x i32> , ptr %b, align 16 + %3 = load <2 x i32>, ptr %a, align 8 + %4 = load <3 x i32>, ptr %b, align 16 + %5 = call <6 x i32> @llvm.matrix.multiply.v6i32.v2i32.v3i32(<2 x i32> %3, <3 x i32> %4, i32 2, i32 1, i32 3) + store <6 x i32> %5, ptr %c, align 32 + %6 = load <6 x i32>, ptr %c, align 32 + %7 = call <6 x i32> @llvm.matrix.transpose.v6i32(<6 x i32> %6, i32 2, i32 3) + store <6 x i32> %7, ptr %c, align 32 + ret i32 0 +} \ No newline at end of file diff --git a/wrapper/src/wrapper.cpp b/wrapper/src/wrapper.cpp index 2b635652d..95ed02fcc 100644 --- a/wrapper/src/wrapper.cpp +++ b/wrapper/src/wrapper.cpp @@ -15,6 +15,7 @@ #include "llvm-c/Target.h" #include "llvm/IR/Verifier.h" #include "llvm/IR/Module.h" +#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" #include "llvm/IR/Function.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Passes/StandardInstrumentations.h" @@ -155,6 +156,8 @@ bool llvm_run_passes(LLVMModuleRef m, LLVMTargetMachineRef tm, llvm::FunctionAnalysisManager FAM; llvm::CGSCCAnalysisManager CGAM; llvm::ModuleAnalysisManager MAM; + + PB.registerLoopAnalyses(LAM); PB.registerFunctionAnalyses(FAM); PB.registerCGSCCAnalyses(CGAM); @@ -196,6 +199,7 @@ bool llvm_run_passes(LLVMModuleRef m, LLVMTargetMachineRef tm, #else llvm::ModulePassManager MPM = PB.buildPerModuleDefaultPipeline(level, false); #endif + MPM.addPass(llvm::createModuleToFunctionPassAdaptor(llvm::LowerMatrixIntrinsicsPass(false))); if (passes->should_verify) { MPM.addPass(llvm::VerifierPass());