Drdynvc needs love (#8059)

* winpr: add lock operation on HashTables

* drdynvc: change the listeners array for a hashtable and other micro cleanups

* logonInfo: drop warning that is shown at every connection

Let's avoid this log, we can't do anything if at Microsoft they don't respect
their own specs.

* rdpei: fix terminate of rdpei

* drdynvc: implement the channel list with a hashtable by channelId
This commit is contained in:
David Fort 2022-07-26 12:53:41 +02:00 committed by GitHub
parent 48abc64a6d
commit 1f08cb9a7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 145 additions and 115 deletions

View File

@ -106,6 +106,9 @@ static UINT generic_plugin_terminated(IWTSPlugin* pPlugin)
WLog_Print(plugin->log, WLOG_TRACE, "...");
/* some channels (namely rdpei), look at initialized to see if they should continue to run */
plugin->initialized = FALSE;
if (plugin->terminatePluginFn)
plugin->terminatePluginFn(plugin);

View File

@ -70,7 +70,7 @@ static UINT dvcman_create_listener(IWTSVirtualChannelManager* pChannelMgr,
DVCMAN* dvcman = (DVCMAN*)pChannelMgr;
DVCMAN_LISTENER* listener;
WLog_DBG(TAG, "create_listener: %d.%s.", ArrayList_Count(dvcman->listeners) + 1,
WLog_DBG(TAG, "create_listener: %d.%s.", HashTable_Count(dvcman->listeners) + 1,
pszChannelName);
listener = (DVCMAN_LISTENER*)calloc(1, sizeof(DVCMAN_LISTENER));
@ -98,7 +98,7 @@ static UINT dvcman_create_listener(IWTSVirtualChannelManager* pChannelMgr,
if (ppListener)
*ppListener = (IWTSListener*)listener;
if (!ArrayList_Append(dvcman->listeners, listener))
if (!HashTable_Insert(dvcman->listeners, listener->channel_name, listener))
return ERROR_INTERNAL_ERROR;
return CHANNEL_RC_OK;
}
@ -113,7 +113,7 @@ static UINT dvcman_destroy_listener(IWTSVirtualChannelManager* pChannelMgr, IWTS
{
DVCMAN* dvcman = listener->dvcman;
if (dvcman)
ArrayList_Remove(dvcman->listeners, listener);
HashTable_Remove(dvcman->listeners, listener->channel_name);
}
return CHANNEL_RC_OK;
@ -202,21 +202,16 @@ static const char* dvcman_get_channel_name(IWTSVirtualChannel* channel)
static IWTSVirtualChannel* dvcman_find_channel_by_id(IWTSVirtualChannelManager* pChannelMgr,
UINT32 ChannelId)
{
size_t index;
IWTSVirtualChannel* channel = NULL;
DVCMAN* dvcman = (DVCMAN*)pChannelMgr;
ArrayList_Lock(dvcman->channels);
for (index = 0; index < ArrayList_Count(dvcman->channels); index++)
{
DVCMAN_CHANNEL* cur = (DVCMAN_CHANNEL*)ArrayList_GetItem(dvcman->channels, index);
if (cur->channel_id == ChannelId)
{
channel = &cur->iface;
break;
}
}
DVCMAN_CHANNEL* dvcChannel;
ArrayList_Unlock(dvcman->channels);
HashTable_Lock(dvcman->channelsById);
dvcChannel = HashTable_GetItemValue(dvcman->channelsById, &ChannelId);
if (dvcChannel)
channel = &dvcChannel->iface;
HashTable_Unlock(dvcman->channelsById);
return channel;
}
@ -234,6 +229,17 @@ static void wts_listener_free(void* arg)
DVCMAN_LISTENER* listener = (DVCMAN_LISTENER*)arg;
dvcman_wtslistener_free(listener);
}
static BOOL channelIdMatch(const void* k1, const void* k2)
{
return *((UINT32*)k1) == *((UINT32*)k2);
}
static UINT32 channelIdHash(const void* id)
{
return *((UINT32*)id);
}
static IWTSVirtualChannelManager* dvcman_new(drdynvcPlugin* plugin)
{
wObject* obj;
@ -249,22 +255,31 @@ static IWTSVirtualChannelManager* dvcman_new(drdynvcPlugin* plugin)
dvcman->iface.GetChannelId = dvcman_get_channel_id;
dvcman->iface.GetChannelName = dvcman_get_channel_name;
dvcman->drdynvc = plugin;
dvcman->channels = ArrayList_New(TRUE);
dvcman->channelsById = HashTable_New(TRUE);
if (!dvcman->channels)
if (!dvcman->channelsById)
goto fail;
obj = ArrayList_Object(dvcman->channels);
HashTable_SetHashFunction(dvcman->channelsById, channelIdHash);
obj = HashTable_KeyObject(dvcman->channelsById);
obj->fnObjectEquals = channelIdMatch;
obj = HashTable_ValueObject(dvcman->channelsById);
obj->fnObjectFree = dvcman_channel_free;
dvcman->pool = StreamPool_New(TRUE, 10);
if (!dvcman->pool)
goto fail;
dvcman->listeners = ArrayList_New(TRUE);
dvcman->listeners = HashTable_New(TRUE);
if (!dvcman->listeners)
goto fail;
obj = ArrayList_Object(dvcman->listeners);
HashTable_SetHashFunction(dvcman->listeners, HashTable_StringHash);
obj = HashTable_KeyObject(dvcman->listeners);
obj->fnObjectEquals = HashTable_StringCompare;
obj = HashTable_ValueObject(dvcman->listeners);
obj->fnObjectFree = wts_listener_free;
dvcman->plugin_names = ArrayList_New(TRUE);
@ -406,10 +421,10 @@ static void dvcman_clear(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pCha
WINPR_UNUSED(drdynvc);
ArrayList_Clear(dvcman->channels);
HashTable_Clear(dvcman->channelsById);
ArrayList_Clear(dvcman->plugins);
ArrayList_Clear(dvcman->plugin_names);
ArrayList_Clear(dvcman->listeners);
HashTable_Clear(dvcman->listeners);
}
static void dvcman_free(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChannelMgr)
{
@ -418,9 +433,9 @@ static void dvcman_free(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChan
WINPR_UNUSED(drdynvc);
ArrayList_Free(dvcman->plugins);
ArrayList_Free(dvcman->channels);
HashTable_Free(dvcman->channelsById);
ArrayList_Free(dvcman->plugin_names);
ArrayList_Free(dvcman->listeners);
HashTable_Free(dvcman->listeners);
StreamPool_Free(dvcman->pool);
free(dvcman);
@ -506,76 +521,75 @@ static UINT dvcman_close_channel_iface(IWTSVirtualChannel* pChannel)
static UINT dvcman_create_channel(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChannelMgr,
UINT32 ChannelId, const char* ChannelName)
{
size_t i;
BOOL bAccept;
DVCMAN_CHANNEL* channel;
DrdynvcClientContext* context;
DVCMAN* dvcman = (DVCMAN*)pChannelMgr;
DVCMAN_LISTENER* listener;
IWTSVirtualChannelCallback* pCallback = NULL;
UINT error;
HashTable_Lock(dvcman->listeners);
listener = (DVCMAN_LISTENER*)HashTable_GetItemValue(dvcman->listeners, ChannelName);
if (!listener)
{
error = ERROR_NOT_FOUND;
goto out;
}
if (!(channel = dvcman_channel_new(drdynvc, pChannelMgr, ChannelId, ChannelName)))
{
WLog_Print(drdynvc->log, WLOG_ERROR, "dvcman_channel_new failed!");
return CHANNEL_RC_NO_MEMORY;
error = CHANNEL_RC_NO_MEMORY;
goto out;
}
channel->status = ERROR_NOT_CONNECTED;
if (!ArrayList_Append(dvcman->channels, channel))
return ERROR_INTERNAL_ERROR;
ArrayList_Lock(dvcman->listeners);
for (i = 0; i < ArrayList_Count(dvcman->listeners); i++)
if (!HashTable_Insert(dvcman->channelsById, &channel->channel_id, channel))
{
DVCMAN_LISTENER* listener = (DVCMAN_LISTENER*)ArrayList_GetItem(dvcman->listeners, i);
if (strcmp(listener->channel_name, ChannelName) == 0)
{
IWTSVirtualChannelCallback* pCallback = NULL;
channel->iface.Write = dvcman_write_channel;
channel->iface.Close = dvcman_close_channel_iface;
bAccept = TRUE;
if ((error = listener->listener_callback->OnNewChannelConnection(
listener->listener_callback, &channel->iface, NULL, &bAccept, &pCallback)) ==
CHANNEL_RC_OK &&
bAccept)
{
WLog_Print(drdynvc->log, WLOG_DEBUG, "listener %s created new channel %" PRIu32 "",
listener->channel_name, channel->channel_id);
channel->status = CHANNEL_RC_OK;
channel->channel_callback = pCallback;
channel->pInterface = listener->iface.pInterface;
context = dvcman->drdynvc->context;
IFCALLRET(context->OnChannelConnected, error, context, ChannelName,
listener->iface.pInterface);
if (error)
WLog_Print(drdynvc->log, WLOG_ERROR,
"context.OnChannelConnected failed with error %" PRIu32 "", error);
goto fail;
}
else
{
if (error)
{
WLog_Print(drdynvc->log, WLOG_ERROR,
"OnNewChannelConnection failed with error %" PRIu32 "!", error);
goto fail;
}
else
{
WLog_Print(drdynvc->log, WLOG_ERROR,
"OnNewChannelConnection returned with bAccept FALSE!");
error = ERROR_INTERNAL_ERROR;
goto fail;
}
}
}
WLog_Print(drdynvc->log, WLOG_ERROR, "unable to register channel in our channel list");
error = ERROR_INTERNAL_ERROR;
goto out;
}
error = ERROR_INTERNAL_ERROR;
fail:
ArrayList_Unlock(dvcman->listeners);
channel->iface.Write = dvcman_write_channel;
channel->iface.Close = dvcman_close_channel_iface;
bAccept = TRUE;
error = listener->listener_callback->OnNewChannelConnection(
listener->listener_callback, &channel->iface, NULL, &bAccept, &pCallback);
if (error != CHANNEL_RC_OK)
{
WLog_Print(drdynvc->log, WLOG_ERROR,
"OnNewChannelConnection failed with error %" PRIu32 "!", error);
error = ERROR_INTERNAL_ERROR;
goto out;
}
if (!bAccept)
{
WLog_Print(drdynvc->log, WLOG_ERROR, "OnNewChannelConnection returned with bAccept FALSE!");
error = ERROR_INTERNAL_ERROR;
goto out;
}
WLog_Print(drdynvc->log, WLOG_DEBUG, "listener %s created new channel %" PRIu32 "",
listener->channel_name, channel->channel_id);
channel->status = CHANNEL_RC_OK;
channel->channel_callback = pCallback;
channel->pInterface = listener->iface.pInterface;
context = dvcman->drdynvc->context;
IFCALLRET(context->OnChannelConnected, error, context, ChannelName, listener->iface.pInterface);
if (error != CHANNEL_RC_OK)
{
WLog_Print(drdynvc->log, WLOG_ERROR,
"context.OnChannelConnected failed with error %" PRIu32 "", error);
}
out:
HashTable_Unlock(dvcman->listeners);
return error;
}
@ -632,8 +646,8 @@ UINT dvcman_close_channel(IWTSVirtualChannelManager* pChannelMgr, UINT32 Channel
UINT error = CHANNEL_RC_OK;
DVCMAN* dvcman = (DVCMAN*)pChannelMgr;
drdynvcPlugin* drdynvc = dvcman->drdynvc;
channel = (DVCMAN_CHANNEL*)dvcman_find_channel_by_id(pChannelMgr, ChannelId);
channel = (DVCMAN_CHANNEL*)dvcman_find_channel_by_id(pChannelMgr, ChannelId);
if (!channel)
{
// WLog_Print(drdynvc->log, WLOG_ERROR, "ChannelId %"PRIu32" not found!", ChannelId);
@ -660,7 +674,7 @@ UINT dvcman_close_channel(IWTSVirtualChannelManager* pChannelMgr, UINT32 Channel
}
}
ArrayList_Remove(dvcman->channels, channel);
HashTable_Remove(dvcman->channelsById, &ChannelId);
return error;
}
@ -1048,6 +1062,7 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c
char* name;
size_t length;
DVCMAN* dvcman;
UINT32 retStatus;
WINPR_UNUSED(Sp);
if (!drdynvc)
@ -1086,8 +1101,8 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c
WLog_Print(drdynvc->log, WLOG_DEBUG,
"process_create_request: ChannelId=%" PRIu32 " ChannelName=%s", ChannelId, name);
channel_status = dvcman_create_channel(drdynvc, drdynvc->channel_mgr, ChannelId, name);
data_out = StreamPool_Take(dvcman->pool, pos + 4);
data_out = StreamPool_Take(dvcman->pool, pos + 4);
if (!data_out)
{
WLog_Print(drdynvc->log, WLOG_ERROR, "StreamPool_Take failed!");
@ -1098,16 +1113,26 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c
Stream_SetPosition(s, 1);
Stream_Copy(s, data_out, pos - 1);
if (channel_status == CHANNEL_RC_OK)
switch (channel_status)
{
WLog_Print(drdynvc->log, WLOG_DEBUG, "channel created");
Stream_Write_UINT32(data_out, 0);
}
else
{
WLog_Print(drdynvc->log, WLOG_DEBUG, "no listener");
Stream_Write_UINT32(data_out, (UINT32)0xC0000001); /* same code used by mstsc */
case CHANNEL_RC_OK:
WLog_Print(drdynvc->log, WLOG_DEBUG, "channel created");
retStatus = 0;
break;
case CHANNEL_RC_NO_MEMORY:
WLog_Print(drdynvc->log, WLOG_DEBUG, "not enough memory for channel creation");
retStatus = STATUS_NO_MEMORY;
break;
case ERROR_NOT_FOUND:
WLog_Print(drdynvc->log, WLOG_DEBUG, "no listener for '%s'", name);
retStatus = (UINT32)0xC0000001; /* same code used by mstsc, STATUS_UNSUCCESSFUL */
break;
default:
WLog_Print(drdynvc->log, WLOG_DEBUG, "channel creation error");
retStatus = (UINT32)0xC0000001; /* same code used by mstsc, STATUS_UNSUCCESSFUL */
break;
}
Stream_Write_UINT32(data_out, retStatus);
status = drdynvc_send(drdynvc, data_out);
@ -1386,6 +1411,15 @@ static void VCAPITYPE drdynvc_virtual_channel_open_event_ex(LPVOID lpUserParam,
"drdynvc_virtual_channel_open_event reported an error");
}
static BOOL channelByIdCleanerFn(const void* key, void* value, void* arg)
{
drdynvcPlugin* drdynvc = (drdynvcPlugin*)arg;
DVCMAN_CHANNEL* channel = (DVCMAN_CHANNEL*)value;
dvcman_close_channel(drdynvc->channel_mgr, channel->channel_id, FALSE);
return TRUE;
}
static DWORD WINAPI drdynvc_virtual_channel_client_thread(LPVOID arg)
{
/* TODO: rewrite this */
@ -1438,23 +1472,9 @@ static DWORD WINAPI drdynvc_virtual_channel_client_thread(LPVOID arg)
/* Disconnect remaining dynamic channels that the server did not.
* This is required to properly shut down channels by calling the appropriate
* event handlers. */
size_t count = 0;
DVCMAN* drdynvcMgr = (DVCMAN*)drdynvc->channel_mgr;
do
{
ArrayList_Lock(drdynvcMgr->channels);
count = ArrayList_Count(drdynvcMgr->channels);
if (count > 0)
{
IWTSVirtualChannel* channel =
(IWTSVirtualChannel*)ArrayList_GetItem(drdynvcMgr->channels, 0);
const UINT32 ChannelId = drdynvc->channel_mgr->GetChannelId(channel);
dvcman_close_channel(drdynvc->channel_mgr, ChannelId, FALSE);
count--;
}
ArrayList_Unlock(drdynvcMgr->channels);
} while (count > 0);
HashTable_Foreach(drdynvcMgr->channelsById, channelByIdCleanerFn, drdynvc);
}
if (error && drdynvc->rdpcontext)

View File

@ -46,8 +46,8 @@ typedef struct
wArrayList* plugin_names;
wArrayList* plugins;
wArrayList* listeners;
wArrayList* channels;
wHashTable* listeners;
wHashTable* channelsById;
wStreamPool* pool;
} DVCMAN;

View File

@ -1098,13 +1098,6 @@ static BOOL rdp_recv_logon_info_v2(rdpRdp* rdp, wStream* s, logon_info* info)
logonInfoV2TotalSize, Size);
return FALSE;
}
else
{
WLog_WARN(TAG,
"[SERVER-BUG] 2.2.10.1.1.2 Logon Info Version 2 (TS_LOGON_INFO_VERSION_2) "
"Size expected %" PRIu32 " bytes, got %" PRIu32 ", ignoring",
logonInfoV2TotalSize, Size);
}
}
Stream_Read_UINT32(s, info->sessionId); /* SessionId (4 bytes) */

View File

@ -356,6 +356,8 @@ extern "C"
WINPR_API wHashTable* HashTable_New(BOOL synchronized);
WINPR_API void HashTable_Free(wHashTable* table);
WINPR_API void HashTable_Lock(wHashTable* table);
WINPR_API void HashTable_Unlock(wHashTable* table);
WINPR_API wObject* HashTable_KeyObject(wHashTable* table);
WINPR_API wObject* HashTable_ValueObject(wHashTable* table);

View File

@ -829,6 +829,18 @@ void HashTable_Free(wHashTable* table)
free(table);
}
void HashTable_Lock(wHashTable* table)
{
WINPR_ASSERT(table);
EnterCriticalSection(&table->lock);
}
void HashTable_Unlock(wHashTable* table)
{
WINPR_ASSERT(table);
LeaveCriticalSection(&table->lock);
}
wObject* HashTable_KeyObject(wHashTable* table)
{
WINPR_ASSERT(table);