mirror of
https://github.com/c3lang/c3c.git
synced 2026-02-27 03:51:18 +00:00
Add $$matrix_mul and $$matrix_transpose builtins.
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
- Limit vector max size, default is 4096 bits, but may be increased using --max-vector-size.
|
||||
- Allow the use of `has_tagof` on builtin types.
|
||||
- `@jump` now included in `--list-attributes` #2155.
|
||||
- Add `$$matrix_mul` and `$$matrix_transpose` builtins.
|
||||
|
||||
### Fixes
|
||||
- Assert triggered when casting from `int[2]` to `uint[2]` #2115
|
||||
|
||||
@@ -459,6 +459,8 @@ typedef enum
|
||||
BUILTIN_LOG,
|
||||
BUILTIN_LOG10,
|
||||
BUILTIN_LOG2,
|
||||
BUILTIN_MATRIX_MUL,
|
||||
BUILTIN_MATRIX_TRANSPOSE,
|
||||
BUILTIN_MASKED_LOAD,
|
||||
BUILTIN_MASKED_STORE,
|
||||
BUILTIN_MAX,
|
||||
|
||||
@@ -824,6 +824,8 @@ static void llvm_codegen_setup()
|
||||
intrinsic_id.masked_store = lookup_intrinsic("llvm.masked.store");
|
||||
intrinsic_id.maximum = lookup_intrinsic("llvm.maximum");
|
||||
intrinsic_id.maxnum = lookup_intrinsic("llvm.maxnum");
|
||||
intrinsic_id.matrix_multiply = lookup_intrinsic("llvm.matrix.multiply");
|
||||
intrinsic_id.matrix_transpose = lookup_intrinsic("llvm.matrix.transpose");
|
||||
intrinsic_id.memcpy = lookup_intrinsic("llvm.memcpy");
|
||||
intrinsic_id.memcpy_inline = lookup_intrinsic("llvm.memcpy.inline");
|
||||
intrinsic_id.memmove = lookup_intrinsic("llvm.memmove");
|
||||
|
||||
@@ -526,6 +526,34 @@ void llvm_emit_simple_builtin(GenContext *c, BEValue *be_value, Expr *expr, unsi
|
||||
llvm_value_set(be_value, result, expr->type);
|
||||
}
|
||||
|
||||
void llvm_emit_matrix_multiply(GenContext *c, BEValue *be_value, Expr *expr)
|
||||
{
|
||||
Expr **args = expr->call_expr.arguments;
|
||||
unsigned count = vec_size(args);
|
||||
ASSERT(count == 5);
|
||||
LLVMValueRef arg_slots[5];
|
||||
llvm_emit_intrinsic_args(c, args, arg_slots, count);
|
||||
LLVMTypeRef type = LLVMTypeOf(arg_slots[0]);
|
||||
LLVMTypeRef result_type = llvm_get_type(c, expr->type);
|
||||
LLVMTypeRef call_type[3] = { result_type, type, LLVMTypeOf(arg_slots[1]) };
|
||||
LLVMValueRef result = llvm_emit_call_intrinsic(c, intrinsic_id.matrix_multiply, call_type, 3, arg_slots, count);
|
||||
llvm_value_set(be_value, result, expr->type);
|
||||
}
|
||||
|
||||
void llvm_emit_matrix_transpose(GenContext *c, BEValue *be_value, Expr *expr)
|
||||
{
|
||||
Expr **args = expr->call_expr.arguments;
|
||||
unsigned count = vec_size(args);
|
||||
ASSERT(count == 3);
|
||||
LLVMValueRef arg_slots[3];
|
||||
llvm_emit_intrinsic_args(c, args, arg_slots, count);
|
||||
LLVMTypeRef type = LLVMTypeOf(arg_slots[0]);
|
||||
LLVMTypeRef result_type = llvm_get_type(c, expr->type);
|
||||
LLVMTypeRef call_type[3] = { result_type, type };
|
||||
LLVMValueRef result = llvm_emit_call_intrinsic(c, intrinsic_id.matrix_transpose, call_type, 2, arg_slots, count);
|
||||
llvm_value_set(be_value, result, expr->type);
|
||||
}
|
||||
|
||||
static void llvm_emit_masked_load(GenContext *c, BEValue *be_value, Expr *expr)
|
||||
{
|
||||
Expr **args = expr->call_expr.arguments;
|
||||
@@ -788,6 +816,12 @@ void llvm_emit_builtin_call(GenContext *c, BEValue *result_value, Expr *expr)
|
||||
case BUILTIN_MEMSET_INLINE:
|
||||
llvm_emit_memset_builtin(c, intrinsic_id.memset_inline, result_value, expr);
|
||||
return;
|
||||
case BUILTIN_MATRIX_MUL:
|
||||
llvm_emit_matrix_multiply(c, result_value, expr);
|
||||
return;
|
||||
case BUILTIN_MATRIX_TRANSPOSE:
|
||||
llvm_emit_simple_builtin(c, result_value, expr, intrinsic_id.matrix_transpose);
|
||||
return;
|
||||
case BUILTIN_SYSCLOCK:
|
||||
llvm_value_set(result_value, llvm_emit_call_intrinsic(c, intrinsic_id.readcyclecounter, NULL, 0, NULL, 0), expr->type);
|
||||
return;
|
||||
|
||||
@@ -199,6 +199,8 @@ typedef struct
|
||||
unsigned masked_store;
|
||||
unsigned maximum;
|
||||
unsigned maxnum;
|
||||
unsigned matrix_multiply;
|
||||
unsigned matrix_transpose;
|
||||
unsigned memcpy;
|
||||
unsigned memcpy_inline;
|
||||
unsigned memmove;
|
||||
|
||||
@@ -704,6 +704,46 @@ bool sema_expr_analyse_builtin_call(SemaContext *context, Expr *expr)
|
||||
if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_INTLIKE}, 1)) return false;
|
||||
rtype = args[0]->type;
|
||||
break;
|
||||
case BUILTIN_MATRIX_TRANSPOSE:
|
||||
{
|
||||
ASSERT(arg_count == 3);
|
||||
if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_VEC, BA_INTEGER, BA_INTEGER}, 3)) return false;
|
||||
if (!sema_check_builtin_args_const(context, &args[1], 2)) return false;
|
||||
ArraySize vec_len = type_flatten(args[0]->type)->array.len;
|
||||
Int sum = int_mul(args[1]->const_expr.ixx, args[2]->const_expr.ixx);
|
||||
if (!int_icomp(sum, vec_len, BINARYOP_EQ))
|
||||
{
|
||||
RETURN_SEMA_ERROR(args[1], "Expected row * col to equal %d.", vec_len);
|
||||
}
|
||||
rtype = args[0]->type;
|
||||
break;
|
||||
}
|
||||
case BUILTIN_MATRIX_MUL:
|
||||
ASSERT(arg_count == 5);
|
||||
if (!sema_check_builtin_args(context, args, (BuiltinArg[]) {BA_VEC, BA_VEC, BA_INTEGER, BA_INTEGER, BA_INTEGER }, 5)) return false;
|
||||
if (!sema_check_builtin_args_const(context, &args[2], 3)) return false;
|
||||
Type *flat1 = type_flatten(args[0]->type);
|
||||
Type *flat2 = type_flatten(args[1]->type);
|
||||
if (flat1->array.base != flat2->array.base)
|
||||
{
|
||||
RETURN_SEMA_ERROR(args[1], "Expected both matrices to be of the same type.");
|
||||
}
|
||||
ArraySize vec_len1 = flat1->array.len;
|
||||
ArraySize vec_len2 = flat2->array.len;
|
||||
Int128 sum = i128_mult(args[2]->const_expr.ixx.i, args[3]->const_expr.ixx.i);
|
||||
args[2]->type = type_int;
|
||||
args[3]->type = type_int;
|
||||
if (sum.high != 0 || sum.low != vec_len1)
|
||||
{
|
||||
RETURN_SEMA_ERROR(args[3], "Expected outer row * inner to equal %d.", vec_len1);
|
||||
}
|
||||
sum = i128_mult(args[3]->const_expr.ixx.i, args[4]->const_expr.ixx.i);
|
||||
if (sum.high != 0 || sum.low != vec_len2)
|
||||
{
|
||||
RETURN_SEMA_ERROR(args[4], "Expected inner * outer col to equal %d.", vec_len2);
|
||||
}
|
||||
rtype = type_get_vector(flat1->array.base, i128_mult(args[2]->const_expr.ixx.i, args[4]->const_expr.ixx.i).low);
|
||||
break;
|
||||
case BUILTIN_SAT_SHL:
|
||||
case BUILTIN_SAT_SUB:
|
||||
case BUILTIN_SAT_ADD:
|
||||
@@ -1258,6 +1298,7 @@ static inline int builtin_expected_args(BuiltinFunction func)
|
||||
case BUILTIN_ATOMIC_LOAD:
|
||||
case BUILTIN_UNALIGNED_STORE:
|
||||
case BUILTIN_SELECT:
|
||||
case BUILTIN_MATRIX_TRANSPOSE:
|
||||
return 3;
|
||||
case BUILTIN_ATOMIC_STORE:
|
||||
case BUILTIN_MASKED_STORE:
|
||||
@@ -1276,6 +1317,7 @@ static inline int builtin_expected_args(BuiltinFunction func)
|
||||
case BUILTIN_ATOMIC_FETCH_MIN:
|
||||
case BUILTIN_ATOMIC_FETCH_SUB:
|
||||
case BUILTIN_ATOMIC_FETCH_DEC_WRAP:
|
||||
case BUILTIN_MATRIX_MUL:
|
||||
return 5;
|
||||
case BUILTIN_MEMCOPY:
|
||||
case BUILTIN_MEMCOPY_INLINE:
|
||||
|
||||
@@ -236,6 +236,8 @@ void symtab_init(uint32_t capacity)
|
||||
builtin_list[BUILTIN_LOG10] = KW_DEF("log10");
|
||||
builtin_list[BUILTIN_MASKED_LOAD] = KW_DEF("masked_load");
|
||||
builtin_list[BUILTIN_MASKED_STORE] = KW_DEF("masked_store");
|
||||
builtin_list[BUILTIN_MATRIX_MUL] = KW_DEF("matrix_mul");
|
||||
builtin_list[BUILTIN_MATRIX_TRANSPOSE] = KW_DEF("matrix_transpose");
|
||||
builtin_list[BUILTIN_MEMCOPY] = KW_DEF("memcpy");
|
||||
builtin_list[BUILTIN_MEMCOPY_INLINE] = KW_DEF("memcpy_inline");
|
||||
builtin_list[BUILTIN_MEMMOVE] = KW_DEF("memmove");
|
||||
|
||||
38
test/test_suite/builtins/matrix_builtin.c3t
Normal file
38
test/test_suite/builtins/matrix_builtin.c3t
Normal file
@@ -0,0 +1,38 @@
|
||||
// #target: macos-x64
|
||||
module test;
|
||||
fn int main()
|
||||
{
|
||||
int[<4>] x = { 1, 2, 3, 4 };
|
||||
int[<4>] z = $$matrix_mul(x, x, 2, 2, 2);
|
||||
int[<2>] a = { 1, 2 };
|
||||
int[<3>] b = { 1, 2, 3 };
|
||||
int[<6>] c = $$matrix_mul(a, b, 2, 1, 3);
|
||||
c = $$matrix_transpose(c, 2, 3);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* #expect: test.ll
|
||||
|
||||
define i32 @main() #0 {
|
||||
entry:
|
||||
%x = alloca <4 x i32>, align 16
|
||||
%z = alloca <4 x i32>, align 16
|
||||
%a = alloca <2 x i32>, align 8
|
||||
%b = alloca <3 x i32>, align 16
|
||||
%c = alloca <6 x i32>, align 32
|
||||
store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %x, align 16
|
||||
%0 = load <4 x i32>, ptr %x, align 16
|
||||
%1 = load <4 x i32>, ptr %x, align 16
|
||||
%2 = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %0, <4 x i32> %1, i32 2, i32 2, i32 2)
|
||||
store <4 x i32> %2, ptr %z, align 16
|
||||
store <2 x i32> <i32 1, i32 2>, ptr %a, align 8
|
||||
store <3 x i32> <i32 1, i32 2, i32 3>, ptr %b, align 16
|
||||
%3 = load <2 x i32>, ptr %a, align 8
|
||||
%4 = load <3 x i32>, ptr %b, align 16
|
||||
%5 = call <6 x i32> @llvm.matrix.multiply.v6i32.v2i32.v3i32(<2 x i32> %3, <3 x i32> %4, i32 2, i32 1, i32 3)
|
||||
store <6 x i32> %5, ptr %c, align 32
|
||||
%6 = load <6 x i32>, ptr %c, align 32
|
||||
%7 = call <6 x i32> @llvm.matrix.transpose.v6i32(<6 x i32> %6, i32 2, i32 3)
|
||||
store <6 x i32> %7, ptr %c, align 32
|
||||
ret i32 0
|
||||
}
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "llvm-c/Target.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/Passes/PassBuilder.h"
|
||||
#include "llvm/Passes/StandardInstrumentations.h"
|
||||
@@ -155,6 +156,8 @@ bool llvm_run_passes(LLVMModuleRef m, LLVMTargetMachineRef tm,
|
||||
llvm::FunctionAnalysisManager FAM;
|
||||
llvm::CGSCCAnalysisManager CGAM;
|
||||
llvm::ModuleAnalysisManager MAM;
|
||||
|
||||
|
||||
PB.registerLoopAnalyses(LAM);
|
||||
PB.registerFunctionAnalyses(FAM);
|
||||
PB.registerCGSCCAnalyses(CGAM);
|
||||
@@ -196,6 +199,7 @@ bool llvm_run_passes(LLVMModuleRef m, LLVMTargetMachineRef tm,
|
||||
#else
|
||||
llvm::ModulePassManager MPM = PB.buildPerModuleDefaultPipeline(level, false);
|
||||
#endif
|
||||
MPM.addPass(llvm::createModuleToFunctionPassAdaptor(llvm::LowerMatrixIntrinsicsPass(false)));
|
||||
if (passes->should_verify)
|
||||
{
|
||||
MPM.addPass(llvm::VerifierPass());
|
||||
|
||||
Reference in New Issue
Block a user