Refactored rdpsnd server

* Assert all input arguments
* Unify stream buffer handling
This commit is contained in:
akallabeth 2022-06-03 08:05:54 +02:00 committed by akallabeth
parent b69499c060
commit 35f575a753

View File

@ -35,6 +35,17 @@
#include "rdpsnd_common.h"
#include "rdpsnd_main.h"
static wStream* rdpsnd_server_get_buffer(RdpsndServerContext* context)
{
wStream* s;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
s = context->priv->rdpsnd_pdu;
Stream_SetPosition(s, 0);
return s;
}
/**
* Send Server Audio Formats and Version PDU (2.2.2.1)
*
@ -42,11 +53,15 @@
*/
static UINT rdpsnd_server_send_formats(RdpsndServerContext* context)
{
wStream* s = context->priv->rdpsnd_pdu;
wStream* s = rdpsnd_server_get_buffer(context);
size_t pos;
UINT16 i;
BOOL status = FALSE;
ULONG written;
if (!Stream_EnsureRemainingCapacity(s, 24))
return ERROR_OUTOFMEMORY;
Stream_Write_UINT8(s, SNDC_FORMATS);
Stream_Write_UINT8(s, 0);
Stream_Seek_UINT16(s);
@ -61,9 +76,9 @@ static UINT rdpsnd_server_send_formats(RdpsndServerContext* context)
for (i = 0; i < context->num_server_formats; i++)
{
AUDIO_FORMAT format = context->server_formats[i];
const AUDIO_FORMAT* format = &context->server_formats[i];
if (!audio_format_write(s, &format))
if (!audio_format_write(s, format))
goto fail;
}
@ -71,6 +86,8 @@ static UINT rdpsnd_server_send_formats(RdpsndServerContext* context)
Stream_SetPosition(s, 2);
Stream_Write_UINT16(s, pos - 4);
Stream_SetPosition(s, pos);
WINPR_ASSERT(context->priv);
status = WTSVirtualChannelWrite(context->priv->ChannelHandle, (PCHAR)Stream_Buffer(s),
Stream_GetPosition(s), &written);
Stream_SetPosition(s, 0);
@ -89,6 +106,8 @@ static UINT rdpsnd_server_recv_waveconfirm(RdpsndServerContext* context, wStream
BYTE confirmBlockNum;
UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
if (!Stream_CheckAndLogRequiredLength(TAG, s, 4))
return ERROR_INVALID_DATA;
@ -114,6 +133,8 @@ static UINT rdpsnd_server_recv_trainingconfirm(RdpsndServerContext* context, wSt
UINT16 packsize;
UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
if (!Stream_CheckAndLogRequiredLength(TAG, s, 4))
return ERROR_INVALID_DATA;
@ -134,6 +155,8 @@ static UINT rdpsnd_server_recv_trainingconfirm(RdpsndServerContext* context, wSt
*/
static UINT rdpsnd_server_recv_quality_mode(RdpsndServerContext* context, wStream* s)
{
WINPR_ASSERT(context);
if (Stream_GetRemainingLength(s) < 4)
{
WLog_ERR(TAG, "not enough data in stream!");
@ -160,6 +183,8 @@ static UINT rdpsnd_server_recv_formats(RdpsndServerContext* context, wStream* s)
BYTE lastblock;
UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
if (!Stream_CheckAndLogRequiredLength(TAG, s, 20))
return ERROR_INVALID_DATA;
@ -192,6 +217,8 @@ static UINT rdpsnd_server_recv_formats(RdpsndServerContext* context, wStream* s)
for (i = 0; i < context->num_client_formats; i++)
{
AUDIO_FORMAT* format = &context->client_formats[i];
if (!Stream_CheckAndLogRequiredLength(TAG, s, 18))
{
WLog_ERR(TAG, "not enough data in stream!");
@ -199,17 +226,17 @@ static UINT rdpsnd_server_recv_formats(RdpsndServerContext* context, wStream* s)
goto out_free;
}
Stream_Read_UINT16(s, context->client_formats[i].wFormatTag);
Stream_Read_UINT16(s, context->client_formats[i].nChannels);
Stream_Read_UINT32(s, context->client_formats[i].nSamplesPerSec);
Stream_Read_UINT32(s, context->client_formats[i].nAvgBytesPerSec);
Stream_Read_UINT16(s, context->client_formats[i].nBlockAlign);
Stream_Read_UINT16(s, context->client_formats[i].wBitsPerSample);
Stream_Read_UINT16(s, context->client_formats[i].cbSize);
Stream_Read_UINT16(s, format->wFormatTag);
Stream_Read_UINT16(s, format->nChannels);
Stream_Read_UINT32(s, format->nSamplesPerSec);
Stream_Read_UINT32(s, format->nAvgBytesPerSec);
Stream_Read_UINT16(s, format->nBlockAlign);
Stream_Read_UINT16(s, format->wBitsPerSample);
Stream_Read_UINT16(s, format->cbSize);
if (context->client_formats[i].cbSize > 0)
if (format->cbSize > 0)
{
if (!Stream_SafeSeek(s, context->client_formats[i].cbSize))
if (!Stream_SafeSeek(s, format->cbSize))
{
WLog_ERR(TAG, "Stream_SafeSeek failed!");
error = ERROR_INTERNAL_ERROR;
@ -217,7 +244,7 @@ static UINT rdpsnd_server_recv_formats(RdpsndServerContext* context, wStream* s)
}
}
if (context->client_formats[i].wFormatTag != 0)
if (format->wFormatTag != 0)
{
// lets call this a known format
// TODO: actually look through our own list of known formats
@ -243,7 +270,10 @@ static DWORD WINAPI rdpsnd_server_thread(LPVOID arg)
HANDLE events[2] = { 0 };
RdpsndServerContext* context = (RdpsndServerContext*)arg;
UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
events[nCount++] = context->priv->channelEvent;
events[nCount++] = context->priv->StopEvent;
@ -293,6 +323,9 @@ static DWORD WINAPI rdpsnd_server_thread(LPVOID arg)
*/
static UINT rdpsnd_server_initialize(RdpsndServerContext* context, BOOL ownThread)
{
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
context->priv->ownThread = ownThread;
return context->Start(context);
}
@ -309,6 +342,9 @@ static UINT rdpsnd_server_select_format(RdpsndServerContext* context, UINT16 cli
AUDIO_FORMAT* format;
UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
if ((client_format_index >= context->num_client_formats) || (!context->src_format))
{
WLog_ERR(TAG, "index %d is not correct.", client_format_index);
@ -391,10 +427,10 @@ out:
static UINT rdpsnd_server_training(RdpsndServerContext* context, UINT16 timestamp, UINT16 packsize,
BYTE* data)
{
wStream* s = context->priv->rdpsnd_pdu;
size_t end = 0;
ULONG written;
BOOL status;
wStream* s = rdpsnd_server_get_buffer(context);
if (!Stream_EnsureRemainingCapacity(s, 8))
return ERROR_INTERNAL_ERROR;
@ -461,12 +497,21 @@ static UINT rdpsnd_server_send_wave_pdu(RdpsndServerContext* context, UINT16 wTi
const BYTE* src;
AUDIO_FORMAT* format;
ULONG written;
wStream* s = context->priv->rdpsnd_pdu;
UINT error = CHANNEL_RC_OK;
wStream* s = rdpsnd_server_get_buffer(context);
if (context->selected_client_format > context->num_client_formats)
return ERROR_INTERNAL_ERROR;
WINPR_ASSERT(context->client_formats);
format = &context->client_formats[context->selected_client_format];
/* WaveInfo PDU */
Stream_SetPosition(s, 0);
if (!Stream_EnsureRemainingCapacity(s, 16))
return ERROR_OUTOFMEMORY;
Stream_Write_UINT8(s, SNDC_WAVE); /* msgType */
Stream_Write_UINT8(s, 0); /* bPad */
Stream_Write_UINT16(s, 0); /* BodySize */
@ -535,11 +580,11 @@ static UINT rdpsnd_server_send_wave2_pdu(RdpsndServerContext* context, UINT16 fo
const BYTE* data, size_t size, BOOL encoded,
UINT16 timestamp, UINT32 audioTimeStamp)
{
wStream* s = context->priv->rdpsnd_pdu;
size_t end = 0;
ULONG written;
UINT error = CHANNEL_RC_OK;
BOOL status;
wStream* s = rdpsnd_server_get_buffer(context);
if (!Stream_EnsureRemainingCapacity(s, 16))
{
@ -548,7 +593,6 @@ static UINT rdpsnd_server_send_wave2_pdu(RdpsndServerContext* context, UINT16 fo
}
/* Wave2 PDU */
Stream_SetPosition(s, 0);
Stream_Write_UINT8(s, SNDC_WAVE2); /* msgType */
Stream_Write_UINT8(s, 0); /* bPad */
Stream_Write_UINT16(s, 0); /* BodySize */
@ -616,6 +660,9 @@ static UINT rdpsnd_server_send_audio_pdu(RdpsndServerContext* context, UINT16 wT
const BYTE* src;
size_t length;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
if (context->selected_client_format >= context->num_client_formats)
return ERROR_INTERNAL_ERROR;
@ -638,6 +685,10 @@ static UINT rdpsnd_server_send_samples(RdpsndServerContext* context, const void*
size_t nframes, UINT16 wTimestamp)
{
UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
EnterCriticalSection(&context->priv->lock);
if (context->selected_client_format >= context->num_client_formats)
@ -686,6 +737,9 @@ static UINT rdpsnd_server_send_samples2(RdpsndServerContext* context, UINT16 for
{
UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
if (context->clientVersion < CHANNEL_VERSION_WIN_8)
return ERROR_INTERNAL_ERROR;
@ -709,15 +763,8 @@ static UINT rdpsnd_server_set_volume(RdpsndServerContext* context, UINT16 left,
size_t len;
BOOL status;
ULONG written;
wStream* s;
wStream* s = rdpsnd_server_get_buffer(context);
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
s = context->priv->rdpsnd_pdu;
WINPR_ASSERT(s);
Stream_SetPosition(s, 0);
if (!Stream_EnsureRemainingCapacity(s, 8))
return ERROR_NOT_ENOUGH_MEMORY;
@ -744,8 +791,9 @@ static UINT rdpsnd_server_close(RdpsndServerContext* context)
size_t pos;
BOOL status;
ULONG written;
wStream* s = context->priv->rdpsnd_pdu;
UINT error = CHANNEL_RC_OK;
wStream* s = rdpsnd_server_get_buffer(context);
EnterCriticalSection(&context->priv->lock);
if (context->priv->out_pending_frames > 0)
@ -767,6 +815,10 @@ static UINT rdpsnd_server_close(RdpsndServerContext* context)
return error;
context->selected_client_format = 0xFFFF;
if (!Stream_EnsureRemainingCapacity(s, 4))
return ERROR_OUTOFMEMORY;
Stream_Write_UINT8(s, SNDC_CLOSE);
Stream_Write_UINT8(s, 0);
Stream_Seek_UINT16(s);
@ -789,10 +841,14 @@ static UINT rdpsnd_server_start(RdpsndServerContext* context)
{
void* buffer = NULL;
DWORD bytesReturned;
RdpsndServerPrivate* priv = context->priv;
RdpsndServerPrivate* priv;
UINT error = ERROR_INTERNAL_ERROR;
PULONG pSessionId = NULL;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
priv = context->priv;
priv->SessionId = WTS_CURRENT_SESSION;
if (context->use_dynamic_virtual_channel)
@ -909,6 +965,10 @@ out_close:
static UINT rdpsnd_server_stop(RdpsndServerContext* context)
{
UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
if (!context->priv->StopEvent)
return error;
@ -951,15 +1011,11 @@ static UINT rdpsnd_server_stop(RdpsndServerContext* context)
RdpsndServerContext* rdpsnd_server_context_new(HANDLE vcm)
{
RdpsndServerContext* context;
RdpsndServerPrivate* priv;
context = (RdpsndServerContext*)calloc(1, sizeof(RdpsndServerContext));
RdpsndServerContext* context = (RdpsndServerContext*)calloc(1, sizeof(RdpsndServerContext));
if (!context)
{
WLog_ERR(TAG, "calloc failed!");
return NULL;
}
goto fail;
context->vcm = vcm;
context->Start = rdpsnd_server_start;
@ -978,7 +1034,7 @@ RdpsndServerContext* rdpsnd_server_context_new(HANDLE vcm)
if (!priv)
{
WLog_ERR(TAG, "calloc failed!");
goto out_free;
goto fail;
}
priv->dsp_context = freerdp_dsp_context_new(TRUE);
@ -986,7 +1042,7 @@ RdpsndServerContext* rdpsnd_server_context_new(HANDLE vcm)
if (!priv->dsp_context)
{
WLog_ERR(TAG, "freerdp_dsp_context_new failed!");
goto out_free_priv;
goto fail;
}
priv->input_stream = Stream_New(NULL, 4);
@ -994,24 +1050,23 @@ RdpsndServerContext* rdpsnd_server_context_new(HANDLE vcm)
if (!priv->input_stream)
{
WLog_ERR(TAG, "Stream_New failed!");
goto out_free_dsp;
goto fail;
}
priv->expectedBytes = 4;
priv->waitingHeader = TRUE;
priv->ownThread = TRUE;
return context;
out_free_dsp:
freerdp_dsp_context_free(priv->dsp_context);
out_free_priv:
free(context->priv);
out_free:
free(context);
fail:
rdpsnd_server_context_free(context);
return NULL;
}
void rdpsnd_server_context_reset(RdpsndServerContext* context)
{
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
context->priv->expectedBytes = 4;
context->priv->waitingHeader = TRUE;
Stream_SetPosition(context->priv->input_stream, 0);
@ -1022,6 +1077,8 @@ void rdpsnd_server_context_free(RdpsndServerContext* context)
if (!context)
return;
if (context->priv)
{
rdpsnd_server_stop(context);
free(context->priv->out_buffer);
@ -1031,6 +1088,7 @@ void rdpsnd_server_context_free(RdpsndServerContext* context)
if (context->priv->input_stream)
Stream_Free(context->priv->input_stream, TRUE);
}
free(context->server_formats);
free(context->client_formats);
@ -1040,6 +1098,9 @@ void rdpsnd_server_context_free(RdpsndServerContext* context)
HANDLE rdpsnd_server_get_event_handle(RdpsndServerContext* context)
{
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
return context->priv->channelEvent;
}
@ -1061,8 +1122,14 @@ UINT rdpsnd_server_handle_messages(RdpsndServerContext* context)
{
DWORD bytesReturned;
UINT ret = CHANNEL_RC_OK;
RdpsndServerPrivate* priv = context->priv;
wStream* s = priv->input_stream;
RdpsndServerPrivate* priv;
wStream* s;
WINPR_ASSERT(context);
WINPR_ASSERT(context->priv);
priv = context->priv;
s = priv->input_stream;
if (!WTSVirtualChannelRead(priv->ChannelHandle, 0, (PCHAR)Stream_Pointer(s),
priv->expectedBytes, &bytesReturned))