Fix performance of RSA public key ops with TFM

Have both a constant-time and a non-constant-time modular exponentiation
available in tfm.c.
Call the non-constant time version explicitly when performing RSA public
key mod exp.
This commit is contained in:
Sean Parkinson 2020-03-26 16:24:41 +10:00
parent 93fd1b1eeb
commit c82531a41a
5 changed files with 108 additions and 52 deletions

View File

@ -2256,7 +2256,7 @@ static int wc_RsaFunctionSync(const byte* in, word32 inLen, byte* out,
#ifdef WOLFSSL_XILINX_CRYPT
ret = wc_RsaFunctionXil(in, inLen, out, outLen, type, key, rng);
#else
if (mp_exptmod(tmp, &key->e, &key->n, tmp) != MP_OKAY)
if (mp_exptmod_nct(tmp, &key->e, &key->n, tmp) != MP_OKAY)
ret = MP_EXPTMOD_E;
#endif
break;

View File

@ -1564,7 +1564,8 @@ int fp_exptmod_nb(exptModNb_t* nb, fp_int* G, fp_int* X, fp_int* P, fp_int* Y)
Based on work by Marc Joye, Sung-Ming Yen, "The Montgomery Powering Ladder",
Cryptographic Hardware and Embedded Systems, CHES 2002
*/
static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P, fp_int * Y)
static int _fp_exptmod_ct(fp_int * G, fp_int * X, int digits, fp_int * P,
fp_int * Y)
{
#ifndef WOLFSSL_SMALL_STACK
#ifdef WC_NO_CACHE_RESISTANT
@ -1701,25 +1702,17 @@ static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P, fp_int *
return err;
}
#else /* TFM_TIMING_RESISTANT */
#endif /* TFM_TIMING_RESISTANT */
/* y = g**x (mod b)
* Some restrictions... x must be positive and < b
*/
static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P,
fp_int * Y)
static int _fp_exptmod_nct(fp_int * G, fp_int * X, fp_int * P, fp_int * Y)
{
fp_digit buf, mp;
int err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
#ifdef WOLFSSL_SMALL_STACK
fp_int *res;
fp_int *M;
#else
fp_int res[1];
fp_int M[64];
#endif
(void)digits;
fp_digit buf, mp;
int err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
/* find window size */
x = fp_count_bits (X);
@ -1740,14 +1733,13 @@ static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P,
return err;
}
#ifdef WOLFSSL_SMALL_STACK
/* only allocate space for what's needed for window plus res */
M = (fp_int*)XMALLOC(sizeof(fp_int)*((1 << winsize) + 1), NULL, DYNAMIC_TYPE_BIGINT);
M = (fp_int*)XMALLOC(sizeof(fp_int)*((1 << winsize) + 1), NULL,
DYNAMIC_TYPE_BIGINT);
if (M == NULL) {
return FP_MEM;
}
res = &M[1 << winsize];
#endif
/* init M array */
for(x = 0; x < (1 << winsize); x++)
@ -1782,9 +1774,7 @@ static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P,
fp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)]);
err = fp_montgomery_reduce (&M[1 << (winsize - 1)], P, mp);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
}
@ -1793,23 +1783,19 @@ static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P,
for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
err = fp_mul(&M[x - 1], &M[1], &M[x]);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
err = fp_montgomery_reduce(&M[x], P, mp);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
}
/* set initial mode and bit cnt */
mode = 0;
bitcnt = 1;
bitcnt = (x % DIGIT_BIT) + 1;
buf = 0;
digidx = X->used - 1;
bitcpy = 0;
@ -1844,16 +1830,12 @@ static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P,
if (mode == 1 && y == 0) {
err = fp_sqr(res, res);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
fp_montgomery_reduce(res, P, mp);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
continue;
@ -1869,16 +1851,12 @@ static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P,
for (x = 0; x < winsize; x++) {
err = fp_sqr(res, res);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
err = fp_montgomery_reduce(res, P, mp);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
}
@ -1886,16 +1864,12 @@ static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P,
/* then multiply */
err = fp_mul(res, &M[bitbuf], res);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
err = fp_montgomery_reduce(res, P, mp);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
@ -1912,16 +1886,12 @@ static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P,
for (x = 0; x < bitcpy; x++) {
err = fp_sqr(res, res);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
err = fp_montgomery_reduce(res, P, mp);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
@ -1931,16 +1901,12 @@ static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P,
/* then multiply */
err = fp_mul(res, &M[1], res);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
err = fp_montgomery_reduce(res, P, mp);
if (err != FP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
}
@ -1958,14 +1924,10 @@ static int _fp_exptmod(fp_int * G, fp_int * X, int digits, fp_int * P,
/* swap res with Y */
fp_copy (res, Y);
#ifdef WOLFSSL_SMALL_STACK
XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return err;
}
#endif /* TFM_TIMING_RESISTANT */
#ifdef TFM_TIMING_RESISTANT
#if DIGIT_BIT <= 16
@ -2331,7 +2293,11 @@ int fp_exptmod(fp_int * G, fp_int * X, fp_int * P, fp_int * Y)
if (err == FP_OKAY) {
fp_copy(X, &tmp[1]);
tmp[1].sign = FP_ZPOS;
err = _fp_exptmod(&tmp[0], &tmp[1], tmp[1].used, P, Y);
#ifdef TFM_TIMING_RESISTANT
err = _fp_exptmod_ct(&tmp[0], &tmp[1], tmp[1].used, P, Y);
#else
err = _fp_exptmod_nct(&tmp[0], &tmp[1], P, Y);
#endif
if (P->sign == FP_NEG) {
fp_add(Y, P, Y);
}
@ -2349,7 +2315,11 @@ int fp_exptmod(fp_int * G, fp_int * X, fp_int * P, fp_int * Y)
}
else {
/* Positive exponent so just exptmod */
return _fp_exptmod(G, X, X->used, P, Y);
#ifdef TFM_TIMING_RESISTANT
return _fp_exptmod_ct(G, X, X->used, P, Y);
#else
return _fp_exptmod_nct(G, X, P, Y);
#endif
}
}
@ -2400,7 +2370,12 @@ int fp_exptmod_ex(fp_int * G, fp_int * X, int digits, fp_int * P, fp_int * Y)
err = fp_invmod(&tmp[0], &tmp[1], &tmp[0]);
if (err == FP_OKAY) {
X->sign = FP_ZPOS;
err = _fp_exptmod(&tmp[0], X, digits, P, Y);
#ifdef TFM_TIMING_RESISTANT
err = _fp_exptmod_ct(&tmp[0], X, digits, P, Y);
#else
err = _fp_exptmod_nct(&tmp[0], X, P, Y);
(void)digits;
#endif
if (X != Y) {
X->sign = FP_NEG;
}
@ -2418,10 +2393,81 @@ int fp_exptmod_ex(fp_int * G, fp_int * X, int digits, fp_int * P, fp_int * Y)
}
else {
/* Positive exponent so just exptmod */
return _fp_exptmod(G, X, digits, P, Y);
#ifdef TFM_TIMING_RESISTANT
return _fp_exptmod_ct(G, X, digits, P, Y);
#else
return _fp_exptmod_nct(G, X, P, Y);
#endif
}
}
/* Y = G**X (mod P) using the non-constant time implementation.
 *
 * Intended for operations on public values (e.g. the RSA public key
 * exponentiation), where a faster, non-constant time mod exp is acceptable.
 *
 * G  base
 * X  exponent. A negative exponent is handled by inverting G modulo |P|
 *    and raising the inverse to |X|; X's sign is flipped temporarily and
 *    restored before returning (unless X and Y are the same object).
 *    When POSITIVE_EXP_ONLY is defined a negative exponent is rejected.
 * P  modulus, must use no more than FP_SIZE/2 digits
 * Y  result
 *
 * returns FP_OKAY on success, FP_VAL when the modulus is too large or a
 *         negative exponent is not supported, FP_MEM when a temporary
 *         could not be allocated, otherwise the error reported by the
 *         underlying inversion/exponentiation.
 */
int fp_exptmod_nct(fp_int * G, fp_int * X, fp_int * P, fp_int * Y)
{
#if defined(WOLFSSL_ESP32WROOM32_CRYPT_RSA_PRI) && \
   !defined(NO_WOLFSSL_ESP32WROOM32_CRYPT_RSA_PRI)
    int x = fp_count_bits (X);
#endif

    /* zero base: the result (not the input) is zero */
    if (fp_iszero(G)) {
        fp_set(Y, 0);
        return FP_OKAY;
    }

    /* prevent overflows */
    if (P->used > (FP_SIZE/2)) {
        return FP_VAL;
    }

#if defined(WOLFSSL_ESP32WROOM32_CRYPT_RSA_PRI) && \
   !defined(NO_WOLFSSL_ESP32WROOM32_CRYPT_RSA_PRI)
    /* large exponents are handed to the ESP32 accelerated implementation */
    if(x > EPS_RSA_EXPT_XBTIS) {
        return esp_mp_exptmod(G, X, x, P, Y);
    }
#endif

    if (X->sign == FP_NEG) {
#ifndef POSITIVE_EXP_ONLY  /* reduce stack if assume no negatives */
        int    err;
    #ifndef WOLFSSL_SMALL_STACK
        fp_int tmp[2];
    #else
        fp_int *tmp;
    #endif

    #ifdef WOLFSSL_SMALL_STACK
        /* allocate and free with the same dynamic type */
        tmp = (fp_int*)XMALLOC(sizeof(fp_int) * 2, NULL,
                               DYNAMIC_TYPE_BIGINT);
        if (tmp == NULL)
            return FP_MEM;
    #endif

        /* yes, copy G and invmod it */
        fp_init_copy(&tmp[0], G);
        fp_init_copy(&tmp[1], P);
        tmp[1].sign = FP_ZPOS;
        err = fp_invmod(&tmp[0], &tmp[1], &tmp[0]);
        if (err == FP_OKAY) {
            /* exponentiate the inverse with |X|, then put the sign back */
            X->sign = FP_ZPOS;
            err = _fp_exptmod_nct(&tmp[0], X, P, Y);
            if (X != Y) {
                X->sign = FP_NEG;
            }
            if (P->sign == FP_NEG) {
                fp_add(Y, P, Y);
            }
        }
    #ifdef WOLFSSL_SMALL_STACK
        XFREE(tmp, NULL, DYNAMIC_TYPE_BIGINT);
    #endif
        return err;
#else
        return FP_VAL;
#endif
    }
    else {
        /* Positive exponent so just exptmod */
        return _fp_exptmod_nct(G, X, P, Y);
    }
}
/* computes a = 2**b */
void fp_2expt(fp_int *a, int b)
@ -3639,6 +3685,12 @@ int mp_exptmod_ex (mp_int * G, mp_int * X, int digits, mp_int * P, mp_int * Y)
return fp_exptmod_ex(G, X, digits, P, Y);
}
/* Y = G**X (mod P), non-constant time variant.
 * mp_ API entry point: passes the operands straight to fp_exptmod_nct and
 * returns whatever it returns.
 */
int mp_exptmod_nct (mp_int * G, mp_int * X, mp_int * P, mp_int * Y)
{
    int ret;

    ret = fp_exptmod_nct(G, X, P, Y);

    return ret;
}
/* compare two ints (signed)*/
int mp_cmp (mp_int * a, mp_int * b)
{

View File

@ -328,6 +328,7 @@ MP_API int mp_dr_is_modulus(mp_int *a);
MP_API int mp_exptmod_fast (mp_int * G, mp_int * X, mp_int * P, mp_int * Y,
int);
MP_API int mp_exptmod_base_2 (mp_int * X, mp_int * P, mp_int * Y);
/* Non-constant time mod exp for this backend: forwards to mp_exptmod_fast
 * with 0 as the final argument.
 * NOTE(review): the meaning of the trailing 0 is not visible here -
 * confirm against mp_exptmod_fast. */
#define mp_exptmod_nct(G,X,P,Y) mp_exptmod_fast(G,X,P,Y,0)
MP_API int mp_montgomery_setup (mp_int * n, mp_digit * rho);
int fast_mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho);
MP_API int mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho);

View File

@ -272,6 +272,7 @@ MP_API int sp_mul_d(sp_int* a, sp_int_digit n, sp_int* r);
#define mp_invmod sp_invmod
#define mp_lcm sp_lcm
#define mp_exptmod sp_exptmod
/* no dedicated non-constant time variant: same target as mp_exptmod */
#define mp_exptmod_nct sp_exptmod
#define mp_prime_is_prime sp_prime_is_prime
#define mp_prime_is_prime_ex sp_prime_is_prime_ex
#define mp_exch sp_exch

View File

@ -557,6 +557,7 @@ int fp_montgomery_reduce(fp_int *a, fp_int *m, fp_digit mp);
/* d = a**b (mod c) */
int fp_exptmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d);
int fp_exptmod_ex(fp_int *a, fp_int *b, int minDigits, fp_int *c, fp_int *d);
int fp_exptmod_nct(fp_int *a, fp_int *b, fp_int *c, fp_int *d);
#ifdef WC_RSA_NONBLOCK
@ -748,6 +749,7 @@ MP_API int mp_invmod_mont_ct(mp_int *a, mp_int *b, mp_int *c, fp_digit mp);
MP_API int mp_exptmod (mp_int * g, mp_int * x, mp_int * p, mp_int * y);
MP_API int mp_exptmod_ex (mp_int * g, mp_int * x, int minDigits, mp_int * p,
mp_int * y);
MP_API int mp_exptmod_nct (mp_int * g, mp_int * x, mp_int * p, mp_int * y);
MP_API int mp_mul_2d(mp_int *a, int b, mp_int *c);
MP_API int mp_2expt(mp_int* a, int b);