diff --git a/decimal/decimal.nim b/decimal/decimal.nim index 914fa86..c4a75fc 100644 --- a/decimal/decimal.nim +++ b/decimal/decimal.nim @@ -6,3 +6,448 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms. import decimal_lowlevel + +type + DecimalType* = ref[ptr mpd_t] + DecimalError* = object of Exception + +const + DEFAULT_PREC = MPD_RDIGITS * 2 + DEFAULT_EMAX = when (sizeof(int) == 8): 999999999999999999 else: 425000000 + DEFAULT_EMIN = when (sizeof(int) == 8): -999999999999999999 else: -425000000 + +var CTX: mpd_context_t +var CTX_ADDR = addr CTX +mpd_defaultcontext(CTX_ADDR) + +proc setPrec*(prec: mpd_ssize_t) = + ## Sets the precision (number of decimals) in the Context + if 0 < prec: + let success = mpd_qsetprec(CTX_ADDR, prec) + if success == 0: + raise newException(DecimalError, "Couldn't set precision") + +proc `$`*(s: DecimalType): string = + ## Convert DecimalType to string + $mpd_to_sci(s[], 0) + +proc newDecimal*(): DecimalType = + ## Initialize a empty DecimalType + new result + result[] = mpd_qnew() + +proc newDecimal*(s: string): DecimalType = + ## Create a new DecimalType from a string + new result + result[] = mpd_qnew() + mpd_set_string(result[], s, CTX_ADDR) + +proc newDecimal*(s: int): DecimalType = + ## Create a new DecimalType from a int64 + new result + result[] = mpd_qnew() + when (sizeof(int) == 8): + mpd_set_i64(result[], s, CTX_ADDR) + else: + mpd_set_i32(result[], s, CTX_ADDR) + +proc clone*(b: DecimalType): DecimalType = + ## Clone a DecimalType and returns a new independent one + var status: uint32 + result = newDecimal() + let success = mpd_qcopy(result[], b[], addr status) + if success == 0: + raise newException(DecimalError, "Decimal failed to copy") + +# Operators + +proc `+`*(a, b: DecimalType): DecimalType = + var status: uint32 + result = newDecimal() + mpd_qadd(result[], a[], b[], CTX_ADDR, addr status) + +template `+`*[T: SomeNumber](a: DecimalType, b: T): untyped = + a + newDecimal($b) + +template `+`*[T: SomeNumber](a: T, b: DecimalType): untyped = + newDecimal($a) + b + +proc `+=`*(a, b: DecimalType) = + ## Inplace addition + var status: uint32 + mpd_qadd(a[], a[], b[], CTX_ADDR, addr status) + +template `+=`*[T: SomeNumber](a: DecimalType, b: T): untyped = + a += newDecimal($b) + + + +proc `-`*(a, b: DecimalType): DecimalType = + var status: uint32 + result = newDecimal() + mpd_qsub(result[], a[], b[], CTX_ADDR, addr status) + +template `-`*[T: SomeNumber](a: DecimalType, b: T): untyped = + a - newDecimal($b) + +template `-`*[T: SomeNumber](a: T, b: DecimalType): untyped = + newDecimal($a) - b + +proc `-=`*(a, b: DecimalType) = + ## Inplace subtraction + var status: uint32 + mpd_qsub(a[], a[], b[], CTX_ADDR, addr status) + +template `-=`*[T: SomeNumber](a: DecimalType, b: T) = + a -= newDecimal($b) + + +proc `*`*(a, b: DecimalType): DecimalType = + var status: uint32 + result = newDecimal() + mpd_qmul(result[], a[], b[], CTX_ADDR, addr status) + +template `*`*[T: SomeNumber](a: T, b: DecimalType): untyped = + newDecimal($a) * b + +template `*`*[T: SomeNumber](a: DecimalType, b: T): untyped = + newDecimal($b) * a + +proc `*=`*(a, b: DecimalType) = + ## Inplace multiplication + var status: uint32 + mpd_qmul(a[], a[], b[], CTX_ADDR, addr status) + +template `*=`*[T: SomeNumber](a: DecimalType, b: T) = + a *= newDecimal($b) + + + +proc `/`*(a, b: DecimalType): DecimalType = + var status: uint32 + result = newDecimal() + mpd_qdiv(result[], a[], b[], CTX_ADDR, addr status) + +template `/`*[T: SomeNumber](a: DecimalType, b: T): untyped = + a / newDecimal($b) + +template `/`*[T: SomeNumber](a: T, b: DecimalType): untyped = + newDecimal($a) / b + +proc `/=`*(a, b: DecimalType) = + ## Inplace division + var status: uint32 + mpd_qdiv(a[], a[], b[], CTX_ADDR, addr status) + +template `/=`*[T: SomeNumber](a: DecimalType, b: T): untyped = + a /= newDecimal($b) + + + +proc `//`*(a, b: DecimalType): DecimalType = + ## Integer division, same as divint + var status: uint32 + result = newDecimal() + mpd_qdivint(result[], a[], b[], CTX_ADDR, addr status) + +template `//`*[T: SomeNumber](a: DecimalType, b: T): untyped = + a // newDecimal($b) + +proc `^`*(a, b: DecimalType): DecimalType = + ## Power operator + var status: uint32 + result = newDecimal() + mpd_qpow(result[], a[], b[], CTX_ADDR, addr status) + +template `^`*[T: SomeNumber](a: DecimalType, b: T): untyped = + a ^ newDecimal($b) + +proc `==`*(a, b: DecimalType): bool = + var status: uint32 + let cmp = mpd_qcmp(a[], b[], addr status) + if cmp == 0: + return true + else: + return false + +template `==`*[T: SomeNumber](a: DecimalType, b: T): untyped = + a == newDecimal($b) + +template `==`*[T: SomeNumber](a: T, b: DecimalType): untyped = + newDecimal($a) == b + +proc `<`*(a, b: DecimalType): bool = + var status: uint32 + let cmp = mpd_qcmp(a[], b[], addr status) + if cmp == -1: + return true + else: + return false + +template `<`*[T: SomeNumber](a: DecimalType, b: T): untyped = + a < newDecimal($b) +template `<`*[T: SomeNumber](a: T, b: DecimalType): untyped = + newDecimal($a) < b + +proc `<=`*(a, b: DecimalType): bool = + let less_cmp = a < b + if less_cmp: return true + let equal_cmp = a == b + if equal_cmp: return true + return false +template `<=`*[T: SomeNumber](a: DecimalType, b: T): untyped = + a <= newDecimal($b) +template `<=`*[T: SomeNumber](a: T, b: DecimalType): untyped = + newDecimal($a) <= b + + + +proc divint*(a, b: DecimalType): DecimalType = + ## Integer division, same ass // + var status: uint32 + result = newDecimal() + mpd_qdivint(result[], a[], b[], CTX_ADDR, addr status) + +proc rem*(a, b: DecimalType): DecimalType = + ## Returns the remainder of the division a/b + var status: uint32 + result = newDecimal() + mpd_qrem(result[], a[], b[], CTX_ADDR, addr status) + +proc rem_near*(a, b: DecimalType): DecimalType = + ## Return a - b * n, where n is the integer nearest the exact value of a / b. If two integers are equally near then the even one is chosen. + var status: uint32 + result = newDecimal() + mpd_qrem_near(result[], a[], b[], CTX_ADDR, addr status) + +proc divmod*(a, b: DecimalType): (DecimalType, DecimalType) = + ## Return both the integer part and remainder of the division a/b, same as (a // b, rem(a, b)) + var status: uint32 + var q = newDecimal() + var r = newDecimal() + mpd_qdivmod(q[], r[], a[], b[], CTX_ADDR, addr status) + result = (q, r) + +proc fma*(a, b, c: DecimalType): DecimalType = + ## Fused multiplication-addition, returns a * b + c + var status: uint32 + result = newDecimal() + mpd_qfma(result[], a[], b[], c[], CTX_ADDR, addr status) + + +# Math functions + +proc exp*(a: DecimalType): DecimalType = + ## The exponential function + var status: uint32 + result = newDecimal() + mpd_qexp(result[], a[], CTX_ADDR, addr status) + +proc ln*(a: DecimalType): DecimalType = + ## The natural logarithm + var status: uint32 + result = newDecimal() + mpd_qln(result[], a[], CTX_ADDR, addr status) + +proc log10*(a: DecimalType): DecimalType = + ## Logarithm base 10 + var status: uint32 + result = newDecimal() + mpd_qlog10(result[], a[], CTX_ADDR, addr status) + +proc sqrt*(a: DecimalType): DecimalType = + ## Square root + var status: uint32 + result = newDecimal() + mpd_qsqrt(result[], a[], CTX_ADDR, addr status) + +proc invroot*(a: DecimalType): DecimalType = + ## Inverse square root, same as 1/sqrt(a) + var status: uint32 + result = newDecimal() + mpd_qinvroot(result[], a[], CTX_ADDR, addr status) + + + +proc `-`*(a: DecimalType): DecimalType = + ## Negation operator + var status: uint32 + result = newDecimal() + mpd_qminus(result[], a[], CTX_ADDR, addr status) + +proc plus*(a: DecimalType): DecimalType = + var status: uint32 + result = newDecimal() + mpd_qplus(result[], a[], CTX_ADDR, addr status) + +proc abs*(a: DecimalType): DecimalType = + ## Absolute value + var status: uint32 + result = newDecimal() + mpd_qabs(result[], a[], CTX_ADDR, addr status) + + + +proc max*(a,b: DecimalType): DecimalType = + ## Returns the most positive of a and b. + var status: uint32 + result = newDecimal() + mpd_qmax(result[], a[], b[], CTX_ADDR, addr status) + +proc max_mag*(a,b: DecimalType): DecimalType = + ## Returns the largest by magnitude of a and b + var status: uint32 + result = newDecimal() + mpd_qmax_mag(result[], a[], b[], CTX_ADDR, addr status) + +proc min*(a,b: DecimalType): DecimalType = + ## Returns the most negative of a and b + var status: uint32 + result = newDecimal() + mpd_qmin(result[], a[], b[], CTX_ADDR, addr status) + +proc min_mag*(a,b: DecimalType): DecimalType = + ## Returns the smallest by magnitude of a and b + var status: uint32 + result = newDecimal() + mpd_qmin_mag(result[], a[], b[], CTX_ADDR, addr status) + +proc next_plus*(a: DecimalType): DecimalType = + ## The closest representable number that is larger than a + var status: uint32 + result = newDecimal() + mpd_qnext_plus(result[], a[], CTX_ADDR, addr status) + +proc next_minus*(a: DecimalType): DecimalType = + ## The closest representable number that is smaller than a + var status: uint32 + result = newDecimal() + mpd_qnext_minus(result[], a[], CTX_ADDR, addr status) + +proc next_toward*(a, b: DecimalType): DecimalType = + ## Representable number closest to a that is in the direction towards b + var status: uint32 + result = newDecimal() + mpd_qnext_toward(result[], a[], b[], CTX_ADDR, addr status) + +proc quantize*(a, b: DecimalType): DecimalType = + ## Return the number that is equal in value to a, but has the exponent of b + var status: uint32 + result = newDecimal() + mpd_qquantize(result[], a[], b[], CTX_ADDR, addr status) + +proc rescale*(a: DecimalType, b: mpd_ssize_t): DecimalType = + ## Return the number that is equal in value to a, but has the exponent exp + var status: uint32 + result = newDecimal() + mpd_qrescale(result[], a[], b, CTX_ADDR, addr status) + +proc same_quantum*(a, b: DecimalType): bool = + ## Return true if a and b have the same exponent, false otherwise + let cmp = mpd_same_quantum(a[], b[]) + if cmp == 1: + return true + else: + return false + +proc reduce*(a: DecimalType): DecimalType = + ## If a is finite after applying rounding and overflow/underflow checks, result is set to the simplest form of a with all trailing zeros removed + var status: uint32 + result = newDecimal() + mpd_qreduce(result[], a[], CTX_ADDR, addr status) + +proc round_to_intx*(a: DecimalType): DecimalType = + ## Round to an integer, using the rounding mode of the context + var status: uint32 + result = newDecimal() + mpd_qround_to_intx(result[], a[], CTX_ADDR, addr status) + +proc round_to_int*(a: DecimalType): DecimalType = + ## Same as mpd_qround_to_intx, but the MPD_Inexact and MPD_Rounded flags are never set + var status: uint32 + result = newDecimal() + mpd_qround_to_int(result[], a[], CTX_ADDR, addr status) + +proc floor*(a: DecimalType): DecimalType = + ## Return the nearest integer towards -infinity + var status: uint32 + result = newDecimal() + mpd_qfloor(result[], a[], CTX_ADDR, addr status) + +proc ceil*(a: DecimalType): DecimalType = + ## Return the nearest integer towards +infinity + var status: uint32 + result = newDecimal() + mpd_qceil(result[], a[], CTX_ADDR, addr status) + +proc truncate*(a: DecimalType): DecimalType = + ## Return the truncated value of a + var status: uint32 + result = newDecimal() + mpd_qtrunc(result[], a[], CTX_ADDR, addr status) + +proc logb*(a: DecimalType): DecimalType = + ## Return the adjusted exponent of a. Same as floor(log10(a)) + var status: uint32 + result = newDecimal() + mpd_qlogb(result[], a[], CTX_ADDR, addr status) + +proc scaleb*(a, b: DecimalType): DecimalType = + ## b must be an integer with exponent 0. If a is infinite, result is set to a. Otherwise, result is a with the value of b added to the exponent. + var status: uint32 + result = newDecimal() + mpd_qscaleb(result[], a[], b[], CTX_ADDR, addr status) + +proc powmod*(base, exp, modulus: DecimalType): DecimalType = + ## Return (base ^ exp) % mod. All operands must be integers. The function fails if result does not fit in the current prec. + var status: uint32 + result = newDecimal() + mpd_qpowmod(result[], base[], exp[], modulus[], CTX_ADDR, addr status) + +proc finalize*(a: DecimalType) = + ## Apply the current context to a + var status: uint32 + mpd_qfinalize(a[], CTX_ADDR, addr status) + +proc shift*(a, b: DecimalType): DecimalType = + ## Return a shifted by b places. b must be in the range [-prec, prec]. A negative b indicates a right shift, a positive b a left shift. Digits that do not fit are discarded. + var status: uint32 + result = newDecimal() + mpd_qshift(result[], a[], b[], CTX_ADDR, addr status) + +proc shift*(a: DecimalType, b: mpd_ssize_t): DecimalType = + ## Like shift, only that the number of places is specified by a integer type rather than a DecimalType + var status: uint32 + result = newDecimal() + mpd_qshiftn(result[], a[], b, CTX_ADDR, addr status) + +proc rotate*(a, b: DecimalType): DecimalType = + ## Return a rotated by b places. b must be in the range [-prec, prec]. A negative b indicates a right rotation, a positive b a left rotation. + var status: uint32 + result = newDecimal() + mpd_qrotate(result[], a[], b[], CTX_ADDR, addr status) + +proc elementwiseAnd*(a, b: DecimalType): DecimalType = + ## Return the digit-wise logical and of a and b + var status: uint32 + result = newDecimal() + mpd_qand(result[], a[], b[], CTX_ADDR, addr status) + +proc elementwiseOr*(a, b: DecimalType): DecimalType = + ## Return the digit-wise logical or of a and b + var status: uint32 + result = newDecimal() + mpd_qor(result[], a[], b[], CTX_ADDR, addr status) + +proc elementwiseXor*(a, b: DecimalType): DecimalType = + ## Return the digit-wise logical xor of a and b + var status: uint32 + result = newDecimal() + mpd_qxor(result[], a[], b[], CTX_ADDR, addr status) + +proc elementwiseInvert*(a: DecimalType): DecimalType = + ## Return the digit-wise logical inversion of a + var status: uint32 + result = newDecimal() + mpd_qinvert(result[], a[], CTX_ADDR, addr status) + diff --git a/tests/all_tests.nim b/tests/all_tests.nim index 4fd69ac..2603887 100644 --- a/tests/all_tests.nim +++ b/tests/all_tests.nim @@ -8,6 +8,292 @@ import unittest, ../decimal/decimal -suite "Mock compile test": - test "Nim-decimal and its wrapped library compiles": - discard +suite "Basic Arithmetic": + test "init Decimal": + var d = newDecimal() + test "Set Decimal from string": + let s = "1.23456" + var d = newDecimal(s) + check $d == s + test "Set Decimal from int": + let s = 123456 + var d = newDecimal(s) + let correct = "123456" + check $d == correct + + test "Decimal Addition": + var a = newDecimal("1.2") + var b = newDecimal("3.5") + var c1 = a + b + var c2 = b + a + let correct = "4.7" + check $c1 == correct + check $c2 == correct + test "Decimal inplace Addition": + var a = newDecimal("1.2") + var b = newDecimal("3.6") + a += b + let correct = "4.8" + check $a == correct + test "Decimal-Int Addition": + var a = newDecimal("1.2") + var b = 5 + var c1 = a + b + var c2 = b + a + let correct = "6.2" + check $c1 == correct + check $c2 == correct + test "Decimal-Int inplace Addition": + var a = newDecimal("1.2") + var b = 4 + a += b + let correct = "5.2" + check $a == correct + + test "Decimal Subtraction": + var a = newDecimal("1.2") + var b = newDecimal("3.5") + var c = a - b + let correct = "-2.3" + check $c == correct + test "Decimal Multiplication": + var a = newDecimal("1.2") + var b = newDecimal("3.5") + var c = a * b + let correct = "4.20" + check $c == correct + + test "Decimal Division": + var a = newDecimal("6.25") + var b = newDecimal("2.5") + var c = a / b + let correct = "2.5" + check $c == correct + test "Decimal-Int Division": + var a = newDecimal("10") + var b = 5 + var c = a / b + var d = b / a + let correctC = "2" + let correctD = "0.5" + check $c == correctC + check $d == correctD + + test "Decimal ==": + var a = newDecimal("6.25") + var b = newDecimal("2.5") + check a == a + check (a == b) == false + test "Decimal <": + var a = newDecimal("6.25") + var b = newDecimal("2.5") + check b < a + check (a < b) == false + test "Decimal >": + var a = newDecimal("6.25") + var b = newDecimal("2.5") + check a > b + check (b > a) == false + test "Decimal Power 1": + var a = newDecimal("2.5") + var b = newDecimal("2") + var c = a ^ b + check $c == "6.25" + test "Decimal Power 2": + var a = newDecimal("81") + var b = newDecimal("0.5") + var c = a ^ b + check $c == "9.0000000000000000000000000000000000000" + test "Decimal divint": + let a = newDecimal("11") + let b = newDecimal("3") + let c = a // b + check $c == "3" + test "Decimal rem": + let a = newDecimal("11") + let b = newDecimal("3") + let c = rem(a, b) + check $c == "2" + test "Decimal divmod": + let a = newDecimal("11") + let b = newDecimal("3") + let (q, r) = divmod(a, b) + check $q == "3" + check $r == "2" + test "Decimal exp": + let a = newDecimal("2") + let c = exp(a) + check $c == "7.3890560989306502272304274605750078132" + test "Decimal rem_near": + let a = newDecimal("11") + let b = newDecimal("3") + let c = rem_near(a, b) + check $c == "-1" + test "Decimal fma": + let a = newDecimal("11") + let b = newDecimal("3") + let c = newDecimal("2") + let d = fma(a, b, c) + check $d == "35" + test "Decimal ln": + let a = newDecimal("1") + let b = exp(newDecimal("1")) + let ln1 = ln(a) + let ln2 = ln(b) + check $ln1 == "0" + check $ln2 == "1.0000000000000000000000000000000000000" + test "Decimal log10": + let a = newDecimal("1") + let b = newDecimal("10") + let c = newDecimal("20") + let log1 = log10(a) + let log2 = log10(b) + let log3 = log10(c) + check $log1 == "0" + check $log2 == "1" + check $log3 == "1.3010299956639811952137388947244930268" + test "Decimal sqrt": + let a = newDecimal("6.25") + let b = sqrt(a) + check $b == "2.5" + test "Decimal invroot": + let a = newDecimal("10") + let b = invroot(a) + check $b == "0.31622776601683793319988935444327185337" + test "Decimal negate": + let a = newDecimal("1.23") + let b = newDecimal("-4.56") + let a2 = -a + let b2 = -b + check $a2 == "-1.23" + check $b2 == "4.56" + test "Decimal abs": + let a = newDecimal("7") + let b = newDecimal("-8") + let c = abs(a) + let d = abs(b) + check $c == "7" + check $d == "8" + test "Decimal quantize": + let a = newDecimal("17.89843759") + let b = newDecimal("1e-5") + check $quantize(a,b) == "17.89844" + test "Decimal max": + let a = newDecimal("5") + let b = newDecimal("-5") + let c = newDecimal("2") + check max(a, b) == a + check max(b, c) == c + test "Decimal max_mag": + let a = newDecimal("5") + let b = newDecimal("-6") + let c = newDecimal("7") + check max_mag(a, b) == b + check max_mag(b, c) == c + test "Decimal min": + let a = newDecimal("5") + let b = newDecimal("-5") + let c = newDecimal("-2") + check min(a, b) == b + check min(b, c) == b + test "Decimal min_mag": + let a = newDecimal("5") + let b = newDecimal("-6") + let c = newDecimal("7") + check min_mag(a, b) == a + check min_mag(b, c) == b + test "Decimal next_plus": + let a = newDecimal("1.01") + let b = next_plus(a) + let correct = "1.0100000000000000000000000000000000001" + check $b == correct + test "Decimal next_minus": + let a = newDecimal("1.01") + let b = next_minus(a) + let correct = "1.0099999999999999999999999999999999999" + check $b == correct + test "Decimal next_toward": + let a = newDecimal("1.01") + let b = newDecimal("2") + let c = newDecimal("1") + let correct1 = "1.0100000000000000000000000000000000001" + let correct2 = "1.0099999999999999999999999999999999999" + check $next_toward(a, b) == correct1 + check $next_toward(a, c) == correct2 + test "Decimal rescale": + let a = newDecimal("2000") + let b = rescale(a, 3) + check $b == "2e+3" + test "Decimal same_quantum": + let a = newDecimal("2e-5") + let b = newDecimal("20e-6") + let c = newDecimal("20e-5") + check not same_quantum(a, b) + check same_quantum(a, c) + test "Decimal reduce": + let a = newDecimal("1.2345000000000000000") + let b = newDecimal("1.2345") + let c = reduce(a) + check $c == $b + test "Decimal round_to_intx": + let a = newDecimal("1.49") + let b = round_to_intx(a) + check $b == "1" + test "Decimal floor": + let a = newDecimal("1.49") + let b = floor(a) + check $b == "1" + test "Decimal ceil": + let a = newDecimal("1.49") + let b = ceil(a) + check $b == "2" + test "Decimal truncate": + let a = newDecimal("10.12345678") + let b = truncate(a) + check $b == "10" + test "Decimal logb": + let a = newDecimal("76543") + let b = logb(a) + let correct = floor(log10(a)) + check b == correct + test "Decimal scaleb": + let a = newDecimal("23e4") + let b = newDecimal("3") + let c = scaleb(a, b) + check $c == "2.3e+8" + test "Decimal powmod": + let base = newDecimal("2") + let exp = newDecimal("5") + let m = newDecimal("7") + let correct = rem(base ^ exp, m) + check powmod(base, exp, m) == correct + test "Decimal shift": + let a = newDecimal("1.23e7") + let b = 2 + let c = shift(a, b) + check $c == "1.2300e+9" + test "Decimal rotate": + let a = newDecimal("1.23e7") + let b = newDecimal(2) + let c = rotate(a, b) + check $c == "1.2300e+9" + test "Decimal elementwiseOr": + let a = newDecimal("1110") + let b = newDecimal("1010") + check $elementwiseOr(a,b) == "1110" + test "Decimal elementwiseAnd": + let a = newDecimal("1110") + let b = newDecimal("1010") + check $elementwiseAnd(a,b) == "1010" + test "Decimal elementwiseXor": + let a = newDecimal("01110") + let b = newDecimal("11010") + check $elementwiseXor(a,b) == "10100" + test "Decimal elementwiseInvert": + let a = newDecimal("111010") + check $elementwiseInvert(a) == "11111111111111111111111111111111000101" + + + + +