Merge pull request #4416 from SparkiDev/mp_submod_addmod_ct

SP math, TFM: constant time addmod, submod
This commit is contained in:
David Garske 2021-09-20 11:37:45 -07:00 committed by GitHub
commit 753a931196
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 132 additions and 106 deletions

View File

@ -4178,47 +4178,13 @@ int sp_submod(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
}
#endif /* WOLFSSL_SP_MATH_ALL */
#if defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC)
/* Compare two multi-precision numbers.
*
* Constant time implementation.
*
* @param [in] a SP integer to compare.
* @param [in] b SP integer to compare.
* @param [in] len Number of digits to compare.
*
* @return MP_GT when a is greater than b.
* @return MP_LT when a is less than b.
* @return MP_EQ when a is equals b.
*/
static int sp_cmp_mag_ct(sp_int* a, sp_int* b, int len)
{
int i;
sp_sint_digit r = MP_EQ;
sp_int_digit mask = SP_MASK;
for (i = len - 1; i >= 0; i--) {
sp_int_digit am = 0 - (i < a->used);
sp_int_digit bm = 0 - (i < b->used);
sp_int_digit ad = a->dp[i] & am;
sp_int_digit bd = b->dp[i] & bm;
r |= mask & (ad > bd);
mask &= (ad > bd) - 1;
r |= mask & (-(ad < bd));
mask &= (ad < bd) - 1;
}
return (int)r;
}
#endif /* WOLFSSL_SP_MATH_ALL && HAVE_ECC */
#if defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC)
/* Add two value and reduce: r = (a + b) % m
*
* r = a + b (mod m) - constant time (a < m and b < m, a, b and m are positive)
*
* Assumes a, b, m and r are not NULL.
* m and r must not be the same pointer.
*
* @param [in] a SP integer to add.
* @param [in] b SP integer to add with.
@ -4230,11 +4196,15 @@ static int sp_cmp_mag_ct(sp_int* a, sp_int* b, int len)
int sp_addmod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
{
int err = MP_OKAY;
sp_int_word w = 0;
sp_int_sword w;
sp_int_sword s;
sp_int_digit mask;
int i;
if ((r->size < m->used + 1) || (m->used == m->size)) {
if (r->size < m->used) {
err = MP_VAL;
}
if ((err == MP_OKAY) && (r == m)) {
err = MP_VAL;
}
@ -4245,19 +4215,43 @@ int sp_addmod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
sp_print(m, "m");
}
_sp_add_off(a, b, r, 0);
mask = 0 - (sp_cmp_mag_ct(r, m, m->used + 1) != MP_LT);
/* Add a to b into r. Do the subtract of modulus but don't store result.
* When subtract result is negative, the overflow will be negative.
* Only need to subtract mod when result is positive - overflow is
* positive.
*/
w = 0;
s = 0;
for (i = 0; i < m->used; i++) {
sp_int_digit mask_r = 0 - (i < r->used);
w += m->dp[i] & mask;
w = (r->dp[i] & mask_r) - w;
r->dp[i] = (sp_int_digit)w;
w = (w >> DIGIT_BIT) & 1;
/* Values past 'used' are not initialized. */
sp_int_digit mask_a = (sp_int_digit)0 - (i < a->used);
sp_int_digit mask_b = (sp_int_digit)0 - (i < b->used);
w += a->dp[i] & mask_a;
w += b->dp[i] & mask_b;
r->dp[i] = (sp_int_digit)w;
s += (sp_int_digit)w;
s -= m->dp[i];
s >>= DIGIT_BIT;
w >>= DIGIT_BIT;
}
r->dp[i] = 0;
s += (sp_int_digit)w;
/* s will be positive when subtracting modulus is needed. */
mask = (sp_int_digit)0 - (s >= 0);
/* Constant time, conditionally, subtract modulus from sum. */
w = 0;
for (i = 0; i < m->used; i++) {
w += r->dp[i];
w -= m->dp[i] & mask;
r->dp[i] = (sp_int_digit)w;
w >>= DIGIT_BIT;
}
/* Result will always have digits equal to or less than those in
* modulus. */
r->used = i;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = a->sign;
r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
sp_clamp(r);
@ -4277,6 +4271,7 @@ int sp_addmod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
* r = a - b (mod m) - constant time (a < m and b < m, a, b and m are positive)
*
* Assumes a, b, m and r are not NULL.
* m and r must not be the same pointer.
*
* @param [in] a SP integer to subtract from
* @param [in] b SP integer to subtract.
@ -4288,13 +4283,16 @@ int sp_addmod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
int sp_submod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
{
int err = MP_OKAY;
sp_int_word w = 0;
sp_int_sword w;
sp_int_digit mask;
int i;
if (r->size < m->used + 1) {
err = MP_VAL;
}
if ((err == MP_OKAY) && (r == m)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
if (0) {
@ -4303,23 +4301,34 @@ int sp_submod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
sp_print(m, "m");
}
mask = 0 - (sp_cmp_mag_ct(a, b, m->used) == MP_LT);
/* In constant time, subtract b from a putting result in r. */
w = 0;
for (i = 0; i < m->used; i++) {
sp_int_digit mask_a = 0 - (i < a->used);
sp_int_digit mask_m = 0 - (i < m->used);
/* Values past 'used' are not initialized. */
sp_int_digit mask_a = (sp_int_digit)0 - (i < a->used);
sp_int_digit mask_b = (sp_int_digit)0 - (i < b->used);
w += m->dp[i] & mask_m & mask;
w += a->dp[i] & mask_a;
w -= b->dp[i] & mask_b;
r->dp[i] = (sp_int_digit)w;
w >>= DIGIT_BIT;
}
r->dp[i] = (sp_int_digit)w;
r->used = i + 1;
/* When w is negative then we need to add modulus to make result
* positive. */
mask = (sp_int_digit)0 - (w < 0);
/* Constant time, conditionally, add modulus to difference. */
w = 0;
for (i = 0; i < m->used; i++) {
w += r->dp[i];
w += m->dp[i] & mask;
r->dp[i] = (sp_int_digit)w;
w >>= DIGIT_BIT;
}
r->used = i;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
sp_clamp(r);
_sp_sub_off(r, b, r, 0);
if (0) {
sp_print(r, "rms");

View File

@ -96,26 +96,6 @@ word32 CheckRunTimeFastMath(void)
/* Functions */
static fp_digit fp_cmp_mag_ct(fp_int *a, fp_int *b, int len)
{
int i;
fp_digit r = FP_EQ;
fp_digit mask = (fp_digit)-1;
for (i = len - 1; i >= 0; i--) {
/* 0 is placed into unused digits. */
fp_digit ad = a->dp[i];
fp_digit bd = b->dp[i];
r |= mask & (ad > bd);
mask &= (ad > bd) - 1;
r |= mask & (-(ad < bd));
mask &= (ad < bd) - 1;
}
return r;
}
int fp_add(fp_int *a, fp_int *b, fp_int *c)
{
int sa, sb;
@ -1619,62 +1599,93 @@ int fp_addmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d)
return err;
}
/* d = a - b (mod c) - constant time (a < c and b < c and all positive) */
/* d = a - b (mod c) - constant time (a < c and b < c and all positive)
* c and d must not be the same pointers.
*/
int fp_submod_ct(fp_int *a, fp_int *b, fp_int *c, fp_int *d)
{
fp_word w = 0;
fp_sword w;
fp_digit mask;
int i;
if (c->used + 1 > FP_SIZE) {
return FP_VAL;
return FP_VAL;
}
if (c == d) {
return FP_VAL;
}
/* Check whether b is greater than a. mask has all bits set when true. */
mask = 0 - (fp_cmp_mag_ct(a, b, c->used + 1) == (fp_digit)FP_LT);
/* Constant time, conditionally, add modulus to a into result. */
/* In constant time, subtract b from a putting result in d. */
w = 0;
for (i = 0; i < c->used; i++) {
fp_digit mask_a = 0 - (i < a->used);
w += c->dp[i] & mask;
w += a->dp[i] & mask_a;
d->dp[i] = (fp_digit)w;
w >>= DIGIT_BIT;
w += a->dp[i];
w -= b->dp[i];
d->dp[i] = (fp_digit)w;
w >>= DIGIT_BIT;
}
/* Handle overflow */
d->dp[i] = (fp_digit)w;
d->used = i + 1;
w += a->dp[i];
w -= b->dp[i];
w >>= DIGIT_BIT;
/* When w is negative then we need to add modulus to make result positive. */
mask = (fp_digit)0 - (w < 0);
/* Constant time, conditionally, add modulus to difference. */
w = 0;
for (i = 0; i < c->used; i++) {
w += d->dp[i];
w += c->dp[i] & mask;
d->dp[i] = (fp_digit)w;
w >>= DIGIT_BIT;
}
/* Result will always have digits equal to or less than those in modulus. */
d->used = i;
d->sign = FP_ZPOS;
fp_clamp(d);
/* Subtract b from a (that my have had modulus added to it). */
s_fp_sub(d, b, d);
return FP_OKAY;
}
/* d = a + b (mod c) - constant time (a < c and b < c and all positive) */
/* d = a + b (mod c) - constant time (a < c and b < c and all positive)
* c and d must not be the same pointers.
*/
int fp_addmod_ct(fp_int *a, fp_int *b, fp_int *c, fp_int *d)
{
fp_word w = 0;
fp_word w;
fp_sword s;
fp_digit mask;
int i;
if (c->used + 1 > FP_SIZE) {
return FP_VAL;
if (c == d) {
return FP_VAL;
}
s_fp_add(a, b, d);
/* Check whether sum is bigger than modulus.
* mask has all bits set when true. */
mask = 0 - (fp_cmp_mag_ct(d, c, c->used + 1) != (fp_digit)FP_LT);
/* Constant time, conditionally, subtract modulus from sum. */
/* Add a to b into d. Do the subtract of modulus but don't store result.
* When subtract result is negative, the overflow will be negative.
* Only need to subtract mod when result is positive - overflow is positive.
*/
w = 0;
s = 0;
for (i = 0; i < c->used; i++) {
w += c->dp[i] & mask;
w = d->dp[i] - w;
d->dp[i] = (fp_digit)w;
w = (w >> DIGIT_BIT)&1;
w += a->dp[i];
w += b->dp[i];
d->dp[i] = (fp_digit)w;
s += (fp_digit)w;
s -= c->dp[i];
w >>= DIGIT_BIT;
s >>= DIGIT_BIT;
}
d->dp[i] = 0;
s += (fp_digit)w;
/* s will be positive when subtracting modulus is needed. */
mask = (fp_digit)0 - (s >= 0);
/* Constant time, conditionally, subtract modulus from sum. */
w = 0;
for (i = 0; i < c->used; i++) {
w += c->dp[i] & mask;
w = d->dp[i] - w;
d->dp[i] = (fp_digit)w;
w = (w >> DIGIT_BIT)&1;
}
/* Result will always have digits equal to or less than those in modulus. */
d->used = i;
d->sign = FP_ZPOS;
fp_clamp(d);

View File

@ -221,22 +221,27 @@
typedef unsigned int fp_digit;
#define SIZEOF_FP_DIGIT 2
typedef unsigned long fp_word;
typedef signed long fp_sword;
#elif defined(FP_64BIT)
/* for GCC only on supported platforms */
typedef unsigned long long fp_digit; /* 64bit, 128 uses mode(TI) below */
#define SIZEOF_FP_DIGIT 8
typedef unsigned long fp_word __attribute__ ((mode(TI)));
typedef unsigned long fp_word __attribute__ ((mode(TI)));
typedef signed long fp_sword __attribute__ ((mode(TI)));
#else
#ifndef NO_TFM_64BIT
#if defined(_MSC_VER) || defined(__BORLANDC__)
typedef unsigned __int64 ulong64;
typedef signed __int64 long64;
#else
typedef unsigned long long ulong64;
typedef signed long long long64;
#endif
typedef unsigned int fp_digit;
#define SIZEOF_FP_DIGIT 4
typedef ulong64 fp_word;
typedef long64 fp_sword;
#define FP_32BIT
#else
/* some procs like coldfire prefer not to place multiply into 64bit type
@ -244,6 +249,7 @@
typedef unsigned short fp_digit;
#define SIZEOF_FP_DIGIT 2
typedef unsigned int fp_word;
typedef signed int fp_sword;
#endif
#endif