Extend mul_var_short() to 5 and 6-digit inputs.

Commit ca481d3c9a introduced mul_var_short(), which is used by
mul_var() whenever the shorter input has 1-4 NBASE digits and the
exact product is requested. As speculated in that commit, it can be
extended to work for more digits in the shorter input. This commit
extends it up to 6 NBASE digits (up to 24 decimal digits), for which
it also gives a significant speedup. This covers more cases likely to
occur in real-world queries, for which using base-NBASE^2 arithmetic
provides little benefit.

To avoid code bloat and duplication, refactor it a bit using macros
and exploiting the fact that some portions of the code are shared
between the different cases.

Dean Rasheed, reviewed by Joel Jacobson.

Discussion: https://postgr.es/m/9d8a4a42-c354-41f3-bbf3-199e1957db97%40app.fastmail.com
This commit is contained in:
Dean Rasheed 2024-08-15 10:33:12 +01:00
parent fce7cb6da0
commit c4e44224cf
1 changed file with 123 additions and 52 deletions

View File

@ -8714,10 +8714,10 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
}
/*
* If var1 has 1-4 digits and the exact result was requested, delegate to
* If var1 has 1-6 digits and the exact result was requested, delegate to
* mul_var_short() which uses a faster direct multiplication algorithm.
*/
if (var1ndigits <= 4 && rscale == var1->dscale + var2->dscale)
if (var1ndigits <= 6 && rscale == var1->dscale + var2->dscale)
{
mul_var_short(var1, var2, result);
return;
@ -8876,7 +8876,7 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
/*
* mul_var_short() -
*
* Special-case multiplication function used when var1 has 1-4 digits, var2
* Special-case multiplication function used when var1 has 1-6 digits, var2
* has at least as many digits as var1, and the exact product var1 * var2 is
* requested.
*/
@ -8898,7 +8898,7 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
/* Check preconditions */
Assert(var1ndigits >= 1);
Assert(var1ndigits <= 4);
Assert(var1ndigits <= 6);
Assert(var2ndigits >= var1ndigits);
/*
@ -8925,6 +8925,13 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
* carry up as we go. The i'th result digit consists of the sum of the
* products var1digits[i1] * var2digits[i2] for which i = i1 + i2 + 1.
*/
/*
 * PRODSUMn(v1, i1, v2, i2) expands to the sum of n digit products
 * v1[i1 + k] * v2[i2 - k] for k = 0 .. n-1, so every pair in the sum has
 * the same index total i1 + i2.  Each macro is the previous one plus a
 * single extra term.  Arguments are expanded more than once, so they must
 * be free of side effects.
 */
#define PRODSUM_TERM(a,ia,b,ib,k) ((a)[(ia)+(k)] * (b)[(ib)-(k)])
#define PRODSUM1(a,ia,b,ib) (PRODSUM_TERM(a,ia,b,ib,0))
#define PRODSUM2(a,ia,b,ib) (PRODSUM1(a,ia,b,ib) + PRODSUM_TERM(a,ia,b,ib,1))
#define PRODSUM3(a,ia,b,ib) (PRODSUM2(a,ia,b,ib) + PRODSUM_TERM(a,ia,b,ib,2))
#define PRODSUM4(a,ia,b,ib) (PRODSUM3(a,ia,b,ib) + PRODSUM_TERM(a,ia,b,ib,3))
#define PRODSUM5(a,ia,b,ib) (PRODSUM4(a,ia,b,ib) + PRODSUM_TERM(a,ia,b,ib,4))
#define PRODSUM6(a,ia,b,ib) (PRODSUM5(a,ia,b,ib) + PRODSUM_TERM(a,ia,b,ib,5))
switch (var1ndigits)
{
case 1:
@ -8936,9 +8943,9 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
* ----------
*/
carry = 0;
for (int i = res_ndigits - 2; i >= 0; i--)
for (int i = var2ndigits - 1; i >= 0; i--)
{
term = (uint32) var1digits[0] * var2digits[i] + carry;
term = PRODSUM1(var1digits, 0, var2digits, i) + carry;
res_digits[i + 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
}
@ -8954,23 +8961,17 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
* ----------
*/
/* last result digit and carry */
term = (uint32) var1digits[1] * var2digits[res_ndigits - 3];
term = PRODSUM1(var1digits, 1, var2digits, var2ndigits - 1);
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
/* remaining digits, except for the first two */
for (int i = res_ndigits - 3; i >= 1; i--)
for (int i = var2ndigits - 1; i >= 1; i--)
{
term = (uint32) var1digits[0] * var2digits[i] +
(uint32) var1digits[1] * var2digits[i - 1] + carry;
term = PRODSUM2(var1digits, 0, var2digits, i) + carry;
res_digits[i + 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
}
/* first two digits */
term = (uint32) var1digits[0] * var2digits[0] + carry;
res_digits[1] = (NumericDigit) (term % NBASE);
res_digits[0] = (NumericDigit) (term / NBASE);
break;
case 3:
@ -8982,34 +8983,21 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
* ----------
*/
/* last two result digits */
term = (uint32) var1digits[2] * var2digits[res_ndigits - 4];
term = PRODSUM1(var1digits, 2, var2digits, var2ndigits - 1);
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[1] * var2digits[res_ndigits - 4] +
(uint32) var1digits[2] * var2digits[res_ndigits - 5] + carry;
term = PRODSUM2(var1digits, 1, var2digits, var2ndigits - 1) + carry;
res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
/* remaining digits, except for the first three */
for (int i = res_ndigits - 4; i >= 2; i--)
for (int i = var2ndigits - 1; i >= 2; i--)
{
term = (uint32) var1digits[0] * var2digits[i] +
(uint32) var1digits[1] * var2digits[i - 1] +
(uint32) var1digits[2] * var2digits[i - 2] + carry;
term = PRODSUM3(var1digits, 0, var2digits, i) + carry;
res_digits[i + 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
}
/* first three digits */
term = (uint32) var1digits[0] * var2digits[1] +
(uint32) var1digits[1] * var2digits[0] + carry;
res_digits[2] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[0] * var2digits[0] + carry;
res_digits[1] = (NumericDigit) (term % NBASE);
res_digits[0] = (NumericDigit) (term / NBASE);
break;
case 4:
@ -9021,45 +9009,128 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2,
* ----------
*/
/* last three result digits */
term = (uint32) var1digits[3] * var2digits[res_ndigits - 5];
term = PRODSUM1(var1digits, 3, var2digits, var2ndigits - 1);
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[2] * var2digits[res_ndigits - 5] +
(uint32) var1digits[3] * var2digits[res_ndigits - 6] + carry;
term = PRODSUM2(var1digits, 2, var2digits, var2ndigits - 1) + carry;
res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[1] * var2digits[res_ndigits - 5] +
(uint32) var1digits[2] * var2digits[res_ndigits - 6] +
(uint32) var1digits[3] * var2digits[res_ndigits - 7] + carry;
term = PRODSUM3(var1digits, 1, var2digits, var2ndigits - 1) + carry;
res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
/* remaining digits, except for the first four */
for (int i = res_ndigits - 5; i >= 3; i--)
for (int i = var2ndigits - 1; i >= 3; i--)
{
term = (uint32) var1digits[0] * var2digits[i] +
(uint32) var1digits[1] * var2digits[i - 1] +
(uint32) var1digits[2] * var2digits[i - 2] +
(uint32) var1digits[3] * var2digits[i - 3] + carry;
term = PRODSUM4(var1digits, 0, var2digits, i) + carry;
res_digits[i + 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
}
break;
/* first four digits */
term = (uint32) var1digits[0] * var2digits[2] +
(uint32) var1digits[1] * var2digits[1] +
(uint32) var1digits[2] * var2digits[0] + carry;
case 5:
/* ---------
* 5-digit case:
* var1ndigits = 5
* var2ndigits >= 5
* res_ndigits = var2ndigits + 5
* ----------
*/
/* last four result digits */
term = PRODSUM1(var1digits, 4, var2digits, var2ndigits - 1);
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = PRODSUM2(var1digits, 3, var2digits, var2ndigits - 1) + carry;
res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = PRODSUM3(var1digits, 2, var2digits, var2ndigits - 1) + carry;
res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = PRODSUM4(var1digits, 1, var2digits, var2ndigits - 1) + carry;
res_digits[res_ndigits - 4] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
/* remaining digits, except for the first five */
for (int i = var2ndigits - 1; i >= 4; i--)
{
term = PRODSUM5(var1digits, 0, var2digits, i) + carry;
res_digits[i + 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
}
break;
case 6:
/* ---------
* 6-digit case:
* var1ndigits = 6
* var2ndigits >= 6
* res_ndigits = var2ndigits + 6
* ----------
*/
/* last five result digits */
term = PRODSUM1(var1digits, 5, var2digits, var2ndigits - 1);
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = PRODSUM2(var1digits, 4, var2digits, var2ndigits - 1) + carry;
res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = PRODSUM3(var1digits, 3, var2digits, var2ndigits - 1) + carry;
res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = PRODSUM4(var1digits, 2, var2digits, var2ndigits - 1) + carry;
res_digits[res_ndigits - 4] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = PRODSUM5(var1digits, 1, var2digits, var2ndigits - 1) + carry;
res_digits[res_ndigits - 5] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
/* remaining digits, except for the first six */
for (int i = var2ndigits - 1; i >= 5; i--)
{
term = PRODSUM6(var1digits, 0, var2digits, i) + carry;
res_digits[i + 1] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
}
break;
}
/*
* Finally, for var1ndigits > 1, compute the remaining var1ndigits most
* significant result digits.
*/
switch (var1ndigits)
{
case 6:
term = PRODSUM5(var1digits, 0, var2digits, 4) + carry;
res_digits[5] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
/* FALLTHROUGH */
case 5:
term = PRODSUM4(var1digits, 0, var2digits, 3) + carry;
res_digits[4] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
/* FALLTHROUGH */
case 4:
term = PRODSUM3(var1digits, 0, var2digits, 2) + carry;
res_digits[3] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[0] * var2digits[1] +
(uint32) var1digits[1] * var2digits[0] + carry;
/* FALLTHROUGH */
case 3:
term = PRODSUM2(var1digits, 0, var2digits, 1) + carry;
res_digits[2] = (NumericDigit) (term % NBASE);
carry = term / NBASE;
term = (uint32) var1digits[0] * var2digits[0] + carry;
/* FALLTHROUGH */
case 2:
term = PRODSUM1(var1digits, 0, var2digits, 0) + carry;
res_digits[1] = (NumericDigit) (term % NBASE);
res_digits[0] = (NumericDigit) (term / NBASE);
break;