rdg websocket support

This commit is contained in:
Michael Saxl 2021-01-25 16:20:18 +01:00 committed by akallabeth
parent b05fea134e
commit bc52147fbb
3 changed files with 772 additions and 105 deletions

View File

@ -30,6 +30,9 @@
#include <freerdp/log.h> #include <freerdp/log.h>
/* websocket need sha1 for Sec-Websocket-Accept */
#include <winpr/crypto.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H #ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h> #include <valgrind/memcheck.h>
#endif #endif
@ -40,6 +43,8 @@
#define RESPONSE_SIZE_LIMIT 64 * 1024 * 1024 #define RESPONSE_SIZE_LIMIT 64 * 1024 * 1024
#define WEBSOCKET_MAGIC_GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
struct _http_context struct _http_context
{ {
char* Method; char* Method;
@ -52,6 +57,8 @@ struct _http_context
char* Pragma; char* Pragma;
char* RdgConnectionId; char* RdgConnectionId;
char* RdgAuthScheme; char* RdgAuthScheme;
BOOL websocketUpgrade;
char SecWebsocketKey[16];
}; };
struct _http_request struct _http_request
@ -77,6 +84,8 @@ struct _http_response
size_t ContentLength; size_t ContentLength;
const char* ContentType; const char* ContentType;
TRANSFER_ENCODING TransferEncoding; TRANSFER_ENCODING TransferEncoding;
const char* SecWebsocketVersion;
const char* SecWebsocketAccept;
size_t BodyLength; size_t BodyLength;
BYTE* BodyContent; BYTE* BodyContent;
@ -259,6 +268,30 @@ BOOL http_context_set_rdg_connection_id(HttpContext* context, const char* RdgCon
return TRUE; return TRUE;
} }
BOOL http_context_enable_websocket_upgrade(HttpContext* context, BOOL enable)
{
if (!context)
return FALSE;
if (enable)
{
int i;
winpr_RAND((BYTE*)context->SecWebsocketKey, 15);
for (i = 0; i < 16; i++)
context->SecWebsocketKey[i] = (context->SecWebsocketKey[i] | 0x40) & 0x5f;
context->SecWebsocketKey[15] = '\0';
}
else
context->SecWebsocketKey[0] = '\0';
context->websocketUpgrade = enable;
return TRUE;
}
BOOL http_context_is_websocket_upgrade_enabled(HttpContext* context)
{
return context->websocketUpgrade;
}
BOOL http_context_set_rdg_auth_scheme(HttpContext* context, const char* RdgAuthScheme) BOOL http_context_set_rdg_auth_scheme(HttpContext* context, const char* RdgAuthScheme)
{ {
if (!context || !RdgAuthScheme) if (!context || !RdgAuthScheme)
@ -426,13 +459,26 @@ wStream* http_request_write(HttpContext* context, HttpRequest* request)
if (!http_encode_header_line(s, request->Method, request->URI) || if (!http_encode_header_line(s, request->Method, request->URI) ||
!http_encode_body_line(s, "Cache-Control", context->CacheControl) || !http_encode_body_line(s, "Cache-Control", context->CacheControl) ||
!http_encode_body_line(s, "Connection", context->Connection) ||
!http_encode_body_line(s, "Pragma", context->Pragma) || !http_encode_body_line(s, "Pragma", context->Pragma) ||
!http_encode_body_line(s, "Accept", context->Accept) || !http_encode_body_line(s, "Accept", context->Accept) ||
!http_encode_body_line(s, "User-Agent", context->UserAgent) || !http_encode_body_line(s, "User-Agent", context->UserAgent) ||
!http_encode_body_line(s, "Host", context->Host)) !http_encode_body_line(s, "Host", context->Host))
goto fail; goto fail;
if (!context->websocketUpgrade)
{
if (!http_encode_body_line(s, "Connection", context->Connection))
goto fail;
}
else
{
if (!http_encode_body_line(s, "Connection", "Upgrade") ||
!http_encode_body_line(s, "Upgrade", "websocket") ||
!http_encode_body_line(s, "Sec-Websocket-Version", "13") ||
!http_encode_body_line(s, "Sec-Websocket-Key", context->SecWebsocketKey))
goto fail;
}
if (context->RdgConnectionId) if (context->RdgConnectionId)
{ {
if (!http_encode_body_line(s, "RDG-Connection-Id", context->RdgConnectionId)) if (!http_encode_body_line(s, "RDG-Connection-Id", context->RdgConnectionId))
@ -556,7 +602,6 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char*
const char* value) const char* value)
{ {
BOOL status = TRUE; BOOL status = TRUE;
if (!response || !name) if (!response || !name)
return FALSE; return FALSE;
@ -587,6 +632,20 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char*
else else
response->TransferEncoding = TransferEncodingUnknown; response->TransferEncoding = TransferEncodingUnknown;
} }
else if (_stricmp(name, "Sec-WebSocket-Version") == 0)
{
response->SecWebsocketVersion = value;
if (!response->SecWebsocketVersion)
return FALSE;
}
else if (_stricmp(name, "Sec-WebSocket-Accept") == 0)
{
response->SecWebsocketAccept = value;
if (!response->SecWebsocketAccept)
return FALSE;
}
else if (_stricmp(name, "WWW-Authenticate") == 0) else if (_stricmp(name, "WWW-Authenticate") == 0)
{ {
char* separator = NULL; char* separator = NULL;
@ -1041,3 +1100,56 @@ TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response)
return response->TransferEncoding; return response->TransferEncoding;
} }
BOOL http_response_is_websocket(HttpContext* http, HttpResponse* response)
{
BOOL isWebsocket = FALSE;
WINPR_DIGEST_CTX* sha1 = NULL;
char* base64accept = NULL;
BYTE sha1_digest[WINPR_SHA1_DIGEST_LENGTH];
if (!http || !response)
return FALSE;
if (!http->websocketUpgrade || response->StatusCode != HTTP_STATUS_SWITCH_PROTOCOLS)
return FALSE;
if (response->SecWebsocketVersion && _stricmp(response->SecWebsocketVersion, "13") != 0)
return FALSE;
if (!response->SecWebsocketAccept)
return FALSE;
/* now check if Sec-Websocket-Accept is correct */
sha1 = winpr_Digest_New();
if (!sha1)
goto out;
if (!winpr_Digest_Init(sha1, WINPR_MD_SHA1))
goto out;
if (!winpr_Digest_Update(sha1, (const BYTE*)http->SecWebsocketKey,
strlen(http->SecWebsocketKey)))
goto out;
if (!winpr_Digest_Update(sha1, (const BYTE*)WEBSOCKET_MAGIC_GUID, strlen(WEBSOCKET_MAGIC_GUID)))
goto out;
if (!winpr_Digest_Final(sha1, sha1_digest, sizeof(sha1_digest)))
goto out;
base64accept = crypto_base64_encode(sha1_digest, WINPR_SHA1_DIGEST_LENGTH);
if (!base64accept)
goto out;
if (_stricmp(response->SecWebsocketAccept, base64accept) != 0)
{
WLog_WARN(TAG, "Webserver gave Websocket Upgrade response but sanity check failed");
goto out;
}
isWebsocket = TRUE;
out:
winpr_Digest_Free(sha1);
free(base64accept);
return isWebsocket;
}

View File

@ -52,6 +52,8 @@ FREERDP_LOCAL BOOL http_context_set_rdg_connection_id(HttpContext* context,
const char* RdgConnectionId); const char* RdgConnectionId);
FREERDP_LOCAL BOOL http_context_set_rdg_auth_scheme(HttpContext* context, FREERDP_LOCAL BOOL http_context_set_rdg_auth_scheme(HttpContext* context,
const char* RdgAuthScheme); const char* RdgAuthScheme);
FREERDP_LOCAL BOOL http_context_enable_websocket_upgrade(HttpContext* context, BOOL enable);
FREERDP_LOCAL BOOL http_context_is_websocket_upgrade_enabled(HttpContext* context);
/* HTTP request */ /* HTTP request */
typedef struct _http_request HttpRequest; typedef struct _http_request HttpRequest;
@ -85,5 +87,6 @@ FREERDP_LOCAL long http_response_get_status_code(HttpResponse* response);
FREERDP_LOCAL SSIZE_T http_response_get_body_length(HttpResponse* response); FREERDP_LOCAL SSIZE_T http_response_get_body_length(HttpResponse* response);
FREERDP_LOCAL const char* http_response_get_auth_token(HttpResponse* response, const char* method); FREERDP_LOCAL const char* http_response_get_auth_token(HttpResponse* response, const char* method);
FREERDP_LOCAL TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response); FREERDP_LOCAL TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response);
FREERDP_LOCAL BOOL http_response_is_websocket(HttpContext* http, HttpResponse* response);
#endif /* FREERDP_LIB_CORE_GATEWAY_HTTP_H */ #endif /* FREERDP_LIB_CORE_GATEWAY_HTTP_H */

View File

@ -104,6 +104,42 @@
#define HTTP_CAPABILITY_REAUTH 0x10 #define HTTP_CAPABILITY_REAUTH 0x10
#define HTTP_CAPABILITY_UDP_TRANSPORT 0x20 #define HTTP_CAPABILITY_UDP_TRANSPORT 0x20
#define WEBSOCKET_MASK_BIT 0x80
#define WEBSOCKET_FIN_BIT 0x80
typedef enum _WEBSOCKET_OPCODE
{
WebsocketContinuationOpcode = 0x0,
WebsocketTextOpcode = 0x1,
WebsocketBinaryOpcode = 0x2,
WebsocketCloseOpcode = 0x8,
WebsocketPingOpcode = 0x9,
WebsocketPongOpcode = 0xa,
} WEBSOCKET_OPCODE;
typedef enum _WEBSOCKET_STATE
{
WebsocketStateOpcodeAndFin,
WebsocketStateLengthAndMasking,
WebsocketStateShortLength,
WebsocketStateLongLength,
WebSocketStateMaskingKey,
WebSocketStatePayload,
} WEBSOCKET_STATE;
typedef struct
{
size_t payloadLength;
uint32_t maskingKey;
BOOL masking;
BOOL closeSent;
BYTE opcode;
BYTE fragmentOriginalOpcode;
BYTE lengthAndMaskPosition;
WEBSOCKET_STATE state;
wStream* responseStreamBuffer;
} rdg_http_websocket_context;
typedef enum _CHUNK_STATE typedef enum _CHUNK_STATE
{ {
ChunkStateLenghHeader, ChunkStateLenghHeader,
@ -122,9 +158,11 @@ typedef struct
typedef struct typedef struct
{ {
TRANSFER_ENCODING httpTransferEncoding; TRANSFER_ENCODING httpTransferEncoding;
BOOL isWebsocketTransport;
union _context union _context
{ {
rdg_http_encoding_chunked_context chunked; rdg_http_encoding_chunked_context chunked;
rdg_http_websocket_context websocket;
} context; } context;
} rdg_http_encoding_context; } rdg_http_encoding_context;
@ -293,9 +331,9 @@ static BOOL rdg_read_http_unicode_string(wStream* s, const WCHAR** string, UINT1
return TRUE; return TRUE;
} }
static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket) static BOOL rdg_write_chunked(BIO* bio, wStream* sPacket)
{ {
size_t s; size_t len;
int status; int status;
wStream* sChunk; wStream* sChunk;
char chunkSize[11]; char chunkSize[11];
@ -309,44 +347,436 @@ static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
Stream_Write(sChunk, Stream_Buffer(sPacket), Stream_Length(sPacket)); Stream_Write(sChunk, Stream_Buffer(sPacket), Stream_Length(sPacket));
Stream_Write(sChunk, "\r\n", 2); Stream_Write(sChunk, "\r\n", 2);
Stream_SealLength(sChunk); Stream_SealLength(sChunk);
s = Stream_Length(sChunk); len = Stream_Length(sChunk);
if (s > INT_MAX) if (len > INT_MAX)
return FALSE; return FALSE;
status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s); status = BIO_write(bio, Stream_Buffer(sChunk), (int)len);
Stream_Free(sChunk, TRUE); Stream_Free(sChunk, TRUE);
if (status < 0) if (status != len)
return FALSE; return FALSE;
return TRUE; return TRUE;
} }
static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size, static BOOL rdg_write_websocket(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode)
rdg_http_encoding_context* encodingContext) {
size_t len;
size_t fullLen;
int status;
wStream* sWS;
uint32_t maskingKey;
size_t streamPos;
len = Stream_Length(sPacket);
Stream_SetPosition(sPacket, 0);
if (len > INT_MAX)
return FALSE;
if (len < 126)
fullLen = len + 6; /* 2 byte "mini header" + 4 byte masking key */
else if (len < 0x10000)
fullLen = len + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
else
fullLen = len + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
sWS = Stream_New(NULL, fullLen);
if (!sWS)
return FALSE;
winpr_RAND((BYTE*)&maskingKey, 4);
Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode);
if (len < 126)
Stream_Write_UINT8(sWS, len | WEBSOCKET_MASK_BIT);
else if (len < 0x10000)
{
Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
Stream_Write_UINT16_BE(sWS, len);
}
else
{
Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
Stream_Write_UINT32_BE(sWS, 0); /* payload is limited to INT_MAX */
Stream_Write_UINT32_BE(sWS, len);
}
Stream_Write_UINT32(sWS, maskingKey);
/* mask as much as possible with 32bit access */
for (streamPos = 0; streamPos + 4 <= len; streamPos += 4)
{
uint32_t data;
Stream_Read_UINT32(sPacket, data);
Stream_Write_UINT32(sWS, data ^ maskingKey);
}
/* mask the rest byte by byte */
for (; streamPos < len; streamPos++)
{
BYTE data;
BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
Stream_Read_UINT8(sPacket, data);
Stream_Write_UINT8(sWS, data ^ *partialMask);
}
Stream_SealLength(sWS);
status = BIO_write(bio, Stream_Buffer(sWS), Stream_Length(sWS));
Stream_Free(sWS, TRUE);
if (status != fullLen)
return FALSE;
return TRUE;
}
static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
{
if (rdg->transferEncoding.isWebsocketTransport)
{
if (rdg->transferEncoding.context.websocket.closeSent)
return FALSE;
return rdg_write_websocket(rdg->tlsOut->bio, sPacket, WebsocketBinaryOpcode);
}
return rdg_write_chunked(rdg->tlsIn->bio, sPacket);
}
static int rdg_websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_websocket_context* encodingContext)
{
int status;
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
return 0;
}
status =
BIO_read(bio, pBuffer,
(encodingContext->payloadLength < size ? encodingContext->payloadLength : size));
if (status <= 0)
return status;
encodingContext->payloadLength -= status;
if (encodingContext->payloadLength == 0)
encodingContext->state = WebsocketStateOpcodeAndFin;
return status;
}
static int rdg_websocket_read_discard(BIO* bio, rdg_http_websocket_context* encodingContext)
{
char _dummy[256];
int status;
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
return 0;
}
status = BIO_read(bio, _dummy, sizeof(_dummy));
if (status <= 0)
return status;
encodingContext->payloadLength -= status;
if (encodingContext->payloadLength == 0)
encodingContext->state = WebsocketStateOpcodeAndFin;
return status;
}
static int rdg_websocket_read_wstream(BIO* bio, wStream* s,
rdg_http_websocket_context* encodingContext)
{
int status;
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
return 0;
}
if (s == NULL || Stream_GetRemainingCapacity(s) != encodingContext->payloadLength)
return -1;
status = BIO_read(bio, Stream_Pointer(s), encodingContext->payloadLength);
if (status <= 0)
return status;
Stream_Seek(s, status);
encodingContext->payloadLength -= status;
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
Stream_SealLength(s);
Stream_SetPosition(s, 0);
}
return status;
}
static BOOL rdg_websocket_reply_close(BIO* bio, wStream* s)
{
/* write back close */
wStream* closeFrame;
uint16_t maskingKey1;
uint16_t maskingKey2;
int status;
size_t closeDataLen;
closeDataLen = 0;
if (s != NULL && Stream_Length(s) >= 2)
closeDataLen = 2;
closeFrame = Stream_New(NULL, 6 + closeDataLen);
Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode);
Stream_Write_UINT8(closeFrame, closeDataLen | WEBSOCKET_MASK_BIT); /* no payload */
winpr_RAND((BYTE*)&maskingKey1, 2);
winpr_RAND((BYTE*)&maskingKey2, 2);
Stream_Write_UINT16(closeFrame, maskingKey1);
Stream_Write_UINT16(closeFrame, maskingKey2); /* unused half, max 2 bytes of data */
if (closeDataLen == 2)
{
uint16_t data;
Stream_Read_UINT16(s, data);
Stream_Write_UINT16(s, data ^ maskingKey1);
}
Stream_SealLength(closeFrame);
status = BIO_write(bio, Stream_Buffer(closeFrame), Stream_Length(closeFrame));
/* server MUST close socket now. The server is not allowed anymore to
* send frames but if he does, nothing bad would happen */
if (status < 0)
return FALSE;
return TRUE;
}
static BOOL rdg_websocket_reply_pong(BIO* bio, wStream* s)
{
wStream* closeFrame;
uint32_t maskingKey;
int status;
if (s != NULL)
return rdg_write_websocket(bio, s, WebsocketPongOpcode);
closeFrame = Stream_New(NULL, 6);
Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode);
Stream_Write_UINT8(closeFrame, 0 | WEBSOCKET_MASK_BIT); /* no payload */
winpr_RAND((BYTE*)&maskingKey, 4);
Stream_Write_UINT32(closeFrame, maskingKey); /* dummy masking key. */
Stream_SealLength(closeFrame);
status = BIO_write(bio, Stream_Buffer(closeFrame), Stream_Length(closeFrame));
if (status < 0)
return FALSE;
return TRUE;
}
static int rdg_websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_websocket_context* encodingContext)
{
int status;
BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
? encodingContext->fragmentOriginalOpcode & 0xf
: encodingContext->opcode & 0xf);
switch (effectiveOpcode)
{
case WebsocketBinaryOpcode:
{
status = rdg_websocket_read_data(bio, pBuffer, size, encodingContext);
if (status < 0)
return status;
return status;
}
break;
case WebsocketPingOpcode:
{
if (encodingContext->responseStreamBuffer == NULL)
encodingContext->responseStreamBuffer =
Stream_New(NULL, encodingContext->payloadLength);
status = rdg_websocket_read_wstream(bio, encodingContext->responseStreamBuffer,
encodingContext);
if (status < 0)
return status;
if (encodingContext->payloadLength == 0)
{
if (!encodingContext->closeSent)
rdg_websocket_reply_pong(bio, encodingContext->responseStreamBuffer);
if (encodingContext->responseStreamBuffer)
Stream_Free(encodingContext->responseStreamBuffer, TRUE);
encodingContext->responseStreamBuffer = NULL;
}
}
break;
case WebsocketCloseOpcode:
{
if (encodingContext->responseStreamBuffer == NULL)
encodingContext->responseStreamBuffer =
Stream_New(NULL, encodingContext->payloadLength);
status = rdg_websocket_read_wstream(bio, encodingContext->responseStreamBuffer,
encodingContext);
if (status < 0)
return status;
if (encodingContext->payloadLength == 0)
{
rdg_websocket_reply_close(bio, encodingContext->responseStreamBuffer);
encodingContext->closeSent = TRUE;
if (encodingContext->responseStreamBuffer)
Stream_Free(encodingContext->responseStreamBuffer, TRUE);
encodingContext->responseStreamBuffer = NULL;
}
}
break;
default:
WLog_WARN(TAG, "Unimplemented websocket opcode %x. Dropping", effectiveOpcode & 0xf);
status = rdg_websocket_read_discard(bio, encodingContext);
if (status < 0)
return status;
}
/* return how many bytes have been written to pBuffer.
* Only WebsocketBinaryOpcode writes into it and it returns directly */
return 0;
}
static int rdg_websocket_read(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_websocket_context* encodingContext)
{ {
int status; int status;
int effectiveDataLen = 0; int effectiveDataLen = 0;
assert(encodingContext != NULL); assert(encodingContext != NULL);
while (TRUE) while (TRUE)
{ {
switch (encodingContext->context.chunked.state) switch (encodingContext->state)
{ {
case ChunkStateData: case WebsocketStateOpcodeAndFin:
{ {
status = BIO_read(bio, pBuffer, BYTE buffer[1];
(size > encodingContext->context.chunked.nextOffset status = BIO_read(bio, (char*)buffer, 1);
? encodingContext->context.chunked.nextOffset
: size));
if (status <= 0) if (status <= 0)
return (effectiveDataLen > 0 ? effectiveDataLen : status); return (effectiveDataLen > 0 ? effectiveDataLen : status);
encodingContext->context.chunked.nextOffset -= status; encodingContext->opcode = buffer[0];
if (encodingContext->context.chunked.nextOffset == 0) if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) &&
(encodingContext->opcode & 0xf) < 0x08)
encodingContext->fragmentOriginalOpcode = encodingContext->opcode;
encodingContext->state = WebsocketStateLengthAndMasking;
}
break;
case WebsocketStateLengthAndMasking:
{ {
encodingContext->context.chunked.state = ChunkStateFooter; BYTE buffer[1];
encodingContext->context.chunked.headerFooterPos = 0; BYTE len;
status = BIO_read(bio, (char*)buffer, 1);
if (status <= 0)
return (effectiveDataLen > 0 ? effectiveDataLen : status);
encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
encodingContext->lengthAndMaskPosition = 0;
encodingContext->payloadLength = 0;
len = buffer[0] & 0x7f;
if (len < 126)
{
encodingContext->payloadLength = len;
encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey
: WebSocketStatePayload);
}
else if (len == 126)
encodingContext->state = WebsocketStateShortLength;
else
encodingContext->state = WebsocketStateLongLength;
}
break;
case WebsocketStateShortLength:
case WebsocketStateLongLength:
{
BYTE buffer[1];
BYTE lenLength = (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
while (encodingContext->lengthAndMaskPosition < lenLength)
{
status = BIO_read(bio, (char*)buffer, 1);
if (status <= 0)
return (effectiveDataLen > 0 ? effectiveDataLen : status);
encodingContext->payloadLength =
(encodingContext->payloadLength) << 8 | buffer[0];
encodingContext->lengthAndMaskPosition += status;
}
encodingContext->state =
(encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload);
}
break;
case WebSocketStateMaskingKey:
{
WLog_WARN(
TAG, "Websocket Server sends data with masking key. This is against RFC 6455.");
return -1;
}
break;
case WebSocketStatePayload:
{
status = rdg_websocket_handle_payload(bio, pBuffer, size, encodingContext);
if (status < 0)
return (effectiveDataLen > 0 ? effectiveDataLen : status);
effectiveDataLen += status;
if ((size_t)status == size)
return effectiveDataLen;
pBuffer += status;
size -= status;
}
}
}
/* should be unreachable */
return -1;
}
static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_encoding_chunked_context* encodingContext)
{
int status;
int effectiveDataLen = 0;
assert(encodingContext != NULL);
while (TRUE)
{
switch (encodingContext->state)
{
case ChunkStateData:
{
status = BIO_read(
bio, pBuffer,
(size > encodingContext->nextOffset ? encodingContext->nextOffset : size));
if (status <= 0)
return (effectiveDataLen > 0 ? effectiveDataLen : status);
encodingContext->nextOffset -= status;
if (encodingContext->nextOffset == 0)
{
encodingContext->state = ChunkStateFooter;
encodingContext->headerFooterPos = 0;
} }
effectiveDataLen += status; effectiveDataLen += status;
@ -360,17 +790,16 @@ static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
case ChunkStateFooter: case ChunkStateFooter:
{ {
char _dummy[2]; char _dummy[2];
assert(encodingContext->context.chunked.nextOffset == 0); assert(encodingContext->nextOffset == 0);
assert(encodingContext->context.chunked.headerFooterPos < 2); assert(encodingContext->headerFooterPos < 2);
status = status = BIO_read(bio, _dummy, 2 - encodingContext->headerFooterPos);
BIO_read(bio, _dummy, 2 - encodingContext->context.chunked.headerFooterPos);
if (status >= 0) if (status >= 0)
{ {
encodingContext->context.chunked.headerFooterPos += status; encodingContext->headerFooterPos += status;
if (encodingContext->context.chunked.headerFooterPos == 2) if (encodingContext->headerFooterPos == 2)
{ {
encodingContext->context.chunked.state = ChunkStateLenghHeader; encodingContext->state = ChunkStateLenghHeader;
encodingContext->context.chunked.headerFooterPos = 0; encodingContext->headerFooterPos = 0;
} }
} }
else else
@ -381,43 +810,40 @@ static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
{ {
BOOL _haveNewLine = FALSE; BOOL _haveNewLine = FALSE;
size_t tmp; size_t tmp;
char* dst = &encodingContext->context.chunked char* dst = &encodingContext->lenBuffer[encodingContext->headerFooterPos];
.lenBuffer[encodingContext->context.chunked.headerFooterPos]; assert(encodingContext->nextOffset == 0);
assert(encodingContext->context.chunked.nextOffset == 0); while (encodingContext->headerFooterPos < 10 && !_haveNewLine)
while (encodingContext->context.chunked.headerFooterPos < 10 && !_haveNewLine)
{ {
status = BIO_read(bio, dst, 1); status = BIO_read(bio, dst, 1);
if (status >= 0) if (status >= 0)
{ {
if (*dst == '\n') if (*dst == '\n')
_haveNewLine = TRUE; _haveNewLine = TRUE;
encodingContext->context.chunked.headerFooterPos += status; encodingContext->headerFooterPos += status;
dst += status; dst += status;
} }
else else
return (effectiveDataLen > 0 ? effectiveDataLen : status); return (effectiveDataLen > 0 ? effectiveDataLen : status);
} }
*dst = '\0'; *dst = '\0';
/* strtoul is tricky, error are reported via errno, we also need /* strtoul is tricky, error are reported via errno, we also need
* to ensure the result does not overflow */ * to ensure the result does not overflow */
errno = 0; errno = 0;
tmp = strtoul(encodingContext->context.chunked.lenBuffer, NULL, 16); tmp = strtoul(encodingContext->lenBuffer, NULL, 16);
if ((errno != 0) || (tmp > SIZE_MAX)) if ((errno != 0) || (tmp > SIZE_MAX))
return -1; return -1;
encodingContext->context.chunked.nextOffset = tmp; encodingContext->nextOffset = tmp;
encodingContext->context.chunked.state = ChunkStateData; encodingContext->state = ChunkStateData;
if (encodingContext->context.chunked.nextOffset == 0) if (encodingContext->nextOffset == 0)
{ // end of stream { /* end of stream */
int fd = BIO_get_fd(bio, NULL); int fd = BIO_get_fd(bio, NULL);
if (fd >= 0) if (fd >= 0)
close(fd); closesocket((SOCKET)fd);
WLog_WARN(TAG, "cunked encoding end of stream received"); WLog_WARN(TAG, "cunked encoding end of stream received");
encodingContext->context.chunked.headerFooterPos = 0; encodingContext->headerFooterPos = 0;
encodingContext->context.chunked.state = ChunkStateFooter; encodingContext->state = ChunkStateFooter;
} }
} }
break; break;
@ -433,12 +859,18 @@ static int rdg_socket_read(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_encoding_context* encodingContext) rdg_http_encoding_context* encodingContext)
{ {
assert(encodingContext != NULL); assert(encodingContext != NULL);
if (encodingContext->isWebsocketTransport)
{
return rdg_websocket_read(bio, pBuffer, size, &encodingContext->context.websocket);
}
switch (encodingContext->httpTransferEncoding) switch (encodingContext->httpTransferEncoding)
{ {
case TransferEncodingIdentity: case TransferEncodingIdentity:
return BIO_read(bio, pBuffer, size); return BIO_read(bio, pBuffer, size);
case TransferEncodingChunked: case TransferEncodingChunked:
return rdg_chuncked_read(bio, pBuffer, size, encodingContext); return rdg_chuncked_read(bio, pBuffer, size, &encodingContext->context.chunked);
default: default:
return -1; return -1;
} }
@ -454,7 +886,6 @@ static BOOL rdg_read_all(rdpTls* tls, BYTE* buffer, size_t size,
while (readCount < size) while (readCount < size)
{ {
int status = rdg_socket_read(tls->bio, pBuffer, size - readCount, transferEncoding); int status = rdg_socket_read(tls->bio, pBuffer, size - readCount, transferEncoding);
if (status <= 0) if (status <= 0)
{ {
if (!BIO_should_retry(tls->bio)) if (!BIO_should_retry(tls->bio))
@ -730,10 +1161,7 @@ static wStream* rdg_build_http_request(rdpRdg* rdg, const char* method,
goto out; goto out;
} }
if (transferEncoding)
{
http_request_set_transfer_encoding(request, transferEncoding); http_request_set_transfer_encoding(request, transferEncoding);
}
s = http_request_write(rdg->http, request); s = http_request_write(rdg->http, request);
out: out:
@ -1320,6 +1748,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
SSIZE_T bodyLength; SSIZE_T bodyLength;
long StatusCode; long StatusCode;
TRANSFER_ENCODING encoding; TRANSFER_ENCODING encoding;
BOOL isWebsocket;
if (!rdg_tls_connect(rdg, tls, peerAddress, timeout)) if (!rdg_tls_connect(rdg, tls, peerAddress, timeout))
return FALSE; return FALSE;
@ -1344,6 +1773,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
case HTTP_STATUS_NOT_FOUND: case HTTP_STATUS_NOT_FOUND:
{ {
WLog_INFO(TAG, "RD Gateway does not support HTTP transport."); WLog_INFO(TAG, "RD Gateway does not support HTTP transport.");
http_context_enable_websocket_upgrade(rdg->http, FALSE);
if (rpcFallback) if (rpcFallback)
*rpcFallback = TRUE; *rpcFallback = TRUE;
@ -1377,16 +1807,44 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
statusCode = http_response_get_status_code(response); statusCode = http_response_get_status_code(response);
bodyLength = http_response_get_body_length(response); bodyLength = http_response_get_body_length(response);
encoding = http_response_get_transfer_encoding(response); encoding = http_response_get_transfer_encoding(response);
isWebsocket = http_response_is_websocket(rdg->http, response);
http_response_free(response); http_response_free(response);
WLog_DBG(TAG, "%s authorization result: %d", method, statusCode); WLog_DBG(TAG, "%s authorization result: %d", method, statusCode);
switch (statusCode) switch (statusCode)
{ {
case HTTP_STATUS_OK: case HTTP_STATUS_OK:
/* old rdg endpoint without websocket support, don't request websocket for RDG_IN_DATA
*/
http_context_enable_websocket_upgrade(rdg->http, FALSE);
break; break;
case HTTP_STATUS_DENIED: case HTTP_STATUS_DENIED:
freerdp_set_last_error_log(rdg->context, FREERDP_ERROR_CONNECT_ACCESS_DENIED); freerdp_set_last_error_log(rdg->context, FREERDP_ERROR_CONNECT_ACCESS_DENIED);
return FALSE; return FALSE;
case HTTP_STATUS_SWITCH_PROTOCOLS:
if (!isWebsocket)
{
/*
* webserver is broken, a fallback may be possible here
* but only if already tested with oppurtonistic upgrade
*/
if (http_context_is_websocket_upgrade_enabled(rdg->http))
{
int fd = BIO_get_fd(tls->bio, NULL);
if (fd >= 0)
closesocket((SOCKET)fd);
http_context_enable_websocket_upgrade(rdg->http, FALSE);
return rdg_establish_data_connection(rdg, tls, method, peerAddress, timeout,
rpcFallback);
}
return FALSE;
}
rdg->transferEncoding.isWebsocketTransport = TRUE;
rdg->transferEncoding.context.websocket.state = WebsocketStateOpcodeAndFin;
rdg->transferEncoding.context.websocket.responseStreamBuffer = NULL;
return TRUE;
break;
default: default:
return FALSE; return FALSE;
} }
@ -1451,16 +1909,23 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback)
rdg_establish_data_connection(rdg, rdg->tlsOut, "RDG_OUT_DATA", NULL, timeout, rpcFallback); rdg_establish_data_connection(rdg, rdg->tlsOut, "RDG_OUT_DATA", NULL, timeout, rpcFallback);
if (status) if (status)
{
if (rdg->transferEncoding.isWebsocketTransport)
{
WLog_DBG(TAG, "Upgraded to websocket. RDG_IN_DATA not required");
}
else
{ {
/* Establish IN connection with the same peer/server as OUT connection, /* Establish IN connection with the same peer/server as OUT connection,
* even when server hostname resolves to different IP addresses. * even when server hostname resolves to different IP addresses.
*/ */
BIO_get_socket(rdg->tlsOut->underlying, &outConnSocket); BIO_get_socket(rdg->tlsOut->underlying, &outConnSocket);
peerAddress = freerdp_tcp_get_peer_address(outConnSocket); peerAddress = freerdp_tcp_get_peer_address(outConnSocket);
status = rdg_establish_data_connection(rdg, rdg->tlsIn, "RDG_IN_DATA", peerAddress, timeout, status = rdg_establish_data_connection(rdg, rdg->tlsIn, "RDG_IN_DATA", peerAddress,
NULL); timeout, NULL);
free(peerAddress); free(peerAddress);
} }
}
if (!status) if (!status)
{ {
@ -1476,10 +1941,97 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback)
return TRUE; return TRUE;
} }
static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize) static int rdg_write_websocket_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
{
size_t payloadSize;
size_t fullLen;
int status;
wStream* sWS;
uint32_t maskingKey;
BYTE* maskingKeyByte1 = (BYTE*)&maskingKey;
BYTE* maskingKeyByte2 = maskingKeyByte1 + 1;
BYTE* maskingKeyByte3 = maskingKeyByte1 + 2;
BYTE* maskingKeyByte4 = maskingKeyByte1 + 3;
int streamPos;
winpr_RAND((BYTE*)&maskingKey, 4);
payloadSize = isize + 10;
if ((isize < 0) || (isize > UINT16_MAX))
return -1;
if (payloadSize < 1)
return 0;
if (payloadSize < 126)
fullLen = payloadSize + 6; /* 2 byte "mini header" + 4 byte masking key */
else if (payloadSize < 0x10000)
fullLen = payloadSize + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
else
fullLen = payloadSize + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
sWS = Stream_New(NULL, fullLen);
if (!sWS)
return FALSE;
Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | WebsocketBinaryOpcode);
if (payloadSize < 126)
Stream_Write_UINT8(sWS, payloadSize | WEBSOCKET_MASK_BIT);
else if (payloadSize < 0x10000)
{
Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
Stream_Write_UINT16_BE(sWS, payloadSize);
}
else
{
Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
/* biggest packet possible is 0xffff + 0xa, so 32bit is always enough */
Stream_Write_UINT32_BE(sWS, 0);
Stream_Write_UINT32_BE(sWS, payloadSize);
}
Stream_Write_UINT32(sWS, maskingKey);
Stream_Write_UINT16(sWS, PKT_TYPE_DATA ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Type */
Stream_Write_UINT16(sWS, 0 ^ (*maskingKeyByte3 | *maskingKeyByte4 << 8)); /* Reserved */
Stream_Write_UINT32(sWS, (UINT32)payloadSize ^ maskingKey); /* Packet length */
Stream_Write_UINT16(sWS,
(UINT16)isize ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Data size */
/* masking key is now off by 2 bytes. fix that */
maskingKey = (maskingKey & 0xffff) << 16 | (maskingKey >> 16);
/* mask as much as possible with 32bit access */
for (streamPos = 0; streamPos + 4 <= isize; streamPos += 4)
{
uint32_t masked = *((uint32_t*)((BYTE*)buf + streamPos)) ^ maskingKey;
Stream_Write_UINT32(sWS, masked);
}
/* mask the rest byte by byte */
for (; streamPos < isize; streamPos++)
{
BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4;
BYTE masked = *((BYTE*)((BYTE*)buf + streamPos)) ^ *partialMask;
Stream_Write_UINT8(sWS, masked);
}
Stream_SealLength(sWS);
status = tls_write_all(rdg->tlsOut, Stream_Buffer(sWS), Stream_Length(sWS));
Stream_Free(sWS, TRUE);
if (status < 0)
return status;
return isize;
}
static int rdg_write_chunked_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
{ {
int status; int status;
size_t s; size_t len;
wStream* sChunk; wStream* sChunk;
size_t size = (size_t)isize; size_t size = (size_t)isize;
size_t packetSize = size + 10; size_t packetSize = size + 10;
@ -1505,12 +2057,12 @@ static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
Stream_Write(sChunk, buf, size); /* Data */ Stream_Write(sChunk, buf, size); /* Data */
Stream_Write(sChunk, "\r\n", 2); Stream_Write(sChunk, "\r\n", 2);
Stream_SealLength(sChunk); Stream_SealLength(sChunk);
s = Stream_Length(sChunk); len = Stream_Length(sChunk);
if (s > INT_MAX) if (len > INT_MAX)
return -1; return -1;
status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s); status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)len);
Stream_Free(sChunk, TRUE); Stream_Free(sChunk, TRUE);
if (status < 0) if (status < 0)
@ -1519,18 +2071,26 @@ static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
return (int)size; return (int)size;
} }
static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
{
if (rdg->transferEncoding.isWebsocketTransport)
{
if (rdg->transferEncoding.context.websocket.closeSent == TRUE)
return -1;
return rdg_write_websocket_data_packet(rdg, buf, isize);
}
else
return rdg_write_chunked_data_packet(rdg, buf, isize);
return -1;
}
static BOOL rdg_process_close_packet(rdpRdg* rdg, wStream* s) static BOOL rdg_process_close_packet(rdpRdg* rdg, wStream* s)
{ {
int status = -1; int status = -1;
size_t len; wStream* sClose;
wStream* sChunk;
UINT32 errorCode; UINT32 errorCode;
UINT32 packetSize = 12; UINT32 packetSize = 12;
char chunkSize[11];
int chunkLen = sprintf_s(chunkSize, sizeof(chunkSize), "%" PRIx32 "\r\n", packetSize);
if (chunkLen < 0)
return FALSE;
/* Read error code */ /* Read error code */
if (Stream_GetRemainingLength(s) < 4) if (Stream_GetRemainingLength(s) < 4)
@ -1540,55 +2100,39 @@ static BOOL rdg_process_close_packet(rdpRdg* rdg, wStream* s)
if (errorCode != 0) if (errorCode != 0)
freerdp_set_last_error_log(rdg->context, errorCode); freerdp_set_last_error_log(rdg->context, errorCode);
sChunk = Stream_New(NULL, (size_t)chunkLen + packetSize + 2); sClose = Stream_New(NULL, packetSize);
if (!sChunk) if (!sClose)
return FALSE; return FALSE;
Stream_Write(sChunk, chunkSize, (size_t)chunkLen); Stream_Write_UINT16(sClose, PKT_TYPE_CLOSE_CHANNEL_RESPONSE); /* Type */
Stream_Write_UINT16(sChunk, PKT_TYPE_CLOSE_CHANNEL_RESPONSE); /* Type */ Stream_Write_UINT16(sClose, 0); /* Reserved */
Stream_Write_UINT16(sChunk, 0); /* Reserved */ Stream_Write_UINT32(sClose, packetSize); /* Packet length */
Stream_Write_UINT32(sChunk, packetSize); /* Packet length */ Stream_Write_UINT32(sClose, 0); /* Status code */
Stream_Write_UINT32(sChunk, 0); /* Status code */ Stream_SealLength(sClose);
Stream_Write(sChunk, "\r\n", 2); status = rdg_write_packet(rdg, sClose);
Stream_SealLength(sChunk); Stream_Free(sClose, TRUE);
len = Stream_Length(sChunk);
if (len <= INT_MAX)
status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)len);
Stream_Free(sChunk, TRUE);
return (status < 0 ? FALSE : TRUE); return (status < 0 ? FALSE : TRUE);
} }
static BOOL rdg_process_keep_alive_packet(rdpRdg* rdg) static BOOL rdg_process_keep_alive_packet(rdpRdg* rdg)
{ {
int status = -1; int status = -1;
size_t s; wStream* sKeepAlive;
wStream* sChunk;
size_t packetSize = 8; size_t packetSize = 8;
char chunkSize[11];
int chunkLen = sprintf_s(chunkSize, sizeof(chunkSize), "%" PRIxz "\r\n", packetSize);
if ((chunkLen < 0) || (packetSize > UINT32_MAX)) sKeepAlive = Stream_New(NULL, packetSize);
if (!sKeepAlive)
return FALSE; return FALSE;
sChunk = Stream_New(NULL, (size_t)chunkLen + packetSize + 2); Stream_Write_UINT16(sKeepAlive, PKT_TYPE_KEEPALIVE); /* Type */
Stream_Write_UINT16(sKeepAlive, 0); /* Reserved */
Stream_Write_UINT32(sKeepAlive, (UINT32)packetSize); /* Packet length */
Stream_SealLength(sKeepAlive);
status = rdg_write_packet(rdg, sKeepAlive);
Stream_Free(sKeepAlive, TRUE);
if (!sChunk)
return FALSE;
Stream_Write(sChunk, chunkSize, (size_t)chunkLen);
Stream_Write_UINT16(sChunk, PKT_TYPE_KEEPALIVE); /* Type */
Stream_Write_UINT16(sChunk, 0); /* Reserved */
Stream_Write_UINT32(sChunk, (UINT32)packetSize); /* Packet length */
Stream_Write(sChunk, "\r\n", 2);
Stream_SealLength(sChunk);
s = Stream_Length(sChunk);
if (s <= INT_MAX)
status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s);
Stream_Free(sChunk, TRUE);
return (status < 0 ? FALSE : TRUE); return (status < 0 ? FALSE : TRUE);
} }
@ -2003,7 +2547,8 @@ rdpRdg* rdg_new(rdpContext* context)
!http_context_set_connection(rdg->http, "Keep-Alive") || !http_context_set_connection(rdg->http, "Keep-Alive") ||
!http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0") || !http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0") ||
!http_context_set_host(rdg->http, rdg->settings->GatewayHostname) || !http_context_set_host(rdg->http, rdg->settings->GatewayHostname) ||
!http_context_set_rdg_connection_id(rdg->http, bracedUuid)) !http_context_set_rdg_connection_id(rdg->http, bracedUuid) ||
!http_context_enable_websocket_upgrade(rdg->http, TRUE))
{ {
goto rdg_alloc_error; goto rdg_alloc_error;
} }
@ -2033,6 +2578,7 @@ rdpRdg* rdg_new(rdpContext* context)
InitializeCriticalSection(&rdg->writeSection); InitializeCriticalSection(&rdg->writeSection);
rdg->transferEncoding.httpTransferEncoding = TransferEncodingIdentity; rdg->transferEncoding.httpTransferEncoding = TransferEncodingIdentity;
rdg->transferEncoding.isWebsocketTransport = FALSE;
} }
return rdg; return rdg;
@ -2056,6 +2602,12 @@ void rdg_free(rdpRdg* rdg)
DeleteCriticalSection(&rdg->writeSection); DeleteCriticalSection(&rdg->writeSection);
if (rdg->transferEncoding.isWebsocketTransport)
{
if (rdg->transferEncoding.context.websocket.responseStreamBuffer != NULL)
Stream_Free(rdg->transferEncoding.context.websocket.responseStreamBuffer, TRUE);
}
free(rdg); free(rdg);
} }