Merge pull request #5944 from akallabeth/fragmented_packet_length_check_fix

Fragmented packet length check fix
This commit is contained in:
Martin Fleisz 2020-03-05 08:26:34 +01:00 committed by GitHub
commit 5facf708dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 89 additions and 40 deletions

View File

@ -176,7 +176,8 @@ extern "C"
typedef int (*pLogonErrorInfo)(freerdp* instance, UINT32 data, UINT32 type);
typedef int (*pSendChannelData)(freerdp* instance, UINT16 channelId, BYTE* data, int size);
typedef int (*pSendChannelData)(freerdp* instance, UINT16 channelId, const BYTE* data,
int size);
typedef int (*pReceiveChannelData)(freerdp* instance, UINT16 channelId, BYTE* data, int size,
int flags, int totalSize);

View File

@ -54,13 +54,13 @@ typedef BOOL (*psPeerClientCapabilities)(freerdp_peer* peer);
typedef int (*psPeerSendChannelData)(freerdp_peer* peer, UINT16 channelId, const BYTE* data,
int size);
typedef int (*psPeerReceiveChannelData)(freerdp_peer* peer, UINT16 channelId, const BYTE* data,
int size, int flags, int totalSize);
size_t size, UINT32 flags, size_t totalSize);
typedef HANDLE (*psPeerVirtualChannelOpen)(freerdp_peer* peer, const char* name, UINT32 flags);
typedef BOOL (*psPeerVirtualChannelClose)(freerdp_peer* peer, HANDLE hChannel);
typedef int (*psPeerVirtualChannelRead)(freerdp_peer* peer, HANDLE hChannel, BYTE* buffer,
UINT32 length);
typedef int (*psPeerVirtualChannelWrite)(freerdp_peer* peer, HANDLE hChannel, BYTE* buffer,
typedef int (*psPeerVirtualChannelWrite)(freerdp_peer* peer, HANDLE hChannel, const BYTE* buffer,
UINT32 length);
typedef void* (*psPeerVirtualChannelGetData)(freerdp_peer* peer, HANDLE hChannel);
typedef int (*psPeerVirtualChannelSetData)(freerdp_peer* peer, HANDLE hChannel, void* data);

View File

@ -49,10 +49,10 @@
#define TAG FREERDP_TAG("core.channels")
BOOL freerdp_channel_send(rdpRdp* rdp, UINT16 channelId, const BYTE* data, int size)
BOOL freerdp_channel_send(rdpRdp* rdp, UINT16 channelId, const BYTE* data, size_t size)
{
DWORD i;
int left;
size_t left;
wStream* s;
UINT32 flags;
size_t chunkSize;
@ -84,7 +84,7 @@ BOOL freerdp_channel_send(rdpRdp* rdp, UINT16 channelId, const BYTE* data, int s
if (!s)
return FALSE;
if (left > (int)rdp->settings->VirtualChannelChunkSize)
if (left > rdp->settings->VirtualChannelChunkSize)
{
chunkSize = rdp->settings->VirtualChannelChunkSize;
}
@ -122,23 +122,51 @@ BOOL freerdp_channel_send(rdpRdp* rdp, UINT16 channelId, const BYTE* data, int s
return TRUE;
}
BOOL freerdp_channel_process(freerdp* instance, wStream* s, UINT16 channelId)
BOOL freerdp_channel_process(freerdp* instance, wStream* s, UINT16 channelId, size_t packetLength)
{
int rc = 0;
UINT32 length;
UINT32 flags;
size_t chunkLength;
if (packetLength < 8)
{
WLog_ERR(TAG, "Header length %" PRIdz " bytes promised, none available", packetLength);
return FALSE;
}
packetLength -= 8;
if (Stream_GetRemainingLength(s) < 8)
return FALSE;
/* [MS-RDPBCGR] 3.1.5.2.2 Processing of Virtual Channel PDU
* chunked data. Length is the total size of the combined data,
* chunkLength is the actual data received.
* check chunkLength against packetLength, which is the TPKT header size.
*/
Stream_Read_UINT32(s, length);
Stream_Read_UINT32(s, flags);
chunkLength = Stream_GetRemainingLength(s);
if (length > chunkLength)
if (packetLength != chunkLength)
{
WLog_ERR(TAG, "Header length %" PRIdz " != actual length %" PRIdz, packetLength,
chunkLength);
return FALSE;
IFCALL(instance->ReceiveChannelData, instance, channelId, Stream_Pointer(s), chunkLength, flags,
length);
return Stream_SafeSeek(s, length);
}
if (length < chunkLength)
{
WLog_ERR(TAG, "Expected %" PRIu32 " bytes, but have %" PRIdz, length, chunkLength);
return FALSE;
}
IFCALLRET(instance->ReceiveChannelData, rc, instance, channelId, Stream_Pointer(s), chunkLength,
flags, length);
if (rc != CHANNEL_RC_OK)
{
WLog_WARN(TAG, "ReceiveChannelData returned %d", rc);
return FALSE;
}
return Stream_SafeSeek(s, chunkLength);
}
BOOL freerdp_channel_peer_process(freerdp_peer* client, wStream* s, UINT16 channelId)

View File

@ -23,8 +23,10 @@
#include <freerdp/api.h>
#include "client.h"
FREERDP_LOCAL BOOL freerdp_channel_send(rdpRdp* rdp, UINT16 channelId, const BYTE* data, int size);
FREERDP_LOCAL BOOL freerdp_channel_process(freerdp* instance, wStream* s, UINT16 channelId);
FREERDP_LOCAL BOOL freerdp_channel_send(rdpRdp* rdp, UINT16 channelId, const BYTE* data,
size_t size);
FREERDP_LOCAL BOOL freerdp_channel_process(freerdp* instance, wStream* s, UINT16 channelId,
size_t packetLength);
FREERDP_LOCAL BOOL freerdp_channel_peer_process(freerdp_peer* client, wStream* s, UINT16 channelId);
#endif /* FREERDP_LIB_CORE_CHANNELS_H */

View File

@ -427,6 +427,14 @@ int freerdp_channels_data(freerdp* instance, UINT16 channelId, BYTE* data, int d
rdpChannels* channels;
rdpMcsChannel* channel = NULL;
CHANNEL_OPEN_DATA* pChannelOpenData;
if (!instance || !data || (dataSize < 0) || (totalSize < 0))
{
WLog_ERR(TAG, "%s(%p, %" PRIu16 ", %p, %d, 0x%08x, %d): Invalid arguments", __FUNCTION__,
instance, channelId, data, dataSize, flags, totalSize);
return -1;
}
mcs = instance->context->rdp->mcs;
channels = instance->context->channels;

View File

@ -488,9 +488,18 @@ int freerdp_message_queue_process_pending_messages(freerdp* instance, DWORD id)
return status;
}
static int freerdp_send_channel_data(freerdp* instance, UINT16 channelId, BYTE* data, int size)
static int freerdp_send_channel_data(freerdp* instance, UINT16 channelId, const BYTE* data,
int size)
{
return rdp_send_channel_data(instance->context->rdp, channelId, data, size);
if (size < 0)
{
WLog_ERR(TAG, "%s: size has invalid value %d", __FUNCTION__, size);
return -1;
}
if (!rdp_send_channel_data(instance->context->rdp, channelId, data, (size_t)size))
return -2;
return 0;
}
BOOL freerdp_disconnect(freerdp* instance)

View File

@ -108,14 +108,8 @@ static BOOL freerdp_peer_virtual_channel_close(freerdp_peer* client, HANDLE hCha
return TRUE;
}
static int freerdp_peer_virtual_channel_read(freerdp_peer* client, HANDLE hChannel, BYTE* buffer,
UINT32 length)
{
return 0; /* this needs to be implemented by the server application */
}
static int freerdp_peer_virtual_channel_write(freerdp_peer* client, HANDLE hChannel, BYTE* buffer,
UINT32 length)
static int freerdp_peer_virtual_channel_write(freerdp_peer* client, HANDLE hChannel,
const BYTE* buffer, UINT32 length)
{
wStream* s;
UINT32 flags;
@ -745,7 +739,14 @@ static void freerdp_peer_disconnect(freerdp_peer* client)
static int freerdp_peer_send_channel_data(freerdp_peer* client, UINT16 channelId, const BYTE* data,
int size)
{
return rdp_send_channel_data(client->context->rdp, channelId, data, size);
if (size < 0)
{
WLog_ERR(TAG, "%s: invalid size %d", __FUNCTION__, size);
return -1;
}
if (!rdp_send_channel_data(client->context->rdp, channelId, data, (size_t)size))
return -1;
return 0;
}
static BOOL freerdp_peer_is_write_blocked(freerdp_peer* peer)

View File

@ -1350,11 +1350,8 @@ static int rdp_recv_tpkt_pdu(rdpRdp* rdp, wStream* s)
{
rdp->inPackets++;
if (!freerdp_channel_process(rdp->instance, s, channelId))
{
WLog_ERR(TAG, "rdp_recv_tpkt_pdu: freerdp_channel_process() fail");
if (!freerdp_channel_process(rdp->instance, s, channelId, length))
return -1;
}
}
out:
@ -1601,7 +1598,7 @@ int rdp_recv_callback(rdpTransport* transport, wStream* s, void* extra)
return status;
}
int rdp_send_channel_data(rdpRdp* rdp, UINT16 channelId, const BYTE* data, int size)
BOOL rdp_send_channel_data(rdpRdp* rdp, UINT16 channelId, const BYTE* data, size_t size)
{
return freerdp_channel_send(rdp, channelId, data, size);
}

View File

@ -205,7 +205,8 @@ FREERDP_LOCAL int rdp_recv_data_pdu(rdpRdp* rdp, wStream* s);
FREERDP_LOCAL BOOL rdp_send(rdpRdp* rdp, wStream* s, UINT16 channelId);
FREERDP_LOCAL int rdp_send_channel_data(rdpRdp* rdp, UINT16 channelId, const BYTE* data, int size);
FREERDP_LOCAL BOOL rdp_send_channel_data(rdpRdp* rdp, UINT16 channelId, const BYTE* data,
size_t size);
FREERDP_LOCAL wStream* rdp_message_channel_pdu_init(rdpRdp* rdp);
FREERDP_LOCAL BOOL rdp_send_message_channel_pdu(rdpRdp* rdp, wStream* s, UINT16 sec_flags);

View File

@ -359,13 +359,13 @@ static BOOL wts_write_drdynvc_create_request(wStream* s, UINT32 ChannelId, const
}
static BOOL WTSProcessChannelData(rdpPeerChannel* channel, UINT16 channelId, const BYTE* data,
int s, int flags, int t)
size_t s, UINT32 flags, size_t t)
{
BOOL ret = TRUE;
const size_t size = (size_t)s;
const size_t totalSize = (size_t)t;
if ((s < 0) || (t < 0))
return FALSE;
WINPR_UNUSED(channelId);
if (flags & CHANNEL_FLAG_FIRST)
{
@ -400,8 +400,8 @@ static BOOL WTSProcessChannelData(rdpPeerChannel* channel, UINT16 channelId, con
return ret;
}
static int WTSReceiveChannelData(freerdp_peer* client, UINT16 channelId, const BYTE* data, int size,
int flags, int totalSize)
static int WTSReceiveChannelData(freerdp_peer* client, UINT16 channelId, const BYTE* data,
size_t size, UINT32 flags, size_t totalSize)
{
UINT32 i;
BOOL status = FALSE;
@ -478,6 +478,7 @@ BOOL WTSVirtualChannelManagerCheckFileDescriptor(HANDLE hServer)
while (MessageQueue_Peek(vcm->queue, &message, TRUE))
{
int rc;
BYTE* buffer;
UINT32 length;
UINT16 channelId;
@ -485,7 +486,8 @@ BOOL WTSVirtualChannelManagerCheckFileDescriptor(HANDLE hServer)
buffer = (BYTE*)message.wParam;
length = (UINT32)(UINT_PTR)message.lParam;
if (vcm->client->SendChannelData(vcm->client, channelId, buffer, length) == FALSE)
rc = vcm->client->SendChannelData(vcm->client, channelId, buffer, length);
if (rc < 0)
{
status = FALSE;
}

View File

@ -198,9 +198,9 @@ static BOOL pf_server_adjust_monitor_layout(freerdp_peer* peer)
return TRUE;
}
static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 channelId,
const BYTE* data, int size, int flags,
int totalSize)
static int pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 channelId,
const BYTE* data, size_t size, UINT32 flags,
size_t totalSize)
{
pServerContext* ps = (pServerContext*)peer->context;
pClientContext* pc = ps->pdata->pc;
@ -222,7 +222,7 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann
ev.data_len = size;
if (!pf_modules_run_filter(FILTER_TYPE_SERVER_PASSTHROUGH_CHANNEL_DATA, pdata, &ev))
return FALSE;
return -1;
client_channel_id = (UINT64)HashTable_GetItemValue(pc->vc_ids, (void*)channel_name);