Add option for external thread to ainput server channel

This commit is contained in:
akallabeth 2022-02-18 09:53:50 +01:00 committed by akallabeth
parent 42abdb056f
commit c5f7f9fba8
2 changed files with 281 additions and 136 deletions

View File

@ -41,7 +41,14 @@
#define TAG CHANNELS_TAG("ainput.server") #define TAG CHANNELS_TAG("ainput.server")
typedef struct ainput_server_ typedef enum
{
AINPUT_INITIAL,
AINPUT_OPENED,
AINPUT_VERSION_SENT,
} eAInputChannelState;
typedef struct
{ {
ainput_server_context context; ainput_server_context context;
@ -54,14 +61,25 @@ typedef struct ainput_server_
DWORD SessionId; DWORD SessionId;
BOOL isOpened;
BOOL externalThread;
/* Channel state */
eAInputChannelState state;
wStream* buffer;
} ainput_server; } ainput_server;
static UINT ainput_server_context_poll(ainput_server_context* context);
static BOOL ainput_server_context_handle(ainput_server_context* context, HANDLE* handle);
static UINT ainput_server_context_poll_int(ainput_server_context* context);
static BOOL ainput_server_is_open(ainput_server_context* context) static BOOL ainput_server_is_open(ainput_server_context* context)
{ {
ainput_server* ainput = (ainput_server*)context; ainput_server* ainput = (ainput_server*)context;
WINPR_ASSERT(ainput); WINPR_ASSERT(ainput);
return ainput->thread != NULL; return ainput->isOpened;
} }
/** /**
@ -118,10 +136,14 @@ static UINT ainput_server_open_channel(ainput_server* ainput)
return ainput->ainput_channel ? CHANNEL_RC_OK : ERROR_INTERNAL_ERROR; return ainput->ainput_channel ? CHANNEL_RC_OK : ERROR_INTERNAL_ERROR;
} }
static UINT ainput_server_send_version(ainput_server* ainput, wStream* s) static UINT ainput_server_send_version(ainput_server* ainput)
{ {
ULONG written; ULONG written;
wStream* s;
WINPR_ASSERT(ainput); WINPR_ASSERT(ainput);
s = ainput->buffer;
WINPR_ASSERT(s); WINPR_ASSERT(s);
Stream_SetPosition(s, 0); Stream_SetPosition(s, 0);
@ -167,31 +189,15 @@ static UINT ainput_server_recv_mouse_event(ainput_server* ainput, wStream* s)
return error; return error;
} }
static DWORD WINAPI ainput_server_thread_func(LPVOID arg)
static HANDLE ainput_server_get_channel_handle(ainput_server* ainput)
{ {
wStream* s; BYTE* buffer = NULL;
void* buffer;
DWORD nCount;
HANDLE events[8] = { 0 };
BOOL ready = FALSE;
HANDLE ChannelEvent;
DWORD BytesReturned = 0; DWORD BytesReturned = 0;
ainput_server* ainput = (ainput_server*)arg; HANDLE ChannelEvent = NULL;
UINT error;
DWORD status;
WINPR_ASSERT(ainput); WINPR_ASSERT(ainput);
if ((error = ainput_server_open_channel(ainput)))
{
WLog_ERR(TAG, "ainput_server_open_channel failed with error %" PRIu32 "!", error);
goto out;
}
buffer = NULL;
BytesReturned = 0;
ChannelEvent = NULL;
if (WTSVirtualChannelQuery(ainput->ainput_channel, WTSVirtualEventHandle, &buffer, if (WTSVirtualChannelQuery(ainput->ainput_channel, WTSVirtualEventHandle, &buffer,
&BytesReturned) == TRUE) &BytesReturned) == TRUE)
{ {
@ -201,130 +207,64 @@ static DWORD WINAPI ainput_server_thread_func(LPVOID arg)
WTSFreeMemory(buffer); WTSFreeMemory(buffer);
} }
return ChannelEvent;
}
static DWORD WINAPI ainput_server_thread_func(LPVOID arg)
{
DWORD nCount;
HANDLE events[2] = { 0 };
ainput_server* ainput = (ainput_server*)arg;
UINT error = CHANNEL_RC_OK;
DWORD status;
WINPR_ASSERT(ainput);
nCount = 0; nCount = 0;
events[nCount++] = ainput->stopEvent; events[nCount++] = ainput->stopEvent;
events[nCount++] = ChannelEvent;
while (1) while ((error == CHANNEL_RC_OK) && (WaitForSingleObject(events[0], 0) != WAIT_OBJECT_0))
{ {
status = WaitForMultipleObjects(nCount, events, FALSE, 100); switch (ainput->state)
if (status == WAIT_FAILED)
{ {
error = GetLastError(); case AINPUT_OPENED:
WLog_ERR(TAG, "WaitForMultipleObjects failed with error %" PRIu32 "", error); events[1] = ainput_server_get_channel_handle(ainput);
break; nCount = 2;
} status = WaitForMultipleObjects(nCount, events, FALSE, 100);
switch (status)
{
case WAIT_TIMEOUT:
case WAIT_OBJECT_0 + 1:
case WAIT_OBJECT_0:
error = ainput_server_context_poll_int(&ainput->context);
if (status == WAIT_OBJECT_0) case WAIT_FAILED:
{ default:
if (error) error = ERROR_INTERNAL_ERROR;
WLog_ERR(TAG, "OpenResult failed with error %" PRIu32 "!", error); break;
}
break;
}
if (WTSVirtualChannelQuery(ainput->ainput_channel, WTSVirtualChannelReady, &buffer,
&BytesReturned) == FALSE)
{
if (error)
WLog_ERR(TAG, "OpenResult failed with error %" PRIu32 "!", error);
break;
}
ready = *((BOOL*)buffer);
WTSFreeMemory(buffer);
if (ready)
{
if (error)
WLog_ERR(TAG, "OpenResult failed with error %" PRIu32 "!", error);
break;
}
}
s = Stream_New(NULL, 4096);
if (!s)
{
WLog_ERR(TAG, "Stream_New failed!");
WTSVirtualChannelClose(ainput->ainput_channel);
ExitThread(ERROR_NOT_ENOUGH_MEMORY);
return ERROR_NOT_ENOUGH_MEMORY;
}
if (ready)
{
if ((error = ainput_server_send_version(ainput, s)))
{
WLog_ERR(TAG, "audin_server_send_version failed with error %" PRIu32 "!", error);
goto out_capacity;
}
}
while (ready)
{
UINT16 MessageId;
status = WaitForMultipleObjects(nCount, events, FALSE, INFINITE);
if (status == WAIT_FAILED)
{
error = GetLastError();
WLog_ERR(TAG, "WaitForMultipleObjects failed with error %" PRIu32 "", error);
break;
}
if (status == WAIT_OBJECT_0)
break;
Stream_SetPosition(s, 0);
WTSVirtualChannelRead(ainput->ainput_channel, 0, NULL, 0, &BytesReturned);
if (BytesReturned < 2)
continue;
if (!Stream_EnsureRemainingCapacity(s, BytesReturned))
{
WLog_ERR(TAG, "Stream_EnsureRemainingCapacity failed!");
error = CHANNEL_RC_NO_MEMORY;
break;
}
if (WTSVirtualChannelRead(ainput->ainput_channel, 0, (PCHAR)Stream_Buffer(s),
(ULONG)Stream_Capacity(s), &BytesReturned) == FALSE)
{
WLog_ERR(TAG, "WTSVirtualChannelRead failed!");
error = ERROR_INTERNAL_ERROR;
break;
}
Stream_SetLength(s, BytesReturned);
Stream_Read_UINT16(s, MessageId);
switch (MessageId)
{
case MSG_AINPUT_MOUSE:
error = ainput_server_recv_mouse_event(ainput, s);
break; break;
case AINPUT_VERSION_SENT:
status = WaitForMultipleObjects(nCount, events, FALSE, INFINITE);
switch (status)
{
case WAIT_TIMEOUT:
case WAIT_OBJECT_0 + 1:
case WAIT_OBJECT_0:
error = ainput_server_context_poll_int(&ainput->context);
case WAIT_FAILED:
default:
error = ERROR_INTERNAL_ERROR;
break;
}
break;
default: default:
WLog_ERR(TAG, "audin_server_thread_func: unknown MessageId %" PRIu8 "", MessageId); error = ainput_server_context_poll_int(&ainput->context);
break; break;
} }
if (error)
{
WLog_ERR(TAG, "Response failed with error %" PRIu32 "!", error);
break;
}
} }
out_capacity:
Stream_Free(s, TRUE);
out:
WTSVirtualChannelClose(ainput->ainput_channel); WTSVirtualChannelClose(ainput->ainput_channel);
ainput->ainput_channel = NULL; ainput->ainput_channel = NULL;
@ -347,7 +287,7 @@ static UINT ainput_server_open(ainput_server_context* context)
WINPR_ASSERT(ainput); WINPR_ASSERT(ainput);
if (ainput->thread == NULL) if (!ainput->externalThread && (ainput->thread == NULL))
{ {
ainput->stopEvent = CreateEvent(NULL, TRUE, FALSE, NULL); ainput->stopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!ainput->stopEvent) if (!ainput->stopEvent)
@ -365,6 +305,7 @@ static UINT ainput_server_open(ainput_server_context* context)
return ERROR_INTERNAL_ERROR; return ERROR_INTERNAL_ERROR;
} }
} }
ainput->isOpened = TRUE;
return CHANNEL_RC_OK; return CHANNEL_RC_OK;
} }
@ -381,7 +322,7 @@ static UINT ainput_server_close(ainput_server_context* context)
WINPR_ASSERT(ainput); WINPR_ASSERT(ainput);
if (ainput->thread) if (!ainput->externalThread && ainput->thread)
{ {
SetEvent(ainput->stopEvent); SetEvent(ainput->stopEvent);
@ -397,10 +338,37 @@ static UINT ainput_server_close(ainput_server_context* context)
ainput->thread = NULL; ainput->thread = NULL;
ainput->stopEvent = NULL; ainput->stopEvent = NULL;
} }
if (ainput->externalThread)
{
if (ainput->state != AINPUT_INITIAL)
{
WTSVirtualChannelClose(ainput->ainput_channel);
ainput->ainput_channel = NULL;
ainput->state = AINPUT_INITIAL;
}
}
ainput->isOpened = FALSE;
return error; return error;
} }
static UINT ainput_server_initialize(ainput_server_context* context, BOOL externalThread)
{
UINT error = CHANNEL_RC_OK;
ainput_server* ainput = (ainput_server*)context;
WINPR_ASSERT(ainput);
if (ainput->isOpened)
{
WLog_WARN(TAG, "Application error: AINPUT channel already initialized, calling in this "
"state is not possible!");
return ERROR_INVALID_STATE;
}
ainput->externalThread = externalThread;
return error;
}
ainput_server_context* ainput_server_context_new(HANDLE vcm) ainput_server_context* ainput_server_context_new(HANDLE vcm)
{ {
ainput_server* ainput = (ainput_server*)calloc(1, sizeof(ainput_server)); ainput_server* ainput = (ainput_server*)calloc(1, sizeof(ainput_server));
@ -412,14 +380,165 @@ ainput_server_context* ainput_server_context_new(HANDLE vcm)
ainput->context.Open = ainput_server_open; ainput->context.Open = ainput_server_open;
ainput->context.IsOpen = ainput_server_is_open; ainput->context.IsOpen = ainput_server_is_open;
ainput->context.Close = ainput_server_close; ainput->context.Close = ainput_server_close;
ainput->context.Initialize = ainput_server_initialize;
ainput->context.Poll = ainput_server_context_poll;
ainput->context.ChannelHandle = ainput_server_context_handle;
ainput->buffer = Stream_New(NULL, 4096);
if (!ainput->buffer)
goto fail;
return &ainput->context; return &ainput->context;
fail:
ainput_server_context_free(ainput);
return NULL;
} }
void ainput_server_context_free(ainput_server_context* context) void ainput_server_context_free(ainput_server_context* context)
{ {
ainput_server* ainput = (ainput_server*)context; ainput_server* ainput = (ainput_server*)context;
if (ainput) if (ainput)
{
ainput_server_close(context); ainput_server_close(context);
Stream_Free(ainput->buffer, TRUE);
}
free(ainput); free(ainput);
} }
static UINT ainput_process_message(ainput_server* ainput)
{
BOOL rc;
UINT error = ERROR_INTERNAL_ERROR;
ULONG BytesReturned;
UINT16 MessageId;
wStream* s;
WINPR_ASSERT(ainput);
WINPR_ASSERT(ainput->ainput_channel);
s = ainput->buffer;
WINPR_ASSERT(s);
Stream_SetPosition(s, 0);
rc = WTSVirtualChannelRead(ainput->ainput_channel, 0, NULL, 0, &BytesReturned);
if (!rc)
goto out;
if (BytesReturned < 2)
{
error = CHANNEL_RC_OK;
goto out;
}
if (!Stream_EnsureRemainingCapacity(s, BytesReturned))
{
WLog_ERR(TAG, "Stream_EnsureRemainingCapacity failed!");
error = CHANNEL_RC_NO_MEMORY;
goto out;
}
if (WTSVirtualChannelRead(ainput->ainput_channel, 0, (PCHAR)Stream_Buffer(s),
(ULONG)Stream_Capacity(s), &BytesReturned) == FALSE)
{
WLog_ERR(TAG, "WTSVirtualChannelRead failed!");
goto out;
}
Stream_SetLength(s, BytesReturned);
Stream_Read_UINT16(s, MessageId);
switch (MessageId)
{
case MSG_AINPUT_MOUSE:
error = ainput_server_recv_mouse_event(ainput, s);
break;
default:
WLog_ERR(TAG, "audin_server_thread_func: unknown MessageId %" PRIu8 "", MessageId);
break;
}
out:
if (error)
WLog_ERR(TAG, "Response failed with error %" PRIu32 "!", error);
return error;
}
BOOL ainput_server_context_handle(ainput_server_context* context, HANDLE* handle)
{
ainput_server* ainput = (ainput_server*)context;
WINPR_ASSERT(ainput);
WINPR_ASSERT(handle);
if (!ainput->externalThread)
return FALSE;
if (ainput->state == AINPUT_INITIAL)
return FALSE;
*handle = ainput_server_get_channel_handle(ainput);
return TRUE;
}
UINT ainput_server_context_poll_int(ainput_server_context* context)
{
ainput_server* ainput = (ainput_server*)context;
UINT error = ERROR_INTERNAL_ERROR;
WINPR_ASSERT(ainput);
switch (ainput->state)
{
case AINPUT_INITIAL:
error = ainput_server_open_channel(ainput);
if (error)
WLog_ERR(TAG, "ainput_server_open_channel failed with error %" PRIu32 "!", error);
else
ainput->state = AINPUT_OPENED;
break;
case AINPUT_OPENED:
{
BYTE* buffer = NULL;
DWORD BytesReturned = 0;
if (WTSVirtualChannelQuery(ainput->ainput_channel, WTSVirtualChannelReady, &buffer,
&BytesReturned) != TRUE)
{
WLog_ERR(TAG, "WTSVirtualChannelReady failed,");
}
else
{
if (*buffer != 0)
{
error = ainput_server_send_version(ainput);
if (error)
WLog_ERR(TAG, "audin_server_send_version failed with error %" PRIu32 "!",
error);
else
ainput->state = AINPUT_VERSION_SENT;
}
else
error = CHANNEL_RC_OK;
}
WTSFreeMemory(buffer);
}
break;
case AINPUT_VERSION_SENT:
error = ainput_process_message(ainput);
break;
default:
WLog_ERR(TAG, "AINPUT chanel is in invalid state %d", ainput->state);
break;
}
return error;
}
UINT ainput_server_context_poll(ainput_server_context* context)
{
ainput_server* ainput = (ainput_server*)context;
WINPR_ASSERT(ainput);
if (!ainput->externalThread)
return ERROR_INTERNAL_ERROR;
return ainput_server_context_poll_int(context);
}

View File

@ -34,6 +34,10 @@ typedef enum AINPUT_SERVER_OPEN_RESULT
typedef struct _ainput_server_context ainput_server_context; typedef struct _ainput_server_context ainput_server_context;
typedef UINT (*psAInputServerInitialize)(ainput_server_context* context, BOOL externalThread);
typedef UINT (*psAInputServerPoll)(ainput_server_context* context);
typedef BOOL (*psAInputServerChannelHandle)(ainput_server_context* context, HANDLE* handle);
typedef UINT (*psAInputServerOpen)(ainput_server_context* context); typedef UINT (*psAInputServerOpen)(ainput_server_context* context);
typedef UINT (*psAInputServerClose)(ainput_server_context* context); typedef UINT (*psAInputServerClose)(ainput_server_context* context);
typedef BOOL (*psAInputServerIsOpen)(ainput_server_context* context); typedef BOOL (*psAInputServerIsOpen)(ainput_server_context* context);
@ -55,6 +59,28 @@ struct _ainput_server_context
* Open the ainput channel. * Open the ainput channel.
*/ */
psAInputServerOpen Open; psAInputServerOpen Open;
/**
* Optional: Set thread handling.
* When externalThread=TRUE the application is responsible to call
* ainput_server_context_poll periodically to process input events.
*
* Defaults to externalThread=FALSE
*/
psAInputServerInitialize Initialize;
/**
* @brief Poll When externalThread=TRUE call periodically from your main loop.
* if externalThread=FALSE do not call.
*/
psAInputServerPoll Poll;
/**
* @brief Poll When externalThread=TRUE call to get a handle to wait for events.
* Will return FALSE until the handle is available.
*/
psAInputServerChannelHandle ChannelHandle;
/** /**
* Close the ainput channel. * Close the ainput channel.
*/ */