diff --git a/aes.c b/aes.c index b39d111..2d002e3 100644 --- a/aes.c +++ b/aes.c @@ -638,14 +638,24 @@ void AES_CCM_decrypt(AES_CCM_ctx* ctx, uint8_t buf[], uint32_t len) CCM_Xcrypt(ctx, buf, len, 0); } -void AES_CCM_generate_tag(AES_CCM_ctx* ctx, uint8_t tag[AES_BLOCKLEN], uint8_t tag_len) +void CCM_generate_tag(AES_CCM_ctx* ctx) { memset(ctx->counter + AES_BLOCKLEN - ctx->ctr_len, 0, ctx->ctr_len); memcpy(ctx->ctr_buf, ctx->counter, AES_BLOCKLEN); Cipher((state_t*)ctx->ctr_buf, ctx->round_key); Cipher((state_t*)ctx->cbc_buf, ctx->round_key); XorWithIv(ctx->cbc_buf, ctx->ctr_buf); +} + +void AES_CCM_generate_tag(AES_CCM_ctx* ctx, uint8_t tag[AES_BLOCKLEN], uint8_t tag_len) +{ + CCM_generate_tag(ctx); memcpy(tag, ctx->cbc_buf, tag_len); } +bool AES_CCM_verify_tag(AES_CCM_ctx* ctx, uint8_t tag[], uint8_t tag_len) +{ + return memcmp(tag, ctx->cbc_buf, tag_len) == 0; +} + #endif // #if defined(CTR) && (CTR == 1) diff --git a/aes.h b/aes.h index ceaed01..29c91b1 100644 --- a/aes.h +++ b/aes.h @@ -2,6 +2,7 @@ #define _AES_H_ #include +#include // #define the macros below to 1/0 to enable/disable the mode of operation. // @@ -105,13 +106,14 @@ typedef struct { // AES_CCM_Encrypt and AES_CCM_Decrypt could be called multiple times. // The sum of len MUST be data_len given in AES_CCM_Init. // The tag_len in AES_CCM_Init and AES_CCM_GenerateTag MUST be the same. -// After calling AES_CCM_GenerateTag, ctx MUST be initialized before reusing it. +// After calling AES_CCM_GenerateTag or AES_CCM_verify_tag, ctx MUST be initialized before reused. void AES_CCM_init(AES_CCM_ctx* ctx, const uint8_t key[AES_KEYLEN], const uint8_t nonce[], uint8_t nonce_len, uint32_t data_len, uint8_t tag_len); void AES_CCM_encrypt(AES_CCM_ctx* ctx, uint8_t buf[], uint32_t len); void AES_CCM_decrypt(AES_CCM_ctx* ctx, uint8_t buf[], uint32_t len); void AES_CCM_generate_tag(AES_CCM_ctx* ctx, uint8_t tag[], uint8_t tag_len); +bool AES_CCM_verify_tag(AES_CCM_ctx* ctx, uint8_t tag[], uint8_t tag_len); #endif // #if defined(CTR) && (CTR == 1) diff --git a/test.c b/test.c index 7979ae2..a90c898 100644 --- a/test.c +++ b/test.c @@ -411,7 +411,7 @@ static int test_decrypt_ccm(void) printf("CCM decrypt: "); if ((0 == memcmp((char *)out, (char *)in, sizeof(out))) - && (0 == memcmp((char *)tag, (char *)tag_out, sizeof(tag)))) + && (AES_CCM_verify_tag(&ctx, tag, 16))) { printf("SUCCESS!\n"); return (0);