diff --git a/wolfcrypt/src/rsa.c b/wolfcrypt/src/rsa.c index aea0ad7c0..d02802a20 100644 --- a/wolfcrypt/src/rsa.c +++ b/wolfcrypt/src/rsa.c @@ -1523,7 +1523,7 @@ static int wc_RsaFunctionAsync(const byte* in, word32 inLen, byte* out, * RSA_PUBLIC_DECRYPT, RSA_PRIVATE_ENCRYPT, RSA_PRIVATE_DECRYPT} * rng wolfSSL RNG to use if needed * - * returns 0 on success + * returns size of result on success */ int wc_RsaDirect(byte* in, word32 inLen, byte* out, word32* outSz, RsaKey* key, int type, WC_RNG* rng) @@ -1560,7 +1560,48 @@ int wc_RsaDirect(byte* in, word32 inLen, byte* out, word32* outSz, return LENGTH_ONLY_E; } - return wc_RsaFunction(in, inLen, out, outSz, type, key, rng); + switch (key->state) { + case RSA_STATE_NONE: + case RSA_STATE_ENCRYPT_PAD: + case RSA_STATE_ENCRYPT_EXPTMOD: + case RSA_STATE_DECRYPT_EXPTMOD: + case RSA_STATE_DECRYPT_UNPAD: + key->state = (type == RSA_PRIVATE_ENCRYPT || + type == RSA_PUBLIC_ENCRYPT) ? RSA_STATE_ENCRYPT_EXPTMOD: + RSA_STATE_DECRYPT_EXPTMOD; + + key->dataLen = *outSz; + + ret = wc_RsaFunction(in, inLen, out, &key->dataLen, type, key, rng); + if (ret >= 0 || ret == WC_PENDING_E) { + key->state = (type == RSA_PRIVATE_ENCRYPT || + type == RSA_PUBLIC_ENCRYPT) ? RSA_STATE_ENCRYPT_RES: + RSA_STATE_DECRYPT_RES; + } + if (ret < 0) { + break; + } + + FALL_THROUGH; + + case RSA_STATE_ENCRYPT_RES: + case RSA_STATE_DECRYPT_RES: + ret = key->dataLen; + break; + + default: + ret = BAD_STATE_E; + } + + /* if async pending then skip cleanup*/ + if (ret == WC_PENDING_E) { + return ret; + } + + key->state = RSA_STATE_NONE; + wc_RsaCleanup(key); + + return ret; } #endif /* WC_RSA_NO_PADDING */ diff --git a/wolfcrypt/test/test.c b/wolfcrypt/test/test.c index 9da753ef0..8bbbd77b7 100644 --- a/wolfcrypt/test/test.c +++ b/wolfcrypt/test/test.c @@ -8585,8 +8585,16 @@ int rsa_no_pad_test(void) inLen = wc_RsaEncryptSize(&key); XMEMSET(tmp, 7, inLen); - ret = wc_RsaDirect(tmp, inLen, out, &outSz, &key, RSA_PRIVATE_ENCRYPT, &rng); - if (ret != 0) { + do { + #if defined(WOLFSSL_ASYNC_CRYPT) + ret = wc_AsyncWait(ret, &key.asyncDev, WC_ASYNC_FLAG_CALL_AGAIN); + #endif + if (ret >= 0) { + ret = wc_RsaDirect(tmp, inLen, out, &outSz, &key, + RSA_PRIVATE_ENCRYPT, &rng); + } + } while (ret == WC_PENDING_E); + if (ret <= 0) { ERROR_OUT(-506, exit_rsa_nopadding); } @@ -8596,9 +8604,16 @@ int rsa_no_pad_test(void) } /* decrypt with public key and compare result */ - ret = wc_RsaDirect(out, outSz, plain, &plainSz, &key, RSA_PUBLIC_DECRYPT, - &rng); - if (ret != 0) { + do { + #if defined(WOLFSSL_ASYNC_CRYPT) + ret = wc_AsyncWait(ret, &key.asyncDev, WC_ASYNC_FLAG_CALL_AGAIN); + #endif + if (ret >= 0) { + ret = wc_RsaDirect(out, outSz, plain, &plainSz, &key, + RSA_PUBLIC_DECRYPT, &rng); + } + } while (ret == WC_PENDING_E); + if (ret <= 0) { ERROR_OUT(-508, exit_rsa_nopadding); } @@ -8614,15 +8629,28 @@ int rsa_no_pad_test(void) #endif /* test encrypt and decrypt using WC_RSA_NO_PAD */ - ret = wc_RsaPublicEncrypt_ex(tmp, inLen, out, (int)outSz, &key, &rng, - WC_RSA_NO_PAD, WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0); + do { + #if defined(WOLFSSL_ASYNC_CRYPT) + ret = wc_AsyncWait(ret, &key.asyncDev, WC_ASYNC_FLAG_CALL_AGAIN); + #endif + if (ret >= 0) { + ret = wc_RsaPublicEncrypt_ex(tmp, inLen, out, (int)outSz, &key, &rng, + WC_RSA_NO_PAD, WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0); + } + } while (ret == WC_PENDING_E); if (ret < 0) { ERROR_OUT(-511, exit_rsa_nopadding); } - printf("outSz = %d plainSz = %d\n", outSz, plainSz); - ret = wc_RsaPrivateDecrypt_ex(out, outSz, plain, (int)plainSz, &key, - WC_RSA_NO_PAD, WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0); + do { + #if defined(WOLFSSL_ASYNC_CRYPT) + ret = wc_AsyncWait(ret, &key.asyncDev, WC_ASYNC_FLAG_CALL_AGAIN); + #endif + if (ret >= 0) { + ret = wc_RsaPrivateDecrypt_ex(out, outSz, plain, (int)plainSz, &key, + WC_RSA_NO_PAD, WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0); + } + } while (ret == WC_PENDING_E); if (ret < 0) { ERROR_OUT(-512, exit_rsa_nopadding); }