diff --git a/src/tls.c b/src/tls.c index f2e5250bb..0b025ecef 100644 --- a/src/tls.c +++ b/src/tls.c @@ -863,32 +863,6 @@ void TLSX_SNI_SetOptions(TLSX* extensions, byte type, byte options) sni->options = options; } -#define BYTE_CHECK(buffer, offset, op, expected) do { \ - if (buffer[offset++] op expected) \ - return BUFFER_ERROR; \ -} while (0) - -#define SAFE_READ_16(buffer, offset, max, len) do { \ - ato16(buffer + offset, &len); offset += 2; \ - \ - if (offset + len > max) \ - return INCOMPLETE_DATA; \ -} while (0) - -#define SAFE_READ_32(buffer, offset, max, len) do { \ - c24to32(buffer + offset, &len); offset += 3; \ - \ - if (offset + len > max) \ - return INCOMPLETE_DATA; \ -} while (0) - -#define SKIP_LEN8(buffer, offset, max) do { \ - if (offset + buffer[offset] > max) \ - return INCOMPLETE_DATA; \ - \ - offset += ENUM_LEN + buffer[offset]; \ -} while (0) - int TLSX_SNI_GetFromBuffer(const byte* buffer, word32 bufferSz, byte type, byte* sni, word32* inOutSz) { @@ -900,37 +874,69 @@ int TLSX_SNI_GetFromBuffer(const byte* buffer, word32 bufferSz, return INCOMPLETE_DATA; /* TLS record header */ - BYTE_CHECK(buffer, offset, !=, handshake); - BYTE_CHECK(buffer, offset, !=, SSLv3_MAJOR); - BYTE_CHECK(buffer, offset, <, TLSv1_MINOR); - SAFE_READ_16(buffer, offset, bufferSz, len16); + if ((enum ContentType) buffer[offset++] != handshake) + return BUFFER_ERROR; - /* Handshake header */ - BYTE_CHECK(buffer, offset, !=, client_hello); - SAFE_READ_32(buffer, offset, bufferSz, len32); + if (buffer[offset++] != SSLv3_MAJOR) + return BUFFER_ERROR; - /* client hello */ - offset += VERSION_SZ + RAN_LEN; /* version, random */ - SKIP_LEN8(buffer, offset, bufferSz); /* session id */ + if (buffer[offset++] < TLSv1_MINOR) + return BUFFER_ERROR; - /* cypher suites */ - if (bufferSz < offset + 2) + ato16(buffer + offset, &len16); + offset += OPAQUE16_LEN; + + if (offset + len16 > bufferSz) return INCOMPLETE_DATA; - SAFE_READ_16(buffer, offset, bufferSz, len16); - offset += len16; + /* Handshake header */ + if ((enum HandShakeType) buffer[offset] != client_hello) + return BUFFER_ERROR; + + c24to32(buffer + offset + 1, &len32); + offset += HANDSHAKE_HEADER_SZ; + + if (offset + len32 > bufferSz) + return INCOMPLETE_DATA; + + /* client hello */ + offset += VERSION_SZ + RAN_LEN; /* version, random */ + + if (bufferSz < offset + buffer[offset]) + return INCOMPLETE_DATA; + + offset += ENUM_LEN + buffer[offset]; /* skip session id */ + + /* cypher suites */ + if (bufferSz < offset + OPAQUE16_LEN) + return INCOMPLETE_DATA; + + ato16(buffer + offset, &len16); + offset += OPAQUE16_LEN; + + if (bufferSz < offset + len16) + return INCOMPLETE_DATA; + + offset += len16; /* skip cypher suites */ /* compression methods */ if (bufferSz < offset + 1) return INCOMPLETE_DATA; - SKIP_LEN8(buffer, offset, bufferSz); + if (bufferSz < offset + buffer[offset]) + return INCOMPLETE_DATA; + + offset += ENUM_LEN + buffer[offset]; /* skip compression methods */ /* extensions */ - if (bufferSz < offset + 2) + if (bufferSz < offset + OPAQUE16_LEN) return 0; /* no extensions in client hello. */ - SAFE_READ_16(buffer, offset, bufferSz, len16); + ato16(buffer + offset, &len16); + offset += OPAQUE16_LEN; + + if (bufferSz < offset + len16) + return INCOMPLETE_DATA; while (len16 > OPAQUE16_LEN + OPAQUE16_LEN) { word16 extType; @@ -939,24 +945,36 @@ int TLSX_SNI_GetFromBuffer(const byte* buffer, word32 bufferSz, ato16(buffer + offset, &extType); offset += OPAQUE16_LEN; - SAFE_READ_16(buffer, offset, bufferSz, extLen); + ato16(buffer + offset, &extLen); + offset += OPAQUE16_LEN; + + if (bufferSz < offset + extLen) + return INCOMPLETE_DATA; if (extType != SERVER_NAME_INDICATION) { - offset += extLen; - continue; + offset += extLen; /* skip extension */ } else { word16 listLen; - SAFE_READ_16(buffer, offset, bufferSz, listLen); + ato16(buffer + offset, &listLen); + offset += OPAQUE16_LEN; + + if (bufferSz < offset + listLen) + return INCOMPLETE_DATA; while (listLen > ENUM_LEN + OPAQUE16_LEN) { byte sniType = buffer[offset++]; word16 sniLen; - SAFE_READ_16(buffer, offset, bufferSz, sniLen); + ato16(buffer + offset, &sniLen); + offset += OPAQUE16_LEN; + + if (bufferSz < offset + sniLen) + return INCOMPLETE_DATA; if (sniType != type) { - offset += sniLen; + offset += sniLen; + listLen -= MIN(ENUM_LEN + OPAQUE16_LEN + sniLen, listLen); continue; } @@ -966,16 +984,13 @@ int TLSX_SNI_GetFromBuffer(const byte* buffer, word32 bufferSz, return SSL_SUCCESS; } } + + len16 -= MIN(2 * OPAQUE16_LEN + extLen, len16); } return len16 ? BUFFER_ERROR : 0; } -#undef SAFE_READ_32 -#undef SAFE_READ_16 -#undef BYTE_CHECK -#undef SKIP_LEN8 - #endif #define SNI_FREE_ALL TLSX_SNI_FreeAll diff --git a/tests/api.c b/tests/api.c index dd211d54c..677cfd3f7 100644 --- a/tests/api.c +++ b/tests/api.c @@ -366,32 +366,43 @@ static void test_CyaSSL_SNI_GetFromBuffer(void) 0x0a, 0x05, 0x01, 0x04, 0x01, 0x02, 0x01, 0x04, 0x03, 0x02, 0x03 }; + byte buffer3[] = { /* no sni extension */ + 0x16, 0x03, 0x03, 0x00, 0x4d, 0x01, 0x00, 0x00, 0x49, 0x03, 0x03, 0xea, + 0xa1, 0x9f, 0x60, 0xdd, 0x52, 0x12, 0x13, 0xbd, 0x84, 0x34, 0xd5, 0x1c, + 0x38, 0x25, 0xa8, 0x97, 0xd2, 0xd5, 0xc6, 0x45, 0xaf, 0x1b, 0x08, 0xe4, + 0x1e, 0xbb, 0xdf, 0x9d, 0x39, 0xf0, 0x65, 0x00, 0x00, 0x16, 0x00, 0x6b, + 0x00, 0x67, 0x00, 0x39, 0x00, 0x33, 0x00, 0x3d, 0x00, 0x3c, 0x00, 0x35, + 0x00, 0x2f, 0x00, 0x05, 0x00, 0x04, 0x00, 0x0a, 0x01, 0x00, 0x00, 0x0a, + 0x00, 0x0d, 0x00, 0x06, 0x00, 0x04, 0x04, 0x01, 0x02, 0x01 + }; + byte result[32] = {0}; word32 length = 32; - AssertIntEQ(-228, CyaSSL_SNI_GetFromBuffer(buffer, sizeof(buffer), 0, + AssertIntEQ(0, CyaSSL_SNI_GetFromBuffer(buffer3, sizeof(buffer3), 0, result, &length)); + AssertIntEQ(0, CyaSSL_SNI_GetFromBuffer(buffer2, sizeof(buffer2), 1, + result, &length)); + + AssertIntEQ(-228, CyaSSL_SNI_GetFromBuffer(buffer, sizeof(buffer), 0, + result, &length)); buffer[0] = 0x16; AssertIntEQ(-228, CyaSSL_SNI_GetFromBuffer(buffer, sizeof(buffer), 0, - result, &length)); - + result, &length)); buffer[1] = 0x03; AssertIntEQ(-228, CyaSSL_SNI_GetFromBuffer(buffer, sizeof(buffer), 0, - result, &length)); - + result, &length)); buffer[2] = 0x03; AssertIntEQ(-210, CyaSSL_SNI_GetFromBuffer(buffer, sizeof(buffer), 0, - result, &length)); - + result, &length)); buffer[4] = 0x64; AssertIntEQ(1, CyaSSL_SNI_GetFromBuffer(buffer, sizeof(buffer), 0, result, &length)); - result[length] = 0; AssertStrEQ("www.paypal.com", (const char*) result);