Allow test runners to take String[] arguments.

This commit is contained in:
Christoffer Lerno
2025-01-09 22:32:59 +01:00
parent 0857363470
commit 3a1bba19af
10 changed files with 446 additions and 393 deletions

View File

@@ -47,6 +47,13 @@ macro int @main_to_int_main_args(#m, int argc, char** argv)
return #m(list);
}
macro int @_main_runner(#m, int argc, char** argv)
{
String[] list = args_to_strings(argc, argv);
defer free(list.ptr);
return #m(list) ? 0 : 1;
}
macro int @main_to_void_main_args(#m, int argc, char** argv)
{
String[] list = args_to_strings(argc, argv);
@@ -157,6 +164,13 @@ macro int @wmain_to_int_main_args(#m, int argc, Char16** argv)
return #m(args);
}
macro int @_wmain_runner(#m, int argc, Char16** argv)
{
String[] args = wargs_strings(argc, argv);
defer release_wargs(args);
return #m(args) ? 0 : 1;
}
macro int @wmain_to_void_main_args(#m, int argc, Char16** argv)
{
String[] args = wargs_strings(argc, argv);

View File

@@ -22,332 +22,6 @@ struct SliceRaw
usz len;
}
def BenchmarkFn = fn void!() @if($$OLD_TEST);
def BenchmarkFn = fn void() @if(!$$OLD_TEST);
struct BenchmarkUnit
{
String name;
BenchmarkFn func;
}
fn BenchmarkUnit[] benchmark_collection_create(Allocator allocator = allocator::heap())
{
BenchmarkFn[] fns = $$BENCHMARK_FNS;
String[] names = $$BENCHMARK_NAMES;
BenchmarkUnit[] benchmarks = allocator::alloc_array(allocator, BenchmarkUnit, names.len);
foreach (i, benchmark : fns)
{
benchmarks[i] = { names[i], fns[i] };
}
return benchmarks;
}
const DEFAULT_BENCHMARK_WARMUP_ITERATIONS = 3;
const DEFAULT_BENCHMARK_MAX_ITERATIONS = 10000;
uint benchmark_warmup_iterations @private = DEFAULT_BENCHMARK_WARMUP_ITERATIONS;
uint benchmark_max_iterations @private = DEFAULT_BENCHMARK_MAX_ITERATIONS;
fn void set_benchmark_warmup_iterations(uint value) @builtin
{
benchmark_warmup_iterations = value;
}
fn void set_benchmark_max_iterations(uint value) @builtin
{
assert(value > 0);
benchmark_max_iterations = value;
}
fn bool run_benchmarks(BenchmarkUnit[] benchmarks) @if($$OLD_TEST)
{
int benchmarks_passed = 0;
int benchmark_count = benchmarks.len;
usz max_name;
foreach (&unit : benchmarks)
{
if (max_name < unit.name.len) max_name = unit.name.len;
}
usz len = max_name + 9;
DString name = dstring::temp_with_capacity(64);
name.append_repeat('-', len / 2);
name.append(" BENCHMARKS ");
name.append_repeat('-', len - len / 2);
io::printn(name);
name.clear();
long sys_clock_started;
long sys_clock_finished;
long sys_clocks;
Clock clock;
anyfault err;
foreach(unit : benchmarks)
{
defer name.clear();
name.appendf("Benchmarking %s ", unit.name);
name.append_repeat('.', max_name - unit.name.len + 2);
io::printf("%s ", name.str_view());
for (uint i = 0; i < benchmark_warmup_iterations; i++)
{
err = @catch(unit.func()) @inline;
@volatile_load(err);
}
clock = std::time::clock::now();
sys_clock_started = $$sysclock();
for (uint i = 0; i < benchmark_max_iterations; i++)
{
err = @catch(unit.func()) @inline;
@volatile_load(err);
}
sys_clock_finished = $$sysclock();
NanoDuration nano_seconds = clock.mark();
sys_clocks = sys_clock_finished - sys_clock_started;
if (err)
{
io::printfn("[failed] Failed due to: %s", err);
continue;
}
io::printfn("[ok] %.2f ns, %.2f CPU's clocks", (float)nano_seconds / benchmark_max_iterations, (float)sys_clocks / benchmark_max_iterations);
benchmarks_passed++;
}
io::printfn("\n%d benchmark%s run.\n", benchmark_count, benchmark_count > 1 ? "s" : "");
io::printfn("Benchmarks Result: %s. %d passed, %d failed.",
benchmarks_passed < benchmark_count ? "FAILED" : "ok",
benchmarks_passed,
benchmark_count - benchmarks_passed);
return benchmark_count == benchmarks_passed;
}
fn bool run_benchmarks(BenchmarkUnit[] benchmarks) @if(!$$OLD_TEST)
{
usz max_name;
foreach (&unit : benchmarks)
{
if (max_name < unit.name.len) max_name = unit.name.len;
}
usz len = max_name + 9;
DString name = dstring::temp_with_capacity(64);
name.append_repeat('-', len / 2);
name.append(" BENCHMARKS ");
name.append_repeat('-', len - len / 2);
io::printn(name);
name.clear();
long sys_clock_started;
long sys_clock_finished;
long sys_clocks;
Clock clock;
foreach(unit : benchmarks)
{
defer name.clear();
name.appendf("Benchmarking %s ", unit.name);
name.append_repeat('.', max_name - unit.name.len + 2);
io::printf("%s ", name.str_view());
for (uint i = 0; i < benchmark_warmup_iterations; i++)
{
unit.func() @inline;
}
clock = std::time::clock::now();
sys_clock_started = $$sysclock();
for (uint i = 0; i < benchmark_max_iterations; i++)
{
unit.func() @inline;
}
sys_clock_finished = $$sysclock();
NanoDuration nano_seconds = clock.mark();
sys_clocks = sys_clock_finished - sys_clock_started;
io::printfn("[COMPLETE] %.2f ns, %.2f CPU's clocks", (float)nano_seconds / benchmark_max_iterations, (float)sys_clocks / benchmark_max_iterations);
}
io::printfn("\n%d benchmark%s run.\n", benchmarks.len, benchmarks.len > 1 ? "s" : "");
return true;
}
fn bool default_benchmark_runner()
{
@pool()
{
return run_benchmarks(benchmark_collection_create(allocator::temp()));
};
}
def TestFn = fn void!() @if($$OLD_TEST);
def TestFn = fn void() @if(!$$OLD_TEST);
struct TestUnit
{
String name;
TestFn func;
}
fn TestUnit[] test_collection_create(Allocator allocator = allocator::heap())
{
TestFn[] fns = $$TEST_FNS;
String[] names = $$TEST_NAMES;
TestUnit[] tests = allocator::alloc_array(allocator, TestUnit, names.len);
foreach (i, test : fns)
{
tests[i] = { names[i], fns[i] };
}
return tests;
}
struct TestContext
{
JmpBuf buf;
}
// Sort the tests by their name in ascending order.
fn int cmp_test_unit(TestUnit a, TestUnit b)
{
usz an = a.name.len;
usz bn = b.name.len;
if (an > bn) @swap(a, b);
foreach (i, ac : a.name)
{
char bc = b.name[i];
if (ac != bc) return an > bn ? bc - ac : ac - bc;
}
return (int)(an - bn);
}
TestContext* test_context @private;
fn void test_panic(String message, String file, String function, uint line)
{
io::printn("[error]");
io::print("\n Error: ");
io::print(message);
io::printn();
io::printfn(" - in %s %s:%s.\n", function, file, line);
libc::longjmp(&test_context.buf, 1);
}
fn bool run_tests(TestUnit[] tests) @if($$OLD_TEST)
{
usz max_name;
foreach (&unit : tests)
{
if (max_name < unit.name.len) max_name = unit.name.len;
}
quicksort(tests, &cmp_test_unit);
TestContext context;
test_context = &context;
PanicFn old_panic = builtin::panic;
defer builtin::panic = old_panic;
builtin::panic = &test_panic;
int tests_passed = 0;
int test_count = tests.len;
DString name = dstring::temp_with_capacity(64);
usz len = max_name + 9;
name.append_repeat('-', len / 2);
name.append(" TESTS ");
name.append_repeat('-', len - len / 2);
io::printn(name);
name.clear();
foreach(unit : tests)
{
defer name.clear();
name.appendf("Testing %s ", unit.name);
name.append_repeat('.', max_name - unit.name.len + 2);
io::printf("%s ", name.str_view());
(void)io::stdout().flush();
if (libc::setjmp(&context.buf) == 0)
{
if (catch err = unit.func())
{
io::printfn("[failed] Failed due to: %s", err);
continue;
}
io::printn("[ok]");
tests_passed++;
}
}
io::printfn("\n%d test%s run.\n", test_count, test_count > 1 ? "s" : "");
io::printfn("Test Result: %s. %d passed, %d failed.",
tests_passed < test_count ? "FAILED" : "ok", tests_passed, test_count - tests_passed);
return test_count == tests_passed;
}
fn bool run_tests(TestUnit[] tests) @if(!$$OLD_TEST)
{
usz max_name;
foreach (&unit : tests)
{
if (max_name < unit.name.len) max_name = unit.name.len;
}
quicksort(tests, &cmp_test_unit);
TestContext context;
test_context = &context;
PanicFn old_panic = builtin::panic;
defer builtin::panic = old_panic;
builtin::panic = &test_panic;
int tests_passed = 0;
int test_count = tests.len;
DString name = dstring::temp_with_capacity(64);
usz len = max_name + 9;
name.append_repeat('-', len / 2);
name.append(" TESTS ");
name.append_repeat('-', len - len / 2);
io::printn(name);
name.clear();
foreach(unit : tests)
{
defer name.clear();
name.appendf("Testing %s ", unit.name);
name.append_repeat('.', max_name - unit.name.len + 2);
io::printf("%s ", name.str_view());
(void)io::stdout().flush();
if (libc::setjmp(&context.buf) == 0)
{
unit.func();
io::printn("[ok]");
tests_passed++;
}
}
io::printfn("\n%d test%s run.\n", test_count, test_count > 1 ? "s" : "");
io::printfn("Test Result: %s. %d passed, %d failed.",
tests_passed < test_count ? "FAILED" : "ok", tests_passed, test_count - tests_passed);
return test_count == tests_passed;
}
fn bool default_test_runner()
{
@pool()
{
return run_tests(test_collection_create(allocator::temp()));
};
}
module std::core::runtime @if(WASM_NOLIBC);

View File

@@ -0,0 +1,177 @@
module std::core::runtime;
import libc, std::time, std::io, std::sort;
def BenchmarkFn = fn void!() @if($$OLD_TEST);
def BenchmarkFn = fn void() @if(!$$OLD_TEST);
struct BenchmarkUnit
{
String name;
BenchmarkFn func;
}
fn BenchmarkUnit[] benchmark_collection_create(Allocator allocator = allocator::heap())
{
BenchmarkFn[] fns = $$BENCHMARK_FNS;
String[] names = $$BENCHMARK_NAMES;
BenchmarkUnit[] benchmarks = allocator::alloc_array(allocator, BenchmarkUnit, names.len);
foreach (i, benchmark : fns)
{
benchmarks[i] = { names[i], fns[i] };
}
return benchmarks;
}
const DEFAULT_BENCHMARK_WARMUP_ITERATIONS = 3;
const DEFAULT_BENCHMARK_MAX_ITERATIONS = 10000;
uint benchmark_warmup_iterations @private = DEFAULT_BENCHMARK_WARMUP_ITERATIONS;
uint benchmark_max_iterations @private = DEFAULT_BENCHMARK_MAX_ITERATIONS;
fn void set_benchmark_warmup_iterations(uint value) @builtin
{
benchmark_warmup_iterations = value;
}
fn void set_benchmark_max_iterations(uint value) @builtin
{
assert(value > 0);
benchmark_max_iterations = value;
}
fn bool run_benchmarks(BenchmarkUnit[] benchmarks) @if($$OLD_TEST)
{
int benchmarks_passed = 0;
int benchmark_count = benchmarks.len;
usz max_name;
foreach (&unit : benchmarks)
{
if (max_name < unit.name.len) max_name = unit.name.len;
}
usz len = max_name + 9;
DString name = dstring::temp_with_capacity(64);
name.append_repeat('-', len / 2);
name.append(" BENCHMARKS ");
name.append_repeat('-', len - len / 2);
io::printn(name);
name.clear();
long sys_clock_started;
long sys_clock_finished;
long sys_clocks;
Clock clock;
anyfault err;
foreach(unit : benchmarks)
{
defer name.clear();
name.appendf("Benchmarking %s ", unit.name);
name.append_repeat('.', max_name - unit.name.len + 2);
io::printf("%s ", name.str_view());
for (uint i = 0; i < benchmark_warmup_iterations; i++)
{
err = @catch(unit.func()) @inline;
@volatile_load(err);
}
clock = std::time::clock::now();
sys_clock_started = $$sysclock();
for (uint i = 0; i < benchmark_max_iterations; i++)
{
err = @catch(unit.func()) @inline;
@volatile_load(err);
}
sys_clock_finished = $$sysclock();
NanoDuration nano_seconds = clock.mark();
sys_clocks = sys_clock_finished - sys_clock_started;
if (err)
{
io::printfn("[failed] Failed due to: %s", err);
continue;
}
io::printfn("[ok] %.2f ns, %.2f CPU's clocks", (float)nano_seconds / benchmark_max_iterations, (float)sys_clocks / benchmark_max_iterations);
benchmarks_passed++;
}
io::printfn("\n%d benchmark%s run.\n", benchmark_count, benchmark_count > 1 ? "s" : "");
io::printfn("Benchmarks Result: %s. %d passed, %d failed.",
benchmarks_passed < benchmark_count ? "FAILED" : "ok",
benchmarks_passed,
benchmark_count - benchmarks_passed);
return benchmark_count == benchmarks_passed;
}
fn bool run_benchmarks(BenchmarkUnit[] benchmarks) @if(!$$OLD_TEST)
{
usz max_name;
foreach (&unit : benchmarks)
{
if (max_name < unit.name.len) max_name = unit.name.len;
}
usz len = max_name + 9;
DString name = dstring::temp_with_capacity(64);
name.append_repeat('-', len / 2);
name.append(" BENCHMARKS ");
name.append_repeat('-', len - len / 2);
io::printn(name);
name.clear();
long sys_clock_started;
long sys_clock_finished;
long sys_clocks;
Clock clock;
foreach(unit : benchmarks)
{
defer name.clear();
name.appendf("Benchmarking %s ", unit.name);
name.append_repeat('.', max_name - unit.name.len + 2);
io::printf("%s ", name.str_view());
for (uint i = 0; i < benchmark_warmup_iterations; i++)
{
unit.func() @inline;
}
clock = std::time::clock::now();
sys_clock_started = $$sysclock();
for (uint i = 0; i < benchmark_max_iterations; i++)
{
unit.func() @inline;
}
sys_clock_finished = $$sysclock();
NanoDuration nano_seconds = clock.mark();
sys_clocks = sys_clock_finished - sys_clock_started;
io::printfn("[COMPLETE] %.2f ns, %.2f CPU's clocks", (float)nano_seconds / benchmark_max_iterations, (float)sys_clocks / benchmark_max_iterations);
}
io::printfn("\n%d benchmark%s run.\n", benchmarks.len, benchmarks.len > 1 ? "s" : "");
return true;
}
fn bool default_benchmark_runner(String[] args)
{
@pool()
{
return run_benchmarks(benchmark_collection_create(allocator::temp()));
};
}

View File

@@ -0,0 +1,157 @@
// Copyright (c) 2025 Christoffer Lerno. All rights reserved.
// Use of this source code is governed by the MIT license
// a copy of which can be found in the LICENSE_STDLIB file.
module std::core::runtime;
import libc, std::time, std::io, std::sort;
def TestFn = fn void!() @if($$OLD_TEST);
def TestFn = fn void() @if(!$$OLD_TEST);
struct TestUnit
{
String name;
TestFn func;
}
fn TestUnit[] test_collection_create(Allocator allocator = allocator::heap())
{
TestFn[] fns = $$TEST_FNS;
String[] names = $$TEST_NAMES;
TestUnit[] tests = allocator::alloc_array(allocator, TestUnit, names.len);
foreach (i, test : fns)
{
tests[i] = { names[i], fns[i] };
}
return tests;
}
struct TestContext
{
JmpBuf buf;
}
// Sort the tests by their name in ascending order.
fn int cmp_test_unit(TestUnit a, TestUnit b)
{
usz an = a.name.len;
usz bn = b.name.len;
if (an > bn) @swap(a, b);
foreach (i, ac : a.name)
{
char bc = b.name[i];
if (ac != bc) return an > bn ? bc - ac : ac - bc;
}
return (int)(an - bn);
}
TestContext* test_context @private;
fn void test_panic(String message, String file, String function, uint line)
{
io::printn("[error]");
io::print("\n Error: ");
io::print(message);
io::printn();
io::printfn(" - in %s %s:%s.\n", function, file, line);
libc::longjmp(&test_context.buf, 1);
}
fn bool run_tests(TestUnit[] tests) @if($$OLD_TEST)
{
usz max_name;
foreach (&unit : tests)
{
if (max_name < unit.name.len) max_name = unit.name.len;
}
quicksort(tests, &cmp_test_unit);
TestContext context;
test_context = &context;
PanicFn old_panic = builtin::panic;
defer builtin::panic = old_panic;
builtin::panic = &test_panic;
int tests_passed = 0;
int test_count = tests.len;
DString name = dstring::temp_with_capacity(64);
usz len = max_name + 9;
name.append_repeat('-', len / 2);
name.append(" TESTS ");
name.append_repeat('-', len - len / 2);
io::printn(name);
name.clear();
foreach(unit : tests)
{
defer name.clear();
name.appendf("Testing %s ", unit.name);
name.append_repeat('.', max_name - unit.name.len + 2);
io::printf("%s ", name.str_view());
(void)io::stdout().flush();
if (libc::setjmp(&context.buf) == 0)
{
if (catch err = unit.func())
{
io::printfn("[failed] Failed due to: %s", err);
continue;
}
io::printn("[ok]");
tests_passed++;
}
}
io::printfn("\n%d test%s run.\n", test_count, test_count > 1 ? "s" : "");
io::printfn("Test Result: %s. %d passed, %d failed.",
tests_passed < test_count ? "FAILED" : "ok", tests_passed, test_count - tests_passed);
return test_count == tests_passed;
}
fn bool run_tests(TestUnit[] tests) @if(!$$OLD_TEST)
{
usz max_name;
foreach (&unit : tests)
{
if (max_name < unit.name.len) max_name = unit.name.len;
}
quicksort(tests, &cmp_test_unit);
TestContext context;
test_context = &context;
PanicFn old_panic = builtin::panic;
defer builtin::panic = old_panic;
builtin::panic = &test_panic;
int tests_passed = 0;
int test_count = tests.len;
DString name = dstring::temp_with_capacity(64);
usz len = max_name + 9;
name.append_repeat('-', len / 2);
name.append(" TESTS ");
name.append_repeat('-', len - len / 2);
io::printn(name);
name.clear();
foreach(unit : tests)
{
defer name.clear();
name.appendf("Testing %s ", unit.name);
name.append_repeat('.', max_name - unit.name.len + 2);
io::printf("%s ", name.str_view());
(void)io::stdout().flush();
if (libc::setjmp(&context.buf) == 0)
{
unit.func();
io::printn("[ok]");
tests_passed++;
}
}
io::printfn("\n%d test%s run.\n", test_count, test_count > 1 ? "s" : "");
io::printfn("Test Result: %s. %d passed, %d failed.",
tests_passed < test_count ? "FAILED" : "ok", tests_passed, test_count - tests_passed);
return test_count == tests_passed;
}
fn bool default_test_runner(String[] args)
{
@pool()
{
return run_tests(test_collection_create(allocator::temp()));
};
}

View File

@@ -24,6 +24,7 @@
- Deprecated '&' macro arguments.
- Deprecate `fn void! main() type main functions.
- Deprecate old `void!` @benchmark and @test functions.
- Allow test runners to take String[] arguments.
### Fixes
- Fix case trying to initialize a `char[*]*` from a String.

View File

@@ -1863,8 +1863,6 @@ typedef struct
Decl *panicf;
Decl *io_error_file_not_found;
Decl *main;
Decl *test_func;
Decl *benchmark_func;
Decl *decl_stack[MAX_GLOBAL_DECL_STACK];
Decl **decl_stack_bottom;
Decl **decl_stack_top;

View File

@@ -1343,28 +1343,6 @@ LLVMValueRef llvm_get_ref(GenContext *c, Decl *decl)
UNREACHABLE
}
static void llvm_gen_test_main(GenContext *c)
{
Decl *test_runner = compiler.context.test_func;
if (!test_runner)
{
error_exit("No test runner found.");
}
ASSERT0(!compiler.context.main && "Main should not be set if a test main is generated.");
compiler.context.main = test_runner;
LLVMTypeRef cint = llvm_get_type(c, type_cint);
LLVMTypeRef main_type = LLVMFunctionType(cint, NULL, 0, true);
LLVMTypeRef runner_type = LLVMFunctionType(c->byte_type, NULL, 0, true);
LLVMValueRef func = LLVMAddFunction(c->module, kw_main, main_type);
scratch_buffer_set_extern_decl_name(test_runner, true);
LLVMValueRef other_func = LLVMAddFunction(c->module, scratch_buffer_to_string(), runner_type);
LLVMBuilderRef builder = llvm_create_function_entry(c, func, NULL);
LLVMValueRef val = LLVMBuildCall2(builder, runner_type, other_func, NULL, 0, "");
val = LLVMBuildSelect(builder, LLVMBuildTrunc(builder, val, c->bool_type, ""),
LLVMConstNull(cint), LLVMConstInt(cint, 1, false), "");
LLVMBuildRet(builder, val);
LLVMDisposeBuilder(builder);
}
INLINE GenContext *llvm_gen_tests(Module** modules, unsigned module_count, LLVMContextRef shared_context)
{
@@ -1431,37 +1409,10 @@ INLINE GenContext *llvm_gen_tests(Module** modules, unsigned module_count, LLVMC
LLVMSetGlobalConstant(decl_list, 1);
LLVMSetInitializer(decl_list, llvm_emit_aggregate_two(c, decls_array_type, decl_ref, count));
if (compiler.build.type == TARGET_TYPE_TEST)
{
llvm_gen_test_main(c);
}
compiler.build.debug_info = actual_debug_info;
return c;
}
static void llvm_gen_benchmark_main(GenContext *c)
{
Decl *benchmark_runner = compiler.context.benchmark_func;
if (!benchmark_runner)
{
error_exit("No benchmark runner found.");
}
ASSERT0(!compiler.context.main && "Main should not be set if a benchmark main is generated.");
compiler.context.main = benchmark_runner;
LLVMTypeRef cint = llvm_get_type(c, type_cint);
LLVMTypeRef main_type = LLVMFunctionType(cint, NULL, 0, true);
LLVMTypeRef runner_type = LLVMFunctionType(c->byte_type, NULL, 0, true);
LLVMValueRef func = LLVMAddFunction(c->module, kw_main, main_type);
scratch_buffer_set_extern_decl_name(benchmark_runner, true);
LLVMValueRef other_func = LLVMAddFunction(c->module, scratch_buffer_to_string(), runner_type);
LLVMBuilderRef builder = llvm_create_function_entry(c, func, NULL);
LLVMValueRef val = LLVMBuildCall2(builder, runner_type, other_func, NULL, 0, "");
val = LLVMBuildSelect(builder, LLVMBuildTrunc(builder, val, c->bool_type, ""),
LLVMConstNull(cint), LLVMConstInt(cint, 1, false), "");
LLVMBuildRet(builder, val);
LLVMDisposeBuilder(builder);
}
INLINE GenContext *llvm_gen_benchmarks(Module** modules, unsigned module_count, LLVMContextRef shared_context)
{
@@ -1527,11 +1478,6 @@ INLINE GenContext *llvm_gen_benchmarks(Module** modules, unsigned module_count,
LLVMSetGlobalConstant(decl_list, 1);
LLVMSetInitializer(decl_list, llvm_emit_aggregate_two(c, decls_array_type, decl_ref, count));
if (compiler.build.type == TARGET_TYPE_BENCHMARK)
{
llvm_gen_benchmark_main(c);
}
compiler.build.debug_info = actual_debug_info;
return c;
}
@@ -1678,7 +1624,7 @@ static GenContext *llvm_gen_module(Module *module, LLVMContextRef shared_context
llvm_emit_function_decl(gen_context, func);
}
if (compiler.build.type != TARGET_TYPE_TEST && compiler.build.type != TARGET_TYPE_BENCHMARK && unit->main_function && unit->main_function->is_synthetic)
if (unit->main_function && unit->main_function->is_synthetic)
{
has_elements = true;
llvm_emit_function_decl(gen_context, unit->main_function);
@@ -1726,7 +1672,7 @@ static GenContext *llvm_gen_module(Module *module, LLVMContextRef shared_context
llvm_emit_function_body(gen_context, func);
}
if (compiler.build.type != TARGET_TYPE_TEST && compiler.build.type != TARGET_TYPE_BENCHMARK && unit->main_function && unit->main_function->is_synthetic)
if (unit->main_function && unit->main_function->is_synthetic)
{
has_elements = true;
llvm_emit_function_body(gen_context, unit->main_function);

View File

@@ -3266,6 +3266,69 @@ static inline MainType sema_find_main_type(SemaContext *context, Signature *sig,
}
Decl *sema_create_runner_main(SemaContext *context, Decl *decl)
{
bool is_win32 = compiler.platform.os == OS_TYPE_WIN32;
Decl *function = decl_new(DECL_FUNC, NULL, decl->span);
function->is_export = true;
function->has_extname = true;
function->extname = kw_mainstub;
function->name = kw_mainstub;
function->unit = decl->unit;
// Pick wWinMain, main or wmain
Decl *params[4] = { NULL, NULL, NULL, NULL };
int param_count;
if (is_win32)
{
function->extname = kw_wmain;
params[0] = decl_new_generated_var(type_cint, VARDECL_PARAM, decl->span);
params[1] = decl_new_generated_var(type_get_ptr(type_get_ptr(type_ushort)), VARDECL_PARAM, decl->span);
param_count = 2;
}
else
{
function->extname = kw_main;
params[0] = decl_new_generated_var(type_cint, VARDECL_PARAM, decl->span);
params[1] = decl_new_generated_var(type_get_ptr(type_get_ptr(type_char)), VARDECL_PARAM, decl->span);
param_count = 2;
}
function->has_extname = true;
function->func_decl.signature.rtype = type_infoid(type_info_new_base(type_cint, decl->span));
function->func_decl.signature.vararg_index = param_count;
Decl **main_params = NULL;
for (int i = 0; i < param_count; i++) vec_add(main_params, params[i]);
function->func_decl.signature.params = main_params;
Ast *body = new_ast(AST_COMPOUND_STMT, decl->span);
AstId *next = &body->compound_stmt.first_stmt;
Ast *ret_stmt = new_ast(AST_RETURN_STMT, decl->span);
const char *kw_main_invoker = symtab_preset(is_win32 ? "@_wmain_runner" : "@_main_runner", TOKEN_AT_IDENT);
Decl *d = sema_find_symbol(context, kw_main_invoker);
if (!d)
{
SEMA_ERROR(decl, "Missing main forwarding function '%s'.", kw_main_invoker);
return poisoned_decl;
}
Expr *invoker = expr_new(EXPR_IDENTIFIER, decl->span);
expr_resolve_ident(invoker, d);
Expr *call = expr_new(EXPR_CALL, decl->span);
Expr *fn_ref = expr_variable(decl);
vec_add(call->call_expr.arguments, fn_ref);
for (int i = 0; i < param_count; i++)
{
Expr *arg = expr_variable(params[i]);
vec_add(call->call_expr.arguments, arg);
}
call->call_expr.function = exprid(invoker);
for (int i = 0; i < param_count; i++) params[i]->resolve_status = RESOLVE_NOT_DONE;
ast_append(&next, ret_stmt);
ret_stmt->return_stmt.expr = call;
function->func_decl.body = astid(body);
function->is_synthetic = true;
return function;
}
static inline Decl *sema_create_synthetic_main(SemaContext *context, Decl *decl, MainType main, bool int_return, bool err_return, bool is_winmain, bool is_wmain)
{
Decl *function = decl_new(DECL_FUNC, NULL, decl->span);
@@ -3377,7 +3440,7 @@ static inline Decl *sema_create_synthetic_main(SemaContext *context, Decl *decl,
default:
UNREACHABLE;
}
NEXT:;
NEXT:;
const char *kw_main_invoker = symtab_preset(main_invoker, TOKEN_AT_IDENT);
Decl *d = sema_find_symbol(context, kw_main_invoker);
if (!d)

View File

@@ -72,6 +72,8 @@ bool sema_analyse_function_body(SemaContext *context, Decl *func);
bool sema_analyse_contracts(SemaContext *context, AstId doc, AstId **asserts, SourceSpan span, bool *has_ensures);
void sema_append_contract_asserts(AstId assert_first, Ast* compound_stmt);
Decl *sema_create_runner_main(SemaContext *context, Decl *decl);
void sema_analyse_pass_top(Module *module);
void sema_analyse_pass_module_hierarchy(Module *module);
void sema_analysis_pass_process_imports(Module *module);

View File

@@ -286,6 +286,20 @@ static void sema_analyze_to_stage(AnalysisStage stage)
halt_on_error();
}
static bool setup_main_runner(Decl *run_function)
{
SemaContext context;
sema_context_init(&context, run_function->unit);
Decl *main = sema_create_runner_main(&context, run_function);
if (!decl_ok(main)) return false;
if (!sema_analyse_decl(&context, main)) return false;
if (!sema_analyse_function_body(&context, main)) return false;
sema_context_destroy(&context);
compiler.context.main = main;
main->unit->main_function = main;
main->no_strip = true;
return true;
}
static void assign_panicfn(void)
{
if (compiler.build.feature.panic_level == PANIC_OFF || (!compiler.build.panicfn && no_stdlib()))
@@ -354,7 +368,7 @@ static void assign_testfn(void)
if (!compiler.build.testing) return;
if (!compiler.build.testfn && no_stdlib())
{
compiler.context.test_func = NULL;
error_exit("No test function could be found.");
return;
}
const char *testfn = compiler.build.testfn ? compiler.build.testfn : "std::core::runtime::default_test_runner";
@@ -374,12 +388,17 @@ static void assign_testfn(void)
{
error_exit("'%s::%s' is not a function.", path->module, ident);
}
if (!type_func_match(type_get_func_ptr(decl->type->canonical), type_bool, 0))
if (!type_func_match(type_get_func_ptr(decl->type->canonical), type_bool, 1, type_get_slice(type_string)))
{
error_exit("Expected test runner to have the signature fn void().");
error_exit("Expected test runner to have the signature fn bool(String[]).");
}
compiler.context.test_func = decl;
decl->no_strip = true;
if (compiler.build.type != TARGET_TYPE_TEST) return;
if (!setup_main_runner(decl))
{
error_exit("Failed to set up test runner.");
}
}
static void assign_benchfn(void)
@@ -387,7 +406,6 @@ static void assign_benchfn(void)
if (!compiler.build.benchmarking) return;
if (!compiler.build.benchfn && no_stdlib())
{
compiler.context.benchmark_func = NULL;
return;
}
const char *testfn = compiler.build.benchfn ? compiler.build.benchfn : "std::core::runtime::default_benchmark_runner";
@@ -407,12 +425,15 @@ static void assign_benchfn(void)
{
error_exit("'%s::%s' is not a function.", path->module, ident);
}
if (!type_func_match(type_get_func_ptr(decl->type->canonical), type_bool, 0))
if (!type_func_match(type_get_func_ptr(decl->type->canonical), type_bool, 1, type_get_slice(type_string)))
{
error_exit("Expected benchmark function to have the signature fn void().");
error_exit("Expected benchmark function to have the signature fn bool(String[] args).");
}
compiler.context.benchmark_func = decl;
decl->no_strip = true;
if (!setup_main_runner(decl))
{
error_exit("Failed to set up benchmark runner.");
}
}
/**