Transport add getter, add checks

* Added transport_get_context to get rdpContext in IO callbacks.
* Added WINPR_ASSERT where possible.
* Fixed handle count mismatch in transport_get_event_handles
This commit is contained in:
akallabeth 2021-09-03 08:00:24 +02:00 committed by akallabeth
parent 732a4d3839
commit 595a40a1e0
3 changed files with 94 additions and 18 deletions

View File

@ -66,6 +66,8 @@ extern "C"
*/
FREERDP_API SSIZE_T transport_parse_pdu(rdpTransport* transport, wStream* s, BOOL* incomplete);
FREERDP_API rdpContext* transport_get_context(rdpTransport* transport);
#ifdef __cplusplus
}
#endif

View File

@ -64,6 +64,7 @@ static void transport_ssl_cb(SSL* ssl, int where, int ret)
if (where & SSL_CB_ALERT)
{
rdpTransport* transport = (rdpTransport*)SSL_get_app_data(ssl);
WINPR_ASSERT(transport);
switch (ret)
{
@ -106,10 +107,12 @@ static void transport_ssl_cb(SSL* ssl, int where, int ret)
}
}
wStream* transport_send_stream_init(rdpTransport* transport, int size)
wStream* transport_send_stream_init(rdpTransport* transport, size_t size)
{
wStream* s;
WINPR_ASSERT(transport);
if (!(s = StreamPool_Take(transport->ReceivePool, size)))
return NULL;
@ -134,6 +137,9 @@ static BOOL transport_default_attach(rdpTransport* transport, int sockfd)
{
BIO* socketBio = NULL;
BIO* bufferedBio;
WINPR_ASSERT(transport);
socketBio = BIO_new(BIO_s_simple_socket());
if (!socketBio)
@ -179,6 +185,8 @@ BOOL transport_connect_tls(rdpTransport* transport)
if (!transport)
return FALSE;
WINPR_ASSERT(transport->settings);
/* Only prompt for password if we use TLS (NLA also calls this function) */
if (transport->settings->SelectedProtocol == PROTOCOL_SSL)
{
@ -200,8 +208,16 @@ static BOOL transport_default_connect_tls(rdpTransport* transport)
{
int tlsStatus;
rdpTls* tls = NULL;
rdpContext* context = transport->context;
rdpSettings* settings = transport->settings;
rdpContext* context;
rdpSettings* settings;
WINPR_ASSERT(transport);
context = transport->context;
WINPR_ASSERT(context);
settings = transport->settings;
WINPR_ASSERT(settings);
if (!(tls = tls_new(settings)))
return FALSE;
@ -263,6 +279,11 @@ BOOL transport_connect_nla(rdpTransport* transport)
instance = context->instance;
rdp = context->rdp;
WINPR_ASSERT(context);
WINPR_ASSERT(settings);
WINPR_ASSERT(instance);
WINPR_ASSERT(rdp);
if (!transport_connect_tls(transport))
return FALSE;
@ -302,9 +323,18 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por
{
int sockfd;
BOOL status = FALSE;
rdpSettings* settings = transport->settings;
rdpContext* context = transport->context;
BOOL rpcFallback = !settings->GatewayHttpTransport;
rdpSettings* settings;
rdpContext* context;
BOOL rpcFallback;
WINPR_ASSERT(transport);
WINPR_ASSERT(hostname);
settings = transport->settings;
context = transport->context;
rpcFallback = !settings->GatewayHttpTransport;
WINPR_ASSERT(settings);
WINPR_ASSERT(context);
if (transport->GatewayEnabled)
{
@ -403,7 +433,12 @@ BOOL transport_accept_tls(rdpTransport* transport)
static BOOL transport_default_accept_tls(rdpTransport* transport)
{
rdpSettings* settings = transport->settings;
rdpSettings* settings;
WINPR_ASSERT(transport);
settings = transport->settings;
WINPR_ASSERT(settings);
if (!transport->tls)
transport->tls = tls_new(transport->settings);
@ -468,6 +503,9 @@ static void transport_bio_error_log(rdpTransport* transport, LPCSTR biofunc, BIO
char* buf;
int saveerrno;
DWORD level;
WINPR_ASSERT(transport);
saveerrno = errno;
level = WLOG_ERROR;
@ -502,7 +540,13 @@ static void transport_bio_error_log(rdpTransport* transport, LPCSTR biofunc, BIO
static SSIZE_T transport_read_layer(rdpTransport* transport, BYTE* data, size_t bytes)
{
SSIZE_T read = 0;
rdpRdp* rdp = transport->context->rdp;
rdpRdp* rdp;
WINPR_ASSERT(transport);
WINPR_ASSERT(transport->context);
rdp = transport->context->rdp;
WINPR_ASSERT(rdp);
if (!transport->frontBio || (bytes > SSIZE_MAX))
{
@ -744,6 +788,9 @@ static int transport_default_read_pdu(rdpTransport* transport, wStream* s)
size_t pduLength;
size_t position;
WINPR_ASSERT(transport);
WINPR_ASSERT(s);
/* Read in pdu length */
status = transport_parse_pdu(transport, s, &incomplete);
while ((status == 0) && incomplete)
@ -904,6 +951,10 @@ DWORD transport_get_event_handles(rdpTransport* transport, HANDLE* events, DWORD
DWORD nCount = 1; /* always the reread Event */
DWORD tmp;
WINPR_ASSERT(transport);
WINPR_ASSERT(events);
WINPR_ASSERT(count > 0);
if (events)
{
if (count < 1)
@ -918,11 +969,9 @@ DWORD transport_get_event_handles(rdpTransport* transport, HANDLE* events, DWORD
if (!transport->GatewayEnabled)
{
nCount++;
if (events)
{
if (nCount > count)
if (nCount >= count)
{
WLog_Print(transport->log, WLOG_ERROR,
"%s: provided handles array is too small (count=%" PRIu32
@ -931,11 +980,15 @@ DWORD transport_get_event_handles(rdpTransport* transport, HANDLE* events, DWORD
return 0;
}
if (BIO_get_event(transport->frontBio, &events[1]) != 1)
if (transport->frontBio)
{
WLog_Print(transport->log, WLOG_ERROR, "%s: error getting the frontBio handle",
__FUNCTION__);
return 0;
if (BIO_get_event(transport->frontBio, &events[1]) != 1)
{
WLog_Print(transport->log, WLOG_ERROR, "%s: error getting the frontBio handle",
__FUNCTION__);
return 0;
}
nCount++;
}
}
}
@ -969,6 +1022,11 @@ void transport_get_fds(rdpTransport* transport, void** rfds, int* rcount)
DWORD index;
DWORD nCount;
HANDLE events[MAXIMUM_WAIT_OBJECTS] = { 0 };
WINPR_ASSERT(transport);
WINPR_ASSERT(rfds);
WINPR_ASSERT(rcount);
nCount = transport_get_event_handles(transport, events, ARRAYSIZE(events));
*rcount = nCount + 1;
@ -982,6 +1040,7 @@ void transport_get_fds(rdpTransport* transport, void** rfds, int* rcount)
BOOL transport_is_write_blocked(rdpTransport* transport)
{
WINPR_ASSERT(transport);
return BIO_write_blocked(transport->frontBio);
}
@ -989,6 +1048,7 @@ int transport_drain_output_buffer(rdpTransport* transport)
{
BOOL status = FALSE;
WINPR_ASSERT(transport);
if (BIO_write_blocked(transport->frontBio))
{
if (BIO_flush(transport->frontBio) < 1)
@ -1008,8 +1068,7 @@ int transport_check_fds(rdpTransport* transport)
UINT64 now = GetTickCount64();
UINT64 dueDate = 0;
if (!transport)
return -1;
WINPR_ASSERT(transport);
if (transport->layer == TRANSPORT_LAYER_CLOSED)
{
@ -1018,6 +1077,7 @@ int transport_check_fds(rdpTransport* transport)
return -1;
}
WINPR_ASSERT(transport->settings);
dueDate = now + transport->settings->MaxTimeInCheckLoop;
if (transport->haveMoreBytesToRead)
@ -1028,6 +1088,7 @@ int transport_check_fds(rdpTransport* transport)
while (now < dueDate)
{
WINPR_ASSERT(transport->context);
if (freerdp_shall_disconnect(transport->context->instance))
{
return -1;
@ -1062,6 +1123,7 @@ int transport_check_fds(rdpTransport* transport)
* 0: success
* 1: redirection
*/
WINPR_ASSERT(transport->ReceiveCallback);
recv_status = transport->ReceiveCallback(transport, received, transport->ReceiveExtra);
Stream_Release(received);
@ -1092,6 +1154,8 @@ int transport_check_fds(rdpTransport* transport)
BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking)
{
WINPR_ASSERT(transport);
transport->blocking = blocking;
if (!BIO_set_nonblock(transport->frontBio, blocking ? FALSE : TRUE))
@ -1102,11 +1166,13 @@ BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking)
void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled)
{
WINPR_ASSERT(transport);
transport->GatewayEnabled = GatewayEnabled;
}
void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode)
{
WINPR_ASSERT(transport);
transport->NlaMode = NlaMode;
}
@ -1156,6 +1222,8 @@ rdpTransport* transport_new(rdpContext* context)
{
rdpTransport* transport = (rdpTransport*)calloc(1, sizeof(rdpTransport));
WINPR_ASSERT(context);
if (!transport)
return NULL;
@ -1248,3 +1316,9 @@ const rdpTransportIo* transport_get_io_callbacks(rdpTransport* transport)
return NULL;
return &transport->io;
}
rdpContext* transport_get_context(rdpTransport* transport)
{
WINPR_ASSERT(transport);
return transport->context;
}

View File

@ -79,7 +79,7 @@ struct rdp_transport
rdpTransportIo io;
};
FREERDP_LOCAL wStream* transport_send_stream_init(rdpTransport* transport, int size);
FREERDP_LOCAL wStream* transport_send_stream_init(rdpTransport* transport, size_t size);
FREERDP_LOCAL BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 port,
DWORD timeout);
FREERDP_LOCAL BOOL transport_attach(rdpTransport* transport, int sockfd);