diff --git a/include/freerdp/freerdp.h b/include/freerdp/freerdp.h index 922f91346..77b4b2732 100644 --- a/include/freerdp/freerdp.h +++ b/include/freerdp/freerdp.h @@ -65,6 +65,9 @@ extern "C" { #endif +#define MCS_BASE_CHANNEL_ID 1001 +#define MCS_GLOBAL_CHANNEL_ID 1003 + /* Flags used by certificate callbacks */ #define VERIFY_CERT_FLAG_NONE 0x00 #define VERIFY_CERT_FLAG_LEGACY 0x02 diff --git a/include/freerdp/server/proxy/proxy_context.h b/include/freerdp/server/proxy/proxy_context.h index 0c6288380..6e377dff4 100644 --- a/include/freerdp/server/proxy/proxy_context.h +++ b/include/freerdp/server/proxy/proxy_context.h @@ -74,7 +74,8 @@ extern "C" struct p_server_static_channel_context { char* channel_name; - UINT32 channel_id; + UINT32 front_channel_id; + UINT32 back_channel_id; pf_utils_channel_mode channelMode; proxyChannelDataFn onFrontData; proxyChannelDataFn onBackData; @@ -97,7 +98,8 @@ extern "C" HANDLE dynvcReady; wHashTable* interceptContextMap; - wHashTable* channelsById; + wHashTable* channelsByFrontId; + wHashTable* channelsByBackId; }; typedef struct p_server_context pServerContext; diff --git a/libfreerdp/core/mcs.h b/libfreerdp/core/mcs.h index 3a1dab34a..b7365d3f0 100644 --- a/libfreerdp/core/mcs.h +++ b/libfreerdp/core/mcs.h @@ -33,9 +33,6 @@ typedef struct rdp_mcs rdpMcs; #include #include -#define MCS_BASE_CHANNEL_ID 1001 -#define MCS_GLOBAL_CHANNEL_ID 1003 - enum MCS_Result { MCS_Result_successful = 0, diff --git a/server/proxy/channels/pf_channel_rdpdr.c b/server/proxy/channels/pf_channel_rdpdr.c index 742bb85a5..8b7832e95 100644 --- a/server/proxy/channels/pf_channel_rdpdr.c +++ b/server/proxy/channels/pf_channel_rdpdr.c @@ -1709,7 +1709,7 @@ static PfChannelResult pf_rdpdr_back_data(proxyData* pdata, WINPR_ASSERT(pdata); WINPR_ASSERT(channel); - if (!pf_channel_rdpdr_client_handle(pdata->pc, channel->channel_id, channel->channel_name, + if (!pf_channel_rdpdr_client_handle(pdata->pc, channel->back_channel_id, channel->channel_name, xdata, xsize, flags, totalSize)) { WLog_ERR(TAG, "error treating client back data"); @@ -1731,7 +1731,7 @@ static PfChannelResult pf_rdpdr_front_data(proxyData* pdata, WINPR_ASSERT(pdata); WINPR_ASSERT(channel); - if (!pf_channel_rdpdr_server_handle(pdata->ps, channel->channel_id, channel->channel_name, + if (!pf_channel_rdpdr_server_handle(pdata->ps, channel->front_channel_id, channel->channel_name, xdata, xsize, flags, totalSize)) { WLog_ERR(TAG, "error treating front data"); diff --git a/server/proxy/pf_channel.c b/server/proxy/pf_channel.c index 5b84d8e2f..2289a400a 100644 --- a/server/proxy/pf_channel.c +++ b/server/proxy/pf_channel.c @@ -160,7 +160,7 @@ PfChannelResult channelTracker_flushCurrent(ChannelStateTracker* t, BOOL first, { proxyChannelDataEventInfo ev; - ev.channel_id = channel->channel_id; + ev.channel_id = channel->front_channel_id; ev.channel_name = channel->channel_name; ev.data = Stream_Buffer(t->currentPacket); ev.data_len = Stream_GetPosition(t->currentPacket); @@ -176,7 +176,7 @@ PfChannelResult channelTracker_flushCurrent(ChannelStateTracker* t, BOOL first, ps = pdata->ps; r = ps->context.peer->SendChannelPacket( - ps->context.peer, channel->channel_id, t->currentPacketSize, flags, + ps->context.peer, channel->front_channel_id, t->currentPacketSize, flags, Stream_Buffer(t->currentPacket), Stream_GetPosition(t->currentPacket)); return r ? PF_CHANNEL_RESULT_DROP : PF_CHANNEL_RESULT_ERROR; @@ -195,7 +195,7 @@ static PfChannelResult pf_channel_generic_back_data(proxyData* pdata, switch (channel->channelMode) { case PF_UTILS_CHANNEL_PASSTHROUGH: - ev.channel_id = channel->channel_id; + ev.channel_id = channel->back_channel_id; ev.channel_name = channel->channel_name; ev.data = xdata; ev.data_len = xsize; @@ -229,7 +229,7 @@ static PfChannelResult pf_channel_generic_front_data(proxyData* pdata, switch (channel->channelMode) { case PF_UTILS_CHANNEL_PASSTHROUGH: - ev.channel_id = channel->channel_id; + ev.channel_id = channel->front_channel_id; ev.channel_name = channel->channel_name; ev.data = xdata; ev.data_len = xsize; diff --git a/server/proxy/pf_client.c b/server/proxy/pf_client.c index cb41f6140..6d3ca0443 100644 --- a/server/proxy/pf_client.c +++ b/server/proxy/pf_client.c @@ -309,6 +309,37 @@ static BOOL pf_client_pre_connect(freerdp* instance) return pf_modules_run_hook(pc->pdata->module, HOOK_TYPE_CLIENT_PRE_CONNECT, pc->pdata, pc); } +/** @brief arguments for updateBackIdFn */ +typedef struct +{ + pServerContext* ps; + const char* name; + UINT32 backId; +} UpdateBackIdArgs; + +static BOOL updateBackIdFn(const void* key, void* value, void* arg) +{ + pServerStaticChannelContext* current = (pServerStaticChannelContext*)value; + UpdateBackIdArgs* updateArgs = (UpdateBackIdArgs*)arg; + + if (strcmp(updateArgs->name, current->channel_name) != 0) + return TRUE; + + current->back_channel_id = updateArgs->backId; + if (!HashTable_Insert(updateArgs->ps->channelsByBackId, ¤t->back_channel_id, current)) + { + WLog_ERR(TAG, "error inserting channel in channelsByBackId table"); + } + return FALSE; +} + +static BOOL pf_client_update_back_id(pServerContext* ps, const char* name, UINT32 backId) +{ + UpdateBackIdArgs res = { ps, name, backId }; + + return HashTable_Foreach(ps->channelsByFrontId, updateBackIdFn, &res) == FALSE; +} + static BOOL pf_client_load_channels(freerdp* instance) { pClientContext* pc; @@ -364,6 +395,7 @@ static BOOL pf_client_load_channels(freerdp* instance) CHANNEL_DEF* channels = (CHANNEL_DEF*)freerdp_settings_get_pointer_array_writable( settings, FreeRDP_ChannelDefArray, 0); size_t x, size = freerdp_settings_get_uint32(settings, FreeRDP_ChannelDefArraySize); + UINT32 id = MCS_GLOBAL_CHANNEL_ID + 1; WINPR_ASSERT(channels || (size == 0)); @@ -385,7 +417,14 @@ static BOOL pf_client_load_channels(freerdp* instance) size--; } else + { + if (!pf_client_update_back_id(ps, cur->name, id++)) + { + WLog_ERR(TAG, "unable to update backid for channel %s", cur->name); + return FALSE; + } x++; + } } if (!freerdp_settings_set_uint32(settings, FreeRDP_ChannelCount, x)) @@ -419,7 +458,7 @@ static BOOL pf_client_receive_channel_data_hook(freerdp* instance, UINT16 channe pdata = ps->pdata; WINPR_ASSERT(pdata); - channel = HashTable_GetItemValue(ps->channelsById, &channelId64); + channel = HashTable_GetItemValue(ps->channelsByBackId, &channelId64); if (!channel) return TRUE; diff --git a/server/proxy/pf_context.c b/server/proxy/pf_context.c index 22c1ebf25..7dab6ce17 100644 --- a/server/proxy/pf_context.c +++ b/server/proxy/pf_context.c @@ -56,7 +56,7 @@ pServerStaticChannelContext* StaticChannelContext_new(pServerContext* ps, const return NULL; } - ret->channel_id = id; + ret->front_channel_id = id; ret->channel_name = _strdup(name); if (!ret->channel_name) { @@ -110,17 +110,27 @@ static BOOL client_to_proxy_context_new(freerdp_peer* client, rdpContext* ctx) obj->fnObjectFree = intercept_context_entry_free; /* channels by ids */ - context->channelsById = HashTable_New(FALSE); - if (!context->channelsById) + context->channelsByFrontId = HashTable_New(FALSE); + if (!context->channelsByFrontId) goto error; - if (!HashTable_SetHashFunction(context->channelsById, ChannelId_Hash)) + if (!HashTable_SetHashFunction(context->channelsByFrontId, ChannelId_Hash)) goto error; - obj = HashTable_KeyObject(context->channelsById); + obj = HashTable_KeyObject(context->channelsByFrontId); obj->fnObjectEquals = (OBJECT_EQUALS_FN)ChannelId_Compare; - obj = HashTable_ValueObject(context->channelsById); + obj = HashTable_ValueObject(context->channelsByFrontId); obj->fnObjectFree = (OBJECT_FREE_FN)StaticChannelContext_free; + + context->channelsByBackId = HashTable_New(FALSE); + if (!context->channelsByBackId) + goto error; + if (!HashTable_SetHashFunction(context->channelsByBackId, ChannelId_Hash)) + goto error; + + obj = HashTable_KeyObject(context->channelsByBackId); + obj->fnObjectEquals = (OBJECT_EQUALS_FN)ChannelId_Compare; + return TRUE; error: @@ -146,7 +156,8 @@ void client_to_proxy_context_free(freerdp_peer* client, rdpContext* ctx) } HashTable_Free(context->interceptContextMap); - HashTable_Free(context->channelsById); + HashTable_Free(context->channelsByFrontId); + HashTable_Free(context->channelsByBackId); if (context->vcm && (context->vcm != INVALID_HANDLE_VALUE)) WTSCloseServer((HANDLE)context->vcm); diff --git a/server/proxy/pf_server.c b/server/proxy/pf_server.c index 80c821c15..f623e8107 100644 --- a/server/proxy/pf_server.c +++ b/server/proxy/pf_server.c @@ -194,7 +194,6 @@ static BOOL pf_server_setup_channels(freerdp_peer* peer) size_t accepted_channels_count; size_t i; pServerContext* ps = (pServerContext*)peer->context; - wHashTable* byId = ps->channelsById; accepted_channels = WTSGetAcceptedChannelNames(peer, &accepted_channels_count); if (!accepted_channels) @@ -243,7 +242,8 @@ static BOOL pf_server_setup_channels(freerdp_peer* peer) } } - if (!HashTable_Insert(byId, &channelContext->channel_id, channelContext)) + if (!HashTable_Insert(ps->channelsByFrontId, &channelContext->front_channel_id, + channelContext)) { StaticChannelContext_free(channelContext); PROXY_LOG_ERR(TAG, ps, "error inserting channelContext in byId table for '%s'", cname); @@ -339,7 +339,6 @@ static BOOL pf_server_activate(freerdp_peer* peer) WINPR_ASSERT(pdata); settings = peer->context->settings; - WINPR_ASSERT(settings); settings->CompressionLevel = PACKET_COMPR_TYPE_RDP8; if (!pf_modules_run_hook(pdata->module, HOOK_TYPE_SERVER_ACTIVATE, pdata, peer)) @@ -408,7 +407,7 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann if (!pc) goto original_cb; - channel = HashTable_GetItemValue(ps->channelsById, &channelId64); + channel = HashTable_GetItemValue(ps->channelsByFrontId, &channelId64); if (!channel) { PROXY_LOG_ERR(TAG, ps, "channel id=%" PRIu64 " not registered here, dropping", channelId64);