Fixes DTLS 1.3 client use-after-free error

This commit is contained in:
jordan 2022-09-20 09:17:08 -05:00
parent 43715d1bb5
commit 8336dbf366

View File

@ -10132,8 +10132,8 @@ int CheckAvailableSize(WOLFSSL *ssl, int size)
#ifdef WOLFSSL_DTLS13
static int GetInputData(WOLFSSL *ssl, word32 size);
static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
word32* inOutIdx, RecordLayerHeader* rh, word16* size)
static int GetDtls13RecordHeader(WOLFSSL* ssl, word32* inOutIdx,
RecordLayerHeader* rh, word16* size)
{
Dtls13UnifiedHdrInfo hdrInfo;
@ -10147,7 +10147,7 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
if (readSize < DTLS_UNIFIED_HEADER_MIN_SZ)
return BUFFER_ERROR;
epochBits = *(input + *inOutIdx) & EE_MASK;
epochBits = *(ssl->buffers.inputBuffer.buffer + *inOutIdx) & EE_MASK;
ret = Dtls13ReconstructEpochNumber(ssl, epochBits, &epochNumber);
if (ret != 0)
return ret;
@ -10179,7 +10179,7 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
}
ret = Dtls13GetUnifiedHeaderSize(ssl,
*(input+*inOutIdx), &ssl->dtls13CurRlLength);
*(ssl->buffers.inputBuffer.buffer+*inOutIdx), &ssl->dtls13CurRlLength);
if (ret != 0)
return ret;
@ -10192,7 +10192,8 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
return ret;
}
ret = Dtls13ParseUnifiedRecordLayer(ssl, input + *inOutIdx, (word16)readSize,
ret = Dtls13ParseUnifiedRecordLayer(ssl,
ssl->buffers.inputBuffer.buffer + *inOutIdx, (word16)readSize,
&hdrInfo);
if (ret != 0)
@ -10219,7 +10220,8 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
ssl->keys.curSeq);
#endif /* WOLFSSL_DEBUG_TLS */
XMEMCPY(ssl->dtls13CurRL, input + *inOutIdx, ssl->dtls13CurRlLength);
XMEMCPY(ssl->dtls13CurRL, ssl->buffers.inputBuffer.buffer + *inOutIdx,
ssl->dtls13CurRlLength);
*inOutIdx += ssl->dtls13CurRlLength;
return 0;
@ -10228,14 +10230,14 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
#endif /* WOLFSSL_DTLS13 */
#ifdef WOLFSSL_DTLS
static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
word32* inOutIdx, RecordLayerHeader* rh, word16* size)
static int GetDtlsRecordHeader(WOLFSSL* ssl, word32* inOutIdx,
RecordLayerHeader* rh, word16* size)
{
#ifdef HAVE_FUZZER
if (ssl->fuzzerCb)
ssl->fuzzerCb(ssl, input + *inOutIdx, DTLS_RECORD_HEADER_SZ,
FUZZ_HEAD, ssl->fuzzerCtx);
ssl->fuzzerCb(ssl, ssl->buffers.inputBuffer.buffer + *inOutIdx,
DTLS_RECORD_HEADER_SZ, FUZZ_HEAD, ssl->fuzzerCtx);
#endif
#ifdef WOLFSSL_DTLS13
@ -10244,11 +10246,11 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
read_size = ssl->buffers.inputBuffer.length - *inOutIdx;
if (Dtls13IsUnifiedHeader(*(input + *inOutIdx))) {
if (Dtls13IsUnifiedHeader(*(ssl->buffers.inputBuffer.buffer + *inOutIdx))) {
/* version 1.3 already negotiated */
if (ssl->options.tls1_3) {
ret = GetDtls13RecordHeader(ssl, input, inOutIdx, rh, size);
ret = GetDtls13RecordHeader(ssl, inOutIdx, rh, size);
if (ret == 0 || ret != SEQUENCE_ERROR || ret != DTLS_CID_ERROR)
return ret;
}
@ -10276,9 +10278,10 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
#endif /* WOLFSSL_DTLS13 */
/* type and version in same spot */
XMEMCPY(rh, input + *inOutIdx, ENUM_LEN + VERSION_SZ);
XMEMCPY(rh, ssl->buffers.inputBuffer.buffer + *inOutIdx,
ENUM_LEN + VERSION_SZ);
*inOutIdx += ENUM_LEN + VERSION_SZ;
ato16(input + *inOutIdx, &ssl->keys.curEpoch);
ato16(ssl->buffers.inputBuffer.buffer + *inOutIdx, &ssl->keys.curEpoch);
#ifdef WOLFSSL_DTLS13
/* only non protected message can use the DTLSPlaintext record header */
if (ssl->options.tls1_3 && ssl->keys.curEpoch != 0)
@ -10292,14 +10295,14 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
*inOutIdx += OPAQUE16_LEN;
if (ssl->options.haveMcast) {
#ifdef WOLFSSL_MULTICAST
ssl->keys.curPeerId = input[*inOutIdx];
ssl->keys.curSeq_hi = input[*inOutIdx+1];
ssl->keys.curPeerId = ssl->buffers.inputBuffer.buffer[*inOutIdx];
ssl->keys.curSeq_hi = ssl->buffers.inputBuffer.buffer[*inOutIdx+1];
#endif
}
else
ato16(input + *inOutIdx, &ssl->keys.curSeq_hi);
ato16(ssl->buffers.inputBuffer.buffer + *inOutIdx, &ssl->keys.curSeq_hi);
*inOutIdx += OPAQUE16_LEN;
ato32(input + *inOutIdx, &ssl->keys.curSeq_lo);
ato32(ssl->buffers.inputBuffer.buffer + *inOutIdx, &ssl->keys.curSeq_lo);
*inOutIdx += OPAQUE32_LEN; /* advance past rest of seq */
#ifdef WOLFSSL_DTLS13
@ -10308,7 +10311,7 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
ssl->keys.curSeq = w64From32(ssl->keys.curSeq_hi, ssl->keys.curSeq_lo);
#endif /* WOLFSSL_DTLS13 */
ato16(input + *inOutIdx, size);
ato16(ssl->buffers.inputBuffer.buffer + *inOutIdx, size);
*inOutIdx += LENGTH_SZ;
return 0;
@ -10316,7 +10319,7 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
#endif /* WOLFSSL_DTLS */
/* do all verify and sanity checks on record header */
static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
static int GetRecordHeader(WOLFSSL* ssl, word32* inOutIdx,
RecordLayerHeader* rh, word16 *size)
{
byte tls12minor;
@ -10333,16 +10336,16 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
if (!ssl->options.dtls) {
#ifdef HAVE_FUZZER
if (ssl->fuzzerCb)
ssl->fuzzerCb(ssl, input + *inOutIdx, RECORD_HEADER_SZ, FUZZ_HEAD,
ssl->fuzzerCtx);
ssl->fuzzerCb(ssl, ssl->buffers.inputBuffer.buffer + *inOutIdx,
RECORD_HEADER_SZ, FUZZ_HEAD, ssl->fuzzerCtx);
#endif
XMEMCPY(rh, input + *inOutIdx, RECORD_HEADER_SZ);
XMEMCPY(rh, ssl->buffers.inputBuffer.buffer + *inOutIdx, RECORD_HEADER_SZ);
*inOutIdx += RECORD_HEADER_SZ;
ato16(rh->length, size);
}
else {
#ifdef WOLFSSL_DTLS
ret = GetDtlsRecordHeader(ssl, input, inOutIdx, rh, size);
ret = GetDtlsRecordHeader(ssl, inOutIdx, rh, size);
if (ret != 0)
return ret;
#endif
@ -10392,7 +10395,7 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
else if (ssl->options.dtls && !ssl->options.handShakeDone) {
/* we may have lost the ServerHello and this is a unified record
before version been negotiated */
if (Dtls13IsUnifiedHeader(*input)) {
if (Dtls13IsUnifiedHeader(*ssl->buffers.inputBuffer.buffer)) {
return SEQUENCE_ERROR;
}
}
@ -10446,7 +10449,7 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
default:
#ifdef OPENSSL_ALL
{
char *method = (char*)input + start;
char *method = (char*)ssl->buffers.inputBuffer.buffer + start;
/* Attempt to identify if this is a plain HTTP request.
* No size checks because this function assumes at least
* RECORD_HEADER_SZ size of data has been read which is
@ -19063,8 +19066,7 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr)
* header, decrypting the numbers inside
* DtlsParseUnifiedRecordLayer(). This violates the const attribute
* of the buffer parameter of GetRecordHeader() used here. */
ret = GetRecordHeader(ssl, ssl->buffers.inputBuffer.buffer,
&ssl->buffers.inputBuffer.idx,
ret = GetRecordHeader(ssl, &ssl->buffers.inputBuffer.idx,
&ssl->curRL, &ssl->curSize);
#ifdef WOLFSSL_DTLS