diff --git a/ctaocrypt/test/test.c b/ctaocrypt/test/test.c index 10da769bb..0675d920e 100644 --- a/ctaocrypt/test/test.c +++ b/ctaocrypt/test/test.c @@ -1875,8 +1875,12 @@ int aes_test(void) if (ret != 0) return -1002; - AesCbcEncrypt(&enc, cipher, msg, AES_BLOCK_SIZE); - AesCbcDecrypt(&dec, plain, cipher, AES_BLOCK_SIZE); + ret = AesCbcEncrypt(&enc, cipher, msg, AES_BLOCK_SIZE); + if (ret != 0) + return -1005; + ret = AesCbcDecrypt(&dec, plain, cipher, AES_BLOCK_SIZE); + if (ret != 0) + return -1006; if (memcmp(plain, msg, AES_BLOCK_SIZE)) return -60; diff --git a/cyassl/ctaocrypt/aes.h b/cyassl/ctaocrypt/aes.h index bc1cd5913..371778b62 100644 --- a/cyassl/ctaocrypt/aes.h +++ b/cyassl/ctaocrypt/aes.h @@ -154,10 +154,15 @@ CYASSL_API int AesCcmDecrypt(Aes* aes, byte* out, const byte* in, word32 inSz, /* fips wrapper calls, user can call direct */ CYASSL_API int AesSetKey_fips(Aes* aes, const byte* key, word32 len, const byte* iv, int dir); - + CYASSL_API int AesCbcEncrypt_fips(Aes* aes, byte* out, const byte* in, + word32 sz); + CYASSL_API int AesCbcDecrypt_fips(Aes* aes, byte* out, const byte* in, + word32 sz); #ifndef FIPS_NO_WRAPPERS /* if not internal or fips.c consumer force fips calls if fips build */ - #define AesSetKey AesSetKey_fips + #define AesSetKey AesSetKey_fips + #define AesCbcEncrypt AesCbcEncrypt_fips + #define AesCbcDecrypt AesCbcDecrypt_fips #endif /* FIPS_NO_WRAPPERS */ #endif /* HAVE_FIPS */ diff --git a/cyassl/sniffer_error.h b/cyassl/sniffer_error.h index 586efcb7d..f8528668f 100644 --- a/cyassl/sniffer_error.h +++ b/cyassl/sniffer_error.h @@ -101,6 +101,7 @@ #define BAD_COMPRESSION_STR 67 #define BAD_DERIVE_STR 68 #define ACK_MISSED_STR 69 +#define BAD_DECRYPT 70 /* !!!! also add to msgTable in sniffer.c and .rc file !!!! */ diff --git a/cyassl/sniffer_error.rc b/cyassl/sniffer_error.rc index 6171f7849..516f7aa11 100644 --- a/cyassl/sniffer_error.rc +++ b/cyassl/sniffer_error.rc @@ -83,5 +83,6 @@ STRINGTABLE 67, "Bad Compression Type" 68, "Bad DeriveKeys Error" 69, "Saw ACK for Missing Packet Error" + 70, "Bad Decrypt Operation" } diff --git a/src/sniffer.c b/src/sniffer.c index 4f6d7c21d..85c00093e 100644 --- a/src/sniffer.c +++ b/src/sniffer.c @@ -224,7 +224,8 @@ static const char* const msgTable[] = "Bad Finished Message Processing", "Bad Compression Type", "Bad DeriveKeys Error", - "Saw ACK for Missing Packet Error" + "Saw ACK for Missing Packet Error", + "Bad Decrypt Operation" }; @@ -1557,9 +1558,11 @@ static int DoHandShake(const byte* input, int* sslBytes, } -/* Decrypt input into plain output */ -static void Decrypt(SSL* ssl, byte* output, const byte* input, word32 sz) +/* Decrypt input into plain output, 0 on success */ +static int Decrypt(SSL* ssl, byte* output, const byte* input, word32 sz) { + int ret = 0; + switch (ssl->specs.bulk_cipher_algorithm) { #ifdef BUILD_ARC4 case cyassl_rc4: @@ -1575,7 +1578,7 @@ static void Decrypt(SSL* ssl, byte* output, const byte* input, word32 sz) #ifdef BUILD_AES case cyassl_aes: - AesCbcDecrypt(ssl->decrypt.aes, output, input, sz); + ret = AesCbcDecrypt(ssl->decrypt.aes, output, input, sz); break; #endif @@ -1599,18 +1602,25 @@ static void Decrypt(SSL* ssl, byte* output, const byte* input, word32 sz) default: Trace(BAD_DECRYPT_TYPE); + ret = -1; break; } + + return ret; } /* Decrypt input message into output, adjust output steam if needed */ static const byte* DecryptMessage(SSL* ssl, const byte* input, word32 sz, - byte* output) + byte* output, int* error) { int ivExtra = 0; - Decrypt(ssl, output, input, sz); + int ret = Decrypt(ssl, output, input, sz); + if (ret != 0) { + *error = ret; + return NULL; + } ssl->keys.encryptSz = sz; if (ssl->options.tls1_1 && ssl->specs.cipher_type == block) { output += ssl->specs.block_size; /* go past TLSv1.1 IV */ @@ -2320,6 +2330,7 @@ static int ProcessMessage(const byte* sslFrame, SnifferSession* session, RecordLayerHeader rh; int rhSize = 0; int ret; + int errCode = 0; int decoded = 0; /* bytes stored for user in data */ int notEnough; /* notEnough bytes yet flag */ SSL* ssl = (session->flags.side == CYASSL_SERVER_END) ? @@ -2372,7 +2383,11 @@ doMessage: return -1; } sslFrame = DecryptMessage(ssl, sslFrame, rhSize, - ssl->buffers.outputBuffer.buffer); + ssl->buffers.outputBuffer.buffer, &errCode); + if (errCode != 0) { + SetError(BAD_DECRYPT, error, session, FATAL_ERROR_STATE); + return -1; + } } switch ((enum ContentType)rh.type) {