1. Added a couple missing checks for NULL pointers in DTLS code.

2. Fixed compiler warning under Windows.
3. DTLS sliding window packet filter.
This commit is contained in:
John Safranek 2013-10-08 14:59:59 -07:00
parent fc97174fb8
commit 9fe165e8f8
3 changed files with 129 additions and 37 deletions

View File

@ -1363,6 +1363,30 @@ enum ClientCertificateType {
enum CipherType { stream, block, aead };
#ifdef CYASSL_DTLS
#ifdef WORD64_AVAILABLE
typedef word64 DtlsSeq;
#else
typedef word32 DtlsSeq;
#endif
#define DTLS_SEQ_BITS (sizeof(DtlsSeq) * CHAR_BIT)
typedef struct DtlsState {
DtlsSeq window; /* Sliding window for current epoch */
word16 nextEpoch; /* Expected epoch in next record */
word32 nextSeq; /* Expected sequence in next record */
word16 curEpoch; /* Received epoch in current record */
word32 curSeq; /* Received sequence in current record */
DtlsSeq prevWindow; /* Sliding window for old epoch */
word32 prevSeq; /* Next sequence in allowed old epoch */
} DtlsState;
#endif /* CYASSL_DTLS */
/* keys and secrets */
typedef struct Keys {
byte client_write_MAC_secret[MAX_DIGEST_SIZE]; /* max sizes */
@ -1381,15 +1405,13 @@ typedef struct Keys {
word32 sequence_number;
#ifdef CYASSL_DTLS
word32 dtls_sequence_number;
word32 dtls_peer_sequence_number;
word32 dtls_expected_peer_sequence_number;
word16 dtls_handshake_number;
DtlsState dtls_state; /* Peer's state */
word16 dtls_peer_handshake_number;
word16 dtls_expected_peer_handshake_number;
word16 dtls_epoch;
word16 dtls_peer_epoch;
word16 dtls_expected_peer_epoch;
word16 dtls_epoch; /* Current tx epoch */
word32 dtls_sequence_number; /* Current tx sequence */
word16 dtls_handshake_number; /* Current tx handshake seq */
#endif
word32 encryptSz; /* last size of encrypted data */

View File

@ -87,6 +87,13 @@ CYASSL_CALLBACKS needs LARGE_STATIC_BUFFERS, please add LARGE_STATIC_BUFFERS
#endif
#endif
#ifdef CYASSL_DTLS
static int DtlsCheckWindow(DtlsState* state);
static int DtlsUpdateWindow(DtlsState* state);
#endif
typedef enum {
doProcessInit = 0,
#ifndef NO_CYASSL_SERVER
@ -1421,6 +1428,9 @@ int InitSSL(CYASSL* ssl, CYASSL_CTX* ctx)
#ifdef CYASSL_DTLS
ssl->IOCB_CookieCtx = NULL; /* we don't use for default cb */
ssl->dtls_expected_rx = MAX_MTU;
ssl->keys.dtls_state.window = 0;
ssl->keys.dtls_state.nextEpoch = 0;
ssl->keys.dtls_state.nextSeq = 0;
#endif
#ifndef NO_OLD_TLS
@ -1478,13 +1488,13 @@ int InitSSL(CYASSL* ssl, CYASSL_CTX* ctx)
#ifdef CYASSL_DTLS
ssl->keys.dtls_sequence_number = 0;
ssl->keys.dtls_peer_sequence_number = 0;
ssl->keys.dtls_expected_peer_sequence_number = 0;
ssl->keys.dtls_state.curSeq = 0;
ssl->keys.dtls_state.nextSeq = 0;
ssl->keys.dtls_handshake_number = 0;
ssl->keys.dtls_expected_peer_handshake_number = 0;
ssl->keys.dtls_epoch = 0;
ssl->keys.dtls_peer_epoch = 0;
ssl->keys.dtls_expected_peer_epoch = 0;
ssl->keys.dtls_state.curEpoch = 0;
ssl->keys.dtls_state.nextEpoch = 0;
ssl->dtls_timeout_init = DTLS_TIMEOUT_INIT;
ssl->dtls_timeout_max = DTLS_TIMEOUT_MAX;
ssl->dtls_timeout = ssl->dtls_timeout_init;
@ -2762,9 +2772,9 @@ static int GetRecordHeader(CYASSL* ssl, const byte* input, word32* inOutIdx,
/* type and version in same sport */
XMEMCPY(rh, input + *inOutIdx, ENUM_LEN + VERSION_SZ);
*inOutIdx += ENUM_LEN + VERSION_SZ;
ato16(input + *inOutIdx, &ssl->keys.dtls_peer_epoch);
ato16(input + *inOutIdx, &ssl->keys.dtls_state.curEpoch);
*inOutIdx += 4; /* advance past epoch, skip first 2 seq bytes for now */
ato32(input + *inOutIdx, &ssl->keys.dtls_peer_sequence_number);
ato32(input + *inOutIdx, &ssl->keys.dtls_state.curSeq);
*inOutIdx += 4; /* advance past rest of seq */
ato16(input + *inOutIdx, size);
*inOutIdx += LENGTH_SZ;
@ -2785,27 +2795,14 @@ static int GetRecordHeader(CYASSL* ssl, const byte* input, word32* inOutIdx,
return VERSION_ERROR; /* only use requested version */
}
}
#if 0
/* Instead of this, check the datagram against the sliding window of
* received datagram goodness. */
#ifdef CYASSL_DTLS
/* If DTLS, check the sequence number against expected. If out of
* order, drop the record. Allows newer records in and resets the
* expected to the next record. */
if (ssl->options.dtls) {
if ((ssl->keys.dtls_expected_peer_epoch ==
ssl->keys.dtls_peer_epoch) &&
(ssl->keys.dtls_peer_sequence_number >=
ssl->keys.dtls_expected_peer_sequence_number)) {
ssl->keys.dtls_expected_peer_sequence_number =
ssl->keys.dtls_peer_sequence_number + 1;
}
else {
if (DtlsCheckWindow(&ssl->keys.dtls_state) != 1)
return SEQUENCE_ERROR;
}
}
#endif
#endif
/* record layer length check */
#ifdef HAVE_MAX_FRAGMENT
if (*size > (ssl->max_fragment + MAX_COMP_EXTRA + MAX_MSG_EXTRA))
@ -3868,6 +3865,68 @@ static int DoHandShakeMsg(CYASSL* ssl, byte* input, word32* inOutIdx,
#ifdef CYASSL_DTLS
static INLINE int DtlsCheckWindow(DtlsState* state)
{
word32 cur;
word32 next;
DtlsSeq window;
if (state->curEpoch == state->nextEpoch) {
next = state->nextSeq;
window = state->window;
}
else if (state->curEpoch < state->nextEpoch) {
next = state->prevSeq;
window = state->prevWindow;
}
else {
return 0;
}
cur = state->curSeq;
if ((next > DTLS_SEQ_BITS) && (cur < next - DTLS_SEQ_BITS)) {
return 0;
}
else if ((cur < next) && (window & (1 << (next - cur - 1)))) {
return 0;
}
return 1;
}
static INLINE int DtlsUpdateWindow(DtlsState* state)
{
word32 cur;
word32* next;
DtlsSeq* window;
if (state->curEpoch == state->nextEpoch) {
next = &state->nextSeq;
window = &state->window;
}
else {
next = &state->prevSeq;
window = &state->prevWindow;
}
cur = state->curSeq;
if (cur < *next) {
*window |= (1 << (*next - cur - 1));
}
else {
*window <<= (1 + cur - *next);
*window |= 1;
*next = cur + 1;
}
return 1;
}
static int DtlsMsgDrain(CYASSL* ssl)
{
DtlsMsg* item = ssl->dtls_msg_list;
@ -4888,8 +4947,6 @@ int ProcessReply(CYASSL* ssl)
&ssl->curRL, &ssl->curSize);
#ifdef CYASSL_DTLS
if (ssl->options.dtls && ret == SEQUENCE_ERROR) {
/* This message is out of order. If we are handshaking, save
*it for later. Otherwise go ahead and process it. */
ssl->options.processReply = doProcessInit;
ssl->buffers.inputBuffer.length = 0;
ssl->buffers.inputBuffer.idx = 0;
@ -4925,7 +4982,14 @@ int ProcessReply(CYASSL* ssl)
/* the record layer is here */
case runProcessingOneMessage:
if (ssl->keys.encryptionOn && ssl->keys.decryptedCur == 0) {
#ifdef CYASSL_DTLS
if (ssl->options.dtls &&
ssl->keys.dtls_state.curEpoch < ssl->keys.dtls_state.nextEpoch)
ssl->keys.decryptedCur = 1;
#endif
if (ssl->keys.encryptionOn && ssl->keys.decryptedCur == 0)
{
ret = SanityCheckCipherText(ssl, ssl->curSize);
if (ret < 0)
return ret;
@ -4975,6 +5039,12 @@ int ProcessReply(CYASSL* ssl)
ssl->keys.decryptedCur = 1;
}
if (ssl->options.dtls) {
#ifdef CYASSL_DTLS
DtlsUpdateWindow(&ssl->keys.dtls_state);
#endif /* CYASSL_DTLS */
}
CYASSL_MSG("received record layer msg");
switch (ssl->curRL.type) {
@ -5034,8 +5104,8 @@ int ProcessReply(CYASSL* ssl)
#ifdef CYASSL_DTLS
if (ssl->options.dtls) {
DtlsPoolReset(ssl);
ssl->keys.dtls_expected_peer_epoch++;
ssl->keys.dtls_expected_peer_sequence_number = 0;
ssl->keys.dtls_state.nextEpoch++;
ssl->keys.dtls_state.nextSeq = 0;
}
#endif

View File

@ -401,7 +401,7 @@ static INLINE word32 GetSEQIncrement(CYASSL* ssl, int verify)
#ifdef CYASSL_DTLS
if (ssl->options.dtls) {
if (verify)
return ssl->keys.dtls_peer_sequence_number; /* explicit from peer */
return ssl->keys.dtls_state.curSeq; /* explicit from peer */
else
return ssl->keys.dtls_sequence_number - 1; /* already incremented */
}
@ -418,9 +418,9 @@ static INLINE word32 GetSEQIncrement(CYASSL* ssl, int verify)
static INLINE word32 GetEpoch(CYASSL* ssl, int verify)
{
if (verify)
return ssl->keys.dtls_peer_epoch;
return ssl->keys.dtls_state.curEpoch;
else
return ssl->keys.dtls_epoch;
return ssl->keys.dtls_epoch;
}
#endif /* CYASSL_DTLS */