diff --git a/Makefile.in b/Makefile.in index d563849851..081d29482b 100644 --- a/Makefile.in +++ b/Makefile.in @@ -443,6 +443,7 @@ TESTSRC += \ $(TOP)/ext/misc/carray.c \ $(TOP)/ext/misc/closure.c \ $(TOP)/ext/misc/csv.c \ + $(TOP)/ext/misc/decimal.c \ $(TOP)/ext/misc/eval.c \ $(TOP)/ext/misc/explain.c \ $(TOP)/ext/misc/fileio.c \ @@ -1082,6 +1083,7 @@ SHELL_SRC = \ $(TOP)/src/shell.c.in \ $(TOP)/ext/misc/appendvfs.c \ $(TOP)/ext/misc/shathree.c \ + $(TOP)/ext/misc/decimal.c \ $(TOP)/ext/misc/fileio.c \ $(TOP)/ext/misc/completion.c \ $(TOP)/ext/misc/sqlar.c \ diff --git a/Makefile.msc b/Makefile.msc index 241961e2d5..ca643d0842 100644 --- a/Makefile.msc +++ b/Makefile.msc @@ -1560,6 +1560,7 @@ TESTEXT = \ $(TOP)\ext\misc\carray.c \ $(TOP)\ext\misc\closure.c \ $(TOP)\ext\misc\csv.c \ + $(TOP)\ext\misc\decimal.c \ $(TOP)\ext\misc\eval.c \ $(TOP)\ext\misc\explain.c \ $(TOP)\ext\misc\fileio.c \ @@ -2204,6 +2205,7 @@ SHELL_SRC = \ $(TOP)\src\shell.c.in \ $(TOP)\ext\misc\appendvfs.c \ $(TOP)\ext\misc\shathree.c \ + $(TOP)\ext\misc\decimal.c \ $(TOP)\ext\misc\fileio.c \ $(TOP)\ext\misc\completion.c \ $(TOP)\ext\misc\uint.c \ diff --git a/ext/misc/decimal.c b/ext/misc/decimal.c new file mode 100644 index 0000000000..078e91dc27 --- /dev/null +++ b/ext/misc/decimal.c @@ -0,0 +1,567 @@ +/* +** 2020-06-22 +** +** The author disclaims copyright to this source code. In place of +** a legal notice, here is a blessing: +** +** May you do good and not evil. +** May you find forgiveness for yourself and forgive others. +** May you share freely, never taking more than you give. +** +****************************************************************************** +** +** Routines to implement arbitrary-precision decimal math. +** +** The focus here is on simplicity and correctness, not performance. +*/ +#include "sqlite3ext.h" +SQLITE_EXTENSION_INIT1 +#include +#include +#include +#include + + +/* A decimal object */ +typedef struct Decimal Decimal; +struct Decimal { + char sign; /* 0 for positive, 1 for negative */ + char oom; /* True if an OOM is encountered */ + char isNull; /* True if holds a NULL rather than a number */ + char isInit; /* True upon initialization */ + int nDigit; /* Total number of digits */ + int nFrac; /* Number of digits to the right of the decimal point */ + signed char *a; /* Array of digits. Most significant first. */ +}; + +/* +** Release memory held by a Decimal, but do not free the object itself. +*/ +static void decimal_clear(Decimal *p){ + sqlite3_free(p->a); +} + +/* +** Destroy a Decimal object +*/ +static void decimal_free(Decimal *p){ + if( p ){ + decimal_clear(p); + sqlite3_free(p); + } +} + +/* +** Allocate a new Decimal object. Initialize it to the number given +** by the input string. +*/ +static Decimal *decimal_new( + sqlite3_context *pCtx, + sqlite3_value *pIn, + int nAlt, + const unsigned char *zAlt +){ + Decimal *p; + int n, i; + const unsigned char *zIn; + int iExp = 0; + p = sqlite3_malloc( sizeof(*p) ); + if( p==0 ) goto new_no_mem; + p->sign = 0; + p->oom = 0; + p->isInit = 1; + p->isNull = 0; + p->nDigit = 0; + p->nFrac = 0; + if( zAlt ){ + n = nAlt, + zIn = zAlt; + }else{ + if( sqlite3_value_type(pIn)==SQLITE_NULL ){ + p->a = 0; + p->isNull = 1; + return p; + } + n = sqlite3_value_bytes(pIn); + zIn = sqlite3_value_text(pIn); + } + p->a = sqlite3_malloc64( n+1 ); + if( p->a==0 ) goto new_no_mem; + for(i=0; isspace(zIn[i]); i++){} + if( zIn[i]=='-' ){ + p->sign = 1; + i++; + }else if( zIn[i]=='+' ){ + i++; + } + while( i='0' && c<='9' ){ + p->a[p->nDigit++] = c - '0'; + }else if( c=='.' ){ + p->nFrac = p->nDigit + 1; + }else if( c=='e' || c=='E' ){ + int j = i+1; + int neg = 0; + if( j>=n ) break; + if( zIn[j]=='-' ){ + neg = 1; + j++; + }else if( zIn[j]=='+' ){ + j++; + } + while( j='0' && zIn[j]<='9' ){ + iExp = iExp*10 + zIn[j] - '0'; + } + j++; + } + if( neg ) iExp = -iExp; + break; + } + i++; + } + if( p->nFrac ){ + p->nFrac = p->nDigit - (p->nFrac - 1); + } + if( iExp>0 ){ + if( p->nFrac>0 ){ + if( iExp<=p->nFrac ){ + p->nFrac -= iExp; + iExp = 0; + }else{ + iExp -= p->nFrac; + p->nFrac = 0; + } + } + if( iExp>0 ){ + p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 ); + if( p->a==0 ) goto new_no_mem; + memset(p->a+p->nDigit, 0, iExp); + p->nDigit += iExp; + } + }else if( iExp<0 ){ + int nExtra; + iExp = -iExp; + nExtra = p->nDigit - p->nFrac - 1; + if( nExtra ){ + if( nExtra>=iExp ){ + p->nFrac += iExp; + iExp = 0; + }else{ + iExp -= nExtra; + p->nFrac = p->nDigit - 1; + } + } + if( iExp>0 ){ + p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 ); + if( p->a==0 ) goto new_no_mem; + memmove(p->a+iExp, p->a, p->nDigit); + memset(p->a, 0, iExp); + p->nDigit += iExp; + p->nFrac += iExp; + } + } + return p; + +new_no_mem: + if( pCtx ) sqlite3_result_error_nomem(pCtx); + sqlite3_free(p); + return 0; +} + +/* +** Make the given Decimal the result. +*/ +static void decimal_result(sqlite3_context *pCtx, Decimal *p){ + char *z; + int i, j; + int n; + if( p==0 || p->oom ){ + sqlite3_result_error_nomem(pCtx); + return; + } + if( p->isNull ){ + sqlite3_result_null(pCtx); + return; + } + z = sqlite3_malloc( p->nDigit+4 ); + if( z==0 ){ + sqlite3_result_error_nomem(pCtx); + return; + } + i = 0; + if( p->nDigit==0 || (p->nDigit==1 && p->a[0]==0) ){ + p->sign = 0; + } + if( p->sign ){ + z[0] = '-'; + i = 1; + } + n = p->nDigit - p->nFrac; + if( n<=0 ){ + z[i++] = '0'; + } + j = 0; + while( n>0 ){ + z[i++] = p->a[j] + '0'; + j++; + n--; + } + if( p->nFrac ){ + z[i++] = '.'; + do{ + z[i++] = p->a[j] + '0'; + j++; + }while( jnDigit ); + } + z[i] = 0; + sqlite3_result_text(pCtx, z, i, sqlite3_free); +} + +/* +** SQL Function: decimal(X) +** +** Convert input X into decimal and then back into text +*/ +static void decimalFunc( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + Decimal *p = decimal_new(context, argv[0], 0, 0); + decimal_result(context, p); + decimal_free(p); +} + +/* +** Compare to Decimal objects. Return negative, 0, or positive if the +** first object is less than, equal to, or greater than the second. +** +** Preconditions for this routine: +** +** pA!=0 +** pA->isNull==0 +** pB!=0 +** pB->isNull==0 +*/ +static int decimal_cmp(const Decimal *pA, const Decimal *pB){ + int nASig, nBSig, rc, n; + if( pA->sign!=pB->sign ){ + return pA->sign ? -1 : +1; + } + if( pA->sign ){ + const Decimal *pTemp = pA; + pA = pB; + pB = pTemp; + } + nASig = pA->nDigit - pA->nFrac; + nBSig = pB->nDigit - pB->nFrac; + if( nASig!=nBSig ){ + return nASig - nBSig; + } + n = pA->nDigit; + if( n>pB->nDigit ) n = pB->nDigit; + rc = memcmp(pA->a, pB->a, n); + if( rc==0 ){ + rc = pA->nDigit - pB->nDigit; + } + return rc; +} + +/* +** SQL Function: decimal_cmp(X, Y) +** +** Return negative, zero, or positive if X is less then, equal to, or +** greater than Y. +*/ +static void decimalCmpFunc( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + Decimal *pA = 0, *pB = 0; + int rc; + + pA = decimal_new(context, argv[0], 0, 0); + if( pA==0 || pA->isNull ) goto cmp_done; + pB = decimal_new(context, argv[1], 0, 0); + if( pB==0 || pB->isNull ) goto cmp_done; + rc = decimal_cmp(pA, pB); + if( rc<0 ) rc = -1; + else if( rc>0 ) rc = +1; + sqlite3_result_int(context, rc); +cmp_done: + decimal_free(pA); + decimal_free(pB); +} + +/* +** Remove leading zeros from the Decimal +*/ +static void decimal_normalize(Decimal *p){ + int i; + int nSig; + if( p==0 ) return; + nSig = p->nDigit - p->nFrac; + for(i=0; ia[i]==0; i++){} + if( i ){ + memmove(p->a, p->a+i, p->nDigit - i); + p->nDigit -= i; + } +} + +/* +** Expand the Decimal so that it has a least nDigit digits and nFrac +** digits to the right of the decimal point. +*/ +static void decimal_expand(Decimal *p, int nDigit, int nFrac){ + int nAddSig; + int nAddFrac; + if( p==0 ) return; + nAddFrac = nFrac - p->nFrac; + nAddSig = (nDigit - p->nDigit) - nAddFrac; + if( nAddFrac==0 && nAddSig==0 ) return; + p->a = sqlite3_realloc64(p->a, nDigit+1); + if( p->a==0 ){ + p->oom = 1; + return; + } + if( nAddSig ){ + memmove(p->a+nAddSig, p->a, p->nDigit); + memset(p->a, 0, nAddSig); + p->nDigit += nAddSig; + } + if( nAddFrac ){ + memset(p->a+p->nDigit, 0, nAddFrac); + p->nDigit += nAddFrac; + p->nFrac += nAddFrac; + } +} + +/* +** Add the value pB into pA. +** +** pB might become denormalized by this routine. +*/ +static void decimal_add(Decimal *pA, Decimal *pB){ + int nSig, nFrac, nDigit; + int i, rc; + if( pA==0 ){ + return; + } + if( pA->oom || pB==0 || pB->oom ){ + pA->oom = 1; + return; + } + if( pA->isNull || pB->isNull ){ + pA->isNull = 1; + return; + } + nSig = pA->nDigit - pA->nFrac; + if( nSignDigit-pB->nFrac ) nSig = pB->nDigit - pB->nFrac; + nFrac = pA->nFrac; + if( nFracnFrac ) nFrac = pB->nFrac; + nDigit = nSig + nFrac + 1; + decimal_expand(pA, nDigit, nFrac); + decimal_expand(pB, nDigit, nFrac); + if( pA->oom || pB->oom ){ + pA->oom = 1; + }else{ + if( pA->sign==pB->sign ){ + int carry = 0; + for(i=nDigit-1; i>=0; i--){ + int x = pA->a[i] + pB->a[i] + carry; + if( x>=10 ){ + carry = 1; + pA->a[i] = x - 10; + }else{ + carry = 0; + pA->a[i] = x; + } + } + }else{ + signed char *aA, *aB; + int borrow = 0; + rc = memcmp(pA->a, pB->a, nDigit); + if( rc<0 ){ + aA = pB->a; + aB = pA->a; + pA->sign = !pA->sign; + }else{ + aA = pA->a; + aB = pB->a; + } + for(i=nDigit-1; i>=0; i--){ + int x = aA[i] - aB[i] - borrow; + if( x<0 ){ + pA->a[i] = x+10; + borrow = 1; + }else{ + pA->a[i] = x; + borrow = 0; + } + } + } + } + decimal_normalize(pA); +} + +/* +** Compare text in decimal order. +*/ +static int decimalCollFunc( + void *notUsed, + int nKey1, const void *pKey1, + int nKey2, const void *pKey2 +){ + const unsigned char *zA = (const unsigned char*)pKey1; + const unsigned char *zB = (const unsigned char*)pKey2; + Decimal *pA = decimal_new(0, 0, nKey1, zA); + Decimal *pB = decimal_new(0, 0, nKey2, zB); + int rc; + if( pA==0 || pB==0 ){ + rc = 0; + }else{ + rc = decimal_cmp(pA, pB); + } + decimal_free(pA); + decimal_free(pB); + return rc; +} + + +/* +** SQL Function: decimal_add(X, Y) +** decimal_sub(X, Y) +** +** Return the sum or difference of X and Y. +*/ +static void decimalAddFunc( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + Decimal *pA = decimal_new(context, argv[0], 0, 0); + Decimal *pB = decimal_new(context, argv[1], 0, 0); + decimal_add(pA, pB); + decimal_result(context, pA); + decimal_free(pA); + decimal_free(pB); +} +static void decimalSubFunc( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + Decimal *pA = decimal_new(context, argv[0], 0, 0); + Decimal *pB = decimal_new(context, argv[1], 0, 0); + if( pB==0 ) return; + pB->sign = !pB->sign; + decimal_add(pA, pB); + decimal_result(context, pA); + decimal_free(pA); + decimal_free(pB); +} + +/* Aggregate funcion: decimal_sum(X) +** +** Works like sum() except that it uses decimal arithmetic for unlimited +** precision. +*/ +static void decimalSumStep( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + Decimal *p; + Decimal *pArg; + p = sqlite3_aggregate_context(context, sizeof(*p)); + if( p==0 ) return; + if( !p->isInit ){ + p->isInit = 1; + p->a = sqlite3_malloc(2); + if( p->a==0 ){ + p->oom = 1; + }else{ + p->a[0] = 0; + } + p->nDigit = 1; + p->nFrac = 0; + } + if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return; + pArg = decimal_new(context, argv[0], 0, 0); + decimal_add(p, pArg); + decimal_free(pArg); +} +static void decimalSumInverse( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + Decimal *p; + Decimal *pArg; + p = sqlite3_aggregate_context(context, sizeof(*p)); + if( p==0 ) return; + if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return; + pArg = decimal_new(context, argv[0], 0, 0); + if( pArg ) pArg->sign = !pArg->sign; + decimal_add(p, pArg); + decimal_free(pArg); +} +static void decimalSumValue(sqlite3_context *context){ + Decimal *p = sqlite3_aggregate_context(context, 0); + if( p==0 ) return; + decimal_normalize(p); + decimal_result(context, p); +} +static void decimalSumFinalize(sqlite3_context *context){ + Decimal *p = sqlite3_aggregate_context(context, 0); + if( p==0 ) return; + decimal_normalize(p); + decimal_result(context, p); + decimal_clear(p); +} + + +#ifdef _WIN32 +__declspec(dllexport) +#endif +int sqlite3_decimal_init( + sqlite3 *db, + char **pzErrMsg, + const sqlite3_api_routines *pApi +){ + int rc = SQLITE_OK; + SQLITE_EXTENSION_INIT2(pApi); + static const struct { + const char *zFuncName; + int nArg; + void (*xFunc)(sqlite3_context*,int,sqlite3_value**); + } aFunc[] = { + { "decimal", 1, decimalFunc }, + { "decimal_cmp", 2, decimalCmpFunc }, + { "decimal_add", 2, decimalAddFunc }, + { "decimal_sub", 2, decimalSubFunc }, + }; + int i; + (void)pzErrMsg; /* Unused parameter */ + + for(i=0; idb, 0, 0); sqlite3_completion_init(p->db, 0, 0); sqlite3_uint_init(p->db, 0, 0); + sqlite3_decimal_init(p->db, 0, 0); #if !defined(SQLITE_OMIT_VIRTUALTABLE) && defined(SQLITE_ENABLE_DBPAGE_VTAB) sqlite3_dbdata_init(p->db, 0, 0); #endif diff --git a/src/test1.c b/src/test1.c index 8c8853737e..f94adf2d60 100644 --- a/src/test1.c +++ b/src/test1.c @@ -7257,6 +7257,7 @@ static int SQLITE_TCLAPI tclLoadStaticExtensionCmd( extern int sqlite3_eval_init(sqlite3*,char**,const sqlite3_api_routines*); extern int sqlite3_explain_init(sqlite3*,char**,const sqlite3_api_routines*); extern int sqlite3_fileio_init(sqlite3*,char**,const sqlite3_api_routines*); + extern int sqlite3_decimal_init(sqlite3*,char**,const sqlite3_api_routines*); extern int sqlite3_fuzzer_init(sqlite3*,char**,const sqlite3_api_routines*); extern int sqlite3_ieee_init(sqlite3*,char**,const sqlite3_api_routines*); extern int sqlite3_nextchar_init(sqlite3*,char**,const sqlite3_api_routines*); @@ -7282,6 +7283,7 @@ static int SQLITE_TCLAPI tclLoadStaticExtensionCmd( { "carray", sqlite3_carray_init }, { "closure", sqlite3_closure_init }, { "csv", sqlite3_csv_init }, + { "decimal", sqlite3_decimal_init }, { "eval", sqlite3_eval_init }, { "explain", sqlite3_explain_init }, { "fileio", sqlite3_fileio_init }, diff --git a/test/decimal.test b/test/decimal.test new file mode 100644 index 0000000000..1097d14db6 --- /dev/null +++ b/test/decimal.test @@ -0,0 +1,118 @@ +# 2017 December 9 +# +# The author disclaims copyright to this source code. In place of +# a legal notice, here is a blessing: +# +# May you do good and not evil. +# May you find forgiveness for yourself and forgive others. +# May you share freely, never taking more than you give. +# +#*********************************************************************** +# + +set testdir [file dirname $argv0] +source $testdir/tester.tcl +set testprefix decimal + +if {[catch {load_static_extension db decimal} error]} { + puts "Skipping zipfile tests, hit load error: $error" + finish_test; return +} + +do_execsql_test 1000 { + SELECT decimal(1); +} {1} +do_execsql_test 1010 { + SELECT decimal(1.0); +} {1.0} +do_execsql_test 1020 { + SELECT decimal(0001.0); +} {1.0} +do_execsql_test 1030 { + SELECT decimal(+0001.0); +} {1.0} +do_execsql_test 1040 { + SELECT decimal(-0001.0); +} {-1.0} +do_execsql_test 1050 { + SELECT decimal(1.0e72); +} {1000000000000000000000000000000000000000000000000000000000000000000000000} +# 123456789 123456789 123456789 123456789 123456789 123456789 123456789 123 +do_execsql_test 1060 { + SELECT decimal(1.0e-72); +} {0.0000000000000000000000000000000000000000000000000000000000000000000000010} +# 123456789 123456789 123456789 123456789 123456789 123456789 123456789 123 +do_execsql_test 1070 { + SELECT decimal(-123e-4); +} {-0.0123} +do_execsql_test 1080 { + SELECT decimal(+123e+4); +} {1230000.0} + + +do_execsql_test 2000 { + CREATE TABLE t1(seq INTEGER PRIMARY KEY, val TEXT); + INSERT INTO t1 VALUES + (1, '-9999e99'), + (2, '-9998.000e+99'), + (3, '-9999.0'), + (4, '-1'), + (5, '-9999e-20'), + (6, '0'), + (7, '1e-30'), + (8, '1e-29'), + (9, '1'), + (10,'1.00000000000000001'), + (11,'+1.00001'), + (12,'99e+99'); + SELECT *, '|' + FROM t1 AS a, t1 AS b + WHERE a.seq=0; +} {} +do_execsql_test 2010 { + SELECT *, '|' + FROM t1 AS a, t1 AS b + WHERE a.seq<>b.seq + AND decimal_cmp(a.val,b.val)==0; +} {} +do_execsql_test 2020 { + SELECT *, '|' + FROM t1 AS a, t1 AS b + WHERE a.seq>b.seq + AND decimal_cmp(a.val,b.val)<=0; +} {} +do_execsql_test 2030 { + SELECT seq FROM t1 ORDER BY val COLLATE decimal; +} {1 2 3 4 5 6 7 8 9 10 11 12} +do_execsql_test 2040 { + SELECT seq FROM t1 ORDER BY val COLLATE decimal DESC; +} {12 11 10 9 8 7 6 5 4 3 2 1} + +do_execsql_test 3000 { + CREATE TABLE t3(seq INTEGER PRIMARY KEY, val TEXT); + WITH RECURSIVE c(x) AS (VALUES(1) UNION SELECT x+1 FROM c WHERE x<10) + INSERT INTO t3(seq, val) SELECT x, x FROM c; + WITH RECURSIVE c(x) AS (VALUES(1) UNION SELECT x+1 FROM c WHERE x<5) + INSERT INTO t3(seq, val) SELECT x+10, x*1000 FROM c; + SELECT decimal(val) FROM t3 ORDER BY seq; +} {1 2 3 4 5 6 7 8 9 10 1000 2000 3000 4000 5000} +do_execsql_test 3020 { + SELECT decimal_add(val,'0.5') FROM t3 WHERE seq>5 ORDER BY seq +} {6.5 7.5 8.5 9.5 10.5 1000.5 2000.5 3000.5 4000.5 5000.5} +do_execsql_test 3030 { + SELECT decimal_add(val,'-10') FROM t3 ORDER BY seq; +} {-9 -8 -7 -6 -5 -4 -3 -2 -1 0 990 1990 2990 3990 4990} + +do_execsql_test 4000 { + SELECT decimal_sum(val) FROM t3; +} {15055} +do_execsql_test 4010 { + SELECT decimal_sum(decimal_add(val,val||'e+10')) FROM t3; +} {150550000015055} +do_execsql_test 4010 { + SELECT decimal_sum(decimal_add(val||'e+20',decimal_add(val,val||'e-20'))) + FROM t3; +} {1505500000000000000015055.00000000000000015055} + +finish_test