From 3a1bba19afb3e76e3c2faa7f70bc006ac96056b4 Mon Sep 17 00:00:00 2001 From: Christoffer Lerno Date: Thu, 9 Jan 2025 22:32:59 +0100 Subject: [PATCH] Allow test runners to take String[] arguments. --- lib/std/core/private/main_stub.c3 | 14 ++ lib/std/core/runtime.c3 | 326 ------------------------------ lib/std/core/runtime_benchmark.c3 | 177 ++++++++++++++++ lib/std/core/runtime_test.c3 | 157 ++++++++++++++ releasenotes.md | 1 + src/compiler/compiler_internal.h | 2 - src/compiler/llvm_codegen.c | 58 +----- src/compiler/sema_decls.c | 65 +++++- src/compiler/sema_internal.h | 2 + src/compiler/semantic_analyser.c | 37 +++- 10 files changed, 446 insertions(+), 393 deletions(-) create mode 100644 lib/std/core/runtime_benchmark.c3 create mode 100644 lib/std/core/runtime_test.c3 diff --git a/lib/std/core/private/main_stub.c3 b/lib/std/core/private/main_stub.c3 index 1c7e50b24..f64ded9e2 100644 --- a/lib/std/core/private/main_stub.c3 +++ b/lib/std/core/private/main_stub.c3 @@ -47,6 +47,13 @@ macro int @main_to_int_main_args(#m, int argc, char** argv) return #m(list); } +macro int @_main_runner(#m, int argc, char** argv) +{ + String[] list = args_to_strings(argc, argv); + defer free(list.ptr); + return #m(list) ? 0 : 1; +} + macro int @main_to_void_main_args(#m, int argc, char** argv) { String[] list = args_to_strings(argc, argv); @@ -157,6 +164,13 @@ macro int @wmain_to_int_main_args(#m, int argc, Char16** argv) return #m(args); } +macro int @_wmain_runner(#m, int argc, Char16** argv) +{ + String[] args = wargs_strings(argc, argv); + defer release_wargs(args); + return #m(args) ? 0 : 1; +} + macro int @wmain_to_void_main_args(#m, int argc, Char16** argv) { String[] args = wargs_strings(argc, argv); diff --git a/lib/std/core/runtime.c3 b/lib/std/core/runtime.c3 index df0f01489..65852328a 100644 --- a/lib/std/core/runtime.c3 +++ b/lib/std/core/runtime.c3 @@ -22,332 +22,6 @@ struct SliceRaw usz len; } -def BenchmarkFn = fn void!() @if($$OLD_TEST); -def BenchmarkFn = fn void() @if(!$$OLD_TEST); - -struct BenchmarkUnit -{ - String name; - BenchmarkFn func; -} - -fn BenchmarkUnit[] benchmark_collection_create(Allocator allocator = allocator::heap()) -{ - BenchmarkFn[] fns = $$BENCHMARK_FNS; - String[] names = $$BENCHMARK_NAMES; - BenchmarkUnit[] benchmarks = allocator::alloc_array(allocator, BenchmarkUnit, names.len); - foreach (i, benchmark : fns) - { - benchmarks[i] = { names[i], fns[i] }; - } - return benchmarks; -} - -const DEFAULT_BENCHMARK_WARMUP_ITERATIONS = 3; -const DEFAULT_BENCHMARK_MAX_ITERATIONS = 10000; - -uint benchmark_warmup_iterations @private = DEFAULT_BENCHMARK_WARMUP_ITERATIONS; -uint benchmark_max_iterations @private = DEFAULT_BENCHMARK_MAX_ITERATIONS; - -fn void set_benchmark_warmup_iterations(uint value) @builtin -{ - benchmark_warmup_iterations = value; -} - -fn void set_benchmark_max_iterations(uint value) @builtin -{ - assert(value > 0); - benchmark_max_iterations = value; -} - -fn bool run_benchmarks(BenchmarkUnit[] benchmarks) @if($$OLD_TEST) -{ - int benchmarks_passed = 0; - int benchmark_count = benchmarks.len; - usz max_name; - - foreach (&unit : benchmarks) - { - if (max_name < unit.name.len) max_name = unit.name.len; - } - - usz len = max_name + 9; - - DString name = dstring::temp_with_capacity(64); - name.append_repeat('-', len / 2); - name.append(" BENCHMARKS "); - name.append_repeat('-', len - len / 2); - - io::printn(name); - - name.clear(); - - long sys_clock_started; - long sys_clock_finished; - long sys_clocks; - Clock clock; - anyfault err; - - foreach(unit : benchmarks) - { - defer name.clear(); - name.appendf("Benchmarking %s ", unit.name); - name.append_repeat('.', max_name - unit.name.len + 2); - io::printf("%s ", name.str_view()); - - for (uint i = 0; i < benchmark_warmup_iterations; i++) - { - err = @catch(unit.func()) @inline; - @volatile_load(err); - } - - clock = std::time::clock::now(); - sys_clock_started = $$sysclock(); - - for (uint i = 0; i < benchmark_max_iterations; i++) - { - err = @catch(unit.func()) @inline; - @volatile_load(err); - } - - sys_clock_finished = $$sysclock(); - NanoDuration nano_seconds = clock.mark(); - sys_clocks = sys_clock_finished - sys_clock_started; - - if (err) - { - io::printfn("[failed] Failed due to: %s", err); - continue; - } - - io::printfn("[ok] %.2f ns, %.2f CPU's clocks", (float)nano_seconds / benchmark_max_iterations, (float)sys_clocks / benchmark_max_iterations); - benchmarks_passed++; - } - - io::printfn("\n%d benchmark%s run.\n", benchmark_count, benchmark_count > 1 ? "s" : ""); - io::printfn("Benchmarks Result: %s. %d passed, %d failed.", - benchmarks_passed < benchmark_count ? "FAILED" : "ok", - benchmarks_passed, - benchmark_count - benchmarks_passed); - - return benchmark_count == benchmarks_passed; -} - -fn bool run_benchmarks(BenchmarkUnit[] benchmarks) @if(!$$OLD_TEST) -{ - usz max_name; - - foreach (&unit : benchmarks) - { - if (max_name < unit.name.len) max_name = unit.name.len; - } - - usz len = max_name + 9; - - DString name = dstring::temp_with_capacity(64); - name.append_repeat('-', len / 2); - name.append(" BENCHMARKS "); - name.append_repeat('-', len - len / 2); - - io::printn(name); - - name.clear(); - - long sys_clock_started; - long sys_clock_finished; - long sys_clocks; - Clock clock; - - foreach(unit : benchmarks) - { - defer name.clear(); - name.appendf("Benchmarking %s ", unit.name); - name.append_repeat('.', max_name - unit.name.len + 2); - io::printf("%s ", name.str_view()); - - for (uint i = 0; i < benchmark_warmup_iterations; i++) - { - unit.func() @inline; - } - - clock = std::time::clock::now(); - sys_clock_started = $$sysclock(); - - for (uint i = 0; i < benchmark_max_iterations; i++) - { - unit.func() @inline; - } - - sys_clock_finished = $$sysclock(); - NanoDuration nano_seconds = clock.mark(); - sys_clocks = sys_clock_finished - sys_clock_started; - - io::printfn("[COMPLETE] %.2f ns, %.2f CPU's clocks", (float)nano_seconds / benchmark_max_iterations, (float)sys_clocks / benchmark_max_iterations); - } - - io::printfn("\n%d benchmark%s run.\n", benchmarks.len, benchmarks.len > 1 ? "s" : ""); - return true; -} - -fn bool default_benchmark_runner() -{ - @pool() - { - return run_benchmarks(benchmark_collection_create(allocator::temp())); - }; -} - -def TestFn = fn void!() @if($$OLD_TEST); -def TestFn = fn void() @if(!$$OLD_TEST); - -struct TestUnit -{ - String name; - TestFn func; -} - -fn TestUnit[] test_collection_create(Allocator allocator = allocator::heap()) -{ - TestFn[] fns = $$TEST_FNS; - String[] names = $$TEST_NAMES; - TestUnit[] tests = allocator::alloc_array(allocator, TestUnit, names.len); - foreach (i, test : fns) - { - tests[i] = { names[i], fns[i] }; - } - return tests; -} - -struct TestContext -{ - JmpBuf buf; -} - -// Sort the tests by their name in ascending order. -fn int cmp_test_unit(TestUnit a, TestUnit b) -{ - usz an = a.name.len; - usz bn = b.name.len; - if (an > bn) @swap(a, b); - foreach (i, ac : a.name) - { - char bc = b.name[i]; - if (ac != bc) return an > bn ? bc - ac : ac - bc; - } - return (int)(an - bn); -} - -TestContext* test_context @private; - -fn void test_panic(String message, String file, String function, uint line) -{ - io::printn("[error]"); - io::print("\n Error: "); - io::print(message); - io::printn(); - io::printfn(" - in %s %s:%s.\n", function, file, line); - libc::longjmp(&test_context.buf, 1); -} - -fn bool run_tests(TestUnit[] tests) @if($$OLD_TEST) -{ - usz max_name; - foreach (&unit : tests) - { - if (max_name < unit.name.len) max_name = unit.name.len; - } - quicksort(tests, &cmp_test_unit); - - TestContext context; - test_context = &context; - - PanicFn old_panic = builtin::panic; - defer builtin::panic = old_panic; - builtin::panic = &test_panic; - int tests_passed = 0; - int test_count = tests.len; - DString name = dstring::temp_with_capacity(64); - usz len = max_name + 9; - name.append_repeat('-', len / 2); - name.append(" TESTS "); - name.append_repeat('-', len - len / 2); - io::printn(name); - name.clear(); - foreach(unit : tests) - { - defer name.clear(); - name.appendf("Testing %s ", unit.name); - name.append_repeat('.', max_name - unit.name.len + 2); - io::printf("%s ", name.str_view()); - (void)io::stdout().flush(); - if (libc::setjmp(&context.buf) == 0) - { - if (catch err = unit.func()) - { - io::printfn("[failed] Failed due to: %s", err); - continue; - } - io::printn("[ok]"); - tests_passed++; - } - } - io::printfn("\n%d test%s run.\n", test_count, test_count > 1 ? "s" : ""); - io::printfn("Test Result: %s. %d passed, %d failed.", - tests_passed < test_count ? "FAILED" : "ok", tests_passed, test_count - tests_passed); - return test_count == tests_passed; -} - -fn bool run_tests(TestUnit[] tests) @if(!$$OLD_TEST) -{ - usz max_name; - foreach (&unit : tests) - { - if (max_name < unit.name.len) max_name = unit.name.len; - } - quicksort(tests, &cmp_test_unit); - - TestContext context; - test_context = &context; - - PanicFn old_panic = builtin::panic; - defer builtin::panic = old_panic; - builtin::panic = &test_panic; - int tests_passed = 0; - int test_count = tests.len; - DString name = dstring::temp_with_capacity(64); - usz len = max_name + 9; - name.append_repeat('-', len / 2); - name.append(" TESTS "); - name.append_repeat('-', len - len / 2); - io::printn(name); - name.clear(); - foreach(unit : tests) - { - defer name.clear(); - name.appendf("Testing %s ", unit.name); - name.append_repeat('.', max_name - unit.name.len + 2); - io::printf("%s ", name.str_view()); - (void)io::stdout().flush(); - if (libc::setjmp(&context.buf) == 0) - { - unit.func(); - io::printn("[ok]"); - tests_passed++; - } - } - io::printfn("\n%d test%s run.\n", test_count, test_count > 1 ? "s" : ""); - io::printfn("Test Result: %s. %d passed, %d failed.", - tests_passed < test_count ? "FAILED" : "ok", tests_passed, test_count - tests_passed); - return test_count == tests_passed; -} - -fn bool default_test_runner() -{ - @pool() - { - return run_tests(test_collection_create(allocator::temp())); - }; -} module std::core::runtime @if(WASM_NOLIBC); diff --git a/lib/std/core/runtime_benchmark.c3 b/lib/std/core/runtime_benchmark.c3 new file mode 100644 index 000000000..d54efd3c7 --- /dev/null +++ b/lib/std/core/runtime_benchmark.c3 @@ -0,0 +1,177 @@ +module std::core::runtime; +import libc, std::time, std::io, std::sort; + +def BenchmarkFn = fn void!() @if($$OLD_TEST); +def BenchmarkFn = fn void() @if(!$$OLD_TEST); + +struct BenchmarkUnit +{ + String name; + BenchmarkFn func; +} + +fn BenchmarkUnit[] benchmark_collection_create(Allocator allocator = allocator::heap()) +{ + BenchmarkFn[] fns = $$BENCHMARK_FNS; + String[] names = $$BENCHMARK_NAMES; + BenchmarkUnit[] benchmarks = allocator::alloc_array(allocator, BenchmarkUnit, names.len); + foreach (i, benchmark : fns) + { + benchmarks[i] = { names[i], fns[i] }; + } + return benchmarks; +} + +const DEFAULT_BENCHMARK_WARMUP_ITERATIONS = 3; +const DEFAULT_BENCHMARK_MAX_ITERATIONS = 10000; + +uint benchmark_warmup_iterations @private = DEFAULT_BENCHMARK_WARMUP_ITERATIONS; +uint benchmark_max_iterations @private = DEFAULT_BENCHMARK_MAX_ITERATIONS; + +fn void set_benchmark_warmup_iterations(uint value) @builtin +{ + benchmark_warmup_iterations = value; +} + +fn void set_benchmark_max_iterations(uint value) @builtin +{ + assert(value > 0); + benchmark_max_iterations = value; +} + +fn bool run_benchmarks(BenchmarkUnit[] benchmarks) @if($$OLD_TEST) +{ + int benchmarks_passed = 0; + int benchmark_count = benchmarks.len; + usz max_name; + + foreach (&unit : benchmarks) + { + if (max_name < unit.name.len) max_name = unit.name.len; + } + + usz len = max_name + 9; + + DString name = dstring::temp_with_capacity(64); + name.append_repeat('-', len / 2); + name.append(" BENCHMARKS "); + name.append_repeat('-', len - len / 2); + + io::printn(name); + + name.clear(); + + long sys_clock_started; + long sys_clock_finished; + long sys_clocks; + Clock clock; + anyfault err; + + foreach(unit : benchmarks) + { + defer name.clear(); + name.appendf("Benchmarking %s ", unit.name); + name.append_repeat('.', max_name - unit.name.len + 2); + io::printf("%s ", name.str_view()); + + for (uint i = 0; i < benchmark_warmup_iterations; i++) + { + err = @catch(unit.func()) @inline; + @volatile_load(err); + } + + clock = std::time::clock::now(); + sys_clock_started = $$sysclock(); + + for (uint i = 0; i < benchmark_max_iterations; i++) + { + err = @catch(unit.func()) @inline; + @volatile_load(err); + } + + sys_clock_finished = $$sysclock(); + NanoDuration nano_seconds = clock.mark(); + sys_clocks = sys_clock_finished - sys_clock_started; + + if (err) + { + io::printfn("[failed] Failed due to: %s", err); + continue; + } + + io::printfn("[ok] %.2f ns, %.2f CPU's clocks", (float)nano_seconds / benchmark_max_iterations, (float)sys_clocks / benchmark_max_iterations); + benchmarks_passed++; + } + + io::printfn("\n%d benchmark%s run.\n", benchmark_count, benchmark_count > 1 ? "s" : ""); + io::printfn("Benchmarks Result: %s. %d passed, %d failed.", + benchmarks_passed < benchmark_count ? "FAILED" : "ok", + benchmarks_passed, + benchmark_count - benchmarks_passed); + + return benchmark_count == benchmarks_passed; +} + +fn bool run_benchmarks(BenchmarkUnit[] benchmarks) @if(!$$OLD_TEST) +{ + usz max_name; + + foreach (&unit : benchmarks) + { + if (max_name < unit.name.len) max_name = unit.name.len; + } + + usz len = max_name + 9; + + DString name = dstring::temp_with_capacity(64); + name.append_repeat('-', len / 2); + name.append(" BENCHMARKS "); + name.append_repeat('-', len - len / 2); + + io::printn(name); + + name.clear(); + + long sys_clock_started; + long sys_clock_finished; + long sys_clocks; + Clock clock; + + foreach(unit : benchmarks) + { + defer name.clear(); + name.appendf("Benchmarking %s ", unit.name); + name.append_repeat('.', max_name - unit.name.len + 2); + io::printf("%s ", name.str_view()); + + for (uint i = 0; i < benchmark_warmup_iterations; i++) + { + unit.func() @inline; + } + + clock = std::time::clock::now(); + sys_clock_started = $$sysclock(); + + for (uint i = 0; i < benchmark_max_iterations; i++) + { + unit.func() @inline; + } + + sys_clock_finished = $$sysclock(); + NanoDuration nano_seconds = clock.mark(); + sys_clocks = sys_clock_finished - sys_clock_started; + + io::printfn("[COMPLETE] %.2f ns, %.2f CPU's clocks", (float)nano_seconds / benchmark_max_iterations, (float)sys_clocks / benchmark_max_iterations); + } + + io::printfn("\n%d benchmark%s run.\n", benchmarks.len, benchmarks.len > 1 ? "s" : ""); + return true; +} + +fn bool default_benchmark_runner(String[] args) +{ + @pool() + { + return run_benchmarks(benchmark_collection_create(allocator::temp())); + }; +} diff --git a/lib/std/core/runtime_test.c3 b/lib/std/core/runtime_test.c3 new file mode 100644 index 000000000..622a17e52 --- /dev/null +++ b/lib/std/core/runtime_test.c3 @@ -0,0 +1,157 @@ +// Copyright (c) 2025 Christoffer Lerno. All rights reserved. +// Use of this source code is governed by the MIT license +// a copy of which can be found in the LICENSE_STDLIB file. +module std::core::runtime; +import libc, std::time, std::io, std::sort; + +def TestFn = fn void!() @if($$OLD_TEST); +def TestFn = fn void() @if(!$$OLD_TEST); + +struct TestUnit +{ + String name; + TestFn func; +} + +fn TestUnit[] test_collection_create(Allocator allocator = allocator::heap()) +{ + TestFn[] fns = $$TEST_FNS; + String[] names = $$TEST_NAMES; + TestUnit[] tests = allocator::alloc_array(allocator, TestUnit, names.len); + foreach (i, test : fns) + { + tests[i] = { names[i], fns[i] }; + } + return tests; +} + +struct TestContext +{ + JmpBuf buf; +} + +// Sort the tests by their name in ascending order. +fn int cmp_test_unit(TestUnit a, TestUnit b) +{ + usz an = a.name.len; + usz bn = b.name.len; + if (an > bn) @swap(a, b); + foreach (i, ac : a.name) + { + char bc = b.name[i]; + if (ac != bc) return an > bn ? bc - ac : ac - bc; + } + return (int)(an - bn); +} + +TestContext* test_context @private; + +fn void test_panic(String message, String file, String function, uint line) +{ + io::printn("[error]"); + io::print("\n Error: "); + io::print(message); + io::printn(); + io::printfn(" - in %s %s:%s.\n", function, file, line); + libc::longjmp(&test_context.buf, 1); +} + +fn bool run_tests(TestUnit[] tests) @if($$OLD_TEST) +{ + usz max_name; + foreach (&unit : tests) + { + if (max_name < unit.name.len) max_name = unit.name.len; + } + quicksort(tests, &cmp_test_unit); + + TestContext context; + test_context = &context; + + PanicFn old_panic = builtin::panic; + defer builtin::panic = old_panic; + builtin::panic = &test_panic; + int tests_passed = 0; + int test_count = tests.len; + DString name = dstring::temp_with_capacity(64); + usz len = max_name + 9; + name.append_repeat('-', len / 2); + name.append(" TESTS "); + name.append_repeat('-', len - len / 2); + io::printn(name); + name.clear(); + foreach(unit : tests) + { + defer name.clear(); + name.appendf("Testing %s ", unit.name); + name.append_repeat('.', max_name - unit.name.len + 2); + io::printf("%s ", name.str_view()); + (void)io::stdout().flush(); + if (libc::setjmp(&context.buf) == 0) + { + if (catch err = unit.func()) + { + io::printfn("[failed] Failed due to: %s", err); + continue; + } + io::printn("[ok]"); + tests_passed++; + } + } + io::printfn("\n%d test%s run.\n", test_count, test_count > 1 ? "s" : ""); + io::printfn("Test Result: %s. %d passed, %d failed.", + tests_passed < test_count ? "FAILED" : "ok", tests_passed, test_count - tests_passed); + return test_count == tests_passed; +} + +fn bool run_tests(TestUnit[] tests) @if(!$$OLD_TEST) +{ + usz max_name; + foreach (&unit : tests) + { + if (max_name < unit.name.len) max_name = unit.name.len; + } + quicksort(tests, &cmp_test_unit); + + TestContext context; + test_context = &context; + + PanicFn old_panic = builtin::panic; + defer builtin::panic = old_panic; + builtin::panic = &test_panic; + int tests_passed = 0; + int test_count = tests.len; + DString name = dstring::temp_with_capacity(64); + usz len = max_name + 9; + name.append_repeat('-', len / 2); + name.append(" TESTS "); + name.append_repeat('-', len - len / 2); + io::printn(name); + name.clear(); + foreach(unit : tests) + { + defer name.clear(); + name.appendf("Testing %s ", unit.name); + name.append_repeat('.', max_name - unit.name.len + 2); + io::printf("%s ", name.str_view()); + (void)io::stdout().flush(); + if (libc::setjmp(&context.buf) == 0) + { + unit.func(); + io::printn("[ok]"); + tests_passed++; + } + } + io::printfn("\n%d test%s run.\n", test_count, test_count > 1 ? "s" : ""); + io::printfn("Test Result: %s. %d passed, %d failed.", + tests_passed < test_count ? "FAILED" : "ok", tests_passed, test_count - tests_passed); + return test_count == tests_passed; +} + +fn bool default_test_runner(String[] args) +{ + @pool() + { + return run_tests(test_collection_create(allocator::temp())); + }; +} diff --git a/releasenotes.md b/releasenotes.md index d15efc1ba..99e124a60 100644 --- a/releasenotes.md +++ b/releasenotes.md @@ -24,6 +24,7 @@ - Deprecated '&' macro arguments. - Deprecate `fn void! main() type main functions. - Deprecate old `void!` @benchmark and @test functions. +- Allow test runners to take String[] arguments. ### Fixes - Fix case trying to initialize a `char[*]*` from a String. diff --git a/src/compiler/compiler_internal.h b/src/compiler/compiler_internal.h index 5397c17a5..ac86af145 100644 --- a/src/compiler/compiler_internal.h +++ b/src/compiler/compiler_internal.h @@ -1863,8 +1863,6 @@ typedef struct Decl *panicf; Decl *io_error_file_not_found; Decl *main; - Decl *test_func; - Decl *benchmark_func; Decl *decl_stack[MAX_GLOBAL_DECL_STACK]; Decl **decl_stack_bottom; Decl **decl_stack_top; diff --git a/src/compiler/llvm_codegen.c b/src/compiler/llvm_codegen.c index 6d50d4d3c..3341ebb48 100644 --- a/src/compiler/llvm_codegen.c +++ b/src/compiler/llvm_codegen.c @@ -1343,28 +1343,6 @@ LLVMValueRef llvm_get_ref(GenContext *c, Decl *decl) UNREACHABLE } -static void llvm_gen_test_main(GenContext *c) -{ - Decl *test_runner = compiler.context.test_func; - if (!test_runner) - { - error_exit("No test runner found."); - } - ASSERT0(!compiler.context.main && "Main should not be set if a test main is generated."); - compiler.context.main = test_runner; - LLVMTypeRef cint = llvm_get_type(c, type_cint); - LLVMTypeRef main_type = LLVMFunctionType(cint, NULL, 0, true); - LLVMTypeRef runner_type = LLVMFunctionType(c->byte_type, NULL, 0, true); - LLVMValueRef func = LLVMAddFunction(c->module, kw_main, main_type); - scratch_buffer_set_extern_decl_name(test_runner, true); - LLVMValueRef other_func = LLVMAddFunction(c->module, scratch_buffer_to_string(), runner_type); - LLVMBuilderRef builder = llvm_create_function_entry(c, func, NULL); - LLVMValueRef val = LLVMBuildCall2(builder, runner_type, other_func, NULL, 0, ""); - val = LLVMBuildSelect(builder, LLVMBuildTrunc(builder, val, c->bool_type, ""), - LLVMConstNull(cint), LLVMConstInt(cint, 1, false), ""); - LLVMBuildRet(builder, val); - LLVMDisposeBuilder(builder); -} INLINE GenContext *llvm_gen_tests(Module** modules, unsigned module_count, LLVMContextRef shared_context) { @@ -1431,37 +1409,10 @@ INLINE GenContext *llvm_gen_tests(Module** modules, unsigned module_count, LLVMC LLVMSetGlobalConstant(decl_list, 1); LLVMSetInitializer(decl_list, llvm_emit_aggregate_two(c, decls_array_type, decl_ref, count)); - if (compiler.build.type == TARGET_TYPE_TEST) - { - llvm_gen_test_main(c); - } - compiler.build.debug_info = actual_debug_info; return c; } -static void llvm_gen_benchmark_main(GenContext *c) -{ - Decl *benchmark_runner = compiler.context.benchmark_func; - if (!benchmark_runner) - { - error_exit("No benchmark runner found."); - } - ASSERT0(!compiler.context.main && "Main should not be set if a benchmark main is generated."); - compiler.context.main = benchmark_runner; - LLVMTypeRef cint = llvm_get_type(c, type_cint); - LLVMTypeRef main_type = LLVMFunctionType(cint, NULL, 0, true); - LLVMTypeRef runner_type = LLVMFunctionType(c->byte_type, NULL, 0, true); - LLVMValueRef func = LLVMAddFunction(c->module, kw_main, main_type); - scratch_buffer_set_extern_decl_name(benchmark_runner, true); - LLVMValueRef other_func = LLVMAddFunction(c->module, scratch_buffer_to_string(), runner_type); - LLVMBuilderRef builder = llvm_create_function_entry(c, func, NULL); - LLVMValueRef val = LLVMBuildCall2(builder, runner_type, other_func, NULL, 0, ""); - val = LLVMBuildSelect(builder, LLVMBuildTrunc(builder, val, c->bool_type, ""), - LLVMConstNull(cint), LLVMConstInt(cint, 1, false), ""); - LLVMBuildRet(builder, val); - LLVMDisposeBuilder(builder); -} INLINE GenContext *llvm_gen_benchmarks(Module** modules, unsigned module_count, LLVMContextRef shared_context) { @@ -1527,11 +1478,6 @@ INLINE GenContext *llvm_gen_benchmarks(Module** modules, unsigned module_count, LLVMSetGlobalConstant(decl_list, 1); LLVMSetInitializer(decl_list, llvm_emit_aggregate_two(c, decls_array_type, decl_ref, count)); - if (compiler.build.type == TARGET_TYPE_BENCHMARK) - { - llvm_gen_benchmark_main(c); - } - compiler.build.debug_info = actual_debug_info; return c; } @@ -1678,7 +1624,7 @@ static GenContext *llvm_gen_module(Module *module, LLVMContextRef shared_context llvm_emit_function_decl(gen_context, func); } - if (compiler.build.type != TARGET_TYPE_TEST && compiler.build.type != TARGET_TYPE_BENCHMARK && unit->main_function && unit->main_function->is_synthetic) + if (unit->main_function && unit->main_function->is_synthetic) { has_elements = true; llvm_emit_function_decl(gen_context, unit->main_function); @@ -1726,7 +1672,7 @@ static GenContext *llvm_gen_module(Module *module, LLVMContextRef shared_context llvm_emit_function_body(gen_context, func); } - if (compiler.build.type != TARGET_TYPE_TEST && compiler.build.type != TARGET_TYPE_BENCHMARK && unit->main_function && unit->main_function->is_synthetic) + if (unit->main_function && unit->main_function->is_synthetic) { has_elements = true; llvm_emit_function_body(gen_context, unit->main_function); diff --git a/src/compiler/sema_decls.c b/src/compiler/sema_decls.c index 5d53009eb..c51d3b9b9 100755 --- a/src/compiler/sema_decls.c +++ b/src/compiler/sema_decls.c @@ -3266,6 +3266,69 @@ static inline MainType sema_find_main_type(SemaContext *context, Signature *sig, } +Decl *sema_create_runner_main(SemaContext *context, Decl *decl) +{ + bool is_win32 = compiler.platform.os == OS_TYPE_WIN32; + Decl *function = decl_new(DECL_FUNC, NULL, decl->span); + function->is_export = true; + function->has_extname = true; + function->extname = kw_mainstub; + function->name = kw_mainstub; + function->unit = decl->unit; + + // Pick wWinMain, main or wmain + Decl *params[4] = { NULL, NULL, NULL, NULL }; + int param_count; + if (is_win32) + { + function->extname = kw_wmain; + params[0] = decl_new_generated_var(type_cint, VARDECL_PARAM, decl->span); + params[1] = decl_new_generated_var(type_get_ptr(type_get_ptr(type_ushort)), VARDECL_PARAM, decl->span); + param_count = 2; + } + else + { + function->extname = kw_main; + params[0] = decl_new_generated_var(type_cint, VARDECL_PARAM, decl->span); + params[1] = decl_new_generated_var(type_get_ptr(type_get_ptr(type_char)), VARDECL_PARAM, decl->span); + param_count = 2; + } + + function->has_extname = true; + function->func_decl.signature.rtype = type_infoid(type_info_new_base(type_cint, decl->span)); + function->func_decl.signature.vararg_index = param_count; + Decl **main_params = NULL; + for (int i = 0; i < param_count; i++) vec_add(main_params, params[i]); + function->func_decl.signature.params = main_params; + Ast *body = new_ast(AST_COMPOUND_STMT, decl->span); + AstId *next = &body->compound_stmt.first_stmt; + Ast *ret_stmt = new_ast(AST_RETURN_STMT, decl->span); + const char *kw_main_invoker = symtab_preset(is_win32 ? "@_wmain_runner" : "@_main_runner", TOKEN_AT_IDENT); + Decl *d = sema_find_symbol(context, kw_main_invoker); + if (!d) + { + SEMA_ERROR(decl, "Missing main forwarding function '%s'.", kw_main_invoker); + return poisoned_decl; + } + Expr *invoker = expr_new(EXPR_IDENTIFIER, decl->span); + expr_resolve_ident(invoker, d); + Expr *call = expr_new(EXPR_CALL, decl->span); + Expr *fn_ref = expr_variable(decl); + vec_add(call->call_expr.arguments, fn_ref); + for (int i = 0; i < param_count; i++) + { + Expr *arg = expr_variable(params[i]); + vec_add(call->call_expr.arguments, arg); + } + call->call_expr.function = exprid(invoker); + for (int i = 0; i < param_count; i++) params[i]->resolve_status = RESOLVE_NOT_DONE; + ast_append(&next, ret_stmt); + ret_stmt->return_stmt.expr = call; + function->func_decl.body = astid(body); + function->is_synthetic = true; + return function; +} + static inline Decl *sema_create_synthetic_main(SemaContext *context, Decl *decl, MainType main, bool int_return, bool err_return, bool is_winmain, bool is_wmain) { Decl *function = decl_new(DECL_FUNC, NULL, decl->span); @@ -3377,7 +3440,7 @@ static inline Decl *sema_create_synthetic_main(SemaContext *context, Decl *decl, default: UNREACHABLE; } -NEXT:; + NEXT:; const char *kw_main_invoker = symtab_preset(main_invoker, TOKEN_AT_IDENT); Decl *d = sema_find_symbol(context, kw_main_invoker); if (!d) diff --git a/src/compiler/sema_internal.h b/src/compiler/sema_internal.h index 82912332f..1dd5b2877 100644 --- a/src/compiler/sema_internal.h +++ b/src/compiler/sema_internal.h @@ -72,6 +72,8 @@ bool sema_analyse_function_body(SemaContext *context, Decl *func); bool sema_analyse_contracts(SemaContext *context, AstId doc, AstId **asserts, SourceSpan span, bool *has_ensures); void sema_append_contract_asserts(AstId assert_first, Ast* compound_stmt); +Decl *sema_create_runner_main(SemaContext *context, Decl *decl); + void sema_analyse_pass_top(Module *module); void sema_analyse_pass_module_hierarchy(Module *module); void sema_analysis_pass_process_imports(Module *module); diff --git a/src/compiler/semantic_analyser.c b/src/compiler/semantic_analyser.c index 5d17c5f70..aa01bfe72 100644 --- a/src/compiler/semantic_analyser.c +++ b/src/compiler/semantic_analyser.c @@ -286,6 +286,20 @@ static void sema_analyze_to_stage(AnalysisStage stage) halt_on_error(); } +static bool setup_main_runner(Decl *run_function) +{ + SemaContext context; + sema_context_init(&context, run_function->unit); + Decl *main = sema_create_runner_main(&context, run_function); + if (!decl_ok(main)) return false; + if (!sema_analyse_decl(&context, main)) return false; + if (!sema_analyse_function_body(&context, main)) return false; + sema_context_destroy(&context); + compiler.context.main = main; + main->unit->main_function = main; + main->no_strip = true; + return true; +} static void assign_panicfn(void) { if (compiler.build.feature.panic_level == PANIC_OFF || (!compiler.build.panicfn && no_stdlib())) @@ -354,7 +368,7 @@ static void assign_testfn(void) if (!compiler.build.testing) return; if (!compiler.build.testfn && no_stdlib()) { - compiler.context.test_func = NULL; + error_exit("No test function could be found."); return; } const char *testfn = compiler.build.testfn ? compiler.build.testfn : "std::core::runtime::default_test_runner"; @@ -374,12 +388,17 @@ static void assign_testfn(void) { error_exit("'%s::%s' is not a function.", path->module, ident); } - if (!type_func_match(type_get_func_ptr(decl->type->canonical), type_bool, 0)) + if (!type_func_match(type_get_func_ptr(decl->type->canonical), type_bool, 1, type_get_slice(type_string))) { - error_exit("Expected test runner to have the signature fn void()."); + error_exit("Expected test runner to have the signature fn bool(String[])."); } - compiler.context.test_func = decl; decl->no_strip = true; + if (compiler.build.type != TARGET_TYPE_TEST) return; + + if (!setup_main_runner(decl)) + { + error_exit("Failed to set up test runner."); + } } static void assign_benchfn(void) @@ -387,7 +406,6 @@ static void assign_benchfn(void) if (!compiler.build.benchmarking) return; if (!compiler.build.benchfn && no_stdlib()) { - compiler.context.benchmark_func = NULL; return; } const char *testfn = compiler.build.benchfn ? compiler.build.benchfn : "std::core::runtime::default_benchmark_runner"; @@ -407,12 +425,15 @@ static void assign_benchfn(void) { error_exit("'%s::%s' is not a function.", path->module, ident); } - if (!type_func_match(type_get_func_ptr(decl->type->canonical), type_bool, 0)) + if (!type_func_match(type_get_func_ptr(decl->type->canonical), type_bool, 1, type_get_slice(type_string))) { - error_exit("Expected benchmark function to have the signature fn void()."); + error_exit("Expected benchmark function to have the signature fn bool(String[] args)."); } - compiler.context.benchmark_func = decl; decl->no_strip = true; + if (!setup_main_runner(decl)) + { + error_exit("Failed to set up benchmark runner."); + } } /**