Merge pull request #3381 from SparkiDev/ecc_ct_fix

ECC mulmod: some curves can't do order-1
This commit is contained in:
David Garske 2020-10-15 14:46:46 -07:00 committed by GitHub
commit 9793414d78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 215 additions and 244 deletions

View File

@ -2074,7 +2074,20 @@ int ecc_projective_dbl_point(ecc_point *P, ecc_point *R, mp_int* a,
/* Use a and prime to determine if a == 3 */
err = mp_submod(modulus, a, modulus, t2);
}
if (err == MP_OKAY && mp_cmp_d(t2, 3) != MP_EQ) {
if (err == MP_OKAY && mp_iszero(t2)) {
/* T2 = X * X */
if (err == MP_OKAY)
err = mp_sqr(x, t2);
if (err == MP_OKAY)
err = mp_montgomery_reduce(t2, modulus, mp);
/* T1 = T2 + T1 */
if (err == MP_OKAY)
err = mp_addmod_ct(t2, t2, modulus, t1);
/* T1 = T2 + T1 */
if (err == MP_OKAY)
err = mp_addmod_ct(t1, t2, modulus, t1);
}
else if (err == MP_OKAY && mp_cmp_d(t2, 3) != MP_EQ) {
/* use "a" in calc */
/* T2 = T1 * T1 */
@ -2635,182 +2648,158 @@ static int wc_ecc_gen_z(WC_RNG* rng, int size, ecc_point* p,
return err;
}
#if defined(WC_NO_CACHE_RESISTANT)
#define M_POINTS 4
#else
#define M_POINTS 5
#endif
#define M_POINTS 3
static int ecc_mulmod(mp_int* k, ecc_point* tG, ecc_point* R, ecc_point** M,
/* Joye double-add ladder.
* "Highly Regular Right-to-Left Algorithms for Scalar Multiplication"
* by Marc Joye (2007)
*
* Algorithm 1':
* Input: P element of curve, k = (k[t-1],..., k[0]) base 2
* Output: Q = kP
* 1: R[0] = P; R[1] = P
* 2: for j = 1 to t-1 do
* 3: b = 1 - k[j]; R[b] = 2*R[b] + R[k[j]]
* 4: end for
* 5: b = k[0]; R[b] = R[b] - P
* 6: return R[0]
*
* Assumes: k < order.
*/
static int ecc_mulmod(mp_int* k, ecc_point* P, ecc_point* Q, ecc_point** R,
mp_int* a, mp_int* modulus, mp_digit mp, WC_RNG* rng)
{
int err = MP_OKAY;
int i;
int bitcnt = 0, mode = 0, digidx = 0;
mp_digit buf;
/* calc the M tab */
/* M[0] == G */
if (err == MP_OKAY)
err = mp_copy(tG->x, M[0]->x);
if (err == MP_OKAY)
err = mp_copy(tG->y, M[0]->y);
if (err == MP_OKAY)
err = mp_copy(tG->z, M[0]->z);
/* M[1] == 2G */
if (err == MP_OKAY)
err = ecc_projective_dbl_point(tG, M[1], a, modulus, mp);
#ifdef WC_NO_CACHE_RESISTANT
if (err == MP_OKAY)
err = wc_ecc_copy_point(M[0], M[2]);
if (rng != NULL) {
if (err == MP_OKAY) {
err = wc_ecc_gen_z(rng, (mp_count_bits(modulus) + 7) / 8, M[0],
modulus, mp, M[3]->x, M[3]->y);
}
if (err == MP_OKAY) {
err = wc_ecc_gen_z(rng, (mp_count_bits(modulus) + 7) / 8, M[1],
modulus, mp, M[3]->x, M[3]->y);
}
}
#else
if (err == MP_OKAY)
err = wc_ecc_copy_point(M[0], M[3]);
if (err == MP_OKAY)
err = wc_ecc_copy_point(M[1], M[4]);
if (rng != NULL) {
if (err == MP_OKAY) {
err = wc_ecc_gen_z(rng, (mp_count_bits(modulus) + 7) / 8, M[3],
modulus, mp, M[2]->x, M[2]->y);
}
if (err == MP_OKAY) {
err = wc_ecc_gen_z(rng, (mp_count_bits(modulus) + 7) / 8, M[4],
modulus, mp, M[2]->x, M[2]->y);
}
}
int err = MP_OKAY;
int bytes = (mp_count_bits(modulus) + 7) / 8;
int i;
int j = 1;
int cnt = DIGIT_BIT;
int t = 0;
mp_digit b;
mp_digit v = 0;
#ifndef WC_NO_CACHE_RESISTANT
/* First bit always 1 (fix at end) and swap equals first bit */
int swap = 1;
#endif
/* setup sliding window */
mode = 0;
digidx = get_digit_count(modulus) - 1;
/* The order MAY be 1 bit longer than the modulus.
* k MAY be 1 bit longer than the order.
*/
bitcnt = (mp_count_bits(modulus) + 2) % DIGIT_BIT;
digidx += (bitcnt <= 3);
buf = get_digit(k, digidx) << (DIGIT_BIT - bitcnt);
bitcnt = (bitcnt + 1) % DIGIT_BIT;
digidx -= bitcnt != 1;
/* Step 1: R[0] = P; R[1] = P */
/* R[0] = P */
if (err == MP_OKAY)
err = mp_copy(P->x, R[0]->x);
if (err == MP_OKAY)
err = mp_copy(P->y, R[0]->y);
if (err == MP_OKAY)
err = mp_copy(P->z, R[0]->z);
/* perform ops */
if (err == MP_OKAY) {
for (;;) {
/* grab next digit as required */
if (--bitcnt == 0) {
if (digidx == -1) {
break;
}
buf = get_digit(k, digidx);
bitcnt = (int)DIGIT_BIT;
--digidx;
}
/* R[1] = P */
if (err == MP_OKAY)
err = mp_copy(P->x, R[1]->x);
if (err == MP_OKAY)
err = mp_copy(P->y, R[1]->y);
if (err == MP_OKAY)
err = mp_copy(P->z, R[1]->z);
/* grab the next msb from the multiplicand */
i = (buf >> (DIGIT_BIT - 1)) & 1;
buf <<= 1;
/* Randomize z ordinates to obfuscate timing. */
if ((err == MP_OKAY) && (rng != NULL))
err = wc_ecc_gen_z(rng, bytes, R[0], modulus, mp, R[2]->x, R[2]->y);
if ((err == MP_OKAY) && (rng != NULL))
err = wc_ecc_gen_z(rng, bytes, R[1], modulus, mp, R[2]->x, R[2]->y);
if (err == MP_OKAY) {
/* Order could be one greater than the size of the modulus. */
t = mp_count_bits(modulus) + 1;
v = k->dp[0] >> 1;
if (cnt > t) {
cnt = t;
}
err = mp_grow(k, modulus->used + 1);
}
/* Step 2: for j = 1 to t-1 do */
for (i = 1; (err == MP_OKAY) && (i < t); i++) {
if (--cnt == 0) {
v = k->dp[j++];
cnt = DIGIT_BIT;
}
/* Step 3: b = 1 - k[j]; R[b] = 2*R[b] + R[k[j]] */
b = v & 1;
v >>= 1;
#ifdef WC_NO_CACHE_RESISTANT
if (mode == 0) {
/* timing resistant - dummy operations */
if (err == MP_OKAY)
err = ecc_projective_add_point(M[1], M[2], M[3], a, modulus,
mp);
if (err == MP_OKAY)
err = ecc_projective_dbl_point(M[3], M[2], a, modulus, mp);
}
else {
if (err == MP_OKAY)
err = ecc_projective_add_point(M[0], M[1], M[i^1], a,
modulus, mp);
if (err == MP_OKAY)
err = ecc_projective_dbl_point(M[i], M[i], a, modulus, mp);
}
err = ecc_projective_dbl_point(R[b^1], R[b^1], a, modulus, mp);
if (err == MP_OKAY) {
err = ecc_projective_add_point(R[b^1], R[b], R[b^1], a, modulus,
mp);
}
#else
if (err == MP_OKAY)
err = ecc_projective_add_point(M[0], M[1], M[2], a, modulus, mp);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->x, i, M[0]->x);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->y, i, M[0]->y);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->z, i, M[0]->z);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->x, i ^ 1, M[1]->x);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->y, i ^ 1, M[1]->y);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->z, i ^ 1, M[1]->z);
/* Swap R[0] and R[1] if other index is needed. */
swap ^= b;
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->x, R[1]->x, modulus->used, swap);
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->y, R[1]->y, modulus->used, swap);
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->z, R[1]->z, modulus->used, swap);
swap = (int)b;
if (err == MP_OKAY)
err = mp_cond_copy(M[0]->x, i ^ 1, M[2]->x);
if (err == MP_OKAY)
err = mp_cond_copy(M[0]->y, i ^ 1, M[2]->y);
if (err == MP_OKAY)
err = mp_cond_copy(M[0]->z, i ^ 1, M[2]->z);
if (err == MP_OKAY)
err = mp_cond_copy(M[1]->x, i, M[2]->x);
if (err == MP_OKAY)
err = mp_cond_copy(M[1]->y, i, M[2]->y);
if (err == MP_OKAY)
err = mp_cond_copy(M[1]->z, i, M[2]->z);
if (err == MP_OKAY)
err = ecc_projective_dbl_point(M[2], M[2], a, modulus, mp);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->x, i ^ 1, M[0]->x);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->y, i ^ 1, M[0]->y);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->z, i ^ 1, M[0]->z);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->x, i, M[1]->x);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->y, i, M[1]->y);
if (err == MP_OKAY)
err = mp_cond_copy(M[2]->z, i, M[1]->z);
if (err == MP_OKAY)
err = mp_cond_copy(M[3]->x, (mode ^ 1) & i, M[0]->x);
if (err == MP_OKAY)
err = mp_cond_copy(M[3]->y, (mode ^ 1) & i, M[0]->y);
if (err == MP_OKAY)
err = mp_cond_copy(M[3]->z, (mode ^ 1) & i, M[0]->z);
if (err == MP_OKAY)
err = mp_cond_copy(M[4]->x, (mode ^ 1) & i, M[1]->x);
if (err == MP_OKAY)
err = mp_cond_copy(M[4]->y, (mode ^ 1) & i, M[1]->y);
if (err == MP_OKAY)
err = mp_cond_copy(M[4]->z, (mode ^ 1) & i, M[1]->z);
if (err == MP_OKAY)
err = ecc_projective_dbl_point(R[0], R[0], a, modulus, mp);
if (err == MP_OKAY)
err = ecc_projective_add_point(R[0], R[1], R[0], a, modulus, mp);
#endif /* WC_NO_CACHE_RESISTANT */
}
/* Step 4: end for */
#ifndef WC_NO_CACHE_RESISTANT
/* Swap back if last bit is 0. */
swap ^= 1;
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->x, R[1]->x, modulus->used, swap);
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->y, R[1]->y, modulus->used, swap);
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->z, R[1]->z, modulus->used, swap);
#endif
if (err != MP_OKAY)
break;
/* Step 5: b = k[0]; R[b] = R[b] - P */
/* R[2] = -P */
if (err == MP_OKAY)
err = mp_copy(P->x, R[2]->x);
if (err == MP_OKAY)
err = mp_sub(modulus, P->y, R[2]->y);
if (err == MP_OKAY)
err = mp_copy(P->z, R[2]->z);
/* Subtract point by adding negative. */
if (err == MP_OKAY) {
b = k->dp[0] & 1;
#ifdef WC_NO_CACHE_RESISTANT
err = ecc_projective_add_point(R[b], R[2], R[b], a, modulus, mp);
#else
/* Swap R[0] and R[1], if necessary, to operate on the one we want. */
err = mp_cond_swap_ct(R[0]->x, R[1]->x, modulus->used, (int)b);
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->y, R[1]->y, modulus->used, (int)b);
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->z, R[1]->z, modulus->used, (int)b);
if (err == MP_OKAY)
err = ecc_projective_add_point(R[0], R[2], R[0], a, modulus, mp);
/* Swap back if necessary. */
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->x, R[1]->x, modulus->used, (int)b);
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->y, R[1]->y, modulus->used, (int)b);
if (err == MP_OKAY)
err = mp_cond_swap_ct(R[0]->z, R[1]->z, modulus->used, (int)b);
#endif
}
mode |= i;
} /* end for */
}
/* Step 6: return R[0] */
if (err == MP_OKAY)
err = mp_copy(R[0]->x, Q->x);
if (err == MP_OKAY)
err = mp_copy(R[0]->y, Q->y);
if (err == MP_OKAY)
err = mp_copy(R[0]->z, Q->z);
/* copy result out */
if (err == MP_OKAY)
err = mp_copy(M[0]->x, R->x);
if (err == MP_OKAY)
err = mp_copy(M[0]->y, R->y);
if (err == MP_OKAY)
err = mp_copy(M[0]->z, R->z);
return err;
return err;
}
#endif
@ -3058,8 +3047,6 @@ int wc_ecc_mulmod_ex2(mp_int* k, ecc_point *G, ecc_point *R, mp_int* a,
mp_digit mp;
#ifdef ECC_TIMING_RESISTANT
mp_int t;
mp_int o;
mp_digit mask;
#endif
if (k == NULL || G == NULL || R == NULL || modulus == NULL) {
@ -3114,49 +3101,30 @@ int wc_ecc_mulmod_ex2(mp_int* k, ecc_point *G, ecc_point *R, mp_int* a,
#ifdef ECC_TIMING_RESISTANT
if ((err = mp_init(&t)) != MP_OKAY)
goto exit;
if ((err = mp_init(&o)) != MP_OKAY) {
mp_free(&t);
goto exit;
}
/* Make k at 1 bit longer than order. */
if (err == MP_OKAY) {
err = mp_add(k, order, &t);
}
if (err == MP_OKAY) {
err = mp_copy(order, &o);
}
if (err == MP_OKAY) {
/* Only add if order + k has same number of bits as order */
mask = (mp_digit)0 - (mp_count_bits(&t) == mp_count_bits(order));
for (i = 0; i < o.used; i++) {
o.dp[i] &= mask;
}
err = mp_add(&t, &o, &t);
}
mp_free(&o);
if (err == MP_OKAY)
err = ecc_mulmod(&t, tG, R, M, a, modulus, mp, rng);
err = ecc_mulmod(k, tG, R, M, a, modulus, mp, rng);
/* Check for k == 1 or k == order+1. Result will be 0 point which is not
* correct. Calculates 2 * order and get 0 point then adds base point
* which results in 0 point with constant time implementation)
/* Check for k == order - 1. Result will be 0 point which is not correct
* Calculates order / 2 and adds order / 2 + 1 and gets infinity.
* (with constant time implementation)
*/
if (err == MP_OKAY)
err = mp_add_d(order, 1, &t);
err = mp_sub_d(order, 1, &t);
if (err == MP_OKAY) {
int kIsOne = (mp_cmp_d(k, 1) == MP_EQ) | (mp_cmp(k, &t) == MP_EQ);
err = mp_cond_copy(tG->x, kIsOne, R->x);
int kIsMinusOne = (mp_cmp(k, &t) == MP_EQ);
err = mp_cond_copy(tG->x, kIsMinusOne, R->x);
if (err == 0) {
err = mp_cond_copy(tG->y, kIsOne, R->y);
err = mp_sub(modulus, tG->y, &t);
}
if (err == 0) {
err = mp_cond_copy(tG->z, kIsOne, R->z);
err = mp_cond_copy(&t, kIsMinusOne, R->y);
}
if (err == 0) {
err = mp_cond_copy(tG->z, kIsMinusOne, R->z);
}
}
mp_forcezero(&t);
mp_free(&t);
#else
err = ecc_mulmod(k, tG, R, M, a, modulus, mp, rng);
@ -5672,7 +5640,7 @@ int ecc_projective_add_point_safe(ecc_point* A, ecc_point* B, ecc_point* R,
err = ecc_projective_add_point(A, B, R, a, modulus, mp);
if ((err == MP_OKAY) && mp_iszero(R->z)) {
/* When all zero then should have done an add */
/* When all zero then should have done a double */
if (mp_iszero(R->x) && mp_iszero(R->y)) {
err = ecc_projective_dbl_point(B, R, a, modulus, mp);
}
@ -9890,52 +9858,6 @@ int wc_ecc_mulmod_ex(mp_int* k, ecc_point *G, ecc_point *R, mp_int* a,
#endif
}
#ifndef WOLFSSL_SP_MATH
static int normal_ecc_mulmod_ex(mp_int* k, ecc_point *G, ecc_point *R,
mp_int* a, mp_int* modulus, mp_int* order,
WC_RNG* rng, int map, void* heap)
{
int err;
mp_int t;
mp_int o;
mp_digit mask;
int i;
if ((err = mp_init(&t)) != MP_OKAY)
return err;
if ((err = mp_init(&o)) != MP_OKAY) {
mp_free(&t);
return err;
}
/* Make k at 1 bit longer than order. */
if (err == MP_OKAY) {
err = mp_add(k, order, &t);
}
if (err == MP_OKAY) {
err = mp_copy(order, &o);
}
if (err == MP_OKAY) {
/* Only add if order + k has same number of bits as order */
mask = (mp_digit)0 - (mp_count_bits(&t) == mp_count_bits(order));
for (i = 0; i < o.used; i++) {
o.dp[i] &= mask;
}
err = mp_add(&t, &o, &t);
}
if (err == MP_OKAY) {
err = normal_ecc_mulmod(&t, G, R, a, modulus, rng, map, heap);
}
mp_forcezero(&t);
mp_free(&o);
mp_free(&t);
return err;
}
#endif /* !WOLFSSL_SP_MATH */
/** ECC Fixed Point mulmod global
k The multiplicand
G Base point to multiply
@ -10022,8 +9944,7 @@ int wc_ecc_mulmod_ex2(mp_int* k, ecc_point *G, ecc_point *R, mp_int* a,
if (err == MP_OKAY)
err = accel_fp_mul(idx, k, R, a, modulus, mp, map);
} else {
err = normal_ecc_mulmod_ex(k, G, R, a, modulus, order, rng, map,
heap);
err = normal_ecc_mulmod(k, G, R, a, modulus, rng, map, heap);
}
}

View File

@ -546,6 +546,14 @@ void mp_exch (mp_int * a, mp_int * b)
*b = t;
}
int mp_cond_swap_ct (mp_int * a, mp_int * b, int c, int m)
{
(void)c;
if (m == 1)
mp_exch(a, b);
return MP_OKAY;
}
/* shift right a certain number of bits */
void mp_rshb (mp_int *c, int x)

View File

@ -4440,6 +4440,41 @@ int mp_montgomery_calc_normalization(mp_int *a, mp_int *b)
#endif /* WOLFSSL_KEYGEN || HAVE_ECC */
static int fp_cond_swap_ct (mp_int * a, mp_int * b, int c, int m)
{
int i;
mp_digit mask = (mp_digit)0 - m;
#ifndef WOLFSSL_SMALL_STACK
fp_int t[1];
#else
fp_int* t;
#endif
#ifdef WOLFSSL_SMALL_STACK
t = (fp_int*)XMALLOC(sizeof(fp_int), NULL, DYNAMIC_TYPE_BIGINT);
if (t == NULL)
return FP_MEM;
#endif
t->used = (a->used ^ b->used) & mask;
for (i = 0; i < c; i++) {
t->dp[i] = (a->dp[i] ^ b->dp[i]) & mask;
}
a->used ^= t->used;
for (i = 0; i < c; i++) {
a->dp[i] ^= t->dp[i];
}
b->used ^= t->used;
for (i = 0; i < c; i++) {
b->dp[i] ^= t->dp[i];
}
#ifdef WOLFSSL_SMALL_STACK
XFREE(t, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return FP_OKAY;
}
#if defined(WC_MP_TO_RADIX) || !defined(NO_DH) || !defined(NO_DSA) || \
!defined(NO_RSA)
@ -4996,6 +5031,11 @@ int mp_prime_is_prime_ex(mp_int* a, int t, int* result, WC_RNG* rng)
#endif /* !NO_RSA || !NO_DSA || !NO_DH || WOLFSSL_KEY_GEN */
int mp_cond_swap_ct(mp_int * a, mp_int * b, int c, int m)
{
return fp_cond_swap_ct(a, b, c, m);
}
#ifdef WOLFSSL_KEY_GEN
static int fp_gcd(fp_int *a, fp_int *b, fp_int *c);

View File

@ -305,6 +305,7 @@ MP_API int mp_div_2d (mp_int * a, int b, mp_int * c, mp_int * d);
MP_API void mp_zero (mp_int * a);
MP_API void mp_clamp (mp_int * a);
MP_API void mp_exch (mp_int * a, mp_int * b);
MP_API int mp_cond_swap_ct (mp_int * a, mp_int * b, int c, int m);
MP_API void mp_rshd (mp_int * a, int b);
MP_API void mp_rshb (mp_int * a, int b);
MP_API int mp_mod_2d (mp_int * a, int b, mp_int * c);

View File

@ -832,6 +832,7 @@ MP_API int mp_lcm(fp_int *a, fp_int *b, fp_int *c);
MP_API int mp_rand_prime(mp_int* N, int len, WC_RNG* rng, void* heap);
MP_API int mp_exch(mp_int *a, mp_int *b);
#endif /* WOLFSSL_KEY_GEN */
MP_API int mp_cond_swap_ct (mp_int * a, mp_int * b, int c, int m);
MP_API int mp_cnt_lsb(fp_int *a);
MP_API int mp_div_2d(fp_int *a, int b, fp_int *c, fp_int *d);