proxy: improve channel treatment

This PR introduces per channel context so that we can speed up operations like
retrieving the channel name from its id, or knowing what shall be done for a
packet (no config ACL recomputation at each packet).
This commit is contained in:
David Fort 2022-02-03 17:36:53 +01:00 committed by akallabeth
parent cefb4e1237
commit 46eb50df2c
6 changed files with 160 additions and 44 deletions

View File

@ -47,6 +47,28 @@ extern "C"
* and set their cleanup function accordingly. */ * and set their cleanup function accordingly. */
FREERDP_API void intercept_context_entry_free(void* obj); FREERDP_API void intercept_context_entry_free(void* obj);
/** @brief how is handled a channel */
enum pf_utils_channel_mode
{
PF_UTILS_CHANNEL_NOT_HANDLED,
PF_UTILS_CHANNEL_BLOCK,
PF_UTILS_CHANNEL_PASSTHROUGH,
PF_UTILS_CHANNEL_INTERCEPT,
};
typedef enum pf_utils_channel_mode pf_utils_channel_mode;
/** @brief per channel configuration */
struct p_server_channel_context
{
char* channel_name;
UINT32 channel_id;
BOOL isDynamic;
pf_utils_channel_mode channelMode;
};
typedef struct p_server_channel_context pServerChannelContext;
void ChannelContext_free(pServerChannelContext* ctx);
/** /**
* Wraps rdpContext and holds the state for the proxy's server. * Wraps rdpContext and holds the state for the proxy's server.
*/ */
@ -60,9 +82,12 @@ extern "C"
HANDLE dynvcReady; HANDLE dynvcReady;
wHashTable* interceptContextMap; wHashTable* interceptContextMap;
wHashTable* channelsById;
}; };
typedef struct p_server_context pServerContext; typedef struct p_server_context pServerContext;
pServerChannelContext* ChannelContext_new(pServerContext* ps, const char* name, UINT32 id);
/** /**
* Wraps rdpContext and holds the state for the proxy's client. * Wraps rdpContext and holds the state for the proxy's client.
*/ */

View File

@ -2006,7 +2006,7 @@ BOOL gcc_write_client_monitor_data(wStream* s, const rdpMcs* mcs)
Stream_Write_UINT32(s, flags); /* flags */ Stream_Write_UINT32(s, flags); /* flags */
} }
} }
WLog_DBG(TAG, "[%s] FINISHED" PRIu32, __FUNCTION__); WLog_DBG(TAG, "[%s] FINISHED", __FUNCTION__);
return TRUE; return TRUE;
} }

View File

@ -376,9 +376,10 @@ static BOOL pf_client_pre_connect(freerdp* instance)
return pf_modules_run_hook(pc->pdata->module, HOOK_TYPE_CLIENT_PRE_CONNECT, pc->pdata, pc); return pf_modules_run_hook(pc->pdata->module, HOOK_TYPE_CLIENT_PRE_CONNECT, pc->pdata, pc);
} }
static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 channelId, static BOOL pf_client_receive_channel_passthrough(proxyData* pdata,
const char* channel_name, const BYTE* xdata, const pServerChannelContext* channel,
size_t xsize, UINT32 flags, size_t totalSize) const BYTE* xdata, size_t xsize, UINT32 flags,
size_t totalSize)
{ {
proxyChannelDataEventInfo ev; proxyChannelDataEventInfo ev;
UINT16 server_channel_id; UINT16 server_channel_id;
@ -389,8 +390,8 @@ static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 chann
ps = pdata->ps; ps = pdata->ps;
WINPR_ASSERT(ps); WINPR_ASSERT(ps);
ev.channel_id = channelId; ev.channel_id = channel->channel_id;
ev.channel_name = channel_name; ev.channel_name = channel->channel_name;
ev.data = xdata; ev.data = xdata;
ev.data_len = xsize; ev.data_len = xsize;
ev.flags = flags; ev.flags = flags;
@ -411,7 +412,7 @@ static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 chann
* CREATE_REQUEST_PDU (0x01) packets as invalid. * CREATE_REQUEST_PDU (0x01) packets as invalid.
*/ */
if ((flags & CHANNEL_FLAG_FIRST) && if ((flags & CHANNEL_FLAG_FIRST) &&
(strncmp(channel_name, DRDYNVC_SVC_CHANNEL_NAME, CHANNEL_NAME_LEN + 1) == 0)) (strncmp(channel->channel_name, DRDYNVC_SVC_CHANNEL_NAME, CHANNEL_NAME_LEN + 1) == 0))
{ {
BYTE cmd, first; BYTE cmd, first;
wStream *s, sbuffer; wStream *s, sbuffer;
@ -468,7 +469,7 @@ static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 chann
return TRUE; /* Silently drop */ return TRUE; /* Silently drop */
} }
} }
server_channel_id = WTSChannelGetId(ps->context.peer, channel_name); server_channel_id = WTSChannelGetId(ps->context.peer, channel->channel_name);
/* Ignore messages for channels that can not be mapped. /* Ignore messages for channels that can not be mapped.
* The client might not have enabled support for this specific channel, * The client might not have enabled support for this specific channel,
@ -499,12 +500,11 @@ static BOOL pf_client_receive_channel_data_hook(freerdp* instance, UINT16 channe
const BYTE* xdata, size_t xsize, UINT32 flags, const BYTE* xdata, size_t xsize, UINT32 flags,
size_t totalSize) size_t totalSize)
{ {
const char* channel_name = freerdp_channels_get_name_by_id(instance, channelId);
pClientContext* pc; pClientContext* pc;
pServerContext* ps; pServerContext* ps;
proxyData* pdata; proxyData* pdata;
const proxyConfig* config; pServerChannelContext* channel;
pf_utils_channel_mode pass; UINT32 channelId32 = channelId;
WINPR_ASSERT(instance); WINPR_ASSERT(instance);
WINPR_ASSERT(xdata || (xsize == 0)); WINPR_ASSERT(xdata || (xsize == 0));
@ -519,21 +519,20 @@ static BOOL pf_client_receive_channel_data_hook(freerdp* instance, UINT16 channe
pdata = ps->pdata; pdata = ps->pdata;
WINPR_ASSERT(pdata); WINPR_ASSERT(pdata);
config = pdata->config; channel = HashTable_GetItemValue(ps->channelsById, &channelId32);
WINPR_ASSERT(config); if (!channel)
return TRUE;
pass = pf_utils_get_channel_mode(config, channel_name); switch (channel->channelMode)
switch (pass)
{ {
case PF_UTILS_CHANNEL_BLOCK: case PF_UTILS_CHANNEL_BLOCK:
return TRUE; /* Silently drop */ return TRUE; /* Silently drop */
case PF_UTILS_CHANNEL_PASSTHROUGH: case PF_UTILS_CHANNEL_PASSTHROUGH:
return pf_client_receive_channel_passthrough(pdata, channelId, channel_name, xdata, return pf_client_receive_channel_passthrough(pdata, channel, xdata, xsize, flags,
xsize, flags, totalSize); totalSize);
case PF_UTILS_CHANNEL_INTERCEPT: case PF_UTILS_CHANNEL_INTERCEPT:
return pf_client_receive_channel_intercept(pdata, channelId, channel_name, xdata, xsize, return pf_client_receive_channel_intercept(pdata, channelId, channel->channel_name,
flags, totalSize); xdata, xsize, flags, totalSize);
case PF_UTILS_CHANNEL_NOT_HANDLED: case PF_UTILS_CHANNEL_NOT_HANDLED:
default: default:
WINPR_ASSERT(pc->client_receive_channel_data_original); WINPR_ASSERT(pc->client_receive_channel_data_original);

View File

@ -24,13 +24,59 @@
#include <winpr/crypto.h> #include <winpr/crypto.h>
#include <winpr/print.h> #include <winpr/print.h>
#include <freerdp/server/proxy/proxy_log.h>
#include <freerdp/server/proxy/proxy_server.h> #include <freerdp/server/proxy/proxy_server.h>
#include "pf_client.h" #include "pf_client.h"
#include "pf_utils.h"
#include <freerdp/server/proxy/proxy_context.h> #include <freerdp/server/proxy/proxy_context.h>
#include "channels/pf_channel_rdpdr.h" #include "channels/pf_channel_rdpdr.h"
#define TAG PROXY_TAG("server")
static UINT32 ChannelId_Hash(const void* key)
{
const UINT32* v = (const UINT32*)key;
return *v;
}
static BOOL ChannelId_Compare(const UINT32* v1, const UINT32* v2)
{
return (*v1 == *v2);
}
pServerChannelContext* ChannelContext_new(pServerContext* ps, const char* name, UINT32 id)
{
pServerChannelContext* ret = calloc(1, sizeof(*ret));
if (!ret)
{
PROXY_LOG_ERR(TAG, ps, "error allocating channel context for '%s'", name);
return NULL;
}
ret->channel_id = id;
ret->channel_name = _strdup(name);
if (!ret->channel_name)
{
PROXY_LOG_ERR(TAG, ps, "error allocating name in channel context for '%s'", ret);
free(ret);
return NULL;
}
ret->channelMode = pf_utils_get_channel_mode(ps->pdata->config, name);
return ret;
}
void ChannelContext_free(pServerChannelContext* ctx)
{
if (!ctx)
return;
free(ctx->channel_name);
free(ctx);
}
/* Proxy context initialization callback */ /* Proxy context initialization callback */
static void client_to_proxy_context_free(freerdp_peer* client, rdpContext* ctx); static void client_to_proxy_context_free(freerdp_peer* client, rdpContext* ctx);
static BOOL client_to_proxy_context_new(freerdp_peer* client, rdpContext* ctx) static BOOL client_to_proxy_context_new(freerdp_peer* client, rdpContext* ctx)
@ -60,6 +106,18 @@ static BOOL client_to_proxy_context_new(freerdp_peer* client, rdpContext* ctx)
WINPR_ASSERT(obj); WINPR_ASSERT(obj);
obj->fnObjectFree = intercept_context_entry_free; obj->fnObjectFree = intercept_context_entry_free;
/* channels by ids */
context->channelsById = HashTable_New(FALSE);
if (!context->channelsById)
goto error;
if (!HashTable_SetHashFunction(context->channelsById, ChannelId_Hash))
goto error;
obj = HashTable_KeyObject(context->channelsById);
obj->fnObjectEquals = (OBJECT_EQUALS_FN)ChannelId_Compare;
obj = HashTable_ValueObject(context->channelsById);
obj->fnObjectFree = (OBJECT_FREE_FN)ChannelContext_free;
return TRUE; return TRUE;
error: error:
@ -85,6 +143,7 @@ void client_to_proxy_context_free(freerdp_peer* client, rdpContext* ctx)
} }
HashTable_Free(context->interceptContextMap); HashTable_Free(context->interceptContextMap);
HashTable_Free(context->channelsById);
if (context->vcm && (context->vcm != INVALID_HANDLE_VALUE)) if (context->vcm && (context->vcm != INVALID_HANDLE_VALUE))
WTSCloseServer((HANDLE)context->vcm); WTSCloseServer((HANDLE)context->vcm);

View File

@ -177,7 +177,46 @@ static BOOL pf_server_get_target_info(rdpContext* context, rdpSettings* settings
return TRUE; return TRUE;
} }
static BOOL pf_server_setup_channels(freerdp_peer* peer)
{
char** accepted_channels = NULL;
size_t accepted_channels_count;
size_t i;
pServerContext* ps = (pServerContext*)peer->context;
wHashTable* byId = ps->channelsById;
accepted_channels = WTSGetAcceptedChannelNames(peer, &accepted_channels_count);
if (!accepted_channels)
return TRUE;
for (i = 0; i < accepted_channels_count; i++)
{
pServerChannelContext* channelContext;
const char* cname = accepted_channels[i];
UINT16 channelId = WTSChannelGetId(peer, cname);
PROXY_LOG_INFO(TAG, ps, "Accepted channel: %s", cname);
channelContext = ChannelContext_new(ps, cname, channelId);
if (!channelContext)
{
PROXY_LOG_ERR(TAG, ps, "error seting up channelContext for '%s'", cname);
return FALSE;
}
if (!HashTable_Insert(byId, &channelContext->channel_id, channelContext))
{
ChannelContext_free(channelContext);
PROXY_LOG_ERR(TAG, ps, "error inserting channelContext in byId table for '%s'", cname);
return FALSE;
}
}
free(accepted_channels);
return TRUE;
}
/* Event callbacks */ /* Event callbacks */
/** /**
* This callback is called when the entire connection sequence is done (as * This callback is called when the entire connection sequence is done (as
* described in MS-RDPBCGR section 1.3) * described in MS-RDPBCGR section 1.3)
@ -191,9 +230,6 @@ static BOOL pf_server_post_connect(freerdp_peer* peer)
pClientContext* pc; pClientContext* pc;
rdpSettings* client_settings; rdpSettings* client_settings;
proxyData* pdata; proxyData* pdata;
char** accepted_channels = NULL;
size_t accepted_channels_count;
size_t i;
WINPR_ASSERT(peer); WINPR_ASSERT(peer);
@ -204,13 +240,10 @@ static BOOL pf_server_post_connect(freerdp_peer* peer)
WINPR_ASSERT(pdata); WINPR_ASSERT(pdata);
PROXY_LOG_INFO(TAG, ps, "Accepted client: %s", peer->settings->ClientHostname); PROXY_LOG_INFO(TAG, ps, "Accepted client: %s", peer->settings->ClientHostname);
accepted_channels = WTSGetAcceptedChannelNames(peer, &accepted_channels_count); if (!pf_server_setup_channels(peer))
if (accepted_channels)
{ {
for (i = 0; i < accepted_channels_count; i++) PROXY_LOG_ERR(TAG, ps, "error setting up channels");
PROXY_LOG_INFO(TAG, ps, "Accepted channel: %s", accepted_channels[i]); return FALSE;
free(accepted_channels);
} }
pc = pf_context_create_client_context(peer->settings); pc = pf_context_create_client_context(peer->settings);
@ -276,6 +309,7 @@ static BOOL pf_server_activate(freerdp_peer* peer)
peer->settings->CompressionLevel = PACKET_COMPR_TYPE_RDP8; peer->settings->CompressionLevel = PACKET_COMPR_TYPE_RDP8;
if (!pf_modules_run_hook(pdata->module, HOOK_TYPE_SERVER_ACTIVATE, pdata, peer)) if (!pf_modules_run_hook(pdata->module, HOOK_TYPE_SERVER_ACTIVATE, pdata, peer))
return FALSE; return FALSE;
return TRUE; return TRUE;
} }
@ -334,8 +368,8 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann
pClientContext* pc; pClientContext* pc;
proxyData* pdata; proxyData* pdata;
const proxyConfig* config; const proxyConfig* config;
pf_utils_channel_mode pass; const pServerChannelContext* channel;
const char* channel_name = WTSChannelGetName(peer, channelId); UINT32 channelId32 = channelId;
WINPR_ASSERT(peer); WINPR_ASSERT(peer);
@ -356,8 +390,14 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann
if (!pc) if (!pc)
goto original_cb; goto original_cb;
pass = pf_utils_get_channel_mode(config, channel_name); channel = HashTable_GetItemValue(ps->channelsById, &channelId32);
switch (pass) if (!channel)
{
PROXY_LOG_ERR(TAG, ps, "channel id=%d not registered here, dropping", channelId32);
return TRUE;
}
switch (channel->channelMode)
{ {
case PF_UTILS_CHANNEL_BLOCK: case PF_UTILS_CHANNEL_BLOCK:
return TRUE; return TRUE;
@ -366,7 +406,7 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann
proxyChannelDataEventInfo ev; proxyChannelDataEventInfo ev;
ev.channel_id = channelId; ev.channel_id = channelId;
ev.channel_name = channel_name; ev.channel_name = channel->channel_name;
ev.data = data; ev.data = data;
ev.data_len = size; ev.data_len = size;
ev.flags = flags; ev.flags = flags;
@ -379,8 +419,8 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann
return IFCALLRESULT(TRUE, pc->sendChannelData, pc, &ev); return IFCALLRESULT(TRUE, pc->sendChannelData, pc, &ev);
} }
case PF_UTILS_CHANNEL_INTERCEPT: case PF_UTILS_CHANNEL_INTERCEPT:
return pf_server_receive_channel_intercept(pdata, channelId, channel_name, data, size, return pf_server_receive_channel_intercept(pdata, channelId, channel->channel_name,
flags, totalSize); data, size, flags, totalSize);
default: default:
break; break;
} }

View File

@ -22,6 +22,7 @@
#define FREERDP_SERVER_PROXY_PFUTILS_H #define FREERDP_SERVER_PROXY_PFUTILS_H
#include <freerdp/server/proxy/proxy_config.h> #include <freerdp/server/proxy/proxy_config.h>
#include <freerdp/server/proxy/proxy_context.h>
/** /**
* @brief pf_utils_channel_is_passthrough Checks of a channel identified by 'name' * @brief pf_utils_channel_is_passthrough Checks of a channel identified by 'name'
@ -34,14 +35,6 @@
* e.g. proxy client and server are termination points and data passed * e.g. proxy client and server are termination points and data passed
* between. * between.
*/ */
typedef enum
{
PF_UTILS_CHANNEL_NOT_HANDLED,
PF_UTILS_CHANNEL_BLOCK,
PF_UTILS_CHANNEL_PASSTHROUGH,
PF_UTILS_CHANNEL_INTERCEPT,
} pf_utils_channel_mode;
pf_utils_channel_mode pf_utils_get_channel_mode(const proxyConfig* config, const char* name); pf_utils_channel_mode pf_utils_get_channel_mode(const proxyConfig* config, const char* name);
const char* pf_utils_channel_mode_string(pf_utils_channel_mode mode); const char* pf_utils_channel_mode_string(pf_utils_channel_mode mode);