fix ecc sign/hash truncation with odd bit sizes when hash length is longer than key size

This commit is contained in:
toddouska 2013-07-25 15:59:09 -07:00
parent 55401c13dd
commit 505b1a8a67
5 changed files with 105 additions and 63 deletions

View File

@ -1144,10 +1144,17 @@ int ecc_sign_hash(const byte* in, word32 inlen, byte* out, word32 *outlen,
err = mp_read_radix(&p, (char *)key->dp->order, 16);
if (err == MP_OKAY) {
int truncLen = (int)inlen;
if (truncLen > ecc_size(key))
truncLen = ecc_size(key);
err = mp_read_unsigned_bin(&e, (byte*)in, truncLen);
/* we may need to truncate if hash is longer than key size */
word32 orderBits = mp_count_bits(&p);
/* truncate down to byte size, may be all that's needed */
if ( (CYASSL_BIT_SIZE * inlen) > orderBits)
inlen = (orderBits + CYASSL_BIT_SIZE - 1)/CYASSL_BIT_SIZE;
err = mp_read_unsigned_bin(&e, (byte*)in, inlen);
/* may still need bit truncation too */
if (err == MP_OKAY && (CYASSL_BIT_SIZE * inlen) > orderBits)
mp_rshb(&e, CYASSL_BIT_SIZE - (orderBits & 0x7));
}
/* make up a key and export the public copy */
@ -1311,10 +1318,17 @@ int ecc_verify_hash(const byte* sig, word32 siglen, byte* hash, word32 hashlen,
}
/* read hash */
if (err == MP_OKAY) {
int truncLen = (int)hashlen;
if (truncLen > ecc_size(key))
truncLen = ecc_size(key);
err = mp_read_unsigned_bin(&e, (byte*)hash, truncLen);
/* we may need to truncate if hash is longer than key size */
unsigned int orderBits = mp_count_bits(&p);
/* truncate down to byte size, may be all that's needed */
if ( (CYASSL_BIT_SIZE * hashlen) > orderBits)
hashlen = (orderBits + CYASSL_BIT_SIZE - 1)/CYASSL_BIT_SIZE;
err = mp_read_unsigned_bin(&e, (byte*)hash, hashlen);
/* may still need bit truncation too */
if (err == MP_OKAY && (CYASSL_BIT_SIZE * hashlen) > orderBits)
mp_rshb(&e, CYASSL_BIT_SIZE - (orderBits & 0x7));
}
/* w = s^-1 mod n */

View File

@ -328,8 +328,7 @@ bn_reverse (unsigned char *s, int len)
remainder in d) */
int mp_div_2d (mp_int * a, int b, mp_int * c, mp_int * d)
{
mp_digit D, r, rr;
int x, res;
int D, res;
mp_int t;
@ -366,33 +365,9 @@ int mp_div_2d (mp_int * a, int b, mp_int * c, mp_int * d)
}
/* shift any bit count < DIGIT_BIT */
D = (mp_digit) (b % DIGIT_BIT);
D = (b % DIGIT_BIT);
if (D != 0) {
register mp_digit *tmpc, mask, shift;
/* mask */
mask = (((mp_digit)1) << D) - 1;
/* shift for lsb */
shift = DIGIT_BIT - D;
/* alias */
tmpc = c->dp + (c->used - 1);
/* carry */
r = 0;
for (x = c->used - 1; x >= 0; x--) {
/* get the lower bits of this word in a temp */
rr = *tmpc & mask;
/* shift the current word and mix in the carry bits from the previous
word */
*tmpc = (*tmpc >> D) | (r << shift);
--tmpc;
/* set the carry to the carry bits of the current word found above */
r = rr;
}
mp_rshb(c, D);
}
mp_clamp (c);
if (d != NULL) {
@ -457,6 +432,38 @@ mp_exch (mp_int * a, mp_int * b)
}
/* shift right a certain number of bits */
void mp_rshb (mp_int *c, int x)
{
register mp_digit *tmpc, mask, shift;
mp_digit r, rr;
mp_digit D = x;
/* mask */
mask = (((mp_digit)1) << D) - 1;
/* shift for lsb */
shift = DIGIT_BIT - D;
/* alias */
tmpc = c->dp + (c->used - 1);
/* carry */
r = 0;
for (x = c->used - 1; x >= 0; x--) {
/* get the lower bits of this word in a temp */
rr = *tmpc & mask;
/* shift the current word and mix in the carry bits from previous word */
*tmpc = (*tmpc >> D) | (r << shift);
--tmpc;
/* set the carry to the carry bits of the current word found above */
r = rr;
}
}
/* shift right a certain amount of digits */
void mp_rshd (mp_int * a, int b)
{

View File

@ -641,8 +641,7 @@ void fp_div_2(fp_int * a, fp_int * b)
/* c = a / 2**b */
void fp_div_2d(fp_int *a, int b, fp_int *c, fp_int *d)
{
fp_digit D, r, rr;
int x;
int D;
fp_int t;
/* if the shift count is <= 0 then we do no work */
@ -670,32 +669,9 @@ void fp_div_2d(fp_int *a, int b, fp_int *c, fp_int *d)
}
/* shift any bit count < DIGIT_BIT */
D = (fp_digit) (b % DIGIT_BIT);
D = (b % DIGIT_BIT);
if (D != 0) {
register fp_digit *tmpc, mask, shift;
/* mask */
mask = (((fp_digit)1) << D) - 1;
/* shift for lsb */
shift = DIGIT_BIT - D;
/* alias */
tmpc = c->dp + (c->used - 1);
/* carry */
r = 0;
for (x = c->used - 1; x >= 0; x--) {
/* get the lower bits of this word in a temp */
rr = *tmpc & mask;
/* shift the current word and mix in the carry bits from the previous word */
*tmpc = (*tmpc >> D) | (r << shift);
--tmpc;
/* set the carry to the carry bits of the current word found above */
r = rr;
}
fp_rshb(c, D);
}
fp_clamp (c);
if (d != NULL) {
@ -1754,6 +1730,39 @@ void fp_lshd(fp_int *a, int x)
fp_clamp(a);
}
/* right shift by bit count */
void fp_rshb(fp_int *c, int x)
{
register fp_digit *tmpc, mask, shift;
fp_digit r, rr;
fp_digit D = x;
/* mask */
mask = (((fp_digit)1) << D) - 1;
/* shift for lsb */
shift = DIGIT_BIT - D;
/* alias */
tmpc = c->dp + (c->used - 1);
/* carry */
r = 0;
for (x = c->used - 1; x >= 0; x--) {
/* get the lower bits of this word in a temp */
rr = *tmpc & mask;
/* shift the current word and mix in the carry bits from previous word */
*tmpc = (*tmpc >> D) | (r << shift);
--tmpc;
/* set the carry to the carry bits of the current word found above */
r = rr;
}
}
void fp_rshd(fp_int *a, int x)
{
int y;
@ -1959,6 +1968,13 @@ int mp_count_bits (mp_int* a)
}
/* fast math conversion */
void mp_rshb (mp_int* a, int x)
{
fp_rshb(a, x);
}
/* fast math wrappers */
int mp_set_int(fp_int *a, fp_digit b)
{

View File

@ -233,6 +233,7 @@ void mp_zero (mp_int * a);
void mp_clamp (mp_int * a);
void mp_exch (mp_int * a, mp_int * b);
void mp_rshd (mp_int * a, int b);
void mp_rshb (mp_int * a, int b);
int mp_mod_2d (mp_int * a, int b, mp_int * c);
int mp_mul_2d (mp_int * a, int b, mp_int * c);
int mp_lshd (mp_int * a, int b);

View File

@ -372,6 +372,9 @@ void fp_set(fp_int *a, fp_digit b);
/* right shift x digits */
void fp_rshd(fp_int *a, int x);
/* right shift x bits */
void fp_rshb(fp_int *a, int x);
/* left shift x digits */
void fp_lshd(fp_int *a, int x);
@ -653,6 +656,7 @@ int mp_isodd(mp_int* a);
int mp_iszero(mp_int* a);
int mp_count_bits(mp_int *a);
int mp_set_int(fp_int *a, fp_digit b);
void mp_rshb(mp_int *a, int x);
#ifdef HAVE_ECC
int mp_read_radix(mp_int* a, const char* str, int radix);