Merge pull request from SparkiDev/pss_var_salt_len

Add support in PSS for salt lengths up to hash length
This commit is contained in:
John Safranek 2018-01-10 11:00:47 -08:00 committed by GitHub
commit 32a345e2f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 447 additions and 53 deletions
src
wolfcrypt
wolfssl/wolfcrypt

@ -4145,7 +4145,7 @@ static int CreateECCEncodedSig(byte* sigData, int sigDataSz, int hashAlgo)
* based on the digest of the signature data.
*
* ssl The SSL/TLS object.
* hashAlgo The signature algorithm used to generate signature.
* sigAlgo The signature algorithm used to generate signature.
* hashAlgo The hash algorithm used to generate signature.
* decSig The decrypted signature.
* decSigSz The size of the decrypted signature.
@ -4170,7 +4170,7 @@ static int CheckRSASignature(WOLFSSL* ssl, int sigAlgo, int hashAlgo,
if (ret < 0)
return ret;
/* PSS signature can be done in-pace */
/* PSS signature can be done in-place */
ret = CreateRSAEncodedSig(sigData, sigData, sigDataSz,
sigAlgo, hashAlgo);
if (ret < 0)

@ -440,6 +440,9 @@ const char* wc_GetErrorString(int error)
case WC_HW_WAIT_E:
return "Hardware waiting on resource";
case PSS_SALTLEN_E:
return "PSS - Length of salt is too big for hash algorithm";
default:
return "unknown error number";

@ -702,10 +702,23 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock,
/* 0x00 .. 0x00 0x01 | Salt | Gen Hash | 0xbc
* XOR MGF over all bytes down to end of Salt
* Gen Hash = HASH(8 * 0x00 | Message Hash | Salt)
*
* input Digest of the message.
* inputLen Length of digest.
* pkcsBlock Buffer to write to.
* pkcsBlockLen Length of buffer to write to.
* rng Random number generator (for salt).
* htype Hash function to use.
* mgf Mask generation function.
* saltLen Length of salt to put in padding.
* bits Length of key in bits.
* heap Used for dynamic memory allocation.
* returns 0 on success, PSS_SALTLEN_E when the salt length is invalid
* and other negative values on error.
*/
static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock,
word32 pkcsBlockLen, WC_RNG* rng, enum wc_HashType hType, int mgf,
int bits, void* heap)
int saltLen, int bits, void* heap)
{
int ret;
int hLen, i;
@ -718,15 +731,22 @@ static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock,
if (hLen < 0)
return hLen;
if (saltLen == -1)
saltLen = hLen;
else if (saltLen > hLen || saltLen < -1)
return PSS_SALTLEN_E;
if ((int)pkcsBlockLen - hLen - 1 < saltLen + 2)
return PSS_SALTLEN_E;
s = m = pkcsBlock;
XMEMSET(m, 0, 8);
m += 8;
XMEMSET(m, 0, RSA_PSS_PAD_SZ);
m += RSA_PSS_PAD_SZ;
XMEMCPY(m, input, inputLen);
m += inputLen;
if ((ret = wc_RNG_GenerateBlock(rng, salt, hLen)) != 0)
if ((ret = wc_RNG_GenerateBlock(rng, salt, saltLen)) != 0)
return ret;
XMEMCPY(m, salt, hLen);
m += hLen;
XMEMCPY(m, salt, saltLen);
m += saltLen;
h = pkcsBlock + pkcsBlockLen - 1 - hLen;
if ((ret = wc_Hash(hType, s, (word32)(m - s), h, hLen)) != 0)
@ -738,9 +758,9 @@ static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock,
return ret;
pkcsBlock[0] &= (1 << ((bits - 1) & 0x7)) - 1;
m = pkcsBlock + pkcsBlockLen - 1 - hLen - hLen - 1;
m = pkcsBlock + pkcsBlockLen - 1 - saltLen - hLen - 1;
*(m++) ^= 0x01;
for (i = 0; i < hLen; i++)
for (i = 0; i < saltLen; i++)
m[i] ^= salt[i];
return 0;
@ -799,8 +819,8 @@ static int RsaPad(const byte* input, word32 inputLen, byte* pkcsBlock,
/* helper function to direct which padding is used */
static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock,
word32 pkcsBlockLen, byte padValue, WC_RNG* rng, int padType,
enum wc_HashType hType, int mgf, byte* optLabel, word32 labelLen, int bits,
void* heap)
enum wc_HashType hType, int mgf, byte* optLabel, word32 labelLen,
int saltLen, int bits, void* heap)
{
int ret;
@ -824,7 +844,7 @@ static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock,
case WC_RSA_PSS_PAD:
WOLFSSL_MSG("wolfSSL Using RSA PSS padding");
ret = RsaPad_PSS(input, inputLen, pkcsBlock, pkcsBlockLen, rng,
hType, mgf, bits, heap);
hType, mgf, saltLen, bits, heap);
break;
#endif
@ -838,6 +858,7 @@ static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock,
(void)mgf;
(void)optLabel;
(void)labelLen;
(void)saltLen;
(void)bits;
(void)heap;
@ -934,9 +955,23 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen,
#endif /* WC_NO_RSA_OAEP */
#ifdef WC_RSA_PSS
/* 0x00 .. 0x00 0x01 | Salt | Gen Hash | 0xbc
* MGF over all bytes down to end of Salt
*
* pkcsBlock Buffer holding decrypted data.
* pkcsBlockLen Length of buffer.
* htype Hash function to use.
* mgf Mask generation function.
* saltLen Length of salt to put in padding.
* bits Length of key in bits.
* heap Used for dynamic memory allocation.
* returns 0 on success, PSS_SALTLEN_E when the salt length is invalid,
* BAD_PADDING_E when the padding is not valid, MEMORY_E when allocation fails
* and other negative values on error.
*/
static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen,
byte **output, enum wc_HashType hType, int mgf,
int bits, void* heap)
int saltLen, int bits, void* heap)
{
int ret;
byte* tmp;
@ -946,15 +981,21 @@ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen,
if (hLen < 0)
return hLen;
if (saltLen == -1)
saltLen = hLen;
else if (saltLen > hLen || saltLen < -1)
return PSS_SALTLEN_E;
if ((int)pkcsBlockLen - hLen - 1 < saltLen + 2)
return PSS_SALTLEN_E;
if (pkcsBlock[pkcsBlockLen - 1] != 0xbc) {
WOLFSSL_MSG("RsaUnPad_PSS: Padding Error 0xBC");
return BAD_PADDING_E;
}
tmp = (byte*)XMALLOC(pkcsBlockLen, heap, DYNAMIC_TYPE_RSA_BUFFER);
if (tmp == NULL) {
if (tmp == NULL)
return MEMORY_E;
}
if ((ret = RsaMGF(mgf, pkcsBlock + pkcsBlockLen - 1 - hLen, hLen,
tmp, pkcsBlockLen - 1 - hLen, heap)) != 0) {
@ -963,7 +1004,7 @@ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen,
}
tmp[0] &= (1 << ((bits - 1) & 0x7)) - 1;
for (i = 0; i < (int)(pkcsBlockLen - 1 - hLen - hLen - 1); i++) {
for (i = 0; i < (int)(pkcsBlockLen - 1 - saltLen - hLen - 1); i++) {
if (tmp[i] != pkcsBlock[i]) {
XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER);
WOLFSSL_MSG("RsaUnPad_PSS: Padding Error Match");
@ -980,11 +1021,11 @@ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen,
XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER);
i = pkcsBlockLen - (RSA_PSS_PAD_SZ + 3 * hLen + 1);
i = pkcsBlockLen - (RSA_PSS_PAD_SZ + saltLen + 2 * hLen + 1);
XMEMSET(pkcsBlock + i, 0, RSA_PSS_PAD_SZ);
*output = pkcsBlock + i;
return RSA_PSS_PAD_SZ + 3 * hLen;
return RSA_PSS_PAD_SZ + saltLen + 2 * hLen;
}
#endif
@ -1038,8 +1079,8 @@ static int RsaUnPad(const byte *pkcsBlock, unsigned int pkcsBlockLen,
/* helper function to direct unpadding */
static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out,
byte padValue, int padType, enum wc_HashType hType,
int mgf, byte* optLabel, word32 labelLen, int bits,
void* heap)
int mgf, byte* optLabel, word32 labelLen, int saltLen,
int bits, void* heap)
{
int ret;
@ -1061,7 +1102,7 @@ static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out,
case WC_RSA_PSS_PAD:
WOLFSSL_MSG("wolfSSL Using RSA PSS un-padding");
ret = RsaUnPad_PSS((byte*)pkcsBlock, pkcsBlockLen, out, hType, mgf,
bits, heap);
saltLen, bits, heap);
break;
#endif
@ -1075,6 +1116,7 @@ static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out,
(void)mgf;
(void)optLabel;
(void)labelLen;
(void)saltLen;
(void)bits;
(void)heap;
@ -1451,12 +1493,15 @@ int wc_RsaFunction(const byte* in, word32 inLen, byte* out,
hash : type of hash algorithm to use found in wolfssl/wolfcrypt/hash.h
mgf : type of mask generation function to use
label : optional label
labelSz : size of optional label buffer */
labelSz : size of optional label buffer
saltLen : Length of salt used in PSS
rng : random number generator */
static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out,
word32 outLen, RsaKey* key, int rsa_type,
byte pad_value, int pad_type,
enum wc_HashType hash, int mgf,
byte* label, word32 labelSz, WC_RNG* rng)
byte* label, word32 labelSz, int saltLen,
WC_RNG* rng)
{
int ret, sz;
@ -1502,7 +1547,7 @@ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out,
#endif
ret = wc_RsaPad_ex(in, inLen, out, sz, pad_value, rng, pad_type, hash,
mgf, label, labelSz, mp_count_bits(&key->n),
mgf, label, labelSz, saltLen, mp_count_bits(&key->n),
key->heap);
if (ret < 0) {
break;
@ -1561,12 +1606,15 @@ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out,
hash : type of hash algorithm to use found in wolfssl/wolfcrypt/hash.h
mgf : type of mask generation function to use
label : optional label
labelSz : size of optional label buffer */
labelSz : size of optional label buffer
saltLen : Length of salt used in PSS
rng : random number generator */
static int RsaPrivateDecryptEx(byte* in, word32 inLen, byte* out,
word32 outLen, byte** outPtr, RsaKey* key,
int rsa_type, byte pad_value, int pad_type,
enum wc_HashType hash, int mgf,
byte* label, word32 labelSz, WC_RNG* rng)
byte* label, word32 labelSz, int saltLen,
WC_RNG* rng)
{
int ret = RSA_WRONG_TYPE_E;
@ -1636,8 +1684,8 @@ static int RsaPrivateDecryptEx(byte* in, word32 inLen, byte* out,
{
byte* pad = NULL;
ret = wc_RsaUnPad_ex(key->data, key->dataLen, &pad, pad_value, pad_type,
hash, mgf, label, labelSz, mp_count_bits(&key->n),
key->heap);
hash, mgf, label, labelSz, saltLen,
mp_count_bits(&key->n), key->heap);
if (ret > 0 && ret <= (int)outLen && pad != NULL) {
/* only copy output if not inline */
if (outPtr == NULL) {
@ -1696,7 +1744,7 @@ int wc_RsaPublicEncrypt(const byte* in, word32 inLen, byte* out, word32 outLen,
{
return RsaPublicEncryptEx(in, inLen, out, outLen, key,
RSA_PUBLIC_ENCRYPT, RSA_BLOCK_TYPE_2, WC_RSA_PKCSV15_PAD,
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng);
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng);
}
@ -1707,7 +1755,7 @@ int wc_RsaPublicEncrypt_ex(const byte* in, word32 inLen, byte* out,
word32 labelSz)
{
return RsaPublicEncryptEx(in, inLen, out, outLen, key, RSA_PUBLIC_ENCRYPT,
RSA_BLOCK_TYPE_2, type, hash, mgf, label, labelSz, rng);
RSA_BLOCK_TYPE_2, type, hash, mgf, label, labelSz, 0, rng);
}
#endif /* WC_NO_RSA_OAEP */
@ -1720,7 +1768,7 @@ int wc_RsaPrivateDecryptInline(byte* in, word32 inLen, byte** out, RsaKey* key)
#endif
return RsaPrivateDecryptEx(in, inLen, in, inLen, out, key,
RSA_PRIVATE_DECRYPT, RSA_BLOCK_TYPE_2, WC_RSA_PKCSV15_PAD,
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng);
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng);
}
@ -1735,7 +1783,7 @@ int wc_RsaPrivateDecryptInline_ex(byte* in, word32 inLen, byte** out,
#endif
return RsaPrivateDecryptEx(in, inLen, in, inLen, out, key,
RSA_PRIVATE_DECRYPT, RSA_BLOCK_TYPE_2, type, hash,
mgf, label, labelSz, rng);
mgf, label, labelSz, 0, rng);
}
#endif /* WC_NO_RSA_OAEP */
@ -1749,7 +1797,7 @@ int wc_RsaPrivateDecrypt(const byte* in, word32 inLen, byte* out,
#endif
return RsaPrivateDecryptEx((byte*)in, inLen, out, outLen, NULL, key,
RSA_PRIVATE_DECRYPT, RSA_BLOCK_TYPE_2, WC_RSA_PKCSV15_PAD,
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng);
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng);
}
#ifndef WC_NO_RSA_OAEP
@ -1764,7 +1812,7 @@ int wc_RsaPrivateDecrypt_ex(const byte* in, word32 inLen, byte* out,
#endif
return RsaPrivateDecryptEx((byte*)in, inLen, out, outLen, NULL, key,
RSA_PRIVATE_DECRYPT, RSA_BLOCK_TYPE_2, type, hash, mgf, label,
labelSz, rng);
labelSz, 0, rng);
}
#endif /* WC_NO_RSA_OAEP */
@ -1777,7 +1825,7 @@ int wc_RsaSSL_VerifyInline(byte* in, word32 inLen, byte** out, RsaKey* key)
#endif
return RsaPrivateDecryptEx(in, inLen, in, inLen, out, key,
RSA_PUBLIC_DECRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PKCSV15_PAD,
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng);
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng);
}
int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out, word32 outLen,
@ -1795,12 +1843,44 @@ int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out, word32 outLen,
#endif
return RsaPrivateDecryptEx((byte*)in, inLen, out, outLen, NULL, key,
RSA_PUBLIC_DECRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PKCSV15_PAD,
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng);
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng);
}
#ifdef WC_RSA_PSS
/* Verify the message signed with RSA-PSS.
* The input buffer is reused for the ouput buffer.
* Salt length is equal to hash length.
*
* in Buffer holding encrypted data.
* inLen Length of data in buffer.
* out Pointer to address containing the PSS data.
* hash Hash algorithm.
* mgf Mask generation function.
* key Public RSA key.
* returns the length of the PSS data on success and negative indicates failure.
*/
int wc_RsaPSS_VerifyInline(byte* in, word32 inLen, byte** out,
enum wc_HashType hash, int mgf, RsaKey* key)
{
return wc_RsaPSS_VerifyInline_ex(in, inLen, out, hash, mgf, -1, key);
}
/* Verify the message signed with RSA-PSS.
* The input buffer is reused for the ouput buffer.
*
* in Buffer holding encrypted data.
* inLen Length of data in buffer.
* out Pointer to address containing the PSS data.
* hash Hash algorithm.
* mgf Mask generation function.
* key Public RSA key.
* saltLen Length of salt used. -1 indicates salt length is the same as the
* hash length.
* returns the length of the PSS data on success and negative indicates failure.
*/
int wc_RsaPSS_VerifyInline_ex(byte* in, word32 inLen, byte** out,
enum wc_HashType hash, int mgf, int saltLen,
RsaKey* key)
{
WC_RNG* rng = NULL;
#ifdef WC_RSA_BLINDING
@ -1808,32 +1888,115 @@ int wc_RsaPSS_VerifyInline(byte* in, word32 inLen, byte** out,
#endif
return RsaPrivateDecryptEx(in, inLen, in, inLen, out, key,
RSA_PUBLIC_DECRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PSS_PAD,
hash, mgf, NULL, 0, rng);
hash, mgf, NULL, 0, saltLen, rng);
}
/* Sig = 8 * 0x00 | Space for Message Hash | Salt | Exp Hash
* Exp Hash = HASH(8 * 0x00 | Message Hash | Salt)
/* Verify the message signed with RSA-PSS.
* Salt length is equal to hash length.
*
* in Buffer holding encrypted data.
* inLen Length of data in buffer.
* out Pointer to address containing the PSS data.
* hash Hash algorithm.
* mgf Mask generation function.
* key Public RSA key.
* returns the length of the PSS data on success and negative indicates failure.
*/
int wc_RsaPSS_Verify(byte* in, word32 inLen, byte* out, word32 outLen,
enum wc_HashType hash, int mgf, RsaKey* key)
{
return wc_RsaPSS_Verify_ex(in, inLen, out, outLen, hash, mgf, -1, key);
}
/* Verify the message signed with RSA-PSS.
*
* in Buffer holding encrypted data.
* inLen Length of data in buffer.
* out Pointer to address containing the PSS data.
* hash Hash algorithm.
* mgf Mask generation function.
* key Public RSA key.
* saltLen Length of salt used. -1 indicates salt length is the same as the
* hash length.
* returns the length of the PSS data on success and negative indicates failure.
*/
int wc_RsaPSS_Verify_ex(byte* in, word32 inLen, byte* out, word32 outLen,
enum wc_HashType hash, int mgf, int saltLen,
RsaKey* key)
{
WC_RNG* rng = NULL;
#ifdef WC_RSA_BLINDING
rng = key->rng;
#endif
return RsaPrivateDecryptEx(in, inLen, out, outLen, NULL, key,
RSA_PUBLIC_DECRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PSS_PAD,
hash, mgf, NULL, 0, saltLen, rng);
}
/* Checks the PSS data to ensure that the signature matches.
* Salt length is equal to hash length.
*
* in Hash of the data that is being verified.
* inSz Length of hash.
* sig Buffer holding PSS data.
* sigSz Size of PSS data.
* hashType Hash algorithm.
* returns BAD_PADDING_E when the PSS data is invalid, BAD_FUNC_ARG when
* NULL is passed in to in or sig or inSz is not the same as the hash
* algorithm length and 0 on success.
*/
int wc_RsaPSS_CheckPadding(const byte* in, word32 inSz, byte* sig,
word32 sigSz, enum wc_HashType hashType)
{
int ret;
return wc_RsaPSS_CheckPadding_ex(in, inSz, sig, sigSz, hashType, inSz);
}
/* Checks the PSS data to ensure that the signature matches.
*
* in Hash of the data that is being verified.
* inSz Length of hash.
* sig Buffer holding PSS data.
* sigSz Size of PSS data.
* hashType Hash algorithm.
* saltLen Length of salt used. -1 indicates salt length is the same as the
* hash length.
* returns BAD_PADDING_E when the PSS data is invalid, BAD_FUNC_ARG when
* NULL is passed in to in or sig or inSz is not the same as the hash
* algorithm length and 0 on success.
*/
int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inSz, byte* sig,
word32 sigSz, enum wc_HashType hashType,
int saltLen)
{
int ret = 0;
if (in == NULL || sig == NULL ||
inSz != (word32)wc_HashGetDigestSize(hashType) ||
sigSz != RSA_PSS_PAD_SZ + inSz * 3)
inSz != (word32)wc_HashGetDigestSize(hashType))
ret = BAD_FUNC_ARG;
else {
if (ret == 0) {
if (saltLen == -1)
saltLen = inSz;
else if (saltLen < -1 || (word32)saltLen > inSz)
ret = PSS_SALTLEN_E;
}
/* Sig = 8 * 0x00 | Space for Message Hash | Salt | Exp Hash */
if (ret == 0) {
if (sigSz != RSA_PSS_PAD_SZ + inSz + (word32)saltLen + inSz)
ret = BAD_PADDING_E;
}
/* Exp Hash = HASH(8 * 0x00 | Message Hash | Salt) */
if (ret == 0) {
XMEMCPY(sig + RSA_PSS_PAD_SZ, in, inSz);
ret = wc_Hash(hashType, sig, RSA_PSS_PAD_SZ + inSz * 2, sig, inSz);
if (ret != 0)
return ret;
if (XMEMCMP(sig, sig + RSA_PSS_PAD_SZ + inSz * 2, inSz) != 0) {
ret = wc_Hash(hashType, sig, RSA_PSS_PAD_SZ + inSz + saltLen, sig,
inSz);
}
if (ret == 0) {
if (XMEMCMP(sig, sig + RSA_PSS_PAD_SZ + inSz + saltLen, inSz) != 0) {
WOLFSSL_MSG("RsaPSS_CheckPadding: Padding Error");
ret = BAD_PADDING_E;
}
else
ret = 0;
}
return ret;
@ -1845,16 +2008,52 @@ int wc_RsaSSL_Sign(const byte* in, word32 inLen, byte* out, word32 outLen,
{
return RsaPublicEncryptEx(in, inLen, out, outLen, key,
RSA_PRIVATE_ENCRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PKCSV15_PAD,
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng);
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng);
}
#ifdef WC_RSA_PSS
/* Sign the hash of a message using RSA-PSS.
* Salt length is equal to hash length.
*
* in Buffer holding hash of message.
* inLen Length of data in buffer (hash length).
* out Buffer to write encrypted signature into.
* outLen Size of buffer to write to.
* hash Hash algorithm.
* mgf Mask generation function.
* key Public RSA key.
* rng Random number generator.
* returns the length of the encrypted signature on success, a negative value
* indicates failure.
*/
int wc_RsaPSS_Sign(const byte* in, word32 inLen, byte* out, word32 outLen,
enum wc_HashType hash, int mgf, RsaKey* key, WC_RNG* rng)
{
return wc_RsaPSS_Sign_ex(in, inLen, out, outLen, hash, mgf, -1, key, rng);
}
/* Sign the hash of a message using RSA-PSS.
*
* in Buffer holding hash of message.
* inLen Length of data in buffer (hash length).
* out Buffer to write encrypted signature into.
* outLen Size of buffer to write to.
* hash Hash algorithm.
* mgf Mask generation function.
* saltLen Length of salt used. -1 indicates salt length is the same as the
* hash length.
* key Public RSA key.
* rng Random number generator.
* returns the length of the encrypted signature on success, a negative value
* indicates failure.
*/
int wc_RsaPSS_Sign_ex(const byte* in, word32 inLen, byte* out, word32 outLen,
enum wc_HashType hash, int mgf, int saltLen, RsaKey* key,
WC_RNG* rng)
{
return RsaPublicEncryptEx(in, inLen, out, outLen, key,
RSA_PRIVATE_ENCRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PSS_PAD,
hash, mgf, NULL, 0, rng);
hash, mgf, NULL, 0, saltLen, rng);
}
#endif

@ -7653,6 +7653,175 @@ done:
#endif
#define RSA_TEST_BYTES 256
#ifdef WC_RSA_PSS
static int rsa_pss_test(WC_RNG* rng, RsaKey* key)
{
byte digest[WC_MAX_DIGEST_SIZE];
int ret = 0;
const char* inStr = "Everyone gets Friday off.";
word32 inLen = (word32)XSTRLEN((char*)inStr);
word32 outSz;
word32 plainSz;
word32 digestSz;
int i, j;
#ifdef RSA_PSS_TEST_WRONG_PARAMS
int k, l;
#endif
byte* plain;
int mgf[] = {
#ifndef NO_SHA
WC_MGF1SHA1,
#endif
#ifdef WOLFSSL_SHA224
WC_MGF1SHA224,
#endif
WC_MGF1SHA256,
#ifdef WOLFSSL_SHA384
WC_MGF1SHA384,
#endif
#ifdef WOLFSSL_SHA512
WC_MGF1SHA512
#endif
};
enum wc_HashType hash[] = {
#ifndef NO_SHA
WC_HASH_TYPE_SHA,
#endif
#ifdef WOLFSSL_SHA224
WC_HASH_TYPE_SHA224,
#endif
WC_HASH_TYPE_SHA256,
#ifdef WOLFSSL_SHA384
WC_HASH_TYPE_SHA384,
#endif
#ifdef WOLFSSL_SHA512
WC_HASH_TYPE_SHA512,
#endif
};
DECLARE_VAR_INIT(in, byte, inLen, inStr, HEAP_HINT);
DECLARE_VAR(out, byte, RSA_TEST_BYTES, HEAP_HINT);
DECLARE_VAR(sig, byte, RSA_TEST_BYTES, HEAP_HINT);
/* Test all combinations of hash and MGF. */
for (j = 0; j < (int)(sizeof(hash)/sizeof(*hash)); j++) {
/* Calculate hash of message. */
ret = wc_Hash(hash[j], in, inLen, digest, sizeof(digest));
if (ret != 0)
ERROR_OUT(-5450, exit_rsa_pss);
digestSz = wc_HashGetDigestSize(hash[j]);
for (i = 0; i < (int)(sizeof(mgf)/sizeof(*mgf)); i++) {
outSz = RSA_TEST_BYTES;
ret = wc_RsaPSS_Sign_ex(digest, digestSz, out, outSz, hash[j],
mgf[i], -1, key, rng);
if (ret <= 0)
ERROR_OUT(-5451, exit_rsa_pss);
outSz = ret;
XMEMCPY(sig, out, outSz);
plain = NULL;
ret = wc_RsaPSS_VerifyInline_ex(sig, outSz, &plain, hash[j],
mgf[i], -1, key);
if (ret <= 0)
ERROR_OUT(-5452, exit_rsa_pss);
plainSz = ret;
ret = wc_RsaPSS_CheckPadding(digest, digestSz, plain, plainSz,
hash[j]);
if (ret != 0)
ERROR_OUT(-5453, exit_rsa_pss);
#ifdef RSA_PSS_TEST_WRONG_PARAMS
for (k = 0; k < (int)(sizeof(mgf)/sizeof(*mgf)); k++) {
for (l = 0; l < (int)(sizeof(hash)/sizeof(*hash)); l++) {
if (i == k && j == l)
continue;
XMEMCPY(sig, out, outSz);
ret = wc_RsaPSS_VerifyInline_ex(sig, outSz, (byte**)&plain,
hash[l], mgf[k], -1, key);
if (ret >= 0)
ERROR_OUT(-5454, exit_rsa_pss);
}
}
#endif
}
}
/* Test that a salt length of zero works. */
digestSz = wc_HashGetDigestSize(hash[0]);
outSz = RSA_TEST_BYTES;
ret = wc_RsaPSS_Sign_ex(digest, digestSz, out, outSz, hash[0], mgf[0], 0,
key, rng);
if (ret <= 0)
ERROR_OUT(-5460, exit_rsa_pss);
outSz = ret;
ret = wc_RsaPSS_Verify_ex(out, outSz, sig, outSz, hash[0], mgf[0], 0,
key);
if (ret <= 0)
ERROR_OUT(-5461, exit_rsa_pss);
plainSz = ret;
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, sig, plainSz, hash[0],
0);
if (ret != 0)
ERROR_OUT(-5462, exit_rsa_pss);
XMEMCPY(sig, out, outSz);
plain = NULL;
ret = wc_RsaPSS_VerifyInline_ex(sig, outSz, &plain, hash[0], mgf[0], 0,
key);
if (ret <= 0)
ERROR_OUT(-5463, exit_rsa_pss);
plainSz = ret;
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0],
0);
if (ret != 0)
ERROR_OUT(-5464, exit_rsa_pss);
/* Test bad salt lengths in various APIs. */
digestSz = wc_HashGetDigestSize(hash[0]);
outSz = RSA_TEST_BYTES;
ret = wc_RsaPSS_Sign_ex(digest, digestSz, out, outSz, hash[0], mgf[0], -2,
key, rng);
if (ret != PSS_SALTLEN_E)
ERROR_OUT(-5470, exit_rsa_pss);
ret = wc_RsaPSS_Sign_ex(digest, digestSz, out, outSz, hash[0], mgf[0],
digestSz + 1, key, rng);
if (ret != PSS_SALTLEN_E)
ERROR_OUT(-5471, exit_rsa_pss);
ret = wc_RsaPSS_VerifyInline_ex(sig, outSz, &plain, hash[0], mgf[0], -2,
key);
if (ret != PSS_SALTLEN_E)
ERROR_OUT(-5472, exit_rsa_pss);
ret = wc_RsaPSS_VerifyInline_ex(sig, outSz, &plain, hash[0], mgf[0],
digestSz + 1, key);
if (ret != PSS_SALTLEN_E)
ERROR_OUT(-5473, exit_rsa_pss);
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0],
-2);
if (ret != PSS_SALTLEN_E)
ERROR_OUT(-5474, exit_rsa_pss);
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0],
digestSz + 1);
if (ret != PSS_SALTLEN_E)
ERROR_OUT(-5475, exit_rsa_pss);
ret = 0;
exit_rsa_pss:
FREE_VAR(in, HEAP_HINT);
FREE_VAR(out, HEAP_HINT);
return ret;
}
#endif
int rsa_test(void)
{
int ret;
@ -8967,6 +9136,10 @@ int rsa_test(void)
#endif /* WOLFSSL_CERT_REQ */
#endif /* WOLFSSL_CERT_GEN */
#ifdef WC_RSA_PSS
ret = rsa_pss_test(&rng, &key);
#endif
exit_rsa:
wc_FreeRsaKey(&key);
#ifdef WOLFSSL_CERT_EXT

@ -194,7 +194,9 @@ enum {
WC_HW_E = -248, /* Error with hardware crypto use */
WC_HW_WAIT_E = -249, /* Hardware waiting on resource */
WC_LAST_E = -249, /* Update this to indicate last error */
PSS_SALTLEN_E = -250, /* PSS length of salt is to long for hash */
WC_LAST_E = -250, /* Update this to indicate last error */
MIN_CODE_E = -300 /* errors -101 - -299 */
/* add new companion error id strings for any new error codes

@ -150,6 +150,10 @@ WOLFSSL_API int wc_RsaSSL_Sign(const byte* in, word32 inLen, byte* out,
WOLFSSL_API int wc_RsaPSS_Sign(const byte* in, word32 inLen, byte* out,
word32 outLen, enum wc_HashType hash, int mgf,
RsaKey* key, WC_RNG* rng);
WOLFSSL_API int wc_RsaPSS_Sign_ex(const byte* in, word32 inLen, byte* out,
word32 outLen, enum wc_HashType hash,
int mgf, int saltLen, RsaKey* key,
WC_RNG* rng);
WOLFSSL_API int wc_RsaSSL_VerifyInline(byte* in, word32 inLen, byte** out,
RsaKey* key);
WOLFSSL_API int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out,
@ -157,9 +161,22 @@ WOLFSSL_API int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out,
WOLFSSL_API int wc_RsaPSS_VerifyInline(byte* in, word32 inLen, byte** out,
enum wc_HashType hash, int mgf,
RsaKey* key);
WOLFSSL_API int wc_RsaPSS_VerifyInline_ex(byte* in, word32 inLen, byte** out,
enum wc_HashType hash, int mgf,
int saltLen, RsaKey* key);
WOLFSSL_API int wc_RsaPSS_Verify(byte* in, word32 inLen, byte* out,
word32 outLen, enum wc_HashType hash, int mgf,
RsaKey* key);
WOLFSSL_API int wc_RsaPSS_Verify_ex(byte* in, word32 inLen, byte* out,
word32 outLen, enum wc_HashType hash,
int mgf, int saltLen, RsaKey* key);
WOLFSSL_API int wc_RsaPSS_CheckPadding(const byte* in, word32 inLen, byte* sig,
word32 sigSz,
enum wc_HashType hashType);
WOLFSSL_API int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inLen,
byte* sig, word32 sigSz,
enum wc_HashType hashType,
int saltLen);
WOLFSSL_API int wc_RsaEncryptSize(RsaKey* key);