linuxkm/lkcapi_glue.c: refactor AES-CBC, AES-CFB, and AES-GCM glue around struct km_AesCtx with separate aes_encrypt and aes_decrypt Aes pointers, and no cached key, to avoid AesSetKey operations at encrypt/decrypt time.

This commit is contained in:
Daniel Pouzzner 2024-01-27 23:16:02 -06:00
parent 8ae031a5ed
commit 957fc7460c

View File

@ -96,35 +96,57 @@ static int linuxkm_test_aesxts(void);
#include <wolfssl/wolfcrypt/aes.h>
struct km_AesCtx {
Aes *aes; /* must be pointer to control alignment, needed for AESNI. */
u8 key[AES_MAX_KEY_SIZE / 8];
unsigned int keylen;
Aes *aes_encrypt; /* must be pointer to control alignment, needed for AESNI. */
Aes *aes_decrypt; /* same. */
};
static inline void km_ForceZero(struct km_AesCtx * ctx)
{
memzero_explicit(ctx->key, sizeof(ctx->key));
ctx->keylen = 0;
}
#if defined(LINUXKM_LKCAPI_REGISTER_ALL) || \
defined(LINUXKM_LKCAPI_REGISTER_AESCBC) || \
defined(LINUXKM_LKCAPI_REGISTER_AESCFB) || \
defined(LINUXKM_LKCAPI_REGISTER_AESGCM)
static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name)
static void km_AesExitCommon(struct km_AesCtx * ctx);
static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name, int need_decryption)
{
int err;
ctx->aes = (Aes *)malloc(sizeof(*ctx->aes));
ctx->aes_encrypt = (Aes *)malloc(sizeof(*ctx->aes_encrypt));
if (! ctx->aes)
if (! ctx->aes_encrypt) {
pr_err("error: km_AesInitCommon %s failed: %d\n", name, MEMORY_E);
return MEMORY_E;
}
err = wc_AesInit(ctx->aes, NULL, INVALID_DEVID);
err = wc_AesInit(ctx->aes_encrypt, NULL, INVALID_DEVID);
if (unlikely(err)) {
pr_err("error: km_AesInitCommon %s failed: %d\n", name, err);
free(ctx->aes_encrypt);
ctx->aes_encrypt = NULL;
return err;
}
if (! need_decryption) {
ctx->aes_decrypt = NULL;
return 0;
}
ctx->aes_decrypt = (Aes *)malloc(sizeof(*ctx->aes_decrypt));
if (! ctx->aes_encrypt) {
pr_err("error: km_AesInitCommon %s failed: %d\n", name, MEMORY_E);
km_AesExitCommon(ctx);
return MEMORY_E;
}
err = wc_AesInit(ctx->aes_decrypt, NULL, INVALID_DEVID);
if (unlikely(err)) {
pr_err("error: km_AesInitCommon %s failed: %d\n", name, err);
free(ctx->aes_decrypt);
ctx->aes_decrypt = NULL;
km_AesExitCommon(ctx);
return err;
}
@ -133,10 +155,16 @@ static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name)
static void km_AesExitCommon(struct km_AesCtx * ctx)
{
wc_AesFree(ctx->aes);
free(ctx->aes);
ctx->aes = NULL;
km_ForceZero(ctx);
if (ctx->aes_encrypt) {
wc_AesFree(ctx->aes_encrypt);
free(ctx->aes_encrypt);
ctx->aes_encrypt = NULL;
}
if (ctx->aes_decrypt) {
wc_AesFree(ctx->aes_decrypt);
free(ctx->aes_decrypt);
ctx->aes_decrypt = NULL;
}
}
static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key,
@ -144,15 +172,21 @@ static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key,
{
int err;
err = wc_AesSetKey(ctx->aes, in_key, key_len, NULL, 0);
err = wc_AesSetKey(ctx->aes_encrypt, in_key, key_len, NULL, AES_ENCRYPTION);
if (unlikely(err)) {
pr_err("error: km_AesSetKeyCommon %s failed: %d\n", name, err);
return err;
}
XMEMCPY(ctx->key, in_key, key_len);
ctx->keylen = key_len;
if (ctx->aes_decrypt) {
err = wc_AesSetKey(ctx->aes_decrypt, in_key, key_len, NULL, AES_DECRYPTION);
if (unlikely(err)) {
pr_err("error: km_AesSetKeyCommon %s failed: %d\n", name, err);
return err;
}
}
return 0;
}
@ -161,25 +195,12 @@ static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key,
defined(LINUXKM_LKCAPI_REGISTER_AESCBC) || \
defined(LINUXKM_LKCAPI_REGISTER_AESCFB)
static int km_AesInit(struct crypto_skcipher *tfm)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesInitCommon(ctx, WOLFKM_AESCBC_DRIVER);
}
static void km_AesExit(struct crypto_skcipher *tfm)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
km_AesExitCommon(ctx);
}
static int km_AesSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
unsigned int key_len)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCBC_DRIVER);
}
#endif /* LINUXKM_LKCAPI_REGISTER_ALL ||
* LINUXKM_LKCAPI_REGISTER_AESCBC ||
* LINUXKM_LKCAPI_REGISTER_AESCFB
@ -192,6 +213,19 @@ static int km_AesSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
#if defined(HAVE_AES_CBC) && \
(defined(LINUXKM_LKCAPI_REGISTER_ALL) || defined(LINUXKM_LKCAPI_REGISTER_AESCBC))
static int km_AesCbcInit(struct crypto_skcipher *tfm)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesInitCommon(ctx, WOLFKM_AESCBC_DRIVER, 1);
}
static int km_AesCbcSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
unsigned int key_len)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCBC_DRIVER);
}
static int km_AesCbcEncrypt(struct skcipher_request *req)
{
struct crypto_skcipher * tfm = NULL;
@ -206,15 +240,14 @@ static int km_AesCbcEncrypt(struct skcipher_request *req)
err = skcipher_walk_virt(&walk, req, false);
while ((nbytes = walk.nbytes)) {
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv,
AES_ENCRYPTION);
err = wc_AesSetIV(ctx->aes_encrypt, walk.iv);
if (unlikely(err)) {
pr_err("wc_AesSetKey failed: %d\n", err);
pr_err("wc_AesSetIV failed: %d\n", err);
return err;
}
err = wc_AesCbcEncrypt(ctx->aes, walk.dst.virt.addr,
err = wc_AesCbcEncrypt(ctx->aes_encrypt, walk.dst.virt.addr,
walk.src.virt.addr, nbytes);
if (unlikely(err)) {
@ -242,15 +275,14 @@ static int km_AesCbcDecrypt(struct skcipher_request *req)
err = skcipher_walk_virt(&walk, req, false);
while ((nbytes = walk.nbytes)) {
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv,
AES_DECRYPTION);
err = wc_AesSetIV(ctx->aes_decrypt, walk.iv);
if (unlikely(err)) {
pr_err("wc_AesSetKey failed");
return err;
}
err = wc_AesCbcDecrypt(ctx->aes, walk.dst.virt.addr,
err = wc_AesCbcDecrypt(ctx->aes_decrypt, walk.dst.virt.addr,
walk.src.virt.addr, nbytes);
if (unlikely(err)) {
@ -271,12 +303,12 @@ static struct skcipher_alg cbcAesAlg = {
.base.cra_blocksize = AES_BLOCK_SIZE,
.base.cra_ctxsize = sizeof(struct km_AesCtx),
.base.cra_module = THIS_MODULE,
.init = km_AesInit,
.init = km_AesCbcInit,
.exit = km_AesExit,
.min_keysize = AES_128_KEY_SIZE,
.max_keysize = AES_256_KEY_SIZE,
.ivsize = AES_BLOCK_SIZE,
.setkey = km_AesSetKey,
.setkey = km_AesCbcSetKey,
.encrypt = km_AesCbcEncrypt,
.decrypt = km_AesCbcDecrypt,
};
@ -289,6 +321,19 @@ static int cbcAesAlg_loaded = 0;
#if defined(WOLFSSL_AES_CFB) && \
(defined(LINUXKM_LKCAPI_REGISTER_ALL) || defined(LINUXKM_LKCAPI_REGISTER_AESCFB))
static int km_AesCfbInit(struct crypto_skcipher *tfm)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesInitCommon(ctx, WOLFKM_AESCFB_DRIVER, 0);
}
static int km_AesCfbSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
unsigned int key_len)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCFB_DRIVER);
}
static int km_AesCfbEncrypt(struct skcipher_request *req)
{
struct crypto_skcipher * tfm = NULL;
@ -303,15 +348,14 @@ static int km_AesCfbEncrypt(struct skcipher_request *req)
err = skcipher_walk_virt(&walk, req, false);
while ((nbytes = walk.nbytes)) {
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv,
AES_ENCRYPTION);
err = wc_AesSetIV(ctx->aes_encrypt, walk.iv);
if (unlikely(err)) {
pr_err("wc_AesSetKey failed: %d\n", err);
return err;
}
err = wc_AesCfbEncrypt(ctx->aes, walk.dst.virt.addr,
err = wc_AesCfbEncrypt(ctx->aes_encrypt, walk.dst.virt.addr,
walk.src.virt.addr, nbytes);
if (unlikely(err)) {
@ -339,15 +383,14 @@ static int km_AesCfbDecrypt(struct skcipher_request *req)
err = skcipher_walk_virt(&walk, req, false);
while ((nbytes = walk.nbytes)) {
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv,
AES_ENCRYPTION);
err = wc_AesSetIV(ctx->aes_encrypt, walk.iv);
if (unlikely(err)) {
pr_err("wc_AesSetKey failed");
return err;
}
err = wc_AesCfbDecrypt(ctx->aes, walk.dst.virt.addr,
err = wc_AesCfbDecrypt(ctx->aes_encrypt, walk.dst.virt.addr,
walk.src.virt.addr, nbytes);
if (unlikely(err)) {
@ -368,12 +411,12 @@ static struct skcipher_alg cfbAesAlg = {
.base.cra_blocksize = AES_BLOCK_SIZE,
.base.cra_ctxsize = sizeof(struct km_AesCtx),
.base.cra_module = THIS_MODULE,
.init = km_AesInit,
.init = km_AesCfbInit,
.exit = km_AesExit,
.min_keysize = AES_128_KEY_SIZE,
.max_keysize = AES_256_KEY_SIZE,
.ivsize = AES_BLOCK_SIZE,
.setkey = km_AesSetKey,
.setkey = km_AesCfbSetKey,
.encrypt = km_AesCfbEncrypt,
.decrypt = km_AesCfbDecrypt,
};
@ -390,8 +433,7 @@ static int cfbAesAlg_loaded = 0;
static int km_AesGcmInit(struct crypto_aead * tfm)
{
struct km_AesCtx * ctx = crypto_aead_ctx(tfm);
km_ForceZero(ctx);
return km_AesInitCommon(ctx, WOLFKM_AESGCM_DRIVER);
return km_AesInitCommon(ctx, WOLFKM_AESGCM_DRIVER, 0);
}
static void km_AesGcmExit(struct crypto_aead * tfm)
@ -403,8 +445,16 @@ static void km_AesGcmExit(struct crypto_aead * tfm)
static int km_AesGcmSetKey(struct crypto_aead *tfm, const u8 *in_key,
unsigned int key_len)
{
int err;
struct km_AesCtx * ctx = crypto_aead_ctx(tfm);
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESGCM_DRIVER);
err = wc_AesGcmSetKey(ctx->aes_encrypt, in_key, key_len);
if (err) {
pr_err("error: km_AesGcmSetKey %s failed: %d\n", WOLFKM_AESGCM_DRIVER, err);
}
return err;
}
static int km_AesGcmSetAuthsize(struct crypto_aead *tfm, unsigned int authsize)
@ -454,7 +504,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
return -1;
}
err = wc_AesGcmInit(ctx->aes, ctx->key, ctx->keylen, walk.iv,
err = wc_AesGcmInit(ctx->aes_encrypt, NULL /* key */, 0 /* keylen */, walk.iv,
AES_BLOCK_SIZE);
if (unlikely(err)) {
pr_err("error: wc_AesGcmInit failed with return code %d.\n", err);
@ -467,7 +517,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
return err;
}
err = wc_AesGcmEncryptUpdate(ctx->aes, NULL, NULL, 0, assoc, assocLeft);
err = wc_AesGcmEncryptUpdate(ctx->aes_encrypt, NULL, NULL, 0, assoc, assocLeft);
assocLeft -= assocLeft;
scatterwalk_unmap(assoc);
assoc = NULL;
@ -483,7 +533,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
if (likely(cryptLeft && nbytes)) {
n = cryptLeft < nbytes ? cryptLeft : nbytes;
err = wc_AesGcmEncryptUpdate(ctx->aes, walk.dst.virt.addr,
err = wc_AesGcmEncryptUpdate(ctx->aes_encrypt, walk.dst.virt.addr,
walk.src.virt.addr, cryptLeft, NULL, 0);
nbytes -= n;
cryptLeft -= n;
@ -497,7 +547,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
err = skcipher_walk_done(&walk, nbytes);
}
err = wc_AesGcmEncryptFinal(ctx->aes, authTag, tfm->authsize);
err = wc_AesGcmEncryptFinal(ctx->aes_encrypt, authTag, tfm->authsize);
if (unlikely(err)) {
pr_err("error: wc_AesGcmEncryptFinal failed with return code %d\n", err);
return err;
@ -542,7 +592,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
return -1;
}
err = wc_AesGcmInit(ctx->aes, ctx->key, ctx->keylen, walk.iv,
err = wc_AesGcmInit(ctx->aes_encrypt, NULL /* key */, 0 /* keylen */, walk.iv,
AES_BLOCK_SIZE);
if (unlikely(err)) {
pr_err("error: wc_AesGcmInit failed with return code %d.\n", err);
@ -555,7 +605,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
return err;
}
err = wc_AesGcmDecryptUpdate(ctx->aes, NULL, NULL, 0, assoc, assocLeft);
err = wc_AesGcmDecryptUpdate(ctx->aes_encrypt, NULL, NULL, 0, assoc, assocLeft);
assocLeft -= assocLeft;
scatterwalk_unmap(assoc);
assoc = NULL;
@ -571,7 +621,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
if (likely(cryptLeft && nbytes)) {
n = cryptLeft < nbytes ? cryptLeft : nbytes;
err = wc_AesGcmDecryptUpdate(ctx->aes, walk.dst.virt.addr,
err = wc_AesGcmDecryptUpdate(ctx->aes_encrypt, walk.dst.virt.addr,
walk.src.virt.addr, cryptLeft, NULL, 0);
nbytes -= n;
cryptLeft -= n;
@ -585,7 +635,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
err = skcipher_walk_done(&walk, nbytes);
}
err = wc_AesGcmDecryptFinal(ctx->aes, origAuthTag, tfm->authsize);
err = wc_AesGcmDecryptFinal(ctx->aes_encrypt, origAuthTag, tfm->authsize);
if (unlikely(err)) {
pr_err("error: wc_AesGcmDecryptFinal failed with return code %d\n", err);