proxy: fix channel shift between front and back

When some channels are filtered, some misalignement of channel ids could happen.
This patch keeps track of the back and front channel ids to correctly identify a
channel and send packets with the correct channel id.
This commit is contained in:
David Fort 2022-11-30 09:09:25 +01:00 committed by Martin Fleisz
parent 9db032f326
commit d59c0a49c3
8 changed files with 74 additions and 23 deletions

View File

@ -65,6 +65,9 @@ extern "C"
{ {
#endif #endif
#define MCS_BASE_CHANNEL_ID 1001
#define MCS_GLOBAL_CHANNEL_ID 1003
/* Flags used by certificate callbacks */ /* Flags used by certificate callbacks */
#define VERIFY_CERT_FLAG_NONE 0x00 #define VERIFY_CERT_FLAG_NONE 0x00
#define VERIFY_CERT_FLAG_LEGACY 0x02 #define VERIFY_CERT_FLAG_LEGACY 0x02

View File

@ -74,7 +74,8 @@ extern "C"
struct p_server_static_channel_context struct p_server_static_channel_context
{ {
char* channel_name; char* channel_name;
UINT32 channel_id; UINT32 front_channel_id;
UINT32 back_channel_id;
pf_utils_channel_mode channelMode; pf_utils_channel_mode channelMode;
proxyChannelDataFn onFrontData; proxyChannelDataFn onFrontData;
proxyChannelDataFn onBackData; proxyChannelDataFn onBackData;
@ -97,7 +98,8 @@ extern "C"
HANDLE dynvcReady; HANDLE dynvcReady;
wHashTable* interceptContextMap; wHashTable* interceptContextMap;
wHashTable* channelsById; wHashTable* channelsByFrontId;
wHashTable* channelsByBackId;
}; };
typedef struct p_server_context pServerContext; typedef struct p_server_context pServerContext;

View File

@ -33,9 +33,6 @@ typedef struct rdp_mcs rdpMcs;
#include <winpr/stream.h> #include <winpr/stream.h>
#include <winpr/wtsapi.h> #include <winpr/wtsapi.h>
#define MCS_BASE_CHANNEL_ID 1001
#define MCS_GLOBAL_CHANNEL_ID 1003
enum MCS_Result enum MCS_Result
{ {
MCS_Result_successful = 0, MCS_Result_successful = 0,

View File

@ -1709,7 +1709,7 @@ static PfChannelResult pf_rdpdr_back_data(proxyData* pdata,
WINPR_ASSERT(pdata); WINPR_ASSERT(pdata);
WINPR_ASSERT(channel); WINPR_ASSERT(channel);
if (!pf_channel_rdpdr_client_handle(pdata->pc, channel->channel_id, channel->channel_name, if (!pf_channel_rdpdr_client_handle(pdata->pc, channel->back_channel_id, channel->channel_name,
xdata, xsize, flags, totalSize)) xdata, xsize, flags, totalSize))
{ {
WLog_ERR(TAG, "error treating client back data"); WLog_ERR(TAG, "error treating client back data");
@ -1731,7 +1731,7 @@ static PfChannelResult pf_rdpdr_front_data(proxyData* pdata,
WINPR_ASSERT(pdata); WINPR_ASSERT(pdata);
WINPR_ASSERT(channel); WINPR_ASSERT(channel);
if (!pf_channel_rdpdr_server_handle(pdata->ps, channel->channel_id, channel->channel_name, if (!pf_channel_rdpdr_server_handle(pdata->ps, channel->front_channel_id, channel->channel_name,
xdata, xsize, flags, totalSize)) xdata, xsize, flags, totalSize))
{ {
WLog_ERR(TAG, "error treating front data"); WLog_ERR(TAG, "error treating front data");

View File

@ -160,7 +160,7 @@ PfChannelResult channelTracker_flushCurrent(ChannelStateTracker* t, BOOL first,
{ {
proxyChannelDataEventInfo ev; proxyChannelDataEventInfo ev;
ev.channel_id = channel->channel_id; ev.channel_id = channel->front_channel_id;
ev.channel_name = channel->channel_name; ev.channel_name = channel->channel_name;
ev.data = Stream_Buffer(t->currentPacket); ev.data = Stream_Buffer(t->currentPacket);
ev.data_len = Stream_GetPosition(t->currentPacket); ev.data_len = Stream_GetPosition(t->currentPacket);
@ -176,7 +176,7 @@ PfChannelResult channelTracker_flushCurrent(ChannelStateTracker* t, BOOL first,
ps = pdata->ps; ps = pdata->ps;
r = ps->context.peer->SendChannelPacket( r = ps->context.peer->SendChannelPacket(
ps->context.peer, channel->channel_id, t->currentPacketSize, flags, ps->context.peer, channel->front_channel_id, t->currentPacketSize, flags,
Stream_Buffer(t->currentPacket), Stream_GetPosition(t->currentPacket)); Stream_Buffer(t->currentPacket), Stream_GetPosition(t->currentPacket));
return r ? PF_CHANNEL_RESULT_DROP : PF_CHANNEL_RESULT_ERROR; return r ? PF_CHANNEL_RESULT_DROP : PF_CHANNEL_RESULT_ERROR;
@ -195,7 +195,7 @@ static PfChannelResult pf_channel_generic_back_data(proxyData* pdata,
switch (channel->channelMode) switch (channel->channelMode)
{ {
case PF_UTILS_CHANNEL_PASSTHROUGH: case PF_UTILS_CHANNEL_PASSTHROUGH:
ev.channel_id = channel->channel_id; ev.channel_id = channel->back_channel_id;
ev.channel_name = channel->channel_name; ev.channel_name = channel->channel_name;
ev.data = xdata; ev.data = xdata;
ev.data_len = xsize; ev.data_len = xsize;
@ -229,7 +229,7 @@ static PfChannelResult pf_channel_generic_front_data(proxyData* pdata,
switch (channel->channelMode) switch (channel->channelMode)
{ {
case PF_UTILS_CHANNEL_PASSTHROUGH: case PF_UTILS_CHANNEL_PASSTHROUGH:
ev.channel_id = channel->channel_id; ev.channel_id = channel->front_channel_id;
ev.channel_name = channel->channel_name; ev.channel_name = channel->channel_name;
ev.data = xdata; ev.data = xdata;
ev.data_len = xsize; ev.data_len = xsize;

View File

@ -309,6 +309,37 @@ static BOOL pf_client_pre_connect(freerdp* instance)
return pf_modules_run_hook(pc->pdata->module, HOOK_TYPE_CLIENT_PRE_CONNECT, pc->pdata, pc); return pf_modules_run_hook(pc->pdata->module, HOOK_TYPE_CLIENT_PRE_CONNECT, pc->pdata, pc);
} }
/** @brief arguments for updateBackIdFn */
typedef struct
{
pServerContext* ps;
const char* name;
UINT32 backId;
} UpdateBackIdArgs;
static BOOL updateBackIdFn(const void* key, void* value, void* arg)
{
pServerStaticChannelContext* current = (pServerStaticChannelContext*)value;
UpdateBackIdArgs* updateArgs = (UpdateBackIdArgs*)arg;
if (strcmp(updateArgs->name, current->channel_name) != 0)
return TRUE;
current->back_channel_id = updateArgs->backId;
if (!HashTable_Insert(updateArgs->ps->channelsByBackId, &current->back_channel_id, current))
{
WLog_ERR(TAG, "error inserting channel in channelsByBackId table");
}
return FALSE;
}
static BOOL pf_client_update_back_id(pServerContext* ps, const char* name, UINT32 backId)
{
UpdateBackIdArgs res = { ps, name, backId };
return HashTable_Foreach(ps->channelsByFrontId, updateBackIdFn, &res) == FALSE;
}
static BOOL pf_client_load_channels(freerdp* instance) static BOOL pf_client_load_channels(freerdp* instance)
{ {
pClientContext* pc; pClientContext* pc;
@ -364,6 +395,7 @@ static BOOL pf_client_load_channels(freerdp* instance)
CHANNEL_DEF* channels = (CHANNEL_DEF*)freerdp_settings_get_pointer_array_writable( CHANNEL_DEF* channels = (CHANNEL_DEF*)freerdp_settings_get_pointer_array_writable(
settings, FreeRDP_ChannelDefArray, 0); settings, FreeRDP_ChannelDefArray, 0);
size_t x, size = freerdp_settings_get_uint32(settings, FreeRDP_ChannelDefArraySize); size_t x, size = freerdp_settings_get_uint32(settings, FreeRDP_ChannelDefArraySize);
UINT32 id = MCS_GLOBAL_CHANNEL_ID + 1;
WINPR_ASSERT(channels || (size == 0)); WINPR_ASSERT(channels || (size == 0));
@ -385,7 +417,14 @@ static BOOL pf_client_load_channels(freerdp* instance)
size--; size--;
} }
else else
{
if (!pf_client_update_back_id(ps, cur->name, id++))
{
WLog_ERR(TAG, "unable to update backid for channel %s", cur->name);
return FALSE;
}
x++; x++;
}
} }
if (!freerdp_settings_set_uint32(settings, FreeRDP_ChannelCount, x)) if (!freerdp_settings_set_uint32(settings, FreeRDP_ChannelCount, x))
@ -419,7 +458,7 @@ static BOOL pf_client_receive_channel_data_hook(freerdp* instance, UINT16 channe
pdata = ps->pdata; pdata = ps->pdata;
WINPR_ASSERT(pdata); WINPR_ASSERT(pdata);
channel = HashTable_GetItemValue(ps->channelsById, &channelId64); channel = HashTable_GetItemValue(ps->channelsByBackId, &channelId64);
if (!channel) if (!channel)
return TRUE; return TRUE;

View File

@ -56,7 +56,7 @@ pServerStaticChannelContext* StaticChannelContext_new(pServerContext* ps, const
return NULL; return NULL;
} }
ret->channel_id = id; ret->front_channel_id = id;
ret->channel_name = _strdup(name); ret->channel_name = _strdup(name);
if (!ret->channel_name) if (!ret->channel_name)
{ {
@ -110,17 +110,27 @@ static BOOL client_to_proxy_context_new(freerdp_peer* client, rdpContext* ctx)
obj->fnObjectFree = intercept_context_entry_free; obj->fnObjectFree = intercept_context_entry_free;
/* channels by ids */ /* channels by ids */
context->channelsById = HashTable_New(FALSE); context->channelsByFrontId = HashTable_New(FALSE);
if (!context->channelsById) if (!context->channelsByFrontId)
goto error; goto error;
if (!HashTable_SetHashFunction(context->channelsById, ChannelId_Hash)) if (!HashTable_SetHashFunction(context->channelsByFrontId, ChannelId_Hash))
goto error; goto error;
obj = HashTable_KeyObject(context->channelsById); obj = HashTable_KeyObject(context->channelsByFrontId);
obj->fnObjectEquals = (OBJECT_EQUALS_FN)ChannelId_Compare; obj->fnObjectEquals = (OBJECT_EQUALS_FN)ChannelId_Compare;
obj = HashTable_ValueObject(context->channelsById); obj = HashTable_ValueObject(context->channelsByFrontId);
obj->fnObjectFree = (OBJECT_FREE_FN)StaticChannelContext_free; obj->fnObjectFree = (OBJECT_FREE_FN)StaticChannelContext_free;
context->channelsByBackId = HashTable_New(FALSE);
if (!context->channelsByBackId)
goto error;
if (!HashTable_SetHashFunction(context->channelsByBackId, ChannelId_Hash))
goto error;
obj = HashTable_KeyObject(context->channelsByBackId);
obj->fnObjectEquals = (OBJECT_EQUALS_FN)ChannelId_Compare;
return TRUE; return TRUE;
error: error:
@ -146,7 +156,8 @@ void client_to_proxy_context_free(freerdp_peer* client, rdpContext* ctx)
} }
HashTable_Free(context->interceptContextMap); HashTable_Free(context->interceptContextMap);
HashTable_Free(context->channelsById); HashTable_Free(context->channelsByFrontId);
HashTable_Free(context->channelsByBackId);
if (context->vcm && (context->vcm != INVALID_HANDLE_VALUE)) if (context->vcm && (context->vcm != INVALID_HANDLE_VALUE))
WTSCloseServer((HANDLE)context->vcm); WTSCloseServer((HANDLE)context->vcm);

View File

@ -194,7 +194,6 @@ static BOOL pf_server_setup_channels(freerdp_peer* peer)
size_t accepted_channels_count; size_t accepted_channels_count;
size_t i; size_t i;
pServerContext* ps = (pServerContext*)peer->context; pServerContext* ps = (pServerContext*)peer->context;
wHashTable* byId = ps->channelsById;
accepted_channels = WTSGetAcceptedChannelNames(peer, &accepted_channels_count); accepted_channels = WTSGetAcceptedChannelNames(peer, &accepted_channels_count);
if (!accepted_channels) if (!accepted_channels)
@ -243,7 +242,8 @@ static BOOL pf_server_setup_channels(freerdp_peer* peer)
} }
} }
if (!HashTable_Insert(byId, &channelContext->channel_id, channelContext)) if (!HashTable_Insert(ps->channelsByFrontId, &channelContext->front_channel_id,
channelContext))
{ {
StaticChannelContext_free(channelContext); StaticChannelContext_free(channelContext);
PROXY_LOG_ERR(TAG, ps, "error inserting channelContext in byId table for '%s'", cname); PROXY_LOG_ERR(TAG, ps, "error inserting channelContext in byId table for '%s'", cname);
@ -339,7 +339,6 @@ static BOOL pf_server_activate(freerdp_peer* peer)
WINPR_ASSERT(pdata); WINPR_ASSERT(pdata);
settings = peer->context->settings; settings = peer->context->settings;
WINPR_ASSERT(settings);
settings->CompressionLevel = PACKET_COMPR_TYPE_RDP8; settings->CompressionLevel = PACKET_COMPR_TYPE_RDP8;
if (!pf_modules_run_hook(pdata->module, HOOK_TYPE_SERVER_ACTIVATE, pdata, peer)) if (!pf_modules_run_hook(pdata->module, HOOK_TYPE_SERVER_ACTIVATE, pdata, peer))
@ -408,7 +407,7 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann
if (!pc) if (!pc)
goto original_cb; goto original_cb;
channel = HashTable_GetItemValue(ps->channelsById, &channelId64); channel = HashTable_GetItemValue(ps->channelsByFrontId, &channelId64);
if (!channel) if (!channel)
{ {
PROXY_LOG_ERR(TAG, ps, "channel id=%" PRIu64 " not registered here, dropping", channelId64); PROXY_LOG_ERR(TAG, ps, "channel id=%" PRIu64 " not registered here, dropping", channelId64);