From bc52147fbb0a2ee20f3dedd2a48246cade09a230 Mon Sep 17 00:00:00 2001 From: Michael Saxl Date: Mon, 25 Jan 2021 16:20:18 +0100 Subject: [PATCH] rdg websocket support --- libfreerdp/core/gateway/http.c | 116 ++++- libfreerdp/core/gateway/http.h | 3 + libfreerdp/core/gateway/rdg.c | 758 ++++++++++++++++++++++++++++----- 3 files changed, 772 insertions(+), 105 deletions(-) diff --git a/libfreerdp/core/gateway/http.c b/libfreerdp/core/gateway/http.c index 85f0908ce..91d19620f 100644 --- a/libfreerdp/core/gateway/http.c +++ b/libfreerdp/core/gateway/http.c @@ -30,6 +30,9 @@ #include +/* websocket need sha1 for Sec-Websocket-Accept */ +#include + #ifdef HAVE_VALGRIND_MEMCHECK_H #include #endif @@ -40,6 +43,8 @@ #define RESPONSE_SIZE_LIMIT 64 * 1024 * 1024 +#define WEBSOCKET_MAGIC_GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + struct _http_context { char* Method; @@ -52,6 +57,8 @@ struct _http_context char* Pragma; char* RdgConnectionId; char* RdgAuthScheme; + BOOL websocketUpgrade; + char SecWebsocketKey[16]; }; struct _http_request @@ -77,6 +84,8 @@ struct _http_response size_t ContentLength; const char* ContentType; TRANSFER_ENCODING TransferEncoding; + const char* SecWebsocketVersion; + const char* SecWebsocketAccept; size_t BodyLength; BYTE* BodyContent; @@ -259,6 +268,30 @@ BOOL http_context_set_rdg_connection_id(HttpContext* context, const char* RdgCon return TRUE; } +BOOL http_context_enable_websocket_upgrade(HttpContext* context, BOOL enable) +{ + if (!context) + return FALSE; + + if (enable) + { + int i; + winpr_RAND((BYTE*)context->SecWebsocketKey, 15); + for (i = 0; i < 16; i++) + context->SecWebsocketKey[i] = (context->SecWebsocketKey[i] | 0x40) & 0x5f; + context->SecWebsocketKey[15] = '\0'; + } + else + context->SecWebsocketKey[0] = '\0'; + context->websocketUpgrade = enable; + return TRUE; +} + +BOOL http_context_is_websocket_upgrade_enabled(HttpContext* context) +{ + return context->websocketUpgrade; +} + BOOL http_context_set_rdg_auth_scheme(HttpContext* context, const char* RdgAuthScheme) { if (!context || !RdgAuthScheme) @@ -426,13 +459,26 @@ wStream* http_request_write(HttpContext* context, HttpRequest* request) if (!http_encode_header_line(s, request->Method, request->URI) || !http_encode_body_line(s, "Cache-Control", context->CacheControl) || - !http_encode_body_line(s, "Connection", context->Connection) || !http_encode_body_line(s, "Pragma", context->Pragma) || !http_encode_body_line(s, "Accept", context->Accept) || !http_encode_body_line(s, "User-Agent", context->UserAgent) || !http_encode_body_line(s, "Host", context->Host)) goto fail; + if (!context->websocketUpgrade) + { + if (!http_encode_body_line(s, "Connection", context->Connection)) + goto fail; + } + else + { + if (!http_encode_body_line(s, "Connection", "Upgrade") || + !http_encode_body_line(s, "Upgrade", "websocket") || + !http_encode_body_line(s, "Sec-Websocket-Version", "13") || + !http_encode_body_line(s, "Sec-Websocket-Key", context->SecWebsocketKey)) + goto fail; + } + if (context->RdgConnectionId) { if (!http_encode_body_line(s, "RDG-Connection-Id", context->RdgConnectionId)) @@ -556,7 +602,6 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char* const char* value) { BOOL status = TRUE; - if (!response || !name) return FALSE; @@ -587,6 +632,20 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char* else response->TransferEncoding = TransferEncodingUnknown; } + else if (_stricmp(name, "Sec-WebSocket-Version") == 0) + { + response->SecWebsocketVersion = value; + + if (!response->SecWebsocketVersion) + return FALSE; + } + else if (_stricmp(name, "Sec-WebSocket-Accept") == 0) + { + response->SecWebsocketAccept = value; + + if (!response->SecWebsocketAccept) + return FALSE; + } else if (_stricmp(name, "WWW-Authenticate") == 0) { char* separator = NULL; @@ -1041,3 +1100,56 @@ TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response) return response->TransferEncoding; } + +BOOL http_response_is_websocket(HttpContext* http, HttpResponse* response) +{ + BOOL isWebsocket = FALSE; + WINPR_DIGEST_CTX* sha1 = NULL; + char* base64accept = NULL; + BYTE sha1_digest[WINPR_SHA1_DIGEST_LENGTH]; + + if (!http || !response) + return FALSE; + + if (!http->websocketUpgrade || response->StatusCode != HTTP_STATUS_SWITCH_PROTOCOLS) + return FALSE; + + if (response->SecWebsocketVersion && _stricmp(response->SecWebsocketVersion, "13") != 0) + return FALSE; + + if (!response->SecWebsocketAccept) + return FALSE; + + /* now check if Sec-Websocket-Accept is correct */ + + sha1 = winpr_Digest_New(); + if (!sha1) + goto out; + + if (!winpr_Digest_Init(sha1, WINPR_MD_SHA1)) + goto out; + + if (!winpr_Digest_Update(sha1, (const BYTE*)http->SecWebsocketKey, + strlen(http->SecWebsocketKey))) + goto out; + if (!winpr_Digest_Update(sha1, (const BYTE*)WEBSOCKET_MAGIC_GUID, strlen(WEBSOCKET_MAGIC_GUID))) + goto out; + + if (!winpr_Digest_Final(sha1, sha1_digest, sizeof(sha1_digest))) + goto out; + + base64accept = crypto_base64_encode(sha1_digest, WINPR_SHA1_DIGEST_LENGTH); + if (!base64accept) + goto out; + + if (_stricmp(response->SecWebsocketAccept, base64accept) != 0) + { + WLog_WARN(TAG, "Webserver gave Websocket Upgrade response but sanity check failed"); + goto out; + } + isWebsocket = TRUE; +out: + winpr_Digest_Free(sha1); + free(base64accept); + return isWebsocket; +} diff --git a/libfreerdp/core/gateway/http.h b/libfreerdp/core/gateway/http.h index af1be0fdb..00bfb03e7 100644 --- a/libfreerdp/core/gateway/http.h +++ b/libfreerdp/core/gateway/http.h @@ -52,6 +52,8 @@ FREERDP_LOCAL BOOL http_context_set_rdg_connection_id(HttpContext* context, const char* RdgConnectionId); FREERDP_LOCAL BOOL http_context_set_rdg_auth_scheme(HttpContext* context, const char* RdgAuthScheme); +FREERDP_LOCAL BOOL http_context_enable_websocket_upgrade(HttpContext* context, BOOL enable); +FREERDP_LOCAL BOOL http_context_is_websocket_upgrade_enabled(HttpContext* context); /* HTTP request */ typedef struct _http_request HttpRequest; @@ -85,5 +87,6 @@ FREERDP_LOCAL long http_response_get_status_code(HttpResponse* response); FREERDP_LOCAL SSIZE_T http_response_get_body_length(HttpResponse* response); FREERDP_LOCAL const char* http_response_get_auth_token(HttpResponse* response, const char* method); FREERDP_LOCAL TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response); +FREERDP_LOCAL BOOL http_response_is_websocket(HttpContext* http, HttpResponse* response); #endif /* FREERDP_LIB_CORE_GATEWAY_HTTP_H */ diff --git a/libfreerdp/core/gateway/rdg.c b/libfreerdp/core/gateway/rdg.c index 73327f24b..a87f890e4 100644 --- a/libfreerdp/core/gateway/rdg.c +++ b/libfreerdp/core/gateway/rdg.c @@ -104,6 +104,42 @@ #define HTTP_CAPABILITY_REAUTH 0x10 #define HTTP_CAPABILITY_UDP_TRANSPORT 0x20 +#define WEBSOCKET_MASK_BIT 0x80 +#define WEBSOCKET_FIN_BIT 0x80 + +typedef enum _WEBSOCKET_OPCODE +{ + WebsocketContinuationOpcode = 0x0, + WebsocketTextOpcode = 0x1, + WebsocketBinaryOpcode = 0x2, + WebsocketCloseOpcode = 0x8, + WebsocketPingOpcode = 0x9, + WebsocketPongOpcode = 0xa, +} WEBSOCKET_OPCODE; + +typedef enum _WEBSOCKET_STATE +{ + WebsocketStateOpcodeAndFin, + WebsocketStateLengthAndMasking, + WebsocketStateShortLength, + WebsocketStateLongLength, + WebSocketStateMaskingKey, + WebSocketStatePayload, +} WEBSOCKET_STATE; + +typedef struct +{ + size_t payloadLength; + uint32_t maskingKey; + BOOL masking; + BOOL closeSent; + BYTE opcode; + BYTE fragmentOriginalOpcode; + BYTE lengthAndMaskPosition; + WEBSOCKET_STATE state; + wStream* responseStreamBuffer; +} rdg_http_websocket_context; + typedef enum _CHUNK_STATE { ChunkStateLenghHeader, @@ -122,9 +158,11 @@ typedef struct typedef struct { TRANSFER_ENCODING httpTransferEncoding; + BOOL isWebsocketTransport; union _context { rdg_http_encoding_chunked_context chunked; + rdg_http_websocket_context websocket; } context; } rdg_http_encoding_context; @@ -293,9 +331,9 @@ static BOOL rdg_read_http_unicode_string(wStream* s, const WCHAR** string, UINT1 return TRUE; } -static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket) +static BOOL rdg_write_chunked(BIO* bio, wStream* sPacket) { - size_t s; + size_t len; int status; wStream* sChunk; char chunkSize[11]; @@ -309,44 +347,436 @@ static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket) Stream_Write(sChunk, Stream_Buffer(sPacket), Stream_Length(sPacket)); Stream_Write(sChunk, "\r\n", 2); Stream_SealLength(sChunk); - s = Stream_Length(sChunk); + len = Stream_Length(sChunk); - if (s > INT_MAX) + if (len > INT_MAX) return FALSE; - status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s); + status = BIO_write(bio, Stream_Buffer(sChunk), (int)len); Stream_Free(sChunk, TRUE); - if (status < 0) + if (status != len) return FALSE; return TRUE; } -static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size, - rdg_http_encoding_context* encodingContext) +static BOOL rdg_write_websocket(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode) +{ + size_t len; + size_t fullLen; + int status; + wStream* sWS; + + uint32_t maskingKey; + + size_t streamPos; + + len = Stream_Length(sPacket); + Stream_SetPosition(sPacket, 0); + + if (len > INT_MAX) + return FALSE; + + if (len < 126) + fullLen = len + 6; /* 2 byte "mini header" + 4 byte masking key */ + else if (len < 0x10000) + fullLen = len + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */ + else + fullLen = len + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */ + + sWS = Stream_New(NULL, fullLen); + if (!sWS) + return FALSE; + + winpr_RAND((BYTE*)&maskingKey, 4); + + Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode); + if (len < 126) + Stream_Write_UINT8(sWS, len | WEBSOCKET_MASK_BIT); + else if (len < 0x10000) + { + Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT); + Stream_Write_UINT16_BE(sWS, len); + } + else + { + Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT); + Stream_Write_UINT32_BE(sWS, 0); /* payload is limited to INT_MAX */ + Stream_Write_UINT32_BE(sWS, len); + } + Stream_Write_UINT32(sWS, maskingKey); + + /* mask as much as possible with 32bit access */ + for (streamPos = 0; streamPos + 4 <= len; streamPos += 4) + { + uint32_t data; + Stream_Read_UINT32(sPacket, data); + Stream_Write_UINT32(sWS, data ^ maskingKey); + } + + /* mask the rest byte by byte */ + for (; streamPos < len; streamPos++) + { + BYTE data; + BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4); + Stream_Read_UINT8(sPacket, data); + Stream_Write_UINT8(sWS, data ^ *partialMask); + } + + Stream_SealLength(sWS); + + status = BIO_write(bio, Stream_Buffer(sWS), Stream_Length(sWS)); + Stream_Free(sWS, TRUE); + + if (status != fullLen) + return FALSE; + + return TRUE; +} + +static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket) +{ + if (rdg->transferEncoding.isWebsocketTransport) + { + if (rdg->transferEncoding.context.websocket.closeSent) + return FALSE; + return rdg_write_websocket(rdg->tlsOut->bio, sPacket, WebsocketBinaryOpcode); + } + + return rdg_write_chunked(rdg->tlsIn->bio, sPacket); +} + +static int rdg_websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size, + rdg_http_websocket_context* encodingContext) +{ + int status; + + if (encodingContext->payloadLength == 0) + { + encodingContext->state = WebsocketStateOpcodeAndFin; + return 0; + } + + status = + BIO_read(bio, pBuffer, + (encodingContext->payloadLength < size ? encodingContext->payloadLength : size)); + if (status <= 0) + return status; + + encodingContext->payloadLength -= status; + + if (encodingContext->payloadLength == 0) + encodingContext->state = WebsocketStateOpcodeAndFin; + + return status; +} + +static int rdg_websocket_read_discard(BIO* bio, rdg_http_websocket_context* encodingContext) +{ + char _dummy[256]; + int status; + + if (encodingContext->payloadLength == 0) + { + encodingContext->state = WebsocketStateOpcodeAndFin; + return 0; + } + + status = BIO_read(bio, _dummy, sizeof(_dummy)); + if (status <= 0) + return status; + + encodingContext->payloadLength -= status; + + if (encodingContext->payloadLength == 0) + encodingContext->state = WebsocketStateOpcodeAndFin; + + return status; +} + +static int rdg_websocket_read_wstream(BIO* bio, wStream* s, + rdg_http_websocket_context* encodingContext) +{ + int status; + + if (encodingContext->payloadLength == 0) + { + encodingContext->state = WebsocketStateOpcodeAndFin; + return 0; + } + if (s == NULL || Stream_GetRemainingCapacity(s) != encodingContext->payloadLength) + return -1; + + status = BIO_read(bio, Stream_Pointer(s), encodingContext->payloadLength); + if (status <= 0) + return status; + + Stream_Seek(s, status); + + encodingContext->payloadLength -= status; + + if (encodingContext->payloadLength == 0) + { + encodingContext->state = WebsocketStateOpcodeAndFin; + Stream_SealLength(s); + Stream_SetPosition(s, 0); + } + + return status; +} + +static BOOL rdg_websocket_reply_close(BIO* bio, wStream* s) +{ + /* write back close */ + wStream* closeFrame; + uint16_t maskingKey1; + uint16_t maskingKey2; + int status; + size_t closeDataLen; + + closeDataLen = 0; + if (s != NULL && Stream_Length(s) >= 2) + closeDataLen = 2; + + closeFrame = Stream_New(NULL, 6 + closeDataLen); + Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode); + Stream_Write_UINT8(closeFrame, closeDataLen | WEBSOCKET_MASK_BIT); /* no payload */ + winpr_RAND((BYTE*)&maskingKey1, 2); + winpr_RAND((BYTE*)&maskingKey2, 2); + Stream_Write_UINT16(closeFrame, maskingKey1); + Stream_Write_UINT16(closeFrame, maskingKey2); /* unused half, max 2 bytes of data */ + + if (closeDataLen == 2) + { + uint16_t data; + Stream_Read_UINT16(s, data); + Stream_Write_UINT16(s, data ^ maskingKey1); + } + Stream_SealLength(closeFrame); + + status = BIO_write(bio, Stream_Buffer(closeFrame), Stream_Length(closeFrame)); + /* server MUST close socket now. The server is not allowed anymore to + * send frames but if he does, nothing bad would happen */ + if (status < 0) + return FALSE; + return TRUE; +} + +static BOOL rdg_websocket_reply_pong(BIO* bio, wStream* s) +{ + wStream* closeFrame; + uint32_t maskingKey; + int status; + + if (s != NULL) + return rdg_write_websocket(bio, s, WebsocketPongOpcode); + + closeFrame = Stream_New(NULL, 6); + Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode); + Stream_Write_UINT8(closeFrame, 0 | WEBSOCKET_MASK_BIT); /* no payload */ + winpr_RAND((BYTE*)&maskingKey, 4); + Stream_Write_UINT32(closeFrame, maskingKey); /* dummy masking key. */ + Stream_SealLength(closeFrame); + + status = BIO_write(bio, Stream_Buffer(closeFrame), Stream_Length(closeFrame)); + + if (status < 0) + return FALSE; + return TRUE; +} + +static int rdg_websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size, + rdg_http_websocket_context* encodingContext) +{ + int status; + BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode + ? encodingContext->fragmentOriginalOpcode & 0xf + : encodingContext->opcode & 0xf); + + switch (effectiveOpcode) + { + case WebsocketBinaryOpcode: + { + status = rdg_websocket_read_data(bio, pBuffer, size, encodingContext); + if (status < 0) + return status; + + return status; + } + break; + case WebsocketPingOpcode: + { + if (encodingContext->responseStreamBuffer == NULL) + encodingContext->responseStreamBuffer = + Stream_New(NULL, encodingContext->payloadLength); + + status = rdg_websocket_read_wstream(bio, encodingContext->responseStreamBuffer, + encodingContext); + if (status < 0) + return status; + + if (encodingContext->payloadLength == 0) + { + if (!encodingContext->closeSent) + rdg_websocket_reply_pong(bio, encodingContext->responseStreamBuffer); + + if (encodingContext->responseStreamBuffer) + Stream_Free(encodingContext->responseStreamBuffer, TRUE); + encodingContext->responseStreamBuffer = NULL; + } + } + break; + case WebsocketCloseOpcode: + { + if (encodingContext->responseStreamBuffer == NULL) + encodingContext->responseStreamBuffer = + Stream_New(NULL, encodingContext->payloadLength); + + status = rdg_websocket_read_wstream(bio, encodingContext->responseStreamBuffer, + encodingContext); + if (status < 0) + return status; + + if (encodingContext->payloadLength == 0) + { + rdg_websocket_reply_close(bio, encodingContext->responseStreamBuffer); + encodingContext->closeSent = TRUE; + + if (encodingContext->responseStreamBuffer) + Stream_Free(encodingContext->responseStreamBuffer, TRUE); + encodingContext->responseStreamBuffer = NULL; + } + } + break; + default: + WLog_WARN(TAG, "Unimplemented websocket opcode %x. Dropping", effectiveOpcode & 0xf); + + status = rdg_websocket_read_discard(bio, encodingContext); + if (status < 0) + return status; + } + /* return how many bytes have been written to pBuffer. + * Only WebsocketBinaryOpcode writes into it and it returns directly */ + return 0; +} + +static int rdg_websocket_read(BIO* bio, BYTE* pBuffer, size_t size, + rdg_http_websocket_context* encodingContext) { int status; int effectiveDataLen = 0; assert(encodingContext != NULL); while (TRUE) { - switch (encodingContext->context.chunked.state) + switch (encodingContext->state) { - case ChunkStateData: + case WebsocketStateOpcodeAndFin: { - status = BIO_read(bio, pBuffer, - (size > encodingContext->context.chunked.nextOffset - ? encodingContext->context.chunked.nextOffset - : size)); + BYTE buffer[1]; + status = BIO_read(bio, (char*)buffer, 1); if (status <= 0) return (effectiveDataLen > 0 ? effectiveDataLen : status); - encodingContext->context.chunked.nextOffset -= status; - if (encodingContext->context.chunked.nextOffset == 0) + encodingContext->opcode = buffer[0]; + if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) && + (encodingContext->opcode & 0xf) < 0x08) + encodingContext->fragmentOriginalOpcode = encodingContext->opcode; + encodingContext->state = WebsocketStateLengthAndMasking; + } + break; + case WebsocketStateLengthAndMasking: + { + BYTE buffer[1]; + BYTE len; + status = BIO_read(bio, (char*)buffer, 1); + if (status <= 0) + return (effectiveDataLen > 0 ? effectiveDataLen : status); + + encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT); + encodingContext->lengthAndMaskPosition = 0; + encodingContext->payloadLength = 0; + len = buffer[0] & 0x7f; + if (len < 126) { - encodingContext->context.chunked.state = ChunkStateFooter; - encodingContext->context.chunked.headerFooterPos = 0; + encodingContext->payloadLength = len; + encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey + : WebSocketStatePayload); + } + else if (len == 126) + encodingContext->state = WebsocketStateShortLength; + else + encodingContext->state = WebsocketStateLongLength; + } + break; + case WebsocketStateShortLength: + case WebsocketStateLongLength: + { + BYTE buffer[1]; + BYTE lenLength = (encodingContext->state == WebsocketStateShortLength ? 2 : 8); + while (encodingContext->lengthAndMaskPosition < lenLength) + { + status = BIO_read(bio, (char*)buffer, 1); + if (status <= 0) + return (effectiveDataLen > 0 ? effectiveDataLen : status); + + encodingContext->payloadLength = + (encodingContext->payloadLength) << 8 | buffer[0]; + encodingContext->lengthAndMaskPosition += status; + } + encodingContext->state = + (encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload); + } + break; + case WebSocketStateMaskingKey: + { + WLog_WARN( + TAG, "Websocket Server sends data with masking key. This is against RFC 6455."); + return -1; + } + break; + case WebSocketStatePayload: + { + status = rdg_websocket_handle_payload(bio, pBuffer, size, encodingContext); + if (status < 0) + return (effectiveDataLen > 0 ? effectiveDataLen : status); + + effectiveDataLen += status; + + if ((size_t)status == size) + return effectiveDataLen; + pBuffer += status; + size -= status; + } + } + } + /* should be unreachable */ + return -1; +} + +static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size, + rdg_http_encoding_chunked_context* encodingContext) +{ + int status; + int effectiveDataLen = 0; + assert(encodingContext != NULL); + while (TRUE) + { + switch (encodingContext->state) + { + case ChunkStateData: + { + status = BIO_read( + bio, pBuffer, + (size > encodingContext->nextOffset ? encodingContext->nextOffset : size)); + if (status <= 0) + return (effectiveDataLen > 0 ? effectiveDataLen : status); + + encodingContext->nextOffset -= status; + if (encodingContext->nextOffset == 0) + { + encodingContext->state = ChunkStateFooter; + encodingContext->headerFooterPos = 0; } effectiveDataLen += status; @@ -360,17 +790,16 @@ static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size, case ChunkStateFooter: { char _dummy[2]; - assert(encodingContext->context.chunked.nextOffset == 0); - assert(encodingContext->context.chunked.headerFooterPos < 2); - status = - BIO_read(bio, _dummy, 2 - encodingContext->context.chunked.headerFooterPos); + assert(encodingContext->nextOffset == 0); + assert(encodingContext->headerFooterPos < 2); + status = BIO_read(bio, _dummy, 2 - encodingContext->headerFooterPos); if (status >= 0) { - encodingContext->context.chunked.headerFooterPos += status; - if (encodingContext->context.chunked.headerFooterPos == 2) + encodingContext->headerFooterPos += status; + if (encodingContext->headerFooterPos == 2) { - encodingContext->context.chunked.state = ChunkStateLenghHeader; - encodingContext->context.chunked.headerFooterPos = 0; + encodingContext->state = ChunkStateLenghHeader; + encodingContext->headerFooterPos = 0; } } else @@ -381,43 +810,40 @@ static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size, { BOOL _haveNewLine = FALSE; size_t tmp; - char* dst = &encodingContext->context.chunked - .lenBuffer[encodingContext->context.chunked.headerFooterPos]; - assert(encodingContext->context.chunked.nextOffset == 0); - while (encodingContext->context.chunked.headerFooterPos < 10 && !_haveNewLine) + char* dst = &encodingContext->lenBuffer[encodingContext->headerFooterPos]; + assert(encodingContext->nextOffset == 0); + while (encodingContext->headerFooterPos < 10 && !_haveNewLine) { status = BIO_read(bio, dst, 1); if (status >= 0) { if (*dst == '\n') _haveNewLine = TRUE; - encodingContext->context.chunked.headerFooterPos += status; + encodingContext->headerFooterPos += status; dst += status; } else return (effectiveDataLen > 0 ? effectiveDataLen : status); } - *dst = '\0'; - /* strtoul is tricky, error are reported via errno, we also need * to ensure the result does not overflow */ errno = 0; - tmp = strtoul(encodingContext->context.chunked.lenBuffer, NULL, 16); + tmp = strtoul(encodingContext->lenBuffer, NULL, 16); if ((errno != 0) || (tmp > SIZE_MAX)) return -1; - encodingContext->context.chunked.nextOffset = tmp; - encodingContext->context.chunked.state = ChunkStateData; + encodingContext->nextOffset = tmp; + encodingContext->state = ChunkStateData; - if (encodingContext->context.chunked.nextOffset == 0) - { // end of stream + if (encodingContext->nextOffset == 0) + { /* end of stream */ int fd = BIO_get_fd(bio, NULL); if (fd >= 0) - close(fd); + closesocket((SOCKET)fd); WLog_WARN(TAG, "cunked encoding end of stream received"); - encodingContext->context.chunked.headerFooterPos = 0; - encodingContext->context.chunked.state = ChunkStateFooter; + encodingContext->headerFooterPos = 0; + encodingContext->state = ChunkStateFooter; } } break; @@ -433,12 +859,18 @@ static int rdg_socket_read(BIO* bio, BYTE* pBuffer, size_t size, rdg_http_encoding_context* encodingContext) { assert(encodingContext != NULL); + + if (encodingContext->isWebsocketTransport) + { + return rdg_websocket_read(bio, pBuffer, size, &encodingContext->context.websocket); + } + switch (encodingContext->httpTransferEncoding) { case TransferEncodingIdentity: return BIO_read(bio, pBuffer, size); case TransferEncodingChunked: - return rdg_chuncked_read(bio, pBuffer, size, encodingContext); + return rdg_chuncked_read(bio, pBuffer, size, &encodingContext->context.chunked); default: return -1; } @@ -454,7 +886,6 @@ static BOOL rdg_read_all(rdpTls* tls, BYTE* buffer, size_t size, while (readCount < size) { int status = rdg_socket_read(tls->bio, pBuffer, size - readCount, transferEncoding); - if (status <= 0) { if (!BIO_should_retry(tls->bio)) @@ -730,10 +1161,7 @@ static wStream* rdg_build_http_request(rdpRdg* rdg, const char* method, goto out; } - if (transferEncoding) - { - http_request_set_transfer_encoding(request, transferEncoding); - } + http_request_set_transfer_encoding(request, transferEncoding); s = http_request_write(rdg->http, request); out: @@ -1320,6 +1748,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char* SSIZE_T bodyLength; long StatusCode; TRANSFER_ENCODING encoding; + BOOL isWebsocket; if (!rdg_tls_connect(rdg, tls, peerAddress, timeout)) return FALSE; @@ -1344,6 +1773,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char* case HTTP_STATUS_NOT_FOUND: { WLog_INFO(TAG, "RD Gateway does not support HTTP transport."); + http_context_enable_websocket_upgrade(rdg->http, FALSE); if (rpcFallback) *rpcFallback = TRUE; @@ -1377,16 +1807,44 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char* statusCode = http_response_get_status_code(response); bodyLength = http_response_get_body_length(response); encoding = http_response_get_transfer_encoding(response); + isWebsocket = http_response_is_websocket(rdg->http, response); http_response_free(response); WLog_DBG(TAG, "%s authorization result: %d", method, statusCode); switch (statusCode) { case HTTP_STATUS_OK: + /* old rdg endpoint without websocket support, don't request websocket for RDG_IN_DATA + */ + http_context_enable_websocket_upgrade(rdg->http, FALSE); break; case HTTP_STATUS_DENIED: freerdp_set_last_error_log(rdg->context, FREERDP_ERROR_CONNECT_ACCESS_DENIED); return FALSE; + case HTTP_STATUS_SWITCH_PROTOCOLS: + if (!isWebsocket) + { + /* + * webserver is broken, a fallback may be possible here + * but only if already tested with oppurtonistic upgrade + */ + if (http_context_is_websocket_upgrade_enabled(rdg->http)) + { + int fd = BIO_get_fd(tls->bio, NULL); + if (fd >= 0) + closesocket((SOCKET)fd); + http_context_enable_websocket_upgrade(rdg->http, FALSE); + return rdg_establish_data_connection(rdg, tls, method, peerAddress, timeout, + rpcFallback); + } + return FALSE; + } + rdg->transferEncoding.isWebsocketTransport = TRUE; + rdg->transferEncoding.context.websocket.state = WebsocketStateOpcodeAndFin; + rdg->transferEncoding.context.websocket.responseStreamBuffer = NULL; + + return TRUE; + break; default: return FALSE; } @@ -1452,14 +1910,21 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback) if (status) { - /* Establish IN connection with the same peer/server as OUT connection, - * even when server hostname resolves to different IP addresses. - */ - BIO_get_socket(rdg->tlsOut->underlying, &outConnSocket); - peerAddress = freerdp_tcp_get_peer_address(outConnSocket); - status = rdg_establish_data_connection(rdg, rdg->tlsIn, "RDG_IN_DATA", peerAddress, timeout, - NULL); - free(peerAddress); + if (rdg->transferEncoding.isWebsocketTransport) + { + WLog_DBG(TAG, "Upgraded to websocket. RDG_IN_DATA not required"); + } + else + { + /* Establish IN connection with the same peer/server as OUT connection, + * even when server hostname resolves to different IP addresses. + */ + BIO_get_socket(rdg->tlsOut->underlying, &outConnSocket); + peerAddress = freerdp_tcp_get_peer_address(outConnSocket); + status = rdg_establish_data_connection(rdg, rdg->tlsIn, "RDG_IN_DATA", peerAddress, + timeout, NULL); + free(peerAddress); + } } if (!status) @@ -1476,10 +1941,97 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback) return TRUE; } -static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize) +static int rdg_write_websocket_data_packet(rdpRdg* rdg, const BYTE* buf, int isize) +{ + size_t payloadSize; + size_t fullLen; + int status; + wStream* sWS; + + uint32_t maskingKey; + BYTE* maskingKeyByte1 = (BYTE*)&maskingKey; + BYTE* maskingKeyByte2 = maskingKeyByte1 + 1; + BYTE* maskingKeyByte3 = maskingKeyByte1 + 2; + BYTE* maskingKeyByte4 = maskingKeyByte1 + 3; + + int streamPos; + + winpr_RAND((BYTE*)&maskingKey, 4); + + payloadSize = isize + 10; + if ((isize < 0) || (isize > UINT16_MAX)) + return -1; + + if (payloadSize < 1) + return 0; + + if (payloadSize < 126) + fullLen = payloadSize + 6; /* 2 byte "mini header" + 4 byte masking key */ + else if (payloadSize < 0x10000) + fullLen = payloadSize + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */ + else + fullLen = payloadSize + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */ + + sWS = Stream_New(NULL, fullLen); + if (!sWS) + return FALSE; + + Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | WebsocketBinaryOpcode); + if (payloadSize < 126) + Stream_Write_UINT8(sWS, payloadSize | WEBSOCKET_MASK_BIT); + else if (payloadSize < 0x10000) + { + Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT); + Stream_Write_UINT16_BE(sWS, payloadSize); + } + else + { + Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT); + /* biggest packet possible is 0xffff + 0xa, so 32bit is always enough */ + Stream_Write_UINT32_BE(sWS, 0); + Stream_Write_UINT32_BE(sWS, payloadSize); + } + Stream_Write_UINT32(sWS, maskingKey); + + Stream_Write_UINT16(sWS, PKT_TYPE_DATA ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Type */ + Stream_Write_UINT16(sWS, 0 ^ (*maskingKeyByte3 | *maskingKeyByte4 << 8)); /* Reserved */ + Stream_Write_UINT32(sWS, (UINT32)payloadSize ^ maskingKey); /* Packet length */ + Stream_Write_UINT16(sWS, + (UINT16)isize ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Data size */ + + /* masking key is now off by 2 bytes. fix that */ + maskingKey = (maskingKey & 0xffff) << 16 | (maskingKey >> 16); + + /* mask as much as possible with 32bit access */ + for (streamPos = 0; streamPos + 4 <= isize; streamPos += 4) + { + uint32_t masked = *((uint32_t*)((BYTE*)buf + streamPos)) ^ maskingKey; + Stream_Write_UINT32(sWS, masked); + } + + /* mask the rest byte by byte */ + for (; streamPos < isize; streamPos++) + { + BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4; + BYTE masked = *((BYTE*)((BYTE*)buf + streamPos)) ^ *partialMask; + Stream_Write_UINT8(sWS, masked); + } + + Stream_SealLength(sWS); + + status = tls_write_all(rdg->tlsOut, Stream_Buffer(sWS), Stream_Length(sWS)); + Stream_Free(sWS, TRUE); + + if (status < 0) + return status; + + return isize; +} + +static int rdg_write_chunked_data_packet(rdpRdg* rdg, const BYTE* buf, int isize) { int status; - size_t s; + size_t len; wStream* sChunk; size_t size = (size_t)isize; size_t packetSize = size + 10; @@ -1505,12 +2057,12 @@ static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize) Stream_Write(sChunk, buf, size); /* Data */ Stream_Write(sChunk, "\r\n", 2); Stream_SealLength(sChunk); - s = Stream_Length(sChunk); + len = Stream_Length(sChunk); - if (s > INT_MAX) + if (len > INT_MAX) return -1; - status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s); + status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)len); Stream_Free(sChunk, TRUE); if (status < 0) @@ -1519,18 +2071,26 @@ static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize) return (int)size; } +static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize) +{ + if (rdg->transferEncoding.isWebsocketTransport) + { + if (rdg->transferEncoding.context.websocket.closeSent == TRUE) + return -1; + return rdg_write_websocket_data_packet(rdg, buf, isize); + } + else + return rdg_write_chunked_data_packet(rdg, buf, isize); + + return -1; +} + static BOOL rdg_process_close_packet(rdpRdg* rdg, wStream* s) { int status = -1; - size_t len; - wStream* sChunk; + wStream* sClose; UINT32 errorCode; UINT32 packetSize = 12; - char chunkSize[11]; - int chunkLen = sprintf_s(chunkSize, sizeof(chunkSize), "%" PRIx32 "\r\n", packetSize); - - if (chunkLen < 0) - return FALSE; /* Read error code */ if (Stream_GetRemainingLength(s) < 4) @@ -1540,55 +2100,39 @@ static BOOL rdg_process_close_packet(rdpRdg* rdg, wStream* s) if (errorCode != 0) freerdp_set_last_error_log(rdg->context, errorCode); - sChunk = Stream_New(NULL, (size_t)chunkLen + packetSize + 2); - if (!sChunk) + sClose = Stream_New(NULL, packetSize); + if (!sClose) return FALSE; - Stream_Write(sChunk, chunkSize, (size_t)chunkLen); - Stream_Write_UINT16(sChunk, PKT_TYPE_CLOSE_CHANNEL_RESPONSE); /* Type */ - Stream_Write_UINT16(sChunk, 0); /* Reserved */ - Stream_Write_UINT32(sChunk, packetSize); /* Packet length */ - Stream_Write_UINT32(sChunk, 0); /* Status code */ - Stream_Write(sChunk, "\r\n", 2); - Stream_SealLength(sChunk); - len = Stream_Length(sChunk); + Stream_Write_UINT16(sClose, PKT_TYPE_CLOSE_CHANNEL_RESPONSE); /* Type */ + Stream_Write_UINT16(sClose, 0); /* Reserved */ + Stream_Write_UINT32(sClose, packetSize); /* Packet length */ + Stream_Write_UINT32(sClose, 0); /* Status code */ + Stream_SealLength(sClose); + status = rdg_write_packet(rdg, sClose); + Stream_Free(sClose, TRUE); - if (len <= INT_MAX) - status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)len); - - Stream_Free(sChunk, TRUE); return (status < 0 ? FALSE : TRUE); } static BOOL rdg_process_keep_alive_packet(rdpRdg* rdg) { int status = -1; - size_t s; - wStream* sChunk; + wStream* sKeepAlive; size_t packetSize = 8; - char chunkSize[11]; - int chunkLen = sprintf_s(chunkSize, sizeof(chunkSize), "%" PRIxz "\r\n", packetSize); - if ((chunkLen < 0) || (packetSize > UINT32_MAX)) + sKeepAlive = Stream_New(NULL, packetSize); + + if (!sKeepAlive) return FALSE; - sChunk = Stream_New(NULL, (size_t)chunkLen + packetSize + 2); + Stream_Write_UINT16(sKeepAlive, PKT_TYPE_KEEPALIVE); /* Type */ + Stream_Write_UINT16(sKeepAlive, 0); /* Reserved */ + Stream_Write_UINT32(sKeepAlive, (UINT32)packetSize); /* Packet length */ + Stream_SealLength(sKeepAlive); + status = rdg_write_packet(rdg, sKeepAlive); + Stream_Free(sKeepAlive, TRUE); - if (!sChunk) - return FALSE; - - Stream_Write(sChunk, chunkSize, (size_t)chunkLen); - Stream_Write_UINT16(sChunk, PKT_TYPE_KEEPALIVE); /* Type */ - Stream_Write_UINT16(sChunk, 0); /* Reserved */ - Stream_Write_UINT32(sChunk, (UINT32)packetSize); /* Packet length */ - Stream_Write(sChunk, "\r\n", 2); - Stream_SealLength(sChunk); - s = Stream_Length(sChunk); - - if (s <= INT_MAX) - status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s); - - Stream_Free(sChunk, TRUE); return (status < 0 ? FALSE : TRUE); } @@ -2003,7 +2547,8 @@ rdpRdg* rdg_new(rdpContext* context) !http_context_set_connection(rdg->http, "Keep-Alive") || !http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0") || !http_context_set_host(rdg->http, rdg->settings->GatewayHostname) || - !http_context_set_rdg_connection_id(rdg->http, bracedUuid)) + !http_context_set_rdg_connection_id(rdg->http, bracedUuid) || + !http_context_enable_websocket_upgrade(rdg->http, TRUE)) { goto rdg_alloc_error; } @@ -2033,6 +2578,7 @@ rdpRdg* rdg_new(rdpContext* context) InitializeCriticalSection(&rdg->writeSection); rdg->transferEncoding.httpTransferEncoding = TransferEncodingIdentity; + rdg->transferEncoding.isWebsocketTransport = FALSE; } return rdg; @@ -2056,6 +2602,12 @@ void rdg_free(rdpRdg* rdg) DeleteCriticalSection(&rdg->writeSection); + if (rdg->transferEncoding.isWebsocketTransport) + { + if (rdg->transferEncoding.context.websocket.responseStreamBuffer != NULL) + Stream_Free(rdg->transferEncoding.context.websocket.responseStreamBuffer, TRUE); + } + free(rdg); }