From 49d13c23bbd0df2ddce83abbbdf0897dea33acc7 Mon Sep 17 00:00:00 2001 From: Christoffer Lerno Date: Fri, 10 Sep 2021 15:46:02 +0200 Subject: [PATCH] Fix issue with grouped expressions in macros. Adding spectral norml code example. --- resources/examples/spectralnorm.c3 | 67 ++++++++++++++++++++++++++++ resources/lib/std/array.c3 | 4 +- resources/lib/std/mem.c3 | 4 +- src/compiler/copying.c | 3 +- src/compiler/llvm_codegen.c | 12 ++--- src/compiler/llvm_codegen_internal.h | 2 +- 6 files changed, 80 insertions(+), 12 deletions(-) create mode 100644 resources/examples/spectralnorm.c3 diff --git a/resources/examples/spectralnorm.c3 b/resources/examples/spectralnorm.c3 new file mode 100644 index 000000000..14a8f756d --- /dev/null +++ b/resources/examples/spectralnorm.c3 @@ -0,0 +1,67 @@ +module spectralnorm; +import std::mem; +import std::array; + +extern func int atoi(char *s); +extern func int printf(char *s, ...); +extern func double sqrt(double); + +double[] temparr; + +func double eval_A(int i, int j) +{ + return 1.0 / ((i + j) * (i + j + 1) / 2 + i + 1); +} + +func void eval_A_times_u(double[] u, double[] au) +{ + foreach (i, &val : au) + { + *val = 0; + foreach (j, uval : u) + { + *val += eval_A((int)(i), (int)(j)) * uval; + } + } +} + +func void eval_At_times_u(double[] u, double[] au) +{ + foreach (i, &val : au) + { + *val = 0; + foreach (j, uval : u) + { + *val += eval_A((int)(j), (int)(i)) * uval; + } + } +} + +func void eval_AtA_times_u(double[] u, double[] atau) @noinline +{ + eval_A_times_u(u, temparr); + eval_At_times_u(temparr, atau); +} + +func int main(int argc, char **argv) +{ + int n = (argc == 2) ? atoi(argv[1]) : 2000; + temparr = @array::make(double, n); + double[] u = @array::make(double, n); + double[] v = @array::make(double, n); + foreach(&uval : u) *uval = 1; + for (int i = 0; i < 10; i++) + { + eval_AtA_times_u(u, v); + eval_AtA_times_u(v, u); + } + double vBv; + double vv; + foreach (i, vval : v) + { + vBv += u[i] * vval; + vv += vval * vval; + } + printf("%0.9f\n", sqrt(vBv / vv)); + return 0; +} diff --git a/resources/lib/std/array.c3 b/resources/lib/std/array.c3 index f588ee0d4..74b858931 100644 --- a/resources/lib/std/array.c3 +++ b/resources/lib/std/array.c3 @@ -4,13 +4,13 @@ import std::mem; macro make($Type, usize elements) { assert(elements > 0); - $Type* ptr = mem::alloc($Type.sizeof, elements); + $Type* ptr = mem::alloc($sizeof($Type), elements); return ptr[0..(elements - 1)]; } macro make_zero($Type, usize elements) { assert(elements > 0); - $Type* ptr = mem::calloc($Type.sizeof, elements); + $Type* ptr = mem::calloc($sizeof($Type), elements); return ptr[0..(elements - 1)]; } diff --git a/resources/lib/std/mem.c3 b/resources/lib/std/mem.c3 index 910bead58..221846caf 100644 --- a/resources/lib/std/mem.c3 +++ b/resources/lib/std/mem.c3 @@ -113,9 +113,9 @@ macro malloc($Type) return ($Type*)(mem::alloc($sizeof($Type))); } -func void* alloc(usize size, usize elements = 1) @inline +func void* alloc(usize size, usize count = 1) @inline { - return _malloc(size * elements); + return _malloc(size * count); } func void* calloc(usize size, usize elements = 1) @inline diff --git a/src/compiler/copying.c b/src/compiler/copying.c index 4d7d5dc6e..41f2b99d4 100644 --- a/src/compiler/copying.c +++ b/src/compiler/copying.c @@ -200,7 +200,7 @@ Expr *copy_expr(Expr *source_expr) MACRO_COPY_EXPR(expr->subscript_expr.index); return expr; case EXPR_GROUP: - MACRO_COPY_EXPR(expr->group_expr->group_expr); + MACRO_COPY_EXPR(expr->group_expr); return expr; case EXPR_ACCESS: MACRO_COPY_EXPR(expr->access_expr.parent); @@ -512,7 +512,6 @@ Decl *copy_decl(Decl *decl) break; case DECL_BITSTRUCT: TODO - break; case DECL_LABEL: TODO break; diff --git a/src/compiler/llvm_codegen.c b/src/compiler/llvm_codegen.c index 4b64c70f6..f4de9f089 100644 --- a/src/compiler/llvm_codegen.c +++ b/src/compiler/llvm_codegen.c @@ -467,14 +467,15 @@ void gencontext_print_llvm_ir(GenContext *context) } -LLVMValueRef llvm_emit_alloca(GenContext *context, LLVMTypeRef type, unsigned alignment, const char *name) +LLVMValueRef llvm_emit_alloca(GenContext *c, LLVMTypeRef type, unsigned alignment, const char *name) { assert(alignment > 0); - LLVMBasicBlockRef current_block = LLVMGetInsertBlock(context->builder); - LLVMPositionBuilderBefore(context->builder, context->alloca_point); - LLVMValueRef alloca = LLVMBuildAlloca(context->builder, type, name); + LLVMBasicBlockRef current_block = LLVMGetInsertBlock(c->builder); + LLVMPositionBuilderBefore(c->builder, c->alloca_point); + assert(LLVMGetTypeContext(type) == c->context); + LLVMValueRef alloca = LLVMBuildAlloca(c->builder, type, name); llvm_set_alignment(alloca, alignment); - LLVMPositionBuilderAtEnd(context->builder, current_block); + LLVMPositionBuilderAtEnd(c->builder, current_block); return alloca; } @@ -1217,6 +1218,7 @@ void llvm_emit_memcpy_to_decl(GenContext *c, Decl *decl, LLVMValueRef source, un LLVMValueRef llvm_emit_load_aligned(GenContext *c, LLVMTypeRef type, LLVMValueRef pointer, AlignSize alignment, const char *name) { LLVMValueRef value = LLVMBuildLoad2(c->builder, type, pointer, name); + assert(LLVMGetTypeContext(type) == c->context); llvm_set_alignment(value, alignment ? alignment : llvm_abi_alignment(c, type)); return value; } diff --git a/src/compiler/llvm_codegen_internal.h b/src/compiler/llvm_codegen_internal.h index a24aef5ac..954095de3 100644 --- a/src/compiler/llvm_codegen_internal.h +++ b/src/compiler/llvm_codegen_internal.h @@ -213,7 +213,7 @@ LLVMBasicBlockRef llvm_basic_block_new(GenContext *c, const char *name); static inline LLVMValueRef llvm_const_int(GenContext *c, Type *type, uint64_t val); LLVMValueRef llvm_emit_const_padding(GenContext *c, ByteSize size); LLVMTypeRef llvm_const_padding_type(GenContext *c, ByteSize size); -LLVMValueRef llvm_emit_alloca(GenContext *context, LLVMTypeRef type, unsigned alignment, const char *name); +LLVMValueRef llvm_emit_alloca(GenContext *c, LLVMTypeRef type, unsigned alignment, const char *name); LLVMValueRef llvm_emit_alloca_aligned(GenContext *c, Type *type, const char *name); BEValue llvm_emit_assign_expr(GenContext *context, BEValue *ref, Expr *expr, LLVMValueRef failable); static inline LLVMValueRef llvm_emit_bitcast(GenContext *context, LLVMValueRef value, Type *type);