Merge pull request #6955 from SparkiDev/rsa_dec_inv_blind_mul_mont

RSA private exponentiation: multiply blinding invert in Mont
This commit is contained in:
JacobBarthelmeh 2023-11-28 11:08:57 -07:00 committed by GitHub
commit 61a2d2de3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 276 additions and 157 deletions

View File

@ -2495,6 +2495,7 @@ static int RsaFunctionPrivate(mp_int* tmp, RsaKey* key, WC_RNG* rng)
{
int ret = 0;
#if defined(WC_RSA_BLINDING) && !defined(WC_NO_RNG)
mp_digit mp;
DECL_MP_INT_SIZE_DYN(rnd, mp_bitsused(&key->n), RSA_MAX_SIZE);
DECL_MP_INT_SIZE_DYN(rndi, mp_bitsused(&key->n), RSA_MAX_SIZE);
#endif /* WC_RSA_BLINDING && !WC_NO_RNG */
@ -2627,9 +2628,31 @@ static int RsaFunctionPrivate(mp_int* tmp, RsaKey* key, WC_RNG* rng)
#endif /* RSA_LOW_MEM */
#if defined(WC_RSA_BLINDING) && !defined(WC_NO_RNG)
/* unblind */
if (ret == 0 && mp_mulmod(tmp, rndi, &key->n, tmp) != MP_OKAY)
/* Multiply result (tmp) by bliding invertor (rndi).
* Use Montogemery form to make operation more constant time.
*/
if ((ret == 0) && (mp_montgomery_setup(&key->n, &mp) != MP_OKAY)) {
ret = MP_MULMOD_E;
}
if ((ret == 0) && (mp_montgomery_calc_normalization(rnd, &key->n) !=
MP_OKAY)) {
ret = MP_MULMOD_E;
}
/* Convert blinding invert to Montogmery form. */
if ((ret == 0) && (mp_mul(rndi, rnd, rndi) != MP_OKAY)) {
ret = MP_MULMOD_E;
}
if ((ret == 0) && (mp_mod(rndi, &key->n, rndi) != MP_OKAY)) {
ret = MP_MULMOD_E;
}
/* Multiply result by blinding invert. */
if ((ret == 0) && (mp_mul(tmp, rndi, tmp) != MP_OKAY)) {
ret = MP_MULMOD_E;
}
/* Reduce result. */
if ((ret == 0) && (mp_montgomery_reduce_ct(tmp, &key->n, mp) != MP_OKAY)) {
ret = MP_MULMOD_E;
}
mp_forcezero(rndi);
mp_forcezero(rnd);
@ -3520,8 +3543,9 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
mgf, label, labelSz, saltLen,
mp_count_bits(&key->n), key->heap);
#endif
if (rsa_type == RSA_PUBLIC_DECRYPT && ret > (int)outLen)
if (rsa_type == RSA_PUBLIC_DECRYPT && ret > (int)outLen) {
ret = RSA_BUFFER_E;
}
else if (ret >= 0 && pad != NULL) {
/* only copy output if not inline */
if (outPtr == NULL) {
@ -3547,8 +3571,9 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
XMEMCPY(out, pad, (size_t)ret);
}
}
else
else {
*outPtr = pad;
}
#if !defined(WOLFSSL_RSA_VERIFY_ONLY)
ret = ctMaskSelInt(ctMaskLTE(ret, (int)outLen), ret, RSA_BUFFER_E);

View File

@ -4770,7 +4770,7 @@ WOLFSSL_LOCAL int sp_ModExp_4096(sp_int* base, sp_int* exp, sp_int* mod,
#if defined(WOLFSSL_SP_MATH_ALL) || defined(WOLFSSL_HAVE_SP_DH) || \
defined(OPENSSL_ALL)
static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp);
static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp, int ct);
#endif
#if defined(WOLFSSL_SP_MATH_ALL) || defined(WOLFSSL_HAVE_SP_DH) || \
defined(WOLFCRYPT_HAVE_ECCSI) || defined(WOLFCRYPT_HAVE_SAKKE) || \
@ -7673,6 +7673,28 @@ int sp_submod(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r)
}
#endif /* WOLFSSL_SP_MATH_ALL */
#if (defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC)) || \
(defined(WOLFSSL_SP_MATH_ALL) || defined(WOLFSSL_HAVE_SP_DH) || \
defined(WOLFCRYPT_HAVE_ECCSI) || defined(WOLFCRYPT_HAVE_SAKKE) || \
defined(OPENSSL_ALL))
/* Constant time clamping/
*
* @param [in, out] a SP integer to clamp.
*/
static void sp_clamp_ct(sp_int* a)
{
int i;
unsigned int used = a->used;
unsigned int mask = (unsigned int)-1;
for (i = a->used-1; i >= 0; i--) {
used -= ((unsigned int)(a->dp[i] == 0)) & mask;
mask &= (unsigned int)0 - (a->dp[i] == 0);
}
a->used = used;
}
#endif
#if defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC)
/* Add two value and reduce: r = (a + b) % m
*
@ -7826,7 +7848,7 @@ int sp_addmod_ct(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r)
r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
/* Remove leading zeros. */
sp_clamp(r);
sp_clamp_ct(r);
#if 0
sp_print(r, "rma");
@ -7837,8 +7859,121 @@ int sp_addmod_ct(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r)
}
#endif /* WOLFSSL_SP_MATH_ALL && HAVE_ECC */
#if (defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC)) || \
(defined(WOLFSSL_SP_MATH_ALL) || defined(WOLFSSL_HAVE_SP_DH) || \
defined(WOLFCRYPT_HAVE_ECCSI) || defined(WOLFCRYPT_HAVE_SAKKE) || \
defined(OPENSSL_ALL))
/* Sub b from a modulo m: r = (a - b) % m
*
* Result is always positive.
*
* Assumes a, b, m and r are not NULL.
* m and r must not be the same pointer.
*
* @param [in] a SP integer to subtract from
* @param [in] b SP integer to subtract.
* @param [in] m SP integer that is the modulus.
* @param [out] r SP integer to hold result.
*
* @return MP_OKAY on success.
*/
static void _sp_submod_ct(const sp_int* a, const sp_int* b, const sp_int* m,
unsigned int max, sp_int* r)
{
#ifndef SQR_MUL_ASM
sp_int_sword w;
#else
sp_int_digit l;
sp_int_digit h;
sp_int_digit t;
#endif
sp_int_digit mask;
sp_int_digit mask_a = (sp_int_digit)-1;
sp_int_digit mask_b = (sp_int_digit)-1;
unsigned int i;
/* In constant time, subtract b from a putting result in r. */
#ifndef SQR_MUL_ASM
w = 0;
#else
l = 0;
h = 0;
#endif
for (i = 0; i < max; i++) {
/* Values past 'used' are not initialized. */
mask_a += (i == a->used);
mask_b += (i == b->used);
#ifndef SQR_MUL_ASM
/* Add a to and subtract b from current value. */
w += a->dp[i] & mask_a;
w -= b->dp[i] & mask_b;
/* Store low digit in result. */
r->dp[i] = (sp_int_digit)w;
/* Move high digit down. */
w >>= DIGIT_BIT;
#else
/* Add a and subtract b from current value. */
t = a->dp[i] & mask_a;
SP_ASM_ADDC_REG(l, h, t);
t = b->dp[i] & mask_b;
SP_ASM_SUBB_REG(l, h, t);
/* Store low digit in result. */
r->dp[i] = l;
/* Move high digit down. */
l = h;
/* High digit is 0 when positive or -1 on negative. */
h = (sp_int_digit)0 - (l >> (SP_WORD_SIZE - 1));
#endif
}
/* When w is negative then we need to add modulus to make result
* positive. */
#ifndef SQR_MUL_ASM
mask = (sp_int_digit)0 - (w < 0);
#else
mask = h;
#endif
/* Constant time, conditionally, add modulus to difference. */
#ifndef SQR_MUL_ASM
w = 0;
#else
l = 0;
#endif
for (i = 0; i < m->used; i++) {
#ifndef SQR_MUL_ASM
/* Add result and conditionally modulus to current value. */
w += r->dp[i];
w += m->dp[i] & mask;
/* Store low digit in result. */
r->dp[i] = (sp_int_digit)w;
/* Move high digit down. */
w >>= DIGIT_BIT;
#else
h = 0;
/* Add result and conditionally modulus to current value. */
SP_ASM_ADDC(l, h, r->dp[i]);
t = m->dp[i] & mask;
SP_ASM_ADDC_REG(l, h, t);
/* Store low digit in result. */
r->dp[i] = l;
/* Move high digit down. */
l = h;
#endif
}
/* Result will always have digits equal to or less than those in
* modulus. */
r->used = i;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
/* Remove leading zeros. */
sp_clamp_ct(r);
}
#endif
#if defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC)
/* Sub b from a and reduce: r = (a - b) % m
/* Sub b from a modulo m: r = (a - b) % m
* Result is always positive.
*
* r = a - b (mod m) - constant time (a < m and b < m, a, b and m are positive)
@ -7856,17 +7991,6 @@ int sp_addmod_ct(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r)
int sp_submod_ct(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r)
{
int err = MP_OKAY;
#ifndef SQR_MUL_ASM
sp_int_sword w;
#else
sp_int_digit l;
sp_int_digit h;
sp_int_digit t;
#endif
sp_int_digit mask;
sp_int_digit mask_a = (sp_int_digit)-1;
sp_int_digit mask_b = (sp_int_digit)-1;
unsigned int i;
/* Check result is as big as modulus plus one digit. */
if (m->used > r->size) {
@ -7884,82 +8008,7 @@ int sp_submod_ct(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r)
sp_print(m, "m");
#endif
/* In constant time, subtract b from a putting result in r. */
#ifndef SQR_MUL_ASM
w = 0;
#else
l = 0;
h = 0;
#endif
for (i = 0; i < m->used; i++) {
/* Values past 'used' are not initialized. */
mask_a += (i == a->used);
mask_b += (i == b->used);
#ifndef SQR_MUL_ASM
/* Add a to and subtract b from current value. */
w += a->dp[i] & mask_a;
w -= b->dp[i] & mask_b;
/* Store low digit in result. */
r->dp[i] = (sp_int_digit)w;
/* Move high digit down. */
w >>= DIGIT_BIT;
#else
/* Add a and subtract b from current value. */
t = a->dp[i] & mask_a;
SP_ASM_ADDC_REG(l, h, t);
t = b->dp[i] & mask_b;
SP_ASM_SUBB_REG(l, h, t);
/* Store low digit in result. */
r->dp[i] = l;
/* Move high digit down. */
l = h;
/* High digit is 0 when positive or -1 on negative. */
h = (sp_int_digit)0 - (l >> (SP_WORD_SIZE - 1));
#endif
}
/* When w is negative then we need to add modulus to make result
* positive. */
#ifndef SQR_MUL_ASM
mask = (sp_int_digit)0 - (w < 0);
#else
mask = h;
#endif
/* Constant time, conditionally, add modulus to difference. */
#ifndef SQR_MUL_ASM
w = 0;
#else
l = 0;
#endif
for (i = 0; i < m->used; i++) {
#ifndef SQR_MUL_ASM
/* Add result and conditionally modulus to current value. */
w += r->dp[i];
w += m->dp[i] & mask;
/* Store low digit in result. */
r->dp[i] = (sp_int_digit)w;
/* Move high digit down. */
w >>= DIGIT_BIT;
#else
h = 0;
/* Add result and conditionally modulus to current value. */
SP_ASM_ADDC(l, h, r->dp[i]);
t = m->dp[i] & mask;
SP_ASM_ADDC_REG(l, h, t);
/* Store low digit in result. */
r->dp[i] = l;
/* Move high digit down. */
l = h;
#endif
}
/* Result will always have digits equal to or less than those in
* modulus. */
r->used = i;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
/* Remove leading zeros. */
sp_clamp(r);
_sp_submod_ct(a, b, m, m->used, r);
#if 0
sp_print(r, "rms");
@ -12377,14 +12426,14 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r,
_sp_init_size(pre[i], m->used * 2 + 1);
err = sp_sqr(pre[i-1], pre[i]);
if (err == MP_OKAY) {
err = _sp_mont_red(pre[i], m, mp);
err = _sp_mont_red(pre[i], m, mp, 0);
}
/* ..10 -> ..11 */
if (err == MP_OKAY) {
err = sp_mul(pre[i], a, pre[i]);
}
if (err == MP_OKAY) {
err = _sp_mont_red(pre[i], m, mp);
err = _sp_mont_red(pre[i], m, mp, 0);
}
}
}
@ -12438,7 +12487,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r,
/* 6.4.2.1. t = (t ^ 2) mod m */
err = sp_sqr(t, t);
if (err == MP_OKAY) {
err = _sp_mont_red(t, m, mp);
err = _sp_mont_red(t, m, mp, 0);
}
}
/* 6.4.3. s = 1 - bit */
@ -12449,7 +12498,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r,
err = sp_mul(t, pre[j-1], t);
}
if (err == MP_OKAY) {
err = _sp_mont_red(t, m, mp);
err = _sp_mont_red(t, m, mp, 0);
}
/* 6.4.5. j = 0
* Reset number of 1 bits seen.
@ -12465,7 +12514,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r,
/* 7.1. t = (t ^ 2) mod m */
err = sp_sqr(t, t);
if (err == MP_OKAY) {
err = _sp_mont_red(t, m, mp);
err = _sp_mont_red(t, m, mp, 0);
}
}
}
@ -12474,7 +12523,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r,
if (j > 0) {
err = sp_mul(t, pre[j-1], r);
if (err == MP_OKAY) {
err = _sp_mont_red(r, m, mp);
err = _sp_mont_red(r, m, mp, 0);
}
}
/* 9. Else r = t */
@ -12887,7 +12936,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
t[3]);
err = sp_sqr(t[3], t[3]);
if (err == MP_OKAY) {
err = _sp_mont_red(t[3], m, mp);
err = _sp_mont_red(t[3], m, mp, 0);
}
_sp_copy(t[3],
(sp_int*)(((size_t)t[0] & sp_off_on_addr[s^1]) +
@ -12907,7 +12956,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
t[3]);
err = sp_mul(t[3], t[2], t[3]);
if (err == MP_OKAY) {
err = _sp_mont_red(t[3], m, mp);
err = _sp_mont_red(t[3], m, mp, 0);
}
_sp_copy(t[3],
(sp_int*)(((size_t)t[0] & sp_off_on_addr[j^1]) +
@ -12916,7 +12965,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
}
if (err == MP_OKAY) {
/* 7. t[1] = FromMont(t[1]) */
err = _sp_mont_red(t[1], m, mp);
err = _sp_mont_red(t[1], m, mp, 0);
/* Reduction implementation returns number to range: 0..m-1. */
}
}
@ -13017,7 +13066,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
/* 4.2. t[2] = t[0] * t[1] */
err = sp_mul(t[0], t[1], t[2]);
if (err == MP_OKAY) {
err = _sp_mont_red(t[2], m, mp);
err = _sp_mont_red(t[2], m, mp, 0);
}
/* 4.3. t[3] = t[y] ^ 2 */
if (err == MP_OKAY) {
@ -13027,7 +13076,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
err = sp_sqr(t[3], t[3]);
}
if (err == MP_OKAY) {
err = _sp_mont_red(t[3], m, mp);
err = _sp_mont_red(t[3], m, mp, 0);
}
/* 4.4. t[y] = t[3], t[y^1] = t[2] */
if (err == MP_OKAY) {
@ -13037,7 +13086,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
if (err == MP_OKAY) {
/* 5. t[0] = FromMont(t[0]) */
err = _sp_mont_red(t[0], m, mp);
err = _sp_mont_red(t[0], m, mp, 0);
/* Reduction implementation returns number to range: 0..m-1. */
}
}
@ -13189,7 +13238,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
}
/* Montgomery reduce square or multiplication result. */
if (err == MP_OKAY) {
err = _sp_mont_red(t[i], m, mp);
err = _sp_mont_red(t[i], m, mp, 0);
}
}
@ -13250,7 +13299,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
for (j = 0; (j < winBits) && (err == MP_OKAY); j++) {
err = sp_sqr(tr, tr);
if (err == MP_OKAY) {
err = _sp_mont_red(tr, m, mp);
err = _sp_mont_red(tr, m, mp, 0);
}
}
@ -13259,14 +13308,14 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
err = sp_mul(tr, t[y], tr);
}
if (err == MP_OKAY) {
err = _sp_mont_red(tr, m, mp);
err = _sp_mont_red(tr, m, mp, 0);
}
}
}
if (err == MP_OKAY) {
/* 7. tr = FromMont(tr) */
err = _sp_mont_red(tr, m, mp);
err = _sp_mont_red(tr, m, mp, 0);
/* Reduction implementation returns number to range: 0..m-1. */
}
}
@ -13475,7 +13524,7 @@ static int _sp_exptmod_base_2(const sp_int* e, int digits, const sp_int* m,
err = sp_sqr(tr, tr);
if (err == MP_OKAY) {
if (useMont) {
err = _sp_mont_red(tr, m, mp);
err = _sp_mont_red(tr, m, mp, 0);
}
else {
err = sp_mod(tr, m, tr);
@ -13501,7 +13550,7 @@ static int _sp_exptmod_base_2(const sp_int* e, int digits, const sp_int* m,
/* 7. if Words(m) > 1 then tr = FromMont(tr) */
if ((err == MP_OKAY) && useMont) {
err = _sp_mont_red(tr, m, mp);
err = _sp_mont_red(tr, m, mp, 0);
/* Reduction implementation returns number to range: 0..m-1. */
}
if (err == MP_OKAY) {
@ -13880,7 +13929,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
for (i = 1; (i < winBits) && (err == MP_OKAY); i++) {
err = sp_sqr(t[0], t[0]);
if (err == MP_OKAY) {
err = _sp_mont_red(t[0], m, mp);
err = _sp_mont_red(t[0], m, mp, 0);
}
}
/* For each table entry after first. */
@ -13888,7 +13937,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
/* Multiply previous entry by the base in Mont form into table. */
err = sp_mul(t[i-1], bm, t[i]);
if (err == MP_OKAY) {
err = _sp_mont_red(t[i], m, mp);
err = _sp_mont_red(t[i], m, mp, 0);
}
}
@ -13972,7 +14021,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
for (; (err == MP_OKAY) && (sqrs > 0); sqrs--) {
err = sp_sqr(tr, tr);
if (err == MP_OKAY) {
err = _sp_mont_red(tr, m, mp);
err = _sp_mont_red(tr, m, mp, 0);
}
}
@ -14013,7 +14062,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
err = sp_mul(tr, t[y], tr);
}
if (err == MP_OKAY) {
err = _sp_mont_red(tr, m, mp);
err = _sp_mont_red(tr, m, mp, 0);
}
}
@ -14027,7 +14076,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
/* 5.1. Montogmery square result */
err = sp_sqr(tr, tr);
if (err == MP_OKAY) {
err = _sp_mont_red(tr, m, mp);
err = _sp_mont_red(tr, m, mp, 0);
}
/* 5.2. If exponent bit set */
if ((err == MP_OKAY) && ((n >> c) & 1)) {
@ -14036,7 +14085,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
*/
err = sp_mul(tr, bm, tr);
if (err == MP_OKAY) {
err = _sp_mont_red(tr, m, mp);
err = _sp_mont_red(tr, m, mp, 0);
}
}
}
@ -14045,7 +14094,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
if (err == MP_OKAY) {
/* 6. Convert result back from Montgomery form. */
err = _sp_mont_red(tr, m, mp);
err = _sp_mont_red(tr, m, mp, 0);
/* Reduction implementation returns number to range: 0..m-1. */
}
}
@ -14141,7 +14190,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
/* 3.1. Montgomery square result. */
err = sp_sqr(t[0], t[0]);
if (err == MP_OKAY) {
err = _sp_mont_red(t[0], m, mp);
err = _sp_mont_red(t[0], m, mp, 0);
}
if (err == MP_OKAY) {
/* Get bit and index i. */
@ -14151,14 +14200,14 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
/* 3.2.1. Montgomery multiply result by Mont of base. */
err = sp_mul(t[0], t[1], t[0]);
if (err == MP_OKAY) {
err = _sp_mont_red(t[0], m, mp);
err = _sp_mont_red(t[0], m, mp, 0);
}
}
}
}
if (err == MP_OKAY) {
/* 4. Convert from Montgomery form. */
err = _sp_mont_red(t[0], m, mp);
err = _sp_mont_red(t[0], m, mp, 0);
/* Reduction implementation returns number of range 0..m-1. */
}
}
@ -16995,10 +17044,11 @@ int sp_sqrmod(const sp_int* a, const sp_int* m, sp_int* r)
* @param [in,out] a SP integer to Montgomery reduce.
* @param [in] m SP integer that is the modulus.
* @param [in] mp SP integer digit that is the bottom digit of inv(-m).
* @param [in] ct Indicates operation must be constant time.
*
* @return MP_OKAY on success.
*/
static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp)
static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp, int ct)
{
#if !defined(SQR_MUL_ASM)
unsigned int i;
@ -17015,8 +17065,15 @@ static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp)
bits = sp_count_bits(m);
/* Adding numbers into m->used * 2 digits - zero out unused digits. */
for (i = a->used; i < m->used * 2; i++) {
a->dp[i] = 0;
if (!ct) {
for (i = a->used; i < m->used * 2; i++) {
a->dp[i] = 0;
}
}
else {
for (i = 0; i < m->used * 2; i++) {
a->dp[i] &= (sp_int_digit)(sp_int_sdigit)ctMaskIntGTE(a->used-1, i);
}
}
/* Special case when modulus is 1 digit or less. */
@ -17087,15 +17144,28 @@ static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp)
a->used = m->used * 2 + 1;
}
/* Remove leading zeros. */
sp_clamp(a);
/* 3. a >>= NumBits(m) */
(void)sp_rshb(a, bits, a);
/* 4. a = a mod m */
if (_sp_cmp_abs(a, m) != MP_LT) {
_sp_sub_off(a, m, a, 0);
if (!ct) {
/* Remove leading zeros. */
sp_clamp(a);
/* 3. a >>= NumBits(m) */
(void)sp_rshb(a, bits, a);
/* 4. a = a mod m */
if (_sp_cmp_abs(a, m) != MP_LT) {
_sp_sub_off(a, m, a, 0);
}
}
else {
/* 3. a >>= NumBits(m) */
(void)sp_rshb(a, bits, a);
/* Constant time clamping. */
sp_clamp_ct(a);
/* 4. a = a mod m
* Always subtract but at a too high offset if a is less than m.
*/
_sp_submod_ct(a, m, m, m->used + 1, a);
}
#if 0
sp_print(a, "rr");
@ -17118,8 +17188,15 @@ static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp)
bits = sp_count_bits(m);
mask = ((sp_int_digit)1 << (bits & (SP_WORD_SIZE - 1))) - 1;
for (i = a->used; i < m->used * 2; i++) {
a->dp[i] = 0;
if (!ct) {
for (i = a->used; i < m->used * 2; i++) {
a->dp[i] = 0;
}
}
else {
for (i = 0; i < m->used * 2; i++) {
a->dp[i] &= (sp_int_digit)(sp_int_sdigit)ctMaskIntGTE(a->used-1, i);
}
}
if (m->used <= 1) {
@ -17398,13 +17475,21 @@ static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp)
a->used = m->used * 2 + 1;
}
/* Remove leading zeros. */
sp_clamp(a);
(void)sp_rshb(a, bits, a);
if (!ct) {
/* Remove leading zeros. */
sp_clamp(a);
(void)sp_rshb(a, bits, a);
/* a = a mod m */
if (_sp_cmp_abs(a, m) != MP_LT) {
_sp_sub_off(a, m, a, 0);
}
}
else {
(void)sp_rshb(a, bits, a);
/* Constant time clamping. */
sp_clamp_ct(a);
/* a = a mod m */
if (_sp_cmp_abs(a, m) != MP_LT) {
_sp_sub_off(a, m, a, 0);
_sp_submod_ct(a, m, m, m->used + 1, a);
}
#if 0
@ -17422,11 +17507,12 @@ static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp)
* @param [in,out] a SP integer to Montgomery reduce.
* @param [in] m SP integer that is the modulus.
* @param [in] mp SP integer digit that is the bottom digit of inv(-m).
* @param [in] ct Indicates operation must be constant time.
*
* @return MP_OKAY on success.
* @return MP_VAL when a or m is NULL or m is zero.
*/
int sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp)
int sp_mont_red_ex(sp_int* a, const sp_int* m, sp_int_digit mp, int ct)
{
int err;
@ -17440,7 +17526,7 @@ int sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp)
}
else {
/* Perform Montogomery Reduction. */
err = _sp_mont_red(a, m, mp);
err = _sp_mont_red(a, m, mp, ct);
}
return err;

View File

@ -6049,15 +6049,8 @@ int mp_read_radix(mp_int *a, const char *str, int radix)
#endif /* !defined(NO_DSA) || defined(HAVE_ECC) */
#ifdef HAVE_ECC
#if defined(HAVE_ECC) || (!defined(NO_RSA) && defined(WC_RSA_BLINDING))
/* fast math conversion */
int mp_sqr(fp_int *A, fp_int *B)
{
return fp_sqr(A, B);
}
/* fast math conversion */
int mp_montgomery_reduce(fp_int *a, fp_int *m, fp_digit mp)
{
return fp_montgomery_reduce(a, m, mp);
@ -6075,6 +6068,17 @@ int mp_montgomery_setup(fp_int *a, fp_digit *rho)
return fp_montgomery_setup(a, rho);
}
#endif /* HAVE_ECC || (!NO_RSA && WC_RSA_BLINDING) */
#ifdef HAVE_ECC
/* fast math conversion */
int mp_sqr(fp_int *A, fp_int *B)
{
return fp_sqr(A, B);
}
/* fast math conversion */
int mp_div_2(fp_int * a, fp_int * b)
{
fp_div_2(a, b);

View File

@ -366,6 +366,7 @@ MP_API int mp_montgomery_setup (mp_int * n, mp_digit * rho);
int fast_mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho);
MP_API int mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho);
#define mp_montgomery_reduce_ex(x, n, rho, ct) mp_montgomery_reduce (x, n, rho)
#define mp_montgomery_reduce_ct(x, n, rho) mp_montgomery_reduce (x, n, rho)
MP_API void mp_dr_setup(mp_int *a, mp_digit *d);
MP_API int mp_dr_reduce (mp_int * x, mp_int * n, mp_digit k);
MP_API int mp_reduce_2k(mp_int *a, mp_int *n, mp_digit d);

View File

@ -1037,7 +1037,8 @@ MP_API int sp_mul_2d(const sp_int* a, int e, sp_int* r);
MP_API int sp_sqr(const sp_int* a, sp_int* r);
MP_API int sp_sqrmod(const sp_int* a, const sp_int* m, sp_int* r);
MP_API int sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp);
MP_API int sp_mont_red_ex(sp_int* a, const sp_int* m, sp_int_digit mp, int ct);
#define sp_mont_red(a, m, mp) sp_mont_red_ex(a, m, mp, 0)
MP_API int sp_mont_setup(const sp_int* m, sp_int_digit* rho);
MP_API int sp_mont_norm(sp_int* norm, const sp_int* m);
@ -1085,7 +1086,8 @@ WOLFSSL_LOCAL void sp_memzero_check(sp_int* sp);
#define mp_div_3(a, r, rem) sp_div_d(a, 3, r, rem)
#define mp_rshb(A,x) sp_rshb(A,x,A)
#define mp_is_bit_set(a,b) sp_is_bit_set(a,(unsigned int)(b))
#define mp_montgomery_reduce sp_mont_red
#define mp_montgomery_reduce(a, m, mp) sp_mont_red_ex(a, m, mp, 0)
#define mp_montgomery_reduce_ct(a, m, mp) sp_mont_red_ex(a, m, mp, 1)
#define mp_montgomery_setup sp_mont_setup
#define mp_montgomery_calc_normalization sp_mont_norm

View File

@ -871,12 +871,13 @@ MP_API int mp_radix_size (mp_int * a, int radix, int *size);
MP_API int mp_read_radix(mp_int* a, const char* str, int radix);
#endif
#define mp_montgomery_reduce_ct(a, m, mp) \
mp_montgomery_reduce_ex(a, m, mp, 1)
MP_API int mp_montgomery_reduce(fp_int *a, fp_int *m, fp_digit mp);
MP_API int mp_montgomery_reduce_ex(fp_int *a, fp_int *m, fp_digit mp, int ct);
MP_API int mp_montgomery_setup(fp_int *a, fp_digit *rho);
#ifdef HAVE_ECC
MP_API int mp_sqr(fp_int *a, fp_int *b);
MP_API int mp_montgomery_reduce(fp_int *a, fp_int *m, fp_digit mp);
MP_API int mp_montgomery_reduce_ex(fp_int *a, fp_int *m, fp_digit mp,
int ct);
MP_API int mp_montgomery_setup(fp_int *a, fp_digit *rho);
MP_API int mp_div_2(fp_int * a, fp_int * b);
MP_API int mp_div_2_mod_ct(mp_int *a, mp_int *b, mp_int *c);
#endif