proxy: improve channel treatment

This PR introduces per channel context so that we can speed up operations like
retrieving the channel name from its id, or knowing what shall be done for a
packet (no config ACL recomputation at each packet).
This commit is contained in:
David Fort 2022-02-03 17:36:53 +01:00 committed by akallabeth
parent cefb4e1237
commit 46eb50df2c
6 changed files with 160 additions and 44 deletions

View File

@ -47,6 +47,28 @@ extern "C"
* and set their cleanup function accordingly. */
FREERDP_API void intercept_context_entry_free(void* obj);
/** @brief how is handled a channel */
enum pf_utils_channel_mode
{
PF_UTILS_CHANNEL_NOT_HANDLED,
PF_UTILS_CHANNEL_BLOCK,
PF_UTILS_CHANNEL_PASSTHROUGH,
PF_UTILS_CHANNEL_INTERCEPT,
};
typedef enum pf_utils_channel_mode pf_utils_channel_mode;
/** @brief per channel configuration */
struct p_server_channel_context
{
char* channel_name;
UINT32 channel_id;
BOOL isDynamic;
pf_utils_channel_mode channelMode;
};
typedef struct p_server_channel_context pServerChannelContext;
void ChannelContext_free(pServerChannelContext* ctx);
/**
* Wraps rdpContext and holds the state for the proxy's server.
*/
@ -60,9 +82,12 @@ extern "C"
HANDLE dynvcReady;
wHashTable* interceptContextMap;
wHashTable* channelsById;
};
typedef struct p_server_context pServerContext;
pServerChannelContext* ChannelContext_new(pServerContext* ps, const char* name, UINT32 id);
/**
* Wraps rdpContext and holds the state for the proxy's client.
*/

View File

@ -2006,7 +2006,7 @@ BOOL gcc_write_client_monitor_data(wStream* s, const rdpMcs* mcs)
Stream_Write_UINT32(s, flags); /* flags */
}
}
WLog_DBG(TAG, "[%s] FINISHED" PRIu32, __FUNCTION__);
WLog_DBG(TAG, "[%s] FINISHED", __FUNCTION__);
return TRUE;
}

View File

@ -376,9 +376,10 @@ static BOOL pf_client_pre_connect(freerdp* instance)
return pf_modules_run_hook(pc->pdata->module, HOOK_TYPE_CLIENT_PRE_CONNECT, pc->pdata, pc);
}
static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 channelId,
const char* channel_name, const BYTE* xdata,
size_t xsize, UINT32 flags, size_t totalSize)
static BOOL pf_client_receive_channel_passthrough(proxyData* pdata,
const pServerChannelContext* channel,
const BYTE* xdata, size_t xsize, UINT32 flags,
size_t totalSize)
{
proxyChannelDataEventInfo ev;
UINT16 server_channel_id;
@ -389,8 +390,8 @@ static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 chann
ps = pdata->ps;
WINPR_ASSERT(ps);
ev.channel_id = channelId;
ev.channel_name = channel_name;
ev.channel_id = channel->channel_id;
ev.channel_name = channel->channel_name;
ev.data = xdata;
ev.data_len = xsize;
ev.flags = flags;
@ -411,7 +412,7 @@ static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 chann
* CREATE_REQUEST_PDU (0x01) packets as invalid.
*/
if ((flags & CHANNEL_FLAG_FIRST) &&
(strncmp(channel_name, DRDYNVC_SVC_CHANNEL_NAME, CHANNEL_NAME_LEN + 1) == 0))
(strncmp(channel->channel_name, DRDYNVC_SVC_CHANNEL_NAME, CHANNEL_NAME_LEN + 1) == 0))
{
BYTE cmd, first;
wStream *s, sbuffer;
@ -468,7 +469,7 @@ static BOOL pf_client_receive_channel_passthrough(proxyData* pdata, UINT16 chann
return TRUE; /* Silently drop */
}
}
server_channel_id = WTSChannelGetId(ps->context.peer, channel_name);
server_channel_id = WTSChannelGetId(ps->context.peer, channel->channel_name);
/* Ignore messages for channels that can not be mapped.
* The client might not have enabled support for this specific channel,
@ -499,12 +500,11 @@ static BOOL pf_client_receive_channel_data_hook(freerdp* instance, UINT16 channe
const BYTE* xdata, size_t xsize, UINT32 flags,
size_t totalSize)
{
const char* channel_name = freerdp_channels_get_name_by_id(instance, channelId);
pClientContext* pc;
pServerContext* ps;
proxyData* pdata;
const proxyConfig* config;
pf_utils_channel_mode pass;
pServerChannelContext* channel;
UINT32 channelId32 = channelId;
WINPR_ASSERT(instance);
WINPR_ASSERT(xdata || (xsize == 0));
@ -519,21 +519,20 @@ static BOOL pf_client_receive_channel_data_hook(freerdp* instance, UINT16 channe
pdata = ps->pdata;
WINPR_ASSERT(pdata);
config = pdata->config;
WINPR_ASSERT(config);
channel = HashTable_GetItemValue(ps->channelsById, &channelId32);
if (!channel)
return TRUE;
pass = pf_utils_get_channel_mode(config, channel_name);
switch (pass)
switch (channel->channelMode)
{
case PF_UTILS_CHANNEL_BLOCK:
return TRUE; /* Silently drop */
case PF_UTILS_CHANNEL_PASSTHROUGH:
return pf_client_receive_channel_passthrough(pdata, channelId, channel_name, xdata,
xsize, flags, totalSize);
return pf_client_receive_channel_passthrough(pdata, channel, xdata, xsize, flags,
totalSize);
case PF_UTILS_CHANNEL_INTERCEPT:
return pf_client_receive_channel_intercept(pdata, channelId, channel_name, xdata, xsize,
flags, totalSize);
return pf_client_receive_channel_intercept(pdata, channelId, channel->channel_name,
xdata, xsize, flags, totalSize);
case PF_UTILS_CHANNEL_NOT_HANDLED:
default:
WINPR_ASSERT(pc->client_receive_channel_data_original);

View File

@ -24,13 +24,59 @@
#include <winpr/crypto.h>
#include <winpr/print.h>
#include <freerdp/server/proxy/proxy_log.h>
#include <freerdp/server/proxy/proxy_server.h>
#include "pf_client.h"
#include "pf_utils.h"
#include <freerdp/server/proxy/proxy_context.h>
#include "channels/pf_channel_rdpdr.h"
#define TAG PROXY_TAG("server")
static UINT32 ChannelId_Hash(const void* key)
{
const UINT32* v = (const UINT32*)key;
return *v;
}
static BOOL ChannelId_Compare(const UINT32* v1, const UINT32* v2)
{
return (*v1 == *v2);
}
pServerChannelContext* ChannelContext_new(pServerContext* ps, const char* name, UINT32 id)
{
pServerChannelContext* ret = calloc(1, sizeof(*ret));
if (!ret)
{
PROXY_LOG_ERR(TAG, ps, "error allocating channel context for '%s'", name);
return NULL;
}
ret->channel_id = id;
ret->channel_name = _strdup(name);
if (!ret->channel_name)
{
PROXY_LOG_ERR(TAG, ps, "error allocating name in channel context for '%s'", ret);
free(ret);
return NULL;
}
ret->channelMode = pf_utils_get_channel_mode(ps->pdata->config, name);
return ret;
}
void ChannelContext_free(pServerChannelContext* ctx)
{
if (!ctx)
return;
free(ctx->channel_name);
free(ctx);
}
/* Proxy context initialization callback */
static void client_to_proxy_context_free(freerdp_peer* client, rdpContext* ctx);
static BOOL client_to_proxy_context_new(freerdp_peer* client, rdpContext* ctx)
@ -60,6 +106,18 @@ static BOOL client_to_proxy_context_new(freerdp_peer* client, rdpContext* ctx)
WINPR_ASSERT(obj);
obj->fnObjectFree = intercept_context_entry_free;
/* channels by ids */
context->channelsById = HashTable_New(FALSE);
if (!context->channelsById)
goto error;
if (!HashTable_SetHashFunction(context->channelsById, ChannelId_Hash))
goto error;
obj = HashTable_KeyObject(context->channelsById);
obj->fnObjectEquals = (OBJECT_EQUALS_FN)ChannelId_Compare;
obj = HashTable_ValueObject(context->channelsById);
obj->fnObjectFree = (OBJECT_FREE_FN)ChannelContext_free;
return TRUE;
error:
@ -85,6 +143,7 @@ void client_to_proxy_context_free(freerdp_peer* client, rdpContext* ctx)
}
HashTable_Free(context->interceptContextMap);
HashTable_Free(context->channelsById);
if (context->vcm && (context->vcm != INVALID_HANDLE_VALUE))
WTSCloseServer((HANDLE)context->vcm);

View File

@ -177,7 +177,46 @@ static BOOL pf_server_get_target_info(rdpContext* context, rdpSettings* settings
return TRUE;
}
static BOOL pf_server_setup_channels(freerdp_peer* peer)
{
char** accepted_channels = NULL;
size_t accepted_channels_count;
size_t i;
pServerContext* ps = (pServerContext*)peer->context;
wHashTable* byId = ps->channelsById;
accepted_channels = WTSGetAcceptedChannelNames(peer, &accepted_channels_count);
if (!accepted_channels)
return TRUE;
for (i = 0; i < accepted_channels_count; i++)
{
pServerChannelContext* channelContext;
const char* cname = accepted_channels[i];
UINT16 channelId = WTSChannelGetId(peer, cname);
PROXY_LOG_INFO(TAG, ps, "Accepted channel: %s", cname);
channelContext = ChannelContext_new(ps, cname, channelId);
if (!channelContext)
{
PROXY_LOG_ERR(TAG, ps, "error seting up channelContext for '%s'", cname);
return FALSE;
}
if (!HashTable_Insert(byId, &channelContext->channel_id, channelContext))
{
ChannelContext_free(channelContext);
PROXY_LOG_ERR(TAG, ps, "error inserting channelContext in byId table for '%s'", cname);
return FALSE;
}
}
free(accepted_channels);
return TRUE;
}
/* Event callbacks */
/**
* This callback is called when the entire connection sequence is done (as
* described in MS-RDPBCGR section 1.3)
@ -191,9 +230,6 @@ static BOOL pf_server_post_connect(freerdp_peer* peer)
pClientContext* pc;
rdpSettings* client_settings;
proxyData* pdata;
char** accepted_channels = NULL;
size_t accepted_channels_count;
size_t i;
WINPR_ASSERT(peer);
@ -204,13 +240,10 @@ static BOOL pf_server_post_connect(freerdp_peer* peer)
WINPR_ASSERT(pdata);
PROXY_LOG_INFO(TAG, ps, "Accepted client: %s", peer->settings->ClientHostname);
accepted_channels = WTSGetAcceptedChannelNames(peer, &accepted_channels_count);
if (accepted_channels)
if (!pf_server_setup_channels(peer))
{
for (i = 0; i < accepted_channels_count; i++)
PROXY_LOG_INFO(TAG, ps, "Accepted channel: %s", accepted_channels[i]);
free(accepted_channels);
PROXY_LOG_ERR(TAG, ps, "error setting up channels");
return FALSE;
}
pc = pf_context_create_client_context(peer->settings);
@ -276,6 +309,7 @@ static BOOL pf_server_activate(freerdp_peer* peer)
peer->settings->CompressionLevel = PACKET_COMPR_TYPE_RDP8;
if (!pf_modules_run_hook(pdata->module, HOOK_TYPE_SERVER_ACTIVATE, pdata, peer))
return FALSE;
return TRUE;
}
@ -334,8 +368,8 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann
pClientContext* pc;
proxyData* pdata;
const proxyConfig* config;
pf_utils_channel_mode pass;
const char* channel_name = WTSChannelGetName(peer, channelId);
const pServerChannelContext* channel;
UINT32 channelId32 = channelId;
WINPR_ASSERT(peer);
@ -356,8 +390,14 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann
if (!pc)
goto original_cb;
pass = pf_utils_get_channel_mode(config, channel_name);
switch (pass)
channel = HashTable_GetItemValue(ps->channelsById, &channelId32);
if (!channel)
{
PROXY_LOG_ERR(TAG, ps, "channel id=%d not registered here, dropping", channelId32);
return TRUE;
}
switch (channel->channelMode)
{
case PF_UTILS_CHANNEL_BLOCK:
return TRUE;
@ -366,7 +406,7 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann
proxyChannelDataEventInfo ev;
ev.channel_id = channelId;
ev.channel_name = channel_name;
ev.channel_name = channel->channel_name;
ev.data = data;
ev.data_len = size;
ev.flags = flags;
@ -379,8 +419,8 @@ static BOOL pf_server_receive_channel_data_hook(freerdp_peer* peer, UINT16 chann
return IFCALLRESULT(TRUE, pc->sendChannelData, pc, &ev);
}
case PF_UTILS_CHANNEL_INTERCEPT:
return pf_server_receive_channel_intercept(pdata, channelId, channel_name, data, size,
flags, totalSize);
return pf_server_receive_channel_intercept(pdata, channelId, channel->channel_name,
data, size, flags, totalSize);
default:
break;
}

View File

@ -22,6 +22,7 @@
#define FREERDP_SERVER_PROXY_PFUTILS_H
#include <freerdp/server/proxy/proxy_config.h>
#include <freerdp/server/proxy/proxy_context.h>
/**
* @brief pf_utils_channel_is_passthrough Checks of a channel identified by 'name'
@ -34,14 +35,6 @@
* e.g. proxy client and server are termination points and data passed
* between.
*/
typedef enum
{
PF_UTILS_CHANNEL_NOT_HANDLED,
PF_UTILS_CHANNEL_BLOCK,
PF_UTILS_CHANNEL_PASSTHROUGH,
PF_UTILS_CHANNEL_INTERCEPT,
} pf_utils_channel_mode;
pf_utils_channel_mode pf_utils_get_channel_mode(const proxyConfig* config, const char* name);
const char* pf_utils_channel_mode_string(pf_utils_channel_mode mode);