Refactored rdpsnd server

* Assert all input arguments
* Unify stream buffer handling
This commit is contained in:
akallabeth 2022-06-03 08:05:54 +02:00 committed by akallabeth
parent b69499c060
commit 35f575a753

View File

@ -35,6 +35,17 @@
#include "rdpsnd_common.h" #include "rdpsnd_common.h"
#include "rdpsnd_main.h" #include "rdpsnd_main.h"
static wStream* rdpsnd_server_get_buffer(RdpsndServerContext* context)
{
wStream* s;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
s = context->priv->rdpsnd_pdu;
Stream_SetPosition(s, 0);
return s;
}
/** /**
* Send Server Audio Formats and Version PDU (2.2.2.1) * Send Server Audio Formats and Version PDU (2.2.2.1)
* *
@ -42,11 +53,15 @@
*/ */
static UINT rdpsnd_server_send_formats(RdpsndServerContext* context) static UINT rdpsnd_server_send_formats(RdpsndServerContext* context)
{ {
wStream* s = context->priv->rdpsnd_pdu; wStream* s = rdpsnd_server_get_buffer(context);
size_t pos; size_t pos;
UINT16 i; UINT16 i;
BOOL status = FALSE; BOOL status = FALSE;
ULONG written; ULONG written;
if (!Stream_EnsureRemainingCapacity(s, 24))
return ERROR_OUTOFMEMORY;
Stream_Write_UINT8(s, SNDC_FORMATS); Stream_Write_UINT8(s, SNDC_FORMATS);
Stream_Write_UINT8(s, 0); Stream_Write_UINT8(s, 0);
Stream_Seek_UINT16(s); Stream_Seek_UINT16(s);
@ -61,9 +76,9 @@ static UINT rdpsnd_server_send_formats(RdpsndServerContext* context)
for (i = 0; i < context->num_server_formats; i++) for (i = 0; i < context->num_server_formats; i++)
{ {
AUDIO_FORMAT format = context->server_formats[i]; const AUDIO_FORMAT* format = &context->server_formats[i];
if (!audio_format_write(s, &format)) if (!audio_format_write(s, format))
goto fail; goto fail;
} }
@ -71,6 +86,8 @@ static UINT rdpsnd_server_send_formats(RdpsndServerContext* context)
Stream_SetPosition(s, 2); Stream_SetPosition(s, 2);
Stream_Write_UINT16(s, pos - 4); Stream_Write_UINT16(s, pos - 4);
Stream_SetPosition(s, pos); Stream_SetPosition(s, pos);
WINPR_ASSERT(context->priv);
status = WTSVirtualChannelWrite(context->priv->ChannelHandle, (PCHAR)Stream_Buffer(s), status = WTSVirtualChannelWrite(context->priv->ChannelHandle, (PCHAR)Stream_Buffer(s),
Stream_GetPosition(s), &written); Stream_GetPosition(s), &written);
Stream_SetPosition(s, 0); Stream_SetPosition(s, 0);
@ -89,6 +106,8 @@ static UINT rdpsnd_server_recv_waveconfirm(RdpsndServerContext* context, wStream
BYTE confirmBlockNum; BYTE confirmBlockNum;
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
if (!Stream_CheckAndLogRequiredLength(TAG, s, 4)) if (!Stream_CheckAndLogRequiredLength(TAG, s, 4))
return ERROR_INVALID_DATA; return ERROR_INVALID_DATA;
@ -114,6 +133,8 @@ static UINT rdpsnd_server_recv_trainingconfirm(RdpsndServerContext* context, wSt
UINT16 packsize; UINT16 packsize;
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
if (!Stream_CheckAndLogRequiredLength(TAG, s, 4)) if (!Stream_CheckAndLogRequiredLength(TAG, s, 4))
return ERROR_INVALID_DATA; return ERROR_INVALID_DATA;
@ -134,6 +155,8 @@ static UINT rdpsnd_server_recv_trainingconfirm(RdpsndServerContext* context, wSt
*/ */
static UINT rdpsnd_server_recv_quality_mode(RdpsndServerContext* context, wStream* s) static UINT rdpsnd_server_recv_quality_mode(RdpsndServerContext* context, wStream* s)
{ {
WINPR_ASSERT(context);
if (Stream_GetRemainingLength(s) < 4) if (Stream_GetRemainingLength(s) < 4)
{ {
WLog_ERR(TAG, "not enough data in stream!"); WLog_ERR(TAG, "not enough data in stream!");
@ -160,6 +183,8 @@ static UINT rdpsnd_server_recv_formats(RdpsndServerContext* context, wStream* s)
BYTE lastblock; BYTE lastblock;
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
if (!Stream_CheckAndLogRequiredLength(TAG, s, 20)) if (!Stream_CheckAndLogRequiredLength(TAG, s, 20))
return ERROR_INVALID_DATA; return ERROR_INVALID_DATA;
@ -192,6 +217,8 @@ static UINT rdpsnd_server_recv_formats(RdpsndServerContext* context, wStream* s)
for (i = 0; i < context->num_client_formats; i++) for (i = 0; i < context->num_client_formats; i++)
{ {
AUDIO_FORMAT* format = &context->client_formats[i];
if (!Stream_CheckAndLogRequiredLength(TAG, s, 18)) if (!Stream_CheckAndLogRequiredLength(TAG, s, 18))
{ {
WLog_ERR(TAG, "not enough data in stream!"); WLog_ERR(TAG, "not enough data in stream!");
@ -199,17 +226,17 @@ static UINT rdpsnd_server_recv_formats(RdpsndServerContext* context, wStream* s)
goto out_free; goto out_free;
} }
Stream_Read_UINT16(s, context->client_formats[i].wFormatTag); Stream_Read_UINT16(s, format->wFormatTag);
Stream_Read_UINT16(s, context->client_formats[i].nChannels); Stream_Read_UINT16(s, format->nChannels);
Stream_Read_UINT32(s, context->client_formats[i].nSamplesPerSec); Stream_Read_UINT32(s, format->nSamplesPerSec);
Stream_Read_UINT32(s, context->client_formats[i].nAvgBytesPerSec); Stream_Read_UINT32(s, format->nAvgBytesPerSec);
Stream_Read_UINT16(s, context->client_formats[i].nBlockAlign); Stream_Read_UINT16(s, format->nBlockAlign);
Stream_Read_UINT16(s, context->client_formats[i].wBitsPerSample); Stream_Read_UINT16(s, format->wBitsPerSample);
Stream_Read_UINT16(s, context->client_formats[i].cbSize); Stream_Read_UINT16(s, format->cbSize);
if (context->client_formats[i].cbSize > 0) if (format->cbSize > 0)
{ {
if (!Stream_SafeSeek(s, context->client_formats[i].cbSize)) if (!Stream_SafeSeek(s, format->cbSize))
{ {
WLog_ERR(TAG, "Stream_SafeSeek failed!"); WLog_ERR(TAG, "Stream_SafeSeek failed!");
error = ERROR_INTERNAL_ERROR; error = ERROR_INTERNAL_ERROR;
@ -217,7 +244,7 @@ static UINT rdpsnd_server_recv_formats(RdpsndServerContext* context, wStream* s)
} }
} }
if (context->client_formats[i].wFormatTag != 0) if (format->wFormatTag != 0)
{ {
// lets call this a known format // lets call this a known format
// TODO: actually look through our own list of known formats // TODO: actually look through our own list of known formats
@ -243,7 +270,10 @@ static DWORD WINAPI rdpsnd_server_thread(LPVOID arg)
HANDLE events[2] = { 0 }; HANDLE events[2] = { 0 };
RdpsndServerContext* context = (RdpsndServerContext*)arg; RdpsndServerContext* context = (RdpsndServerContext*)arg;
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context); WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
events[nCount++] = context->priv->channelEvent; events[nCount++] = context->priv->channelEvent;
events[nCount++] = context->priv->StopEvent; events[nCount++] = context->priv->StopEvent;
@ -293,6 +323,9 @@ static DWORD WINAPI rdpsnd_server_thread(LPVOID arg)
*/ */
static UINT rdpsnd_server_initialize(RdpsndServerContext* context, BOOL ownThread) static UINT rdpsnd_server_initialize(RdpsndServerContext* context, BOOL ownThread)
{ {
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
context->priv->ownThread = ownThread; context->priv->ownThread = ownThread;
return context->Start(context); return context->Start(context);
} }
@ -309,6 +342,9 @@ static UINT rdpsnd_server_select_format(RdpsndServerContext* context, UINT16 cli
AUDIO_FORMAT* format; AUDIO_FORMAT* format;
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
if ((client_format_index >= context->num_client_formats) || (!context->src_format)) if ((client_format_index >= context->num_client_formats) || (!context->src_format))
{ {
WLog_ERR(TAG, "index %d is not correct.", client_format_index); WLog_ERR(TAG, "index %d is not correct.", client_format_index);
@ -391,10 +427,10 @@ out:
static UINT rdpsnd_server_training(RdpsndServerContext* context, UINT16 timestamp, UINT16 packsize, static UINT rdpsnd_server_training(RdpsndServerContext* context, UINT16 timestamp, UINT16 packsize,
BYTE* data) BYTE* data)
{ {
wStream* s = context->priv->rdpsnd_pdu;
size_t end = 0; size_t end = 0;
ULONG written; ULONG written;
BOOL status; BOOL status;
wStream* s = rdpsnd_server_get_buffer(context);
if (!Stream_EnsureRemainingCapacity(s, 8)) if (!Stream_EnsureRemainingCapacity(s, 8))
return ERROR_INTERNAL_ERROR; return ERROR_INTERNAL_ERROR;
@ -461,12 +497,21 @@ static UINT rdpsnd_server_send_wave_pdu(RdpsndServerContext* context, UINT16 wTi
const BYTE* src; const BYTE* src;
AUDIO_FORMAT* format; AUDIO_FORMAT* format;
ULONG written; ULONG written;
wStream* s = context->priv->rdpsnd_pdu;
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
wStream* s = rdpsnd_server_get_buffer(context);
if (context->selected_client_format > context->num_client_formats)
return ERROR_INTERNAL_ERROR;
WINPR_ASSERT(context->client_formats);
format = &context->client_formats[context->selected_client_format]; format = &context->client_formats[context->selected_client_format];
/* WaveInfo PDU */ /* WaveInfo PDU */
Stream_SetPosition(s, 0); Stream_SetPosition(s, 0);
if (!Stream_EnsureRemainingCapacity(s, 16))
return ERROR_OUTOFMEMORY;
Stream_Write_UINT8(s, SNDC_WAVE); /* msgType */ Stream_Write_UINT8(s, SNDC_WAVE); /* msgType */
Stream_Write_UINT8(s, 0); /* bPad */ Stream_Write_UINT8(s, 0); /* bPad */
Stream_Write_UINT16(s, 0); /* BodySize */ Stream_Write_UINT16(s, 0); /* BodySize */
@ -535,11 +580,11 @@ static UINT rdpsnd_server_send_wave2_pdu(RdpsndServerContext* context, UINT16 fo
const BYTE* data, size_t size, BOOL encoded, const BYTE* data, size_t size, BOOL encoded,
UINT16 timestamp, UINT32 audioTimeStamp) UINT16 timestamp, UINT32 audioTimeStamp)
{ {
wStream* s = context->priv->rdpsnd_pdu;
size_t end = 0; size_t end = 0;
ULONG written; ULONG written;
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
BOOL status; BOOL status;
wStream* s = rdpsnd_server_get_buffer(context);
if (!Stream_EnsureRemainingCapacity(s, 16)) if (!Stream_EnsureRemainingCapacity(s, 16))
{ {
@ -548,7 +593,6 @@ static UINT rdpsnd_server_send_wave2_pdu(RdpsndServerContext* context, UINT16 fo
} }
/* Wave2 PDU */ /* Wave2 PDU */
Stream_SetPosition(s, 0);
Stream_Write_UINT8(s, SNDC_WAVE2); /* msgType */ Stream_Write_UINT8(s, SNDC_WAVE2); /* msgType */
Stream_Write_UINT8(s, 0); /* bPad */ Stream_Write_UINT8(s, 0); /* bPad */
Stream_Write_UINT16(s, 0); /* BodySize */ Stream_Write_UINT16(s, 0); /* BodySize */
@ -616,6 +660,9 @@ static UINT rdpsnd_server_send_audio_pdu(RdpsndServerContext* context, UINT16 wT
const BYTE* src; const BYTE* src;
size_t length; size_t length;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
if (context->selected_client_format >= context->num_client_formats) if (context->selected_client_format >= context->num_client_formats)
return ERROR_INTERNAL_ERROR; return ERROR_INTERNAL_ERROR;
@ -638,6 +685,10 @@ static UINT rdpsnd_server_send_samples(RdpsndServerContext* context, const void*
size_t nframes, UINT16 wTimestamp) size_t nframes, UINT16 wTimestamp)
{ {
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
EnterCriticalSection(&context->priv->lock); EnterCriticalSection(&context->priv->lock);
if (context->selected_client_format >= context->num_client_formats) if (context->selected_client_format >= context->num_client_formats)
@ -686,6 +737,9 @@ static UINT rdpsnd_server_send_samples2(RdpsndServerContext* context, UINT16 for
{ {
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
if (context->clientVersion < CHANNEL_VERSION_WIN_8) if (context->clientVersion < CHANNEL_VERSION_WIN_8)
return ERROR_INTERNAL_ERROR; return ERROR_INTERNAL_ERROR;
@ -709,15 +763,8 @@ static UINT rdpsnd_server_set_volume(RdpsndServerContext* context, UINT16 left,
size_t len; size_t len;
BOOL status; BOOL status;
ULONG written; ULONG written;
wStream* s; wStream* s = rdpsnd_server_get_buffer(context);
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
s = context->priv->rdpsnd_pdu;
WINPR_ASSERT(s);
Stream_SetPosition(s, 0);
if (!Stream_EnsureRemainingCapacity(s, 8)) if (!Stream_EnsureRemainingCapacity(s, 8))
return ERROR_NOT_ENOUGH_MEMORY; return ERROR_NOT_ENOUGH_MEMORY;
@ -744,8 +791,9 @@ static UINT rdpsnd_server_close(RdpsndServerContext* context)
size_t pos; size_t pos;
BOOL status; BOOL status;
ULONG written; ULONG written;
wStream* s = context->priv->rdpsnd_pdu;
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
wStream* s = rdpsnd_server_get_buffer(context);
EnterCriticalSection(&context->priv->lock); EnterCriticalSection(&context->priv->lock);
if (context->priv->out_pending_frames > 0) if (context->priv->out_pending_frames > 0)
@ -767,6 +815,10 @@ static UINT rdpsnd_server_close(RdpsndServerContext* context)
return error; return error;
context->selected_client_format = 0xFFFF; context->selected_client_format = 0xFFFF;
if (!Stream_EnsureRemainingCapacity(s, 4))
return ERROR_OUTOFMEMORY;
Stream_Write_UINT8(s, SNDC_CLOSE); Stream_Write_UINT8(s, SNDC_CLOSE);
Stream_Write_UINT8(s, 0); Stream_Write_UINT8(s, 0);
Stream_Seek_UINT16(s); Stream_Seek_UINT16(s);
@ -789,10 +841,14 @@ static UINT rdpsnd_server_start(RdpsndServerContext* context)
{ {
void* buffer = NULL; void* buffer = NULL;
DWORD bytesReturned; DWORD bytesReturned;
RdpsndServerPrivate* priv = context->priv; RdpsndServerPrivate* priv;
UINT error = ERROR_INTERNAL_ERROR; UINT error = ERROR_INTERNAL_ERROR;
PULONG pSessionId = NULL; PULONG pSessionId = NULL;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
priv = context->priv;
priv->SessionId = WTS_CURRENT_SESSION; priv->SessionId = WTS_CURRENT_SESSION;
if (context->use_dynamic_virtual_channel) if (context->use_dynamic_virtual_channel)
@ -909,6 +965,10 @@ out_close:
static UINT rdpsnd_server_stop(RdpsndServerContext* context) static UINT rdpsnd_server_stop(RdpsndServerContext* context)
{ {
UINT error = CHANNEL_RC_OK; UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
if (!context->priv->StopEvent) if (!context->priv->StopEvent)
return error; return error;
@ -951,15 +1011,11 @@ static UINT rdpsnd_server_stop(RdpsndServerContext* context)
RdpsndServerContext* rdpsnd_server_context_new(HANDLE vcm) RdpsndServerContext* rdpsnd_server_context_new(HANDLE vcm)
{ {
RdpsndServerContext* context;
RdpsndServerPrivate* priv; RdpsndServerPrivate* priv;
context = (RdpsndServerContext*)calloc(1, sizeof(RdpsndServerContext)); RdpsndServerContext* context = (RdpsndServerContext*)calloc(1, sizeof(RdpsndServerContext));
if (!context) if (!context)
{ goto fail;
WLog_ERR(TAG, "calloc failed!");
return NULL;
}
context->vcm = vcm; context->vcm = vcm;
context->Start = rdpsnd_server_start; context->Start = rdpsnd_server_start;
@ -978,7 +1034,7 @@ RdpsndServerContext* rdpsnd_server_context_new(HANDLE vcm)
if (!priv) if (!priv)
{ {
WLog_ERR(TAG, "calloc failed!"); WLog_ERR(TAG, "calloc failed!");
goto out_free; goto fail;
} }
priv->dsp_context = freerdp_dsp_context_new(TRUE); priv->dsp_context = freerdp_dsp_context_new(TRUE);
@ -986,7 +1042,7 @@ RdpsndServerContext* rdpsnd_server_context_new(HANDLE vcm)
if (!priv->dsp_context) if (!priv->dsp_context)
{ {
WLog_ERR(TAG, "freerdp_dsp_context_new failed!"); WLog_ERR(TAG, "freerdp_dsp_context_new failed!");
goto out_free_priv; goto fail;
} }
priv->input_stream = Stream_New(NULL, 4); priv->input_stream = Stream_New(NULL, 4);
@ -994,24 +1050,23 @@ RdpsndServerContext* rdpsnd_server_context_new(HANDLE vcm)
if (!priv->input_stream) if (!priv->input_stream)
{ {
WLog_ERR(TAG, "Stream_New failed!"); WLog_ERR(TAG, "Stream_New failed!");
goto out_free_dsp; goto fail;
} }
priv->expectedBytes = 4; priv->expectedBytes = 4;
priv->waitingHeader = TRUE; priv->waitingHeader = TRUE;
priv->ownThread = TRUE; priv->ownThread = TRUE;
return context; return context;
out_free_dsp: fail:
freerdp_dsp_context_free(priv->dsp_context); rdpsnd_server_context_free(context);
out_free_priv:
free(context->priv);
out_free:
free(context);
return NULL; return NULL;
} }
void rdpsnd_server_context_reset(RdpsndServerContext* context) void rdpsnd_server_context_reset(RdpsndServerContext* context)
{ {
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
context->priv->expectedBytes = 4; context->priv->expectedBytes = 4;
context->priv->waitingHeader = TRUE; context->priv->waitingHeader = TRUE;
Stream_SetPosition(context->priv->input_stream, 0); Stream_SetPosition(context->priv->input_stream, 0);
@ -1022,6 +1077,8 @@ void rdpsnd_server_context_free(RdpsndServerContext* context)
if (!context) if (!context)
return; return;
if (context->priv)
{
rdpsnd_server_stop(context); rdpsnd_server_stop(context);
free(context->priv->out_buffer); free(context->priv->out_buffer);
@ -1031,6 +1088,7 @@ void rdpsnd_server_context_free(RdpsndServerContext* context)
if (context->priv->input_stream) if (context->priv->input_stream)
Stream_Free(context->priv->input_stream, TRUE); Stream_Free(context->priv->input_stream, TRUE);
}
free(context->server_formats); free(context->server_formats);
free(context->client_formats); free(context->client_formats);
@ -1040,6 +1098,9 @@ void rdpsnd_server_context_free(RdpsndServerContext* context)
HANDLE rdpsnd_server_get_event_handle(RdpsndServerContext* context) HANDLE rdpsnd_server_get_event_handle(RdpsndServerContext* context)
{ {
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
return context->priv->channelEvent; return context->priv->channelEvent;
} }
@ -1061,8 +1122,14 @@ UINT rdpsnd_server_handle_messages(RdpsndServerContext* context)
{ {
DWORD bytesReturned; DWORD bytesReturned;
UINT ret = CHANNEL_RC_OK; UINT ret = CHANNEL_RC_OK;
RdpsndServerPrivate* priv = context->priv; RdpsndServerPrivate* priv;
wStream* s = priv->input_stream; wStream* s;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
priv = context->priv;
s = priv->input_stream;
if (!WTSVirtualChannelRead(priv->ChannelHandle, 0, (PCHAR)Stream_Pointer(s), if (!WTSVirtualChannelRead(priv->ChannelHandle, 0, (PCHAR)Stream_Pointer(s),
priv->expectedBytes, &bytesReturned)) priv->expectedBytes, &bytesReturned))