Optimise numeric multiplication for short inputs.

When either input has a small number of digits, and the exact product
is requested, the speed of numeric multiplication can be increased
significantly by using a faster direct multiplication algorithm. This
works by fully computing each result digit in turn, starting with the
least significant, and propagating the carry up. This save cycles by
not requiring a temporary buffer to store digit products, not making
multiple passes over the digits of the longer input, and not requiring
separate carry-propagation passes.

For now, this is used when the shorter input has 1-4 NBASE digits (up
to 13-16 decimal digits), and the longer input is of any size, which
covers a lot of common real-world cases. Also, the relative benefit
increases as the size of the longer input increases.

Possible future work would be to try extending the technique to larger
numbers of digits in the shorter input.

Joel Jacobson and Dean Rasheed.

Discussion: https://postgr.es/m/44d2ffca-d560-4919-b85a-4d07060946aa@app.fastmail.com
This commit is contained in:
Dean Rasheed 2024-07-09 10:00:42 +01:00
parent 42de72fa7b
commit ca481d3c9a
1 changed files with 219 additions and 1 deletions

View File

@ -558,6 +558,8 @@ static void sub_var(const NumericVar *var1, const NumericVar *var2,
static void mul_var(const NumericVar *var1, const NumericVar *var2,
NumericVar *result,
int rscale);
static void mul_var_short(const NumericVar *var1, const NumericVar *var2,
NumericVar *result);
static void div_var(const NumericVar *var1, const NumericVar *var2,
NumericVar *result,
int rscale, bool round);
@ -8722,7 +8724,7 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
var1digits = var1->digits;
var2digits = var2->digits;
if (var1ndigits == 0 || var2ndigits == 0)
if (var1ndigits == 0)
{
/* one or both inputs is zero; so is result */
zero_var(result);
@ -8730,6 +8732,16 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
return;
}
/*
* If var1 has 1-4 digits and the exact result was requested, delegate to
* mul_var_short() which uses a faster direct multiplication algorithm.
*/
if (var1ndigits <= 4 && rscale == var1->dscale + var2->dscale)
{
mul_var_short(var1, var2, result);
return;
}
/* Determine result sign and (maximum possible) weight */
if (var1->sign == var2->sign)
res_sign = NUMERIC_POS;
@ -8880,6 +8892,212 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
}
/*
* mul_var_short() -
*
* Special-case multiplication function used when var1 has 1-4 digits, var2
* has at least as many digits as var1, and the exact product var1 * var2 is
* requested.
*/
static void
mul_var_short(const NumericVar *var1, const NumericVar *var2,
NumericVar *result)
{
int var1ndigits = var1->ndigits;
int var2ndigits = var2->ndigits;
NumericDigit *var1digits = var1->digits;
NumericDigit *var2digits = var2->digits;
int res_sign;
int res_weight;
int res_ndigits;
NumericDigit *res_buf;
NumericDigit *res_digits;
uint32 carry;
uint32 term;
/* Check preconditions */
Assert(var1ndigits >= 1);
Assert(var1ndigits <= 4);
Assert(var2ndigits >= var1ndigits);
/*
* Determine the result sign, weight, and number of digits to calculate.
* The weight figured here is correct if the product has no leading zero
* digits; otherwise strip_var() will fix things up. Note that, unlike
* mul_var(), we do not need to allocate an extra output digit, because we
* are not rounding here.
*/
if (var1->sign == var2->sign)
res_sign = NUMERIC_POS;
else
res_sign = NUMERIC_NEG;
res_weight = var1->weight + var2->weight + 1;
res_ndigits = var1ndigits + var2ndigits;
/* Allocate result digit array */
res_buf = digitbuf_alloc(res_ndigits + 1);
res_buf[0] = 0; /* spare digit for later rounding */
res_digits = res_buf + 1;
/*
* Compute the result digits in reverse, in one pass, propagating the
* carry up as we go. The i'th result digit consists of the sum of the
* products var1digits[i1] * var2digits[i2] for which i = i1 + i2 + 1.
*/
switch (var1ndigits)
{
case 1:
/* ---------
* 1-digit case:
* var1ndigits = 1
* var2ndigits >= 1
* res_ndigits = var2ndigits + 1
* ----------
*/
carry = 0;
for (int i = res_ndigits - 2; i >= 0; i--)
{
term = (uint32) var1digits[0] * var2digits[i] + carry;
res_digits[i + 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
}
res_digits[0] = (NumericDigit) carry;
break;
case 2:
/* ---------
* 2-digit case:
* var1ndigits = 2
* var2ndigits >= 2
* res_ndigits = var2ndigits + 2
* ----------
*/
/* last result digit and carry */
term = (uint32) var1digits[1] * var2digits[res_ndigits - 3];
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
/* remaining digits, except for the first two */
for (int i = res_ndigits - 3; i >= 1; i--)
{
term = (uint32) var1digits[0] * var2digits[i] +
(uint32) var1digits[1] * var2digits[i - 1] + carry;
res_digits[i + 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
}
/* first two digits */
term = (uint32) var1digits[0] * var2digits[0] + carry;
res_digits[1] = (NumericDigit) (term % NBASE);
res_digits[0] = (NumericDigit) (term / NBASE);
break;
case 3:
/* ---------
* 3-digit case:
* var1ndigits = 3
* var2ndigits >= 3
* res_ndigits = var2ndigits + 3
* ----------
*/
/* last two result digits */
term = (uint32) var1digits[2] * var2digits[res_ndigits - 4];
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[1] * var2digits[res_ndigits - 4] +
(uint32) var1digits[2] * var2digits[res_ndigits - 5] + carry;
res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
/* remaining digits, except for the first three */
for (int i = res_ndigits - 4; i >= 2; i--)
{
term = (uint32) var1digits[0] * var2digits[i] +
(uint32) var1digits[1] * var2digits[i - 1] +
(uint32) var1digits[2] * var2digits[i - 2] + carry;
res_digits[i + 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
}
/* first three digits */
term = (uint32) var1digits[0] * var2digits[1] +
(uint32) var1digits[1] * var2digits[0] + carry;
res_digits[2] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[0] * var2digits[0] + carry;
res_digits[1] = (NumericDigit) (term % NBASE);
res_digits[0] = (NumericDigit) (term / NBASE);
break;
case 4:
/* ---------
* 4-digit case:
* var1ndigits = 4
* var2ndigits >= 4
* res_ndigits = var2ndigits + 4
* ----------
*/
/* last three result digits */
term = (uint32) var1digits[3] * var2digits[res_ndigits - 5];
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[2] * var2digits[res_ndigits - 5] +
(uint32) var1digits[3] * var2digits[res_ndigits - 6] + carry;
res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[1] * var2digits[res_ndigits - 5] +
(uint32) var1digits[2] * var2digits[res_ndigits - 6] +
(uint32) var1digits[3] * var2digits[res_ndigits - 7] + carry;
res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
/* remaining digits, except for the first four */
for (int i = res_ndigits - 5; i >= 3; i--)
{
term = (uint32) var1digits[0] * var2digits[i] +
(uint32) var1digits[1] * var2digits[i - 1] +
(uint32) var1digits[2] * var2digits[i - 2] +
(uint32) var1digits[3] * var2digits[i - 3] + carry;
res_digits[i + 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
}
/* first four digits */
term = (uint32) var1digits[0] * var2digits[2] +
(uint32) var1digits[1] * var2digits[1] +
(uint32) var1digits[2] * var2digits[0] + carry;
res_digits[3] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[0] * var2digits[1] +
(uint32) var1digits[1] * var2digits[0] + carry;
res_digits[2] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[0] * var2digits[0] + carry;
res_digits[1] = (NumericDigit) (term % NBASE);
res_digits[0] = (NumericDigit) (term / NBASE);
break;
}
/* Store the product in result */
digitbuf_free(result->buf);
result->ndigits = res_ndigits;
result->buf = res_buf;
result->digits = res_digits;
result->weight = res_weight;
result->sign = res_sign;
result->dscale = var1->dscale + var2->dscale;
/* Strip leading and trailing zeroes */
strip_var(result);
}
/*
* div_var() -
*