From 46eb50df2c38be1b3e343ae03361f0b2df5a2f8e Mon Sep 17 00:00:00 2001 From: David Fort Date: Thu, 3 Feb 2022 17:36:53 +0100 Subject: [PATCH] proxy: improve channel treatment This PR introduces per channel context so that we can speed up operations like retrieving the channel name from its id, or knowing what shall be done for a packet (no config ACL recomputation at each packet). --- include/freerdp/server/proxy/proxy_context.h | 25 +++++++ libfreerdp/core/gcc.c | 2 +- server/proxy/pf_client.c | 37 +++++----- server/proxy/pf_context.c | 59 ++++++++++++++++ server/proxy/pf_server.c | 72 +++++++++++++++----- server/proxy/pf_utils.h | 9 +-- 6 files changed, 160 insertions(+), 44 deletions(-) diff --git a/include/freerdp/server/proxy/proxy_context.h b/include/freerdp/server/proxy/proxy_context.h index 93a365655..cfaae08bb 100644 --- a/include/freerdp/server/proxy/proxy_context.h +++ b/include/freerdp/server/proxy/proxy_context.h @@ -47,6 +47,28 @@ extern "C" * and set their cleanup function accordingly. */ FREERDP_API void intercept_context_entry_free(void* obj); + /** @brief how is handled a channel */ + enum pf_utils_channel_mode + { + PF_UTILS_CHANNEL_NOT_HANDLED, + PF_UTILS_CHANNEL_BLOCK, + PF_UTILS_CHANNEL_PASSTHROUGH, + PF_UTILS_CHANNEL_INTERCEPT, + }; + typedef enum pf_utils_channel_mode pf_utils_channel_mode; + + /** @brief per channel configuration */ + struct p_server_channel_context + { + char* channel_name; + UINT32 channel_id; + BOOL isDynamic; + pf_utils_channel_mode channelMode; + }; + typedef struct p_server_channel_context pServerChannelContext; + + void ChannelContext_free(pServerChannelContext* ctx); + /** * Wraps rdpContext and holds the state for the proxy's server. */ @@ -60,9 +82,12 @@ extern "C" HANDLE dynvcReady; wHashTable* interceptContextMap; + wHashTable* channelsById; }; typedef struct p_server_context pServerContext; + pServerChannelContext* ChannelContext_new(pServerContext* ps, const char* name, UINT32 id); + /** * Wraps rdpContext and holds the state for the proxy's client. */ diff --git a/libfreerdp/core/gcc.c b/libfreerdp/core/gcc.c index 8d52a887e..9333bcafb 100644 --- a/libfreerdp/core/gcc.c +++ b/libfreerdp/core/gcc.c @@ -2006,7 +2006,7 @@ BOOL gcc_write_client_monitor_data(wStream* s, const rdpMcs* mcs) Stream_Write_UINT32(s, flags); /* flags */ } } - WLog_DBG(TAG, "[%s] FINISHED" PRIu32, __FUNCTION__); + WLog_DBG(TAG, "[%s] FINISHED", __FUNCTION__); return TRUE; } diff --git a/server/proxy/pf_client.c b/server/proxy/pf_client.c index 841c98014..678f3004f 100644 --- a/server/proxy/pf_client.c +++ b/server/proxy/pf_client.c @@ -376,9 +376,10 @@ static BOOL pf_client_pre_connect(freerdp* instance) return pf_modules_run_hook(pc->pdata->module, HOOK_TYPE_CLIENT_PRE_CONNECT, pc->pdata, pc); } -static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 channelId, - const char* channel_name, const BYTE* xdata, - size_t xsize, UINT32 flags, size_t totalSize) +static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, + const pServerChannelContext* channel, + const BYTE* xdata, size_t xsize, UINT32 flags, + size_t totalSize) { proxyChannelDataEventInfo ev; UINT16 server_channel_id; @@ -389,8 +390,8 @@ static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 chann ps = pdata->ps; WINPR_ASSERT(ps); - ev.channel_id = channelId; - ev.channel_name = channel_name; + ev.channel_id = channel->channel_id; + ev.channel_name = channel->channel_name; ev.data = xdata; ev.data_len = xsize; ev.flags = flags; @@ -411,7 +412,7 @@ static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 chann * CREATE_REQUEST_PDU (0x01) packets as invalid. */ if ((flags & CHANNEL_FLAG_FIRST) && - (strncmp(channel_name, DRDYNVC_SVC_CHANNEL_NAME, CHANNEL_NAME_LEN + 1) == 0)) + (strncmp(channel->channel_name, DRDYNVC_SVC_CHANNEL_NAME, CHANNEL_NAME_LEN + 1) == 0)) { BYTE cmd, first; wStream *s, sbuffer; @@ -468,7 +469,7 @@ static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 chann return TRUE; /* Silently drop */ } } - server_channel_id = WTSChannelGetId(ps->context.peer, channel_name); + server_channel_id = WTSChannelGetId(ps->context.peer, channel->channel_name); /* Ignore messages for channels that can not be mapped. * The client might not have enabled support for this specific channel, @@ -499,12 +500,11 @@ static BOOL pf_client_receive_channel_data_hook(freerdp* instance, UINT16 channe const BYTE* xdata, size_t xsize, UINT32 flags, size_t totalSize) { - const char* channel_name = freerdp_channels_get_name_by_id(instance, channelId); pClientContext* pc; pServerContext* ps; proxyData* pdata; - const proxyConfig* config; - pf_utils_channel_mode pass; + pServerChannelContext* channel; + UINT32 channelId32 = channelId; WINPR_ASSERT(instance); WINPR_ASSERT(xdata || (xsize == 0)); @@ -519,21 +519,20 @@ static BOOL pf_client_receive_channel_data_hook(freerdp* instance, UINT16 channe pdata = ps->pdata; WINPR_ASSERT(pdata); - config = pdata->config; - WINPR_ASSERT(config); + channel = HashTable_GetItemValue(ps->channelsById, &channelId32); + if (!channel) + return TRUE; - pass = pf_utils_get_channel_mode(config, channel_name); - - switch (pass) + switch (channel->channelMode) { case PF_UTILS_CHANNEL_BLOCK: return TRUE; /* Silently drop */ case PF_UTILS_CHANNEL_PASSTHROUGH: - return pf_client_receive_channel_passthrough(pdata, channelId, channel_name, xdata, - xsize, flags, totalSize); + return pf_client_receive_channel_passthrough(pdata, channel, xdata, xsize, flags, + totalSize); case PF_UTILS_CHANNEL_INTERCEPT: - return pf_client_receive_channel_intercept(pdata, channelId, channel_name, xdata, xsize, - flags, totalSize); + return pf_client_receive_channel_intercept(pdata, channelId, channel->channel_name, + xdata, xsize, flags, totalSize); case PF_UTILS_CHANNEL_NOT_HANDLED: default: WINPR_ASSERT(pc->client_receive_channel_data_original); diff --git a/server/proxy/pf_context.c b/server/proxy/pf_context.c index d41cc9301..f1be7a6b0 100644 --- a/server/proxy/pf_context.c +++ b/server/proxy/pf_context.c @@ -24,13 +24,59 @@ #include #include +#include #include #include "pf_client.h" +#include "pf_utils.h" #include #include "channels/pf_channel_rdpdr.h" +#define TAG PROXY_TAG("server") + +static UINT32 ChannelId_Hash(const void* key) +{ + const UINT32* v = (const UINT32*)key; + return *v; +} + +static BOOL ChannelId_Compare(const UINT32* v1, const UINT32* v2) +{ + return (*v1 == *v2); +} + +pServerChannelContext* ChannelContext_new(pServerContext* ps, const char* name, UINT32 id) +{ + pServerChannelContext* ret = calloc(1, sizeof(*ret)); + if (!ret) + { + PROXY_LOG_ERR(TAG, ps, "error allocating channel context for '%s'", name); + return NULL; + } + + ret->channel_id = id; + ret->channel_name = _strdup(name); + if (!ret->channel_name) + { + PROXY_LOG_ERR(TAG, ps, "error allocating name in channel context for '%s'", ret); + free(ret); + return NULL; + } + + ret->channelMode = pf_utils_get_channel_mode(ps->pdata->config, name); + return ret; +} + +void ChannelContext_free(pServerChannelContext* ctx) +{ + if (!ctx) + return; + + free(ctx->channel_name); + free(ctx); +} + /* Proxy context initialization callback */ static void client_to_proxy_context_free(freerdp_peer* client, rdpContext* ctx); static BOOL client_to_proxy_context_new(freerdp_peer* client, rdpContext* ctx) @@ -60,6 +106,18 @@ static BOOL client_to_proxy_context_new(freerdp_peer* client, rdpContext* ctx) WINPR_ASSERT(obj); obj->fnObjectFree = intercept_context_entry_free; + /* channels by ids */ + context->channelsById = HashTable_New(FALSE); + if (!context->channelsById) + goto error; + if (!HashTable_SetHashFunction(context->channelsById, ChannelId_Hash)) + goto error; + + obj = HashTable_KeyObject(context->channelsById); + obj->fnObjectEquals = (OBJECT_EQUALS_FN)ChannelId_Compare; + + obj = HashTable_ValueObject(context->channelsById); + obj->fnObjectFree = (OBJECT_FREE_FN)ChannelContext_free; return TRUE; error: @@ -85,6 +143,7 @@ void client_to_proxy_context_free(freerdp_peer* client, rdpContext* ctx) } HashTable_Free(context->interceptContextMap); + HashTable_Free(context->channelsById); if (context->vcm && (context->vcm != INVALID_HANDLE_VALUE)) WTSCloseServer((HANDLE)context->vcm); diff --git a/server/proxy/pf_server.c b/server/proxy/pf_server.c index b59ff001a..88947c6c4 100644 --- a/server/proxy/pf_server.c +++ b/server/proxy/pf_server.c @@ -177,7 +177,46 @@ static BOOL pf_server_get_target_info(rdpContext* context, rdpSettings* settings return TRUE; } +static BOOL pf_server_setup_channels(freerdp_peer* peer) +{ + char** accepted_channels = NULL; + size_t accepted_channels_count; + size_t i; + pServerContext* ps = (pServerContext*)peer->context; + wHashTable* byId = ps->channelsById; + + accepted_channels = WTSGetAcceptedChannelNames(peer, &accepted_channels_count); + if (!accepted_channels) + return TRUE; + + for (i = 0; i < accepted_channels_count; i++) + { + pServerChannelContext* channelContext; + const char* cname = accepted_channels[i]; + UINT16 channelId = WTSChannelGetId(peer, cname); + + PROXY_LOG_INFO(TAG, ps, "Accepted channel: %s", cname); + channelContext = ChannelContext_new(ps, cname, channelId); + if (!channelContext) + { + PROXY_LOG_ERR(TAG, ps, "error seting up channelContext for '%s'", cname); + return FALSE; + } + + if (!HashTable_Insert(byId, &channelContext->channel_id, channelContext)) + { + ChannelContext_free(channelContext); + PROXY_LOG_ERR(TAG, ps, "error inserting channelContext in byId table for '%s'", cname); + return FALSE; + } + } + + free(accepted_channels); + return TRUE; +} + /* Event callbacks */ + /** * This callback is called when the entire connection sequence is done (as * described in MS-RDPBCGR section 1.3) @@ -191,9 +230,6 @@ static BOOL pf_server_post_connect(freerdp_peer* peer) pClientContext* pc; rdpSettings* client_settings; proxyData* pdata; - char** accepted_channels = NULL; - size_t accepted_channels_count; - size_t i; WINPR_ASSERT(peer); @@ -204,13 +240,10 @@ static BOOL pf_server_post_connect(freerdp_peer* peer) WINPR_ASSERT(pdata); PROXY_LOG_INFO(TAG, ps, "Accepted client: %s", peer->settings->ClientHostname); - accepted_channels = WTSGetAcceptedChannelNames(peer, &accepted_channels_count); - if (accepted_channels) + if (!pf_server_setup_channels(peer)) { - for (i = 0; i < accepted_channels_count; i++) - PROXY_LOG_INFO(TAG, ps, "Accepted channel: %s", accepted_channels[i]); - - free(accepted_channels); + PROXY_LOG_ERR(TAG, ps, "error setting up channels"); + return FALSE; } pc = pf_context_create_client_context(peer->settings); @@ -276,6 +309,7 @@ static BOOL pf_server_activate(freerdp_peer* peer) peer->settings->CompressionLevel = PACKET_COMPR_TYPE_RDP8; if (!pf_modules_run_hook(pdata->module, HOOK_TYPE_SERVER_ACTIVATE, pdata, peer)) return FALSE; + return TRUE; } @@ -334,8 +368,8 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann pClientContext* pc; proxyData* pdata; const proxyConfig* config; - pf_utils_channel_mode pass; - const char* channel_name = WTSChannelGetName(peer, channelId); + const pServerChannelContext* channel; + UINT32 channelId32 = channelId; WINPR_ASSERT(peer); @@ -356,8 +390,14 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann if (!pc) goto original_cb; - pass = pf_utils_get_channel_mode(config, channel_name); - switch (pass) + channel = HashTable_GetItemValue(ps->channelsById, &channelId32); + if (!channel) + { + PROXY_LOG_ERR(TAG, ps, "channel id=%d not registered here, dropping", channelId32); + return TRUE; + } + + switch (channel->channelMode) { case PF_UTILS_CHANNEL_BLOCK: return TRUE; @@ -366,7 +406,7 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann proxyChannelDataEventInfo ev; ev.channel_id = channelId; - ev.channel_name = channel_name; + ev.channel_name = channel->channel_name; ev.data = data; ev.data_len = size; ev.flags = flags; @@ -379,8 +419,8 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann return IFCALLRESULT(TRUE, pc->sendChannelData, pc, &ev); } case PF_UTILS_CHANNEL_INTERCEPT: - return pf_server_receive_channel_intercept(pdata, channelId, channel_name, data, size, - flags, totalSize); + return pf_server_receive_channel_intercept(pdata, channelId, channel->channel_name, + data, size, flags, totalSize); default: break; } diff --git a/server/proxy/pf_utils.h b/server/proxy/pf_utils.h index 0d03b3c91..0e899e94a 100644 --- a/server/proxy/pf_utils.h +++ b/server/proxy/pf_utils.h @@ -22,6 +22,7 @@ #define FREERDP_SERVER_PROXY_PFUTILS_H #include +#include /** * @brief pf_utils_channel_is_passthrough Checks of a channel identified by 'name' @@ -34,14 +35,6 @@ * e.g. proxy client and server are termination points and data passed * between. */ -typedef enum -{ - PF_UTILS_CHANNEL_NOT_HANDLED, - PF_UTILS_CHANNEL_BLOCK, - PF_UTILS_CHANNEL_PASSTHROUGH, - PF_UTILS_CHANNEL_INTERCEPT, -} pf_utils_channel_mode; - pf_utils_channel_mode pf_utils_get_channel_mode(const proxyConfig* config, const char* name); const char* pf_utils_channel_mode_string(pf_utils_channel_mode mode);