diff --git a/src/ssl.c b/src/ssl.c index faa75a82e..64c97576d 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -10445,44 +10445,57 @@ void CyaSSL_DH_free(CYASSL_DH* dh) static int SetDhInternal(CYASSL_DH* dh) { - unsigned char p[1024]; - unsigned char g[1024]; - int pSz = sizeof(p); - int gSz = sizeof(g); + int ret = SSL_FATAL_ERROR; + int pSz = 1024; + int gSz = 1024; +#ifdef CYASSL_SMALL_STACK + unsigned char* p = NULL; + unsigned char* g = NULL; +#else + unsigned char p[1024]; + unsigned char g[1024]; +#endif CYASSL_ENTER("SetDhInternal"); - if (dh == NULL || dh->p == NULL || dh->g == NULL) { + if (dh == NULL || dh->p == NULL || dh->g == NULL) CYASSL_MSG("Bad function arguments"); - return SSL_FATAL_ERROR; - } - - if (CyaSSL_BN_bn2bin(dh->p, NULL) > pSz) { + else if (CyaSSL_BN_bn2bin(dh->p, NULL) > pSz) CYASSL_MSG("Bad p internal size"); - return SSL_FATAL_ERROR; - } - - if (CyaSSL_BN_bn2bin(dh->g, NULL) > gSz) { + else if (CyaSSL_BN_bn2bin(dh->g, NULL) > gSz) CYASSL_MSG("Bad g internal size"); - return SSL_FATAL_ERROR; + else { + #ifdef CYASSL_SMALL_STACK + p = (unsigned char*)XMALLOC(pSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); + g = (unsigned char*)XMALLOC(gSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); + + if (p == NULL || g == NULL) { + XFREE(p, NULL, DYNAMIC_TYPE_TMP_BUFFER); + XFREE(g, NULL, DYNAMIC_TYPE_TMP_BUFFER); + return ret; + } + #endif + + pSz = CyaSSL_BN_bn2bin(dh->p, p); + gSz = CyaSSL_BN_bn2bin(dh->g, g); + + if (pSz <= 0 || gSz <= 0) + CYASSL_MSG("Bad BN2bin set"); + else if (DhSetKey((DhKey*)dh->internal, p, pSz, g, gSz) < 0) + CYASSL_MSG("Bad DH SetKey"); + else { + dh->inSet = 1; + ret = 0; + } + + #ifdef CYASSL_SMALL_STACK + XFREE(p, NULL, DYNAMIC_TYPE_TMP_BUFFER); + XFREE(g, NULL, DYNAMIC_TYPE_TMP_BUFFER); + #endif } - pSz = CyaSSL_BN_bn2bin(dh->p, p); - gSz = CyaSSL_BN_bn2bin(dh->g, g); - if (pSz <= 0 || gSz <= 0) { - CYASSL_MSG("Bad BN2bin set"); - return SSL_FATAL_ERROR; - } - - if (DhSetKey((DhKey*)dh->internal, p, pSz, g, gSz) < 0) { - CYASSL_MSG("Bad DH SetKey"); - return SSL_FATAL_ERROR; - } - - dh->inSet = 1; - - return 0; + return ret; }