From 5ff726d8d197030b01f912fe8e23ad46ac3ea676 Mon Sep 17 00:00:00 2001 From: Dmitry Atamanov Date: Mon, 14 Nov 2022 17:07:32 +0500 Subject: [PATCH] Added `$$get_rounding_mode` and `$$set_rounding_mode` builtins. (#655) --- lib/std/math.c3 | 43 +++++++++++++++++++++++++ src/compiler/enums.h | 2 ++ src/compiler/llvm_codegen.c | 2 ++ src/compiler/llvm_codegen_builtins.c | 11 +++++++ src/compiler/llvm_codegen_internal.h | 2 ++ src/compiler/sema_builtins.c | 14 +++++++- src/compiler/symtab.c | 2 ++ test/test_suite/stdlib/rounding_mode.c3 | 30 +++++++++++++++++ 8 files changed, 105 insertions(+), 1 deletion(-) create mode 100644 test/test_suite/stdlib/rounding_mode.c3 diff --git a/lib/std/math.c3 b/lib/std/math.c3 index c2ac698f6..e25341ca8 100644 --- a/lib/std/math.c3 +++ b/lib/std/math.c3 @@ -70,6 +70,14 @@ const QUAD_MIN_EXP = -16481; const QUAD_EPSILON = 1.92592994438723585305597794258492732e-34; */ +enum RoundingMode : int +{ + TOWARD_ZERO, + TO_NEAREST, + TOWARD_INFINITY, + TOWARD_NEG_INFINITY +} + define Complex32 = Complex; define Complex64 = Complex; @@ -78,6 +86,11 @@ define Complex64 = Complex; **/ macro abs(x) = $$abs(x); +/** + * @require values::@is_floatlike(x) `The input must be a floating point value or float vector` + **/ +macro ceil(x) = $$ceil(x); + /** * @require types::is_numerical($typeof(x)) `The input must be a numerical value or numerical vector` * @require types::@has_same(x, lower, upper) `The input types must be equal` @@ -118,6 +131,11 @@ macro exp(x) = $$exp(x); **/ macro exp2(x) = $$exp2(x); +/** + * @require values::@is_floatlike(x) `The input must be a floating point value or float vector` + **/ +macro floor(x) = $$floor(x); + /** * @require values::@is_floatlike(a) `The input must be a floating point value or float vector` * @require types::@has_same(a, b, c) `The input types must be equal` @@ -163,6 +181,11 @@ macro min(x, y) = $$min(x, y); **/ macro muladd(a, b, c) = $$fmuladd(a, b, c); +/** + * @require values::@is_floatlike(x) `The input must be a floating point value or float vector` + **/ +macro nearbyint(x) = $$nearbyint(x); + /** * @require values::@is_floatlike(x) `The input must be a floating point value or float vector` * @require values::@convertable_to(exp, x) || values::@is_int(exp) `The input must be an integer, castable to the type of x` @@ -176,6 +199,21 @@ macro pow(x, exp) $endif; } +/** + * @require values::@is_floatlike(x) `The input must be a floating point value or float vector` + **/ +macro rint(x) = $$rint(x); + +/** + * @require values::@is_floatlike(x) `The input must be a floating point value or float vector` + **/ +macro round(x) = $$round(x); + +/** + * @require values::@is_floatlike(x) `The input must be a floating point value or float vector` + **/ +macro roundeven(x) = $$roundeven(x); + macro sec(x) = 1 / cos(x); macro sech(x) = 2 / (exp(x) + exp(-x)); @@ -207,6 +245,11 @@ macro tan(x) = sin(x) / cos(x); **/ macro tanh(x) = (exp(2.0 * x) - 1.0) / (exp(2.0 * x) + 1.0); +/** + * @require values::@is_floatlike(x) `The input must be a floating point value or float vector` + **/ +macro trunc(x) = $$trunc(x); + macro float float.ceil(float x) = $$ceil(x); macro float float.clamp(float x, float lower, float upper) = $$max(lower, $$min(x, upper)); macro float float.copysign(float mag, float sgn) = $$copysign(mag, sgn); diff --git a/src/compiler/enums.h b/src/compiler/enums.h index 9f805f988..15b6ba922 100644 --- a/src/compiler/enums.h +++ b/src/compiler/enums.h @@ -838,6 +838,7 @@ typedef enum BUILTIN_FMULADD, BUILTIN_FSHL, BUILTIN_FSHR, + BUILTIN_GET_ROUNDING_MODE, BUILTIN_LOG, BUILTIN_LOG10, BUILTIN_LOG2, @@ -870,6 +871,7 @@ typedef enum BUILTIN_SAT_ADD, BUILTIN_SAT_SHL, BUILTIN_SAT_SUB, + BUILTIN_SET_ROUNDING_MODE, BUILTIN_SHUFFLEVECTOR, BUILTIN_SIN, BUILTIN_SQRT, diff --git a/src/compiler/llvm_codegen.c b/src/compiler/llvm_codegen.c index caad6ba0c..34de6175e 100644 --- a/src/compiler/llvm_codegen.c +++ b/src/compiler/llvm_codegen.c @@ -635,6 +635,7 @@ static void llvm_codegen_setup() intrinsic_id.exp2 = lookup_intrinsic("llvm.exp2"); intrinsic_id.fabs = lookup_intrinsic("llvm.fabs"); intrinsic_id.floor = lookup_intrinsic("llvm.floor"); + intrinsic_id.flt_rounds = lookup_intrinsic("llvm.flt.rounds"); intrinsic_id.fma = lookup_intrinsic("llvm.fma"); intrinsic_id.fshl = lookup_intrinsic("llvm.fshl"); intrinsic_id.fshr = lookup_intrinsic("llvm.fshr"); @@ -665,6 +666,7 @@ static void llvm_codegen_setup() intrinsic_id.roundeven = lookup_intrinsic("llvm.roundeven"); intrinsic_id.sadd_overflow = lookup_intrinsic("llvm.sadd.with.overflow"); intrinsic_id.sadd_sat = lookup_intrinsic("llvm.sadd.sat"); + intrinsic_id.set_rounding = lookup_intrinsic("llvm.set.rounding"); intrinsic_id.sin = lookup_intrinsic("llvm.sin"); intrinsic_id.sshl_sat = lookup_intrinsic("llvm.sshl.sat"); intrinsic_id.smax = lookup_intrinsic("llvm.smax"); diff --git a/src/compiler/llvm_codegen_builtins.c b/src/compiler/llvm_codegen_builtins.c index b4d6a7b08..24dd44bfd 100644 --- a/src/compiler/llvm_codegen_builtins.c +++ b/src/compiler/llvm_codegen_builtins.c @@ -589,6 +589,9 @@ void llvm_emit_builtin_call(GenContext *c, BEValue *result_value, Expr *expr) case BUILTIN_FSHR: llvm_emit_simple_builtin(c, result_value, expr, intrinsic_id.fshr); return; + case BUILTIN_GET_ROUNDING_MODE: + llvm_value_set(result_value, llvm_emit_call_intrinsic(c, intrinsic_id.flt_rounds, NULL, 0, NULL, 0), expr->type); + return; case BUILTIN_LOG: llvm_emit_simple_builtin(c, result_value, expr, intrinsic_id.log); return; @@ -616,6 +619,14 @@ void llvm_emit_builtin_call(GenContext *c, BEValue *result_value, Expr *expr) case BUILTIN_ROUNDEVEN: llvm_emit_simple_builtin(c, result_value, expr, intrinsic_id.roundeven); return; + case BUILTIN_SET_ROUNDING_MODE: + { + Expr **args = expr->call_expr.arguments; + LLVMValueRef arg_slots[1]; + llvm_emit_intrinsic_args(c, args, arg_slots, 1); + llvm_value_set(result_value, llvm_emit_call_intrinsic(c, intrinsic_id.set_rounding, NULL, 0, arg_slots, 1), type_void); + } + return; case BUILTIN_SIN: llvm_emit_simple_builtin(c, result_value, expr, intrinsic_id.sin); return; diff --git a/src/compiler/llvm_codegen_internal.h b/src/compiler/llvm_codegen_internal.h index 1df6ca883..d90766a59 100644 --- a/src/compiler/llvm_codegen_internal.h +++ b/src/compiler/llvm_codegen_internal.h @@ -132,6 +132,7 @@ typedef struct unsigned exp2; unsigned fabs; unsigned floor; + unsigned flt_rounds; unsigned fma; unsigned fshl; unsigned fshr; @@ -162,6 +163,7 @@ typedef struct unsigned roundeven; unsigned sadd_overflow; unsigned sadd_sat; + unsigned set_rounding; unsigned sin; unsigned smax; unsigned smin; diff --git a/src/compiler/sema_builtins.c b/src/compiler/sema_builtins.c index fb9192c3e..fac69bc08 100644 --- a/src/compiler/sema_builtins.c +++ b/src/compiler/sema_builtins.c @@ -281,6 +281,16 @@ bool sema_expr_analyse_builtin_call(SemaContext *context, Expr *expr) case BUILTIN_SYSCLOCK: rtype = type_ulong; break; + case BUILTIN_GET_ROUNDING_MODE: + rtype = type_int; + break; + case BUILTIN_SET_ROUNDING_MODE: + if (!sema_check_builtin_args(args, + (BuiltinArg[]) { BA_INTEGER }, + arg_count)) return false; + if (!sema_check_builtin_args_match(args, 1)) return false; + rtype = type_void; + break; case BUILTIN_SYSCALL: if (arg_count > 7) { @@ -414,7 +424,7 @@ bool sema_expr_analyse_builtin_call(SemaContext *context, Expr *expr) } if (!expr_in_int_range(args[1], 0, 1)) { - SEMA_ERROR(args[2], "Expected a value between 0 and 3."); + SEMA_ERROR(args[1], "Expected a value between 0 and 1."); return false; } if (!expr_in_int_range(args[2], 0, 3)) @@ -536,6 +546,7 @@ static inline unsigned builtin_expected_args(BuiltinFunction func) { switch (func) { + case BUILTIN_GET_ROUNDING_MODE: case BUILTIN_STACKTRACE: case BUILTIN_SYSCLOCK: case BUILTIN_TRAP: @@ -577,6 +588,7 @@ static inline unsigned builtin_expected_args(BuiltinFunction func) case BUILTIN_REDUCE_XOR: case BUILTIN_REDUCE_MAX: case BUILTIN_REDUCE_MIN: + case BUILTIN_SET_ROUNDING_MODE: return 1; case BUILTIN_COPYSIGN: case BUILTIN_EXACT_ADD: diff --git a/src/compiler/symtab.c b/src/compiler/symtab.c index a25192bf5..06516f450 100644 --- a/src/compiler/symtab.c +++ b/src/compiler/symtab.c @@ -205,6 +205,7 @@ void symtab_init(uint32_t capacity) builtin_list[BUILTIN_FMULADD] = KW_DEF("fmuladd"); builtin_list[BUILTIN_FSHL] = KW_DEF("fshl"); builtin_list[BUILTIN_FSHR] = KW_DEF("fshr"); + builtin_list[BUILTIN_GET_ROUNDING_MODE] = KW_DEF("get_rounding_mode"); builtin_list[BUILTIN_LOG] = KW_DEF("log"); builtin_list[BUILTIN_LOG2] = KW_DEF("log2"); builtin_list[BUILTIN_LOG10] = KW_DEF("log10"); @@ -235,6 +236,7 @@ void symtab_init(uint32_t capacity) builtin_list[BUILTIN_SAT_ADD] = KW_DEF("sat_add"); builtin_list[BUILTIN_SAT_SHL] = KW_DEF("sat_shl"); builtin_list[BUILTIN_SAT_SUB] = KW_DEF("sat_sub"); + builtin_list[BUILTIN_SET_ROUNDING_MODE] = KW_DEF("set_rounding_mode"); builtin_list[BUILTIN_SIN] = KW_DEF("sin"); builtin_list[BUILTIN_SHUFFLEVECTOR] = KW_DEF("shufflevector"); builtin_list[BUILTIN_SQRT] = KW_DEF("sqrt"); diff --git a/test/test_suite/stdlib/rounding_mode.c3 b/test/test_suite/stdlib/rounding_mode.c3 new file mode 100644 index 000000000..0027c6595 --- /dev/null +++ b/test/test_suite/stdlib/rounding_mode.c3 @@ -0,0 +1,30 @@ +import std::io; +import std::math; + +fn void main() +{ + io::printfln("Current rounding mode: %s", $$get_rounding_mode()); + float f1 = 11.5; + float f2 = -11.5; + + foreach (int mode : RoundingMode.values) + { + $$set_rounding_mode(mode); + io::printfln("Rounding mode: %s", $$get_rounding_mode()); + + io::printfln(" ceil(%s) == %s", f1, math::ceil(f1)); + io::printfln(" ceil(%s) == %s", f2, math::ceil(f2)); + io::printfln(" floor(%s) == %s", f1, math::floor(f1)); + io::printfln(" floor(%s) == %s", f2, math::floor(f2)); + io::printfln(" nearbyint(%s) == %s", f1, math::nearbyint(f1)); + io::printfln(" nearbyint(%s) == %s", f2, math::nearbyint(f2)); + io::printfln(" rint(%s) == %s", f1, math::rint(f1)); + io::printfln(" rint(%s) == %s", f2, math::rint(f2)); + io::printfln(" round(%s) == %s", f1, math::round(f1)); + io::printfln(" round(%s) == %s", f2, math::round(f2)); + io::printfln(" roundeven(%s) == %s", f1, math::roundeven(f1)); + io::printfln(" roundeven(%s) == %s", f2, math::roundeven(f2)); + io::printfln(" trunc(%s) == %s", f1, math::trunc(f1)); + io::printfln(" trunc(%s) == %s", f2, math::trunc(f2)); + } +}