rdg websocket support

This commit is contained in:
Michael Saxl 2021-01-25 16:20:18 +01:00 committed by akallabeth
parent b05fea134e
commit bc52147fbb
3 changed files with 772 additions and 105 deletions

View File

@ -30,6 +30,9 @@
#include <freerdp/log.h>
/* websocket need sha1 for Sec-Websocket-Accept */
#include <winpr/crypto.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
@ -40,6 +43,8 @@
#define RESPONSE_SIZE_LIMIT 64 * 1024 * 1024
#define WEBSOCKET_MAGIC_GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
struct _http_context
{
char* Method;
@ -52,6 +57,8 @@ struct _http_context
char* Pragma;
char* RdgConnectionId;
char* RdgAuthScheme;
BOOL websocketUpgrade;
char SecWebsocketKey[16];
};
struct _http_request
@ -77,6 +84,8 @@ struct _http_response
size_t ContentLength;
const char* ContentType;
TRANSFER_ENCODING TransferEncoding;
const char* SecWebsocketVersion;
const char* SecWebsocketAccept;
size_t BodyLength;
BYTE* BodyContent;
@ -259,6 +268,30 @@ BOOL http_context_set_rdg_connection_id(HttpContext* context, const char* RdgCon
return TRUE;
}
BOOL http_context_enable_websocket_upgrade(HttpContext* context, BOOL enable)
{
if (!context)
return FALSE;
if (enable)
{
int i;
winpr_RAND((BYTE*)context->SecWebsocketKey, 15);
for (i = 0; i < 16; i++)
context->SecWebsocketKey[i] = (context->SecWebsocketKey[i] | 0x40) & 0x5f;
context->SecWebsocketKey[15] = '\0';
}
else
context->SecWebsocketKey[0] = '\0';
context->websocketUpgrade = enable;
return TRUE;
}
BOOL http_context_is_websocket_upgrade_enabled(HttpContext* context)
{
return context->websocketUpgrade;
}
BOOL http_context_set_rdg_auth_scheme(HttpContext* context, const char* RdgAuthScheme)
{
if (!context || !RdgAuthScheme)
@ -426,13 +459,26 @@ wStream* http_request_write(HttpContext* context, HttpRequest* request)
if (!http_encode_header_line(s, request->Method, request->URI) ||
!http_encode_body_line(s, "Cache-Control", context->CacheControl) ||
!http_encode_body_line(s, "Connection", context->Connection) ||
!http_encode_body_line(s, "Pragma", context->Pragma) ||
!http_encode_body_line(s, "Accept", context->Accept) ||
!http_encode_body_line(s, "User-Agent", context->UserAgent) ||
!http_encode_body_line(s, "Host", context->Host))
goto fail;
if (!context->websocketUpgrade)
{
if (!http_encode_body_line(s, "Connection", context->Connection))
goto fail;
}
else
{
if (!http_encode_body_line(s, "Connection", "Upgrade") ||
!http_encode_body_line(s, "Upgrade", "websocket") ||
!http_encode_body_line(s, "Sec-Websocket-Version", "13") ||
!http_encode_body_line(s, "Sec-Websocket-Key", context->SecWebsocketKey))
goto fail;
}
if (context->RdgConnectionId)
{
if (!http_encode_body_line(s, "RDG-Connection-Id", context->RdgConnectionId))
@ -556,7 +602,6 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char*
const char* value)
{
BOOL status = TRUE;
if (!response || !name)
return FALSE;
@ -587,6 +632,20 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char*
else
response->TransferEncoding = TransferEncodingUnknown;
}
else if (_stricmp(name, "Sec-WebSocket-Version") == 0)
{
response->SecWebsocketVersion = value;
if (!response->SecWebsocketVersion)
return FALSE;
}
else if (_stricmp(name, "Sec-WebSocket-Accept") == 0)
{
response->SecWebsocketAccept = value;
if (!response->SecWebsocketAccept)
return FALSE;
}
else if (_stricmp(name, "WWW-Authenticate") == 0)
{
char* separator = NULL;
@ -1041,3 +1100,56 @@ TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response)
return response->TransferEncoding;
}
BOOL http_response_is_websocket(HttpContext* http, HttpResponse* response)
{
BOOL isWebsocket = FALSE;
WINPR_DIGEST_CTX* sha1 = NULL;
char* base64accept = NULL;
BYTE sha1_digest[WINPR_SHA1_DIGEST_LENGTH];
if (!http || !response)
return FALSE;
if (!http->websocketUpgrade || response->StatusCode != HTTP_STATUS_SWITCH_PROTOCOLS)
return FALSE;
if (response->SecWebsocketVersion && _stricmp(response->SecWebsocketVersion, "13") != 0)
return FALSE;
if (!response->SecWebsocketAccept)
return FALSE;
/* now check if Sec-Websocket-Accept is correct */
sha1 = winpr_Digest_New();
if (!sha1)
goto out;
if (!winpr_Digest_Init(sha1, WINPR_MD_SHA1))
goto out;
if (!winpr_Digest_Update(sha1, (const BYTE*)http->SecWebsocketKey,
strlen(http->SecWebsocketKey)))
goto out;
if (!winpr_Digest_Update(sha1, (const BYTE*)WEBSOCKET_MAGIC_GUID, strlen(WEBSOCKET_MAGIC_GUID)))
goto out;
if (!winpr_Digest_Final(sha1, sha1_digest, sizeof(sha1_digest)))
goto out;
base64accept = crypto_base64_encode(sha1_digest, WINPR_SHA1_DIGEST_LENGTH);
if (!base64accept)
goto out;
if (_stricmp(response->SecWebsocketAccept, base64accept) != 0)
{
WLog_WARN(TAG, "Webserver gave Websocket Upgrade response but sanity check failed");
goto out;
}
isWebsocket = TRUE;
out:
winpr_Digest_Free(sha1);
free(base64accept);
return isWebsocket;
}

View File

@ -52,6 +52,8 @@ FREERDP_LOCAL BOOL http_context_set_rdg_connection_id(HttpContext* context,
const char* RdgConnectionId);
FREERDP_LOCAL BOOL http_context_set_rdg_auth_scheme(HttpContext* context,
const char* RdgAuthScheme);
FREERDP_LOCAL BOOL http_context_enable_websocket_upgrade(HttpContext* context, BOOL enable);
FREERDP_LOCAL BOOL http_context_is_websocket_upgrade_enabled(HttpContext* context);
/* HTTP request */
typedef struct _http_request HttpRequest;
@ -85,5 +87,6 @@ FREERDP_LOCAL long http_response_get_status_code(HttpResponse* response);
FREERDP_LOCAL SSIZE_T http_response_get_body_length(HttpResponse* response);
FREERDP_LOCAL const char* http_response_get_auth_token(HttpResponse* response, const char* method);
FREERDP_LOCAL TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response);
FREERDP_LOCAL BOOL http_response_is_websocket(HttpContext* http, HttpResponse* response);
#endif /* FREERDP_LIB_CORE_GATEWAY_HTTP_H */

View File

@ -104,6 +104,42 @@
#define HTTP_CAPABILITY_REAUTH 0x10
#define HTTP_CAPABILITY_UDP_TRANSPORT 0x20
#define WEBSOCKET_MASK_BIT 0x80
#define WEBSOCKET_FIN_BIT 0x80
typedef enum _WEBSOCKET_OPCODE
{
WebsocketContinuationOpcode = 0x0,
WebsocketTextOpcode = 0x1,
WebsocketBinaryOpcode = 0x2,
WebsocketCloseOpcode = 0x8,
WebsocketPingOpcode = 0x9,
WebsocketPongOpcode = 0xa,
} WEBSOCKET_OPCODE;
typedef enum _WEBSOCKET_STATE
{
WebsocketStateOpcodeAndFin,
WebsocketStateLengthAndMasking,
WebsocketStateShortLength,
WebsocketStateLongLength,
WebSocketStateMaskingKey,
WebSocketStatePayload,
} WEBSOCKET_STATE;
typedef struct
{
size_t payloadLength;
uint32_t maskingKey;
BOOL masking;
BOOL closeSent;
BYTE opcode;
BYTE fragmentOriginalOpcode;
BYTE lengthAndMaskPosition;
WEBSOCKET_STATE state;
wStream* responseStreamBuffer;
} rdg_http_websocket_context;
typedef enum _CHUNK_STATE
{
ChunkStateLenghHeader,
@ -122,9 +158,11 @@ typedef struct
typedef struct
{
TRANSFER_ENCODING httpTransferEncoding;
BOOL isWebsocketTransport;
union _context
{
rdg_http_encoding_chunked_context chunked;
rdg_http_websocket_context websocket;
} context;
} rdg_http_encoding_context;
@ -293,9 +331,9 @@ static BOOL rdg_read_http_unicode_string(wStream* s, const WCHAR** string, UINT1
return TRUE;
}
static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
static BOOL rdg_write_chunked(BIO* bio, wStream* sPacket)
{
size_t s;
size_t len;
int status;
wStream* sChunk;
char chunkSize[11];
@ -309,44 +347,436 @@ static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
Stream_Write(sChunk, Stream_Buffer(sPacket), Stream_Length(sPacket));
Stream_Write(sChunk, "\r\n", 2);
Stream_SealLength(sChunk);
s = Stream_Length(sChunk);
len = Stream_Length(sChunk);
if (s > INT_MAX)
if (len > INT_MAX)
return FALSE;
status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s);
status = BIO_write(bio, Stream_Buffer(sChunk), (int)len);
Stream_Free(sChunk, TRUE);
if (status < 0)
if (status != len)
return FALSE;
return TRUE;
}
static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_encoding_context* encodingContext)
static BOOL rdg_write_websocket(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode)
{
size_t len;
size_t fullLen;
int status;
wStream* sWS;
uint32_t maskingKey;
size_t streamPos;
len = Stream_Length(sPacket);
Stream_SetPosition(sPacket, 0);
if (len > INT_MAX)
return FALSE;
if (len < 126)
fullLen = len + 6; /* 2 byte "mini header" + 4 byte masking key */
else if (len < 0x10000)
fullLen = len + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
else
fullLen = len + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
sWS = Stream_New(NULL, fullLen);
if (!sWS)
return FALSE;
winpr_RAND((BYTE*)&maskingKey, 4);
Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode);
if (len < 126)
Stream_Write_UINT8(sWS, len | WEBSOCKET_MASK_BIT);
else if (len < 0x10000)
{
Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
Stream_Write_UINT16_BE(sWS, len);
}
else
{
Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
Stream_Write_UINT32_BE(sWS, 0); /* payload is limited to INT_MAX */
Stream_Write_UINT32_BE(sWS, len);
}
Stream_Write_UINT32(sWS, maskingKey);
/* mask as much as possible with 32bit access */
for (streamPos = 0; streamPos + 4 <= len; streamPos += 4)
{
uint32_t data;
Stream_Read_UINT32(sPacket, data);
Stream_Write_UINT32(sWS, data ^ maskingKey);
}
/* mask the rest byte by byte */
for (; streamPos < len; streamPos++)
{
BYTE data;
BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
Stream_Read_UINT8(sPacket, data);
Stream_Write_UINT8(sWS, data ^ *partialMask);
}
Stream_SealLength(sWS);
status = BIO_write(bio, Stream_Buffer(sWS), Stream_Length(sWS));
Stream_Free(sWS, TRUE);
if (status != fullLen)
return FALSE;
return TRUE;
}
static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
{
if (rdg->transferEncoding.isWebsocketTransport)
{
if (rdg->transferEncoding.context.websocket.closeSent)
return FALSE;
return rdg_write_websocket(rdg->tlsOut->bio, sPacket, WebsocketBinaryOpcode);
}
return rdg_write_chunked(rdg->tlsIn->bio, sPacket);
}
static int rdg_websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_websocket_context* encodingContext)
{
int status;
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
return 0;
}
status =
BIO_read(bio, pBuffer,
(encodingContext->payloadLength < size ? encodingContext->payloadLength : size));
if (status <= 0)
return status;
encodingContext->payloadLength -= status;
if (encodingContext->payloadLength == 0)
encodingContext->state = WebsocketStateOpcodeAndFin;
return status;
}
static int rdg_websocket_read_discard(BIO* bio, rdg_http_websocket_context* encodingContext)
{
char _dummy[256];
int status;
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
return 0;
}
status = BIO_read(bio, _dummy, sizeof(_dummy));
if (status <= 0)
return status;
encodingContext->payloadLength -= status;
if (encodingContext->payloadLength == 0)
encodingContext->state = WebsocketStateOpcodeAndFin;
return status;
}
static int rdg_websocket_read_wstream(BIO* bio, wStream* s,
rdg_http_websocket_context* encodingContext)
{
int status;
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
return 0;
}
if (s == NULL || Stream_GetRemainingCapacity(s) != encodingContext->payloadLength)
return -1;
status = BIO_read(bio, Stream_Pointer(s), encodingContext->payloadLength);
if (status <= 0)
return status;
Stream_Seek(s, status);
encodingContext->payloadLength -= status;
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
Stream_SealLength(s);
Stream_SetPosition(s, 0);
}
return status;
}
static BOOL rdg_websocket_reply_close(BIO* bio, wStream* s)
{
/* write back close */
wStream* closeFrame;
uint16_t maskingKey1;
uint16_t maskingKey2;
int status;
size_t closeDataLen;
closeDataLen = 0;
if (s != NULL && Stream_Length(s) >= 2)
closeDataLen = 2;
closeFrame = Stream_New(NULL, 6 + closeDataLen);
Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode);
Stream_Write_UINT8(closeFrame, closeDataLen | WEBSOCKET_MASK_BIT); /* no payload */
winpr_RAND((BYTE*)&maskingKey1, 2);
winpr_RAND((BYTE*)&maskingKey2, 2);
Stream_Write_UINT16(closeFrame, maskingKey1);
Stream_Write_UINT16(closeFrame, maskingKey2); /* unused half, max 2 bytes of data */
if (closeDataLen == 2)
{
uint16_t data;
Stream_Read_UINT16(s, data);
Stream_Write_UINT16(s, data ^ maskingKey1);
}
Stream_SealLength(closeFrame);
status = BIO_write(bio, Stream_Buffer(closeFrame), Stream_Length(closeFrame));
/* server MUST close socket now. The server is not allowed anymore to
* send frames but if he does, nothing bad would happen */
if (status < 0)
return FALSE;
return TRUE;
}
static BOOL rdg_websocket_reply_pong(BIO* bio, wStream* s)
{
wStream* closeFrame;
uint32_t maskingKey;
int status;
if (s != NULL)
return rdg_write_websocket(bio, s, WebsocketPongOpcode);
closeFrame = Stream_New(NULL, 6);
Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode);
Stream_Write_UINT8(closeFrame, 0 | WEBSOCKET_MASK_BIT); /* no payload */
winpr_RAND((BYTE*)&maskingKey, 4);
Stream_Write_UINT32(closeFrame, maskingKey); /* dummy masking key. */
Stream_SealLength(closeFrame);
status = BIO_write(bio, Stream_Buffer(closeFrame), Stream_Length(closeFrame));
if (status < 0)
return FALSE;
return TRUE;
}
static int rdg_websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_websocket_context* encodingContext)
{
int status;
BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
? encodingContext->fragmentOriginalOpcode & 0xf
: encodingContext->opcode & 0xf);
switch (effectiveOpcode)
{
case WebsocketBinaryOpcode:
{
status = rdg_websocket_read_data(bio, pBuffer, size, encodingContext);
if (status < 0)
return status;
return status;
}
break;
case WebsocketPingOpcode:
{
if (encodingContext->responseStreamBuffer == NULL)
encodingContext->responseStreamBuffer =
Stream_New(NULL, encodingContext->payloadLength);
status = rdg_websocket_read_wstream(bio, encodingContext->responseStreamBuffer,
encodingContext);
if (status < 0)
return status;
if (encodingContext->payloadLength == 0)
{
if (!encodingContext->closeSent)
rdg_websocket_reply_pong(bio, encodingContext->responseStreamBuffer);
if (encodingContext->responseStreamBuffer)
Stream_Free(encodingContext->responseStreamBuffer, TRUE);
encodingContext->responseStreamBuffer = NULL;
}
}
break;
case WebsocketCloseOpcode:
{
if (encodingContext->responseStreamBuffer == NULL)
encodingContext->responseStreamBuffer =
Stream_New(NULL, encodingContext->payloadLength);
status = rdg_websocket_read_wstream(bio, encodingContext->responseStreamBuffer,
encodingContext);
if (status < 0)
return status;
if (encodingContext->payloadLength == 0)
{
rdg_websocket_reply_close(bio, encodingContext->responseStreamBuffer);
encodingContext->closeSent = TRUE;
if (encodingContext->responseStreamBuffer)
Stream_Free(encodingContext->responseStreamBuffer, TRUE);
encodingContext->responseStreamBuffer = NULL;
}
}
break;
default:
WLog_WARN(TAG, "Unimplemented websocket opcode %x. Dropping", effectiveOpcode & 0xf);
status = rdg_websocket_read_discard(bio, encodingContext);
if (status < 0)
return status;
}
/* return how many bytes have been written to pBuffer.
* Only WebsocketBinaryOpcode writes into it and it returns directly */
return 0;
}
static int rdg_websocket_read(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_websocket_context* encodingContext)
{
int status;
int effectiveDataLen = 0;
assert(encodingContext != NULL);
while (TRUE)
{
switch (encodingContext->context.chunked.state)
switch (encodingContext->state)
{
case ChunkStateData:
case WebsocketStateOpcodeAndFin:
{
status = BIO_read(bio, pBuffer,
(size > encodingContext->context.chunked.nextOffset
? encodingContext->context.chunked.nextOffset
: size));
BYTE buffer[1];
status = BIO_read(bio, (char*)buffer, 1);
if (status <= 0)
return (effectiveDataLen > 0 ? effectiveDataLen : status);
encodingContext->context.chunked.nextOffset -= status;
if (encodingContext->context.chunked.nextOffset == 0)
encodingContext->opcode = buffer[0];
if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) &&
(encodingContext->opcode & 0xf) < 0x08)
encodingContext->fragmentOriginalOpcode = encodingContext->opcode;
encodingContext->state = WebsocketStateLengthAndMasking;
}
break;
case WebsocketStateLengthAndMasking:
{
BYTE buffer[1];
BYTE len;
status = BIO_read(bio, (char*)buffer, 1);
if (status <= 0)
return (effectiveDataLen > 0 ? effectiveDataLen : status);
encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
encodingContext->lengthAndMaskPosition = 0;
encodingContext->payloadLength = 0;
len = buffer[0] & 0x7f;
if (len < 126)
{
encodingContext->context.chunked.state = ChunkStateFooter;
encodingContext->context.chunked.headerFooterPos = 0;
encodingContext->payloadLength = len;
encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey
: WebSocketStatePayload);
}
else if (len == 126)
encodingContext->state = WebsocketStateShortLength;
else
encodingContext->state = WebsocketStateLongLength;
}
break;
case WebsocketStateShortLength:
case WebsocketStateLongLength:
{
BYTE buffer[1];
BYTE lenLength = (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
while (encodingContext->lengthAndMaskPosition < lenLength)
{
status = BIO_read(bio, (char*)buffer, 1);
if (status <= 0)
return (effectiveDataLen > 0 ? effectiveDataLen : status);
encodingContext->payloadLength =
(encodingContext->payloadLength) << 8 | buffer[0];
encodingContext->lengthAndMaskPosition += status;
}
encodingContext->state =
(encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload);
}
break;
case WebSocketStateMaskingKey:
{
WLog_WARN(
TAG, "Websocket Server sends data with masking key. This is against RFC 6455.");
return -1;
}
break;
case WebSocketStatePayload:
{
status = rdg_websocket_handle_payload(bio, pBuffer, size, encodingContext);
if (status < 0)
return (effectiveDataLen > 0 ? effectiveDataLen : status);
effectiveDataLen += status;
if ((size_t)status == size)
return effectiveDataLen;
pBuffer += status;
size -= status;
}
}
}
/* should be unreachable */
return -1;
}
static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_encoding_chunked_context* encodingContext)
{
int status;
int effectiveDataLen = 0;
assert(encodingContext != NULL);
while (TRUE)
{
switch (encodingContext->state)
{
case ChunkStateData:
{
status = BIO_read(
bio, pBuffer,
(size > encodingContext->nextOffset ? encodingContext->nextOffset : size));
if (status <= 0)
return (effectiveDataLen > 0 ? effectiveDataLen : status);
encodingContext->nextOffset -= status;
if (encodingContext->nextOffset == 0)
{
encodingContext->state = ChunkStateFooter;
encodingContext->headerFooterPos = 0;
}
effectiveDataLen += status;
@ -360,17 +790,16 @@ static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
case ChunkStateFooter:
{
char _dummy[2];
assert(encodingContext->context.chunked.nextOffset == 0);
assert(encodingContext->context.chunked.headerFooterPos < 2);
status =
BIO_read(bio, _dummy, 2 - encodingContext->context.chunked.headerFooterPos);
assert(encodingContext->nextOffset == 0);
assert(encodingContext->headerFooterPos < 2);
status = BIO_read(bio, _dummy, 2 - encodingContext->headerFooterPos);
if (status >= 0)
{
encodingContext->context.chunked.headerFooterPos += status;
if (encodingContext->context.chunked.headerFooterPos == 2)
encodingContext->headerFooterPos += status;
if (encodingContext->headerFooterPos == 2)
{
encodingContext->context.chunked.state = ChunkStateLenghHeader;
encodingContext->context.chunked.headerFooterPos = 0;
encodingContext->state = ChunkStateLenghHeader;
encodingContext->headerFooterPos = 0;
}
}
else
@ -381,43 +810,40 @@ static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
{
BOOL _haveNewLine = FALSE;
size_t tmp;
char* dst = &encodingContext->context.chunked
.lenBuffer[encodingContext->context.chunked.headerFooterPos];
assert(encodingContext->context.chunked.nextOffset == 0);
while (encodingContext->context.chunked.headerFooterPos < 10 && !_haveNewLine)
char* dst = &encodingContext->lenBuffer[encodingContext->headerFooterPos];
assert(encodingContext->nextOffset == 0);
while (encodingContext->headerFooterPos < 10 && !_haveNewLine)
{
status = BIO_read(bio, dst, 1);
if (status >= 0)
{
if (*dst == '\n')
_haveNewLine = TRUE;
encodingContext->context.chunked.headerFooterPos += status;
encodingContext->headerFooterPos += status;
dst += status;
}
else
return (effectiveDataLen > 0 ? effectiveDataLen : status);
}
*dst = '\0';
/* strtoul is tricky, error are reported via errno, we also need
* to ensure the result does not overflow */
errno = 0;
tmp = strtoul(encodingContext->context.chunked.lenBuffer, NULL, 16);
tmp = strtoul(encodingContext->lenBuffer, NULL, 16);
if ((errno != 0) || (tmp > SIZE_MAX))
return -1;
encodingContext->context.chunked.nextOffset = tmp;
encodingContext->context.chunked.state = ChunkStateData;
encodingContext->nextOffset = tmp;
encodingContext->state = ChunkStateData;
if (encodingContext->context.chunked.nextOffset == 0)
{ // end of stream
if (encodingContext->nextOffset == 0)
{ /* end of stream */
int fd = BIO_get_fd(bio, NULL);
if (fd >= 0)
close(fd);
closesocket((SOCKET)fd);
WLog_WARN(TAG, "cunked encoding end of stream received");
encodingContext->context.chunked.headerFooterPos = 0;
encodingContext->context.chunked.state = ChunkStateFooter;
encodingContext->headerFooterPos = 0;
encodingContext->state = ChunkStateFooter;
}
}
break;
@ -433,12 +859,18 @@ static int rdg_socket_read(BIO* bio, BYTE* pBuffer, size_t size,
rdg_http_encoding_context* encodingContext)
{
assert(encodingContext != NULL);
if (encodingContext->isWebsocketTransport)
{
return rdg_websocket_read(bio, pBuffer, size, &encodingContext->context.websocket);
}
switch (encodingContext->httpTransferEncoding)
{
case TransferEncodingIdentity:
return BIO_read(bio, pBuffer, size);
case TransferEncodingChunked:
return rdg_chuncked_read(bio, pBuffer, size, encodingContext);
return rdg_chuncked_read(bio, pBuffer, size, &encodingContext->context.chunked);
default:
return -1;
}
@ -454,7 +886,6 @@ static BOOL rdg_read_all(rdpTls* tls, BYTE* buffer, size_t size,
while (readCount < size)
{
int status = rdg_socket_read(tls->bio, pBuffer, size - readCount, transferEncoding);
if (status <= 0)
{
if (!BIO_should_retry(tls->bio))
@ -730,10 +1161,7 @@ static wStream* rdg_build_http_request(rdpRdg* rdg, const char* method,
goto out;
}
if (transferEncoding)
{
http_request_set_transfer_encoding(request, transferEncoding);
}
http_request_set_transfer_encoding(request, transferEncoding);
s = http_request_write(rdg->http, request);
out:
@ -1320,6 +1748,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
SSIZE_T bodyLength;
long StatusCode;
TRANSFER_ENCODING encoding;
BOOL isWebsocket;
if (!rdg_tls_connect(rdg, tls, peerAddress, timeout))
return FALSE;
@ -1344,6 +1773,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
case HTTP_STATUS_NOT_FOUND:
{
WLog_INFO(TAG, "RD Gateway does not support HTTP transport.");
http_context_enable_websocket_upgrade(rdg->http, FALSE);
if (rpcFallback)
*rpcFallback = TRUE;
@ -1377,16 +1807,44 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
statusCode = http_response_get_status_code(response);
bodyLength = http_response_get_body_length(response);
encoding = http_response_get_transfer_encoding(response);
isWebsocket = http_response_is_websocket(rdg->http, response);
http_response_free(response);
WLog_DBG(TAG, "%s authorization result: %d", method, statusCode);
switch (statusCode)
{
case HTTP_STATUS_OK:
/* old rdg endpoint without websocket support, don't request websocket for RDG_IN_DATA
*/
http_context_enable_websocket_upgrade(rdg->http, FALSE);
break;
case HTTP_STATUS_DENIED:
freerdp_set_last_error_log(rdg->context, FREERDP_ERROR_CONNECT_ACCESS_DENIED);
return FALSE;
case HTTP_STATUS_SWITCH_PROTOCOLS:
if (!isWebsocket)
{
/*
* webserver is broken, a fallback may be possible here
* but only if already tested with oppurtonistic upgrade
*/
if (http_context_is_websocket_upgrade_enabled(rdg->http))
{
int fd = BIO_get_fd(tls->bio, NULL);
if (fd >= 0)
closesocket((SOCKET)fd);
http_context_enable_websocket_upgrade(rdg->http, FALSE);
return rdg_establish_data_connection(rdg, tls, method, peerAddress, timeout,
rpcFallback);
}
return FALSE;
}
rdg->transferEncoding.isWebsocketTransport = TRUE;
rdg->transferEncoding.context.websocket.state = WebsocketStateOpcodeAndFin;
rdg->transferEncoding.context.websocket.responseStreamBuffer = NULL;
return TRUE;
break;
default:
return FALSE;
}
@ -1452,14 +1910,21 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback)
if (status)
{
/* Establish IN connection with the same peer/server as OUT connection,
* even when server hostname resolves to different IP addresses.
*/
BIO_get_socket(rdg->tlsOut->underlying, &outConnSocket);
peerAddress = freerdp_tcp_get_peer_address(outConnSocket);
status = rdg_establish_data_connection(rdg, rdg->tlsIn, "RDG_IN_DATA", peerAddress, timeout,
NULL);
free(peerAddress);
if (rdg->transferEncoding.isWebsocketTransport)
{
WLog_DBG(TAG, "Upgraded to websocket. RDG_IN_DATA not required");
}
else
{
/* Establish IN connection with the same peer/server as OUT connection,
* even when server hostname resolves to different IP addresses.
*/
BIO_get_socket(rdg->tlsOut->underlying, &outConnSocket);
peerAddress = freerdp_tcp_get_peer_address(outConnSocket);
status = rdg_establish_data_connection(rdg, rdg->tlsIn, "RDG_IN_DATA", peerAddress,
timeout, NULL);
free(peerAddress);
}
}
if (!status)
@ -1476,10 +1941,97 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback)
return TRUE;
}
static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
static int rdg_write_websocket_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
{
size_t payloadSize;
size_t fullLen;
int status;
wStream* sWS;
uint32_t maskingKey;
BYTE* maskingKeyByte1 = (BYTE*)&maskingKey;
BYTE* maskingKeyByte2 = maskingKeyByte1 + 1;
BYTE* maskingKeyByte3 = maskingKeyByte1 + 2;
BYTE* maskingKeyByte4 = maskingKeyByte1 + 3;
int streamPos;
winpr_RAND((BYTE*)&maskingKey, 4);
payloadSize = isize + 10;
if ((isize < 0) || (isize > UINT16_MAX))
return -1;
if (payloadSize < 1)
return 0;
if (payloadSize < 126)
fullLen = payloadSize + 6; /* 2 byte "mini header" + 4 byte masking key */
else if (payloadSize < 0x10000)
fullLen = payloadSize + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
else
fullLen = payloadSize + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
sWS = Stream_New(NULL, fullLen);
if (!sWS)
return FALSE;
Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | WebsocketBinaryOpcode);
if (payloadSize < 126)
Stream_Write_UINT8(sWS, payloadSize | WEBSOCKET_MASK_BIT);
else if (payloadSize < 0x10000)
{
Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
Stream_Write_UINT16_BE(sWS, payloadSize);
}
else
{
Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
/* biggest packet possible is 0xffff + 0xa, so 32bit is always enough */
Stream_Write_UINT32_BE(sWS, 0);
Stream_Write_UINT32_BE(sWS, payloadSize);
}
Stream_Write_UINT32(sWS, maskingKey);
Stream_Write_UINT16(sWS, PKT_TYPE_DATA ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Type */
Stream_Write_UINT16(sWS, 0 ^ (*maskingKeyByte3 | *maskingKeyByte4 << 8)); /* Reserved */
Stream_Write_UINT32(sWS, (UINT32)payloadSize ^ maskingKey); /* Packet length */
Stream_Write_UINT16(sWS,
(UINT16)isize ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Data size */
/* masking key is now off by 2 bytes. fix that */
maskingKey = (maskingKey & 0xffff) << 16 | (maskingKey >> 16);
/* mask as much as possible with 32bit access */
for (streamPos = 0; streamPos + 4 <= isize; streamPos += 4)
{
uint32_t masked = *((uint32_t*)((BYTE*)buf + streamPos)) ^ maskingKey;
Stream_Write_UINT32(sWS, masked);
}
/* mask the rest byte by byte */
for (; streamPos < isize; streamPos++)
{
BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4;
BYTE masked = *((BYTE*)((BYTE*)buf + streamPos)) ^ *partialMask;
Stream_Write_UINT8(sWS, masked);
}
Stream_SealLength(sWS);
status = tls_write_all(rdg->tlsOut, Stream_Buffer(sWS), Stream_Length(sWS));
Stream_Free(sWS, TRUE);
if (status < 0)
return status;
return isize;
}
static int rdg_write_chunked_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
{
int status;
size_t s;
size_t len;
wStream* sChunk;
size_t size = (size_t)isize;
size_t packetSize = size + 10;
@ -1505,12 +2057,12 @@ static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
Stream_Write(sChunk, buf, size); /* Data */
Stream_Write(sChunk, "\r\n", 2);
Stream_SealLength(sChunk);
s = Stream_Length(sChunk);
len = Stream_Length(sChunk);
if (s > INT_MAX)
if (len > INT_MAX)
return -1;
status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s);
status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)len);
Stream_Free(sChunk, TRUE);
if (status < 0)
@ -1519,18 +2071,26 @@ static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
return (int)size;
}
static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
{
if (rdg->transferEncoding.isWebsocketTransport)
{
if (rdg->transferEncoding.context.websocket.closeSent == TRUE)
return -1;
return rdg_write_websocket_data_packet(rdg, buf, isize);
}
else
return rdg_write_chunked_data_packet(rdg, buf, isize);
return -1;
}
static BOOL rdg_process_close_packet(rdpRdg* rdg, wStream* s)
{
int status = -1;
size_t len;
wStream* sChunk;
wStream* sClose;
UINT32 errorCode;
UINT32 packetSize = 12;
char chunkSize[11];
int chunkLen = sprintf_s(chunkSize, sizeof(chunkSize), "%" PRIx32 "\r\n", packetSize);
if (chunkLen < 0)
return FALSE;
/* Read error code */
if (Stream_GetRemainingLength(s) < 4)
@ -1540,55 +2100,39 @@ static BOOL rdg_process_close_packet(rdpRdg* rdg, wStream* s)
if (errorCode != 0)
freerdp_set_last_error_log(rdg->context, errorCode);
sChunk = Stream_New(NULL, (size_t)chunkLen + packetSize + 2);
if (!sChunk)
sClose = Stream_New(NULL, packetSize);
if (!sClose)
return FALSE;
Stream_Write(sChunk, chunkSize, (size_t)chunkLen);
Stream_Write_UINT16(sChunk, PKT_TYPE_CLOSE_CHANNEL_RESPONSE); /* Type */
Stream_Write_UINT16(sChunk, 0); /* Reserved */
Stream_Write_UINT32(sChunk, packetSize); /* Packet length */
Stream_Write_UINT32(sChunk, 0); /* Status code */
Stream_Write(sChunk, "\r\n", 2);
Stream_SealLength(sChunk);
len = Stream_Length(sChunk);
Stream_Write_UINT16(sClose, PKT_TYPE_CLOSE_CHANNEL_RESPONSE); /* Type */
Stream_Write_UINT16(sClose, 0); /* Reserved */
Stream_Write_UINT32(sClose, packetSize); /* Packet length */
Stream_Write_UINT32(sClose, 0); /* Status code */
Stream_SealLength(sClose);
status = rdg_write_packet(rdg, sClose);
Stream_Free(sClose, TRUE);
if (len <= INT_MAX)
status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)len);
Stream_Free(sChunk, TRUE);
return (status < 0 ? FALSE : TRUE);
}
static BOOL rdg_process_keep_alive_packet(rdpRdg* rdg)
{
int status = -1;
size_t s;
wStream* sChunk;
wStream* sKeepAlive;
size_t packetSize = 8;
char chunkSize[11];
int chunkLen = sprintf_s(chunkSize, sizeof(chunkSize), "%" PRIxz "\r\n", packetSize);
if ((chunkLen < 0) || (packetSize > UINT32_MAX))
sKeepAlive = Stream_New(NULL, packetSize);
if (!sKeepAlive)
return FALSE;
sChunk = Stream_New(NULL, (size_t)chunkLen + packetSize + 2);
Stream_Write_UINT16(sKeepAlive, PKT_TYPE_KEEPALIVE); /* Type */
Stream_Write_UINT16(sKeepAlive, 0); /* Reserved */
Stream_Write_UINT32(sKeepAlive, (UINT32)packetSize); /* Packet length */
Stream_SealLength(sKeepAlive);
status = rdg_write_packet(rdg, sKeepAlive);
Stream_Free(sKeepAlive, TRUE);
if (!sChunk)
return FALSE;
Stream_Write(sChunk, chunkSize, (size_t)chunkLen);
Stream_Write_UINT16(sChunk, PKT_TYPE_KEEPALIVE); /* Type */
Stream_Write_UINT16(sChunk, 0); /* Reserved */
Stream_Write_UINT32(sChunk, (UINT32)packetSize); /* Packet length */
Stream_Write(sChunk, "\r\n", 2);
Stream_SealLength(sChunk);
s = Stream_Length(sChunk);
if (s <= INT_MAX)
status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s);
Stream_Free(sChunk, TRUE);
return (status < 0 ? FALSE : TRUE);
}
@ -2003,7 +2547,8 @@ rdpRdg* rdg_new(rdpContext* context)
!http_context_set_connection(rdg->http, "Keep-Alive") ||
!http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0") ||
!http_context_set_host(rdg->http, rdg->settings->GatewayHostname) ||
!http_context_set_rdg_connection_id(rdg->http, bracedUuid))
!http_context_set_rdg_connection_id(rdg->http, bracedUuid) ||
!http_context_enable_websocket_upgrade(rdg->http, TRUE))
{
goto rdg_alloc_error;
}
@ -2033,6 +2578,7 @@ rdpRdg* rdg_new(rdpContext* context)
InitializeCriticalSection(&rdg->writeSection);
rdg->transferEncoding.httpTransferEncoding = TransferEncodingIdentity;
rdg->transferEncoding.isWebsocketTransport = FALSE;
}
return rdg;
@ -2056,6 +2602,12 @@ void rdg_free(rdpRdg* rdg)
DeleteCriticalSection(&rdg->writeSection);
if (rdg->transferEncoding.isWebsocketTransport)
{
if (rdg->transferEncoding.context.websocket.responseStreamBuffer != NULL)
Stream_Free(rdg->transferEncoding.context.websocket.responseStreamBuffer, TRUE);
}
free(rdg);
}