From cf2daeb01d3325a9de97348047caf4b8974f2b76 Mon Sep 17 00:00:00 2001 From: Alexandr Date: Wed, 2 Sep 2020 11:37:04 +0000 Subject: [PATCH] cleanup of https://github.com/FreeRDP/FreeRDP/pull/6448 (#6455) * Implemented switchable transport layer Co-authored-by: akallabeth --- include/freerdp/freerdp.h | 1 + include/freerdp/transport_io.h | 73 +++++++++++ libfreerdp/core/freerdp.c | 14 +++ libfreerdp/core/rdp.c | 32 +++++ libfreerdp/core/rdp.h | 4 + libfreerdp/core/tcp.c | 14 +++ libfreerdp/core/tcp.h | 3 + libfreerdp/core/transport.c | 215 ++++++++++++++++++++++++--------- libfreerdp/core/transport.h | 7 +- 9 files changed, 302 insertions(+), 61 deletions(-) create mode 100644 include/freerdp/transport_io.h diff --git a/include/freerdp/freerdp.h b/include/freerdp/freerdp.h index ac0678ad6..aa1214f98 100644 --- a/include/freerdp/freerdp.h +++ b/include/freerdp/freerdp.h @@ -30,6 +30,7 @@ typedef struct rdp_channels rdpChannels; typedef struct rdp_graphics rdpGraphics; typedef struct rdp_metrics rdpMetrics; typedef struct rdp_codecs rdpCodecs; +typedef struct rdp_transport rdpTransport; /* Opaque */ typedef struct rdp_freerdp freerdp; typedef struct rdp_context rdpContext; diff --git a/include/freerdp/transport_io.h b/include/freerdp/transport_io.h new file mode 100644 index 000000000..9d91a7336 --- /dev/null +++ b/include/freerdp/transport_io.h @@ -0,0 +1,73 @@ +/** + * FreeRDP: A Remote Desktop Protocol Implementation + * FreeRDP Interface + * + * Copyright 2009-2011 Jay Sorg + * Copyright 2015 Thincast Technologies GmbH + * Copyright 2015 DI (FH) Martin Haimberger + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef FREERDP_TRANSPORT_IO_H +#define FREERDP_TRANSPORT_IO_H + +typedef struct rdp_transport_io rdpTransportIo; + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" +{ +#endif + + typedef int (*pTCPConnect)(rdpContext* context, rdpSettings* settings, const char* hostname, + int port, DWORD timeout); + typedef BOOL (*pTransportFkt)(rdpTransport* transport); + typedef BOOL (*pTransportAttach)(rdpTransport* transport, int sockfd); + typedef int (*pTransportRWFkt)(rdpTransport* transport, wStream* s); + typedef SSIZE_T (*pTransportRead)(rdpTransport* transport, BYTE* data, size_t bytes); + + struct rdp_transport_io + { + pTCPConnect TCPConnect; + pTransportFkt TLSConnect; + pTransportFkt TLSAccept; + pTransportAttach TransportAttach; + pTransportFkt TransportDisconnect; + pTransportRWFkt ReadPdu; /* Reads a whole PDU from the transport */ + pTransportRWFkt WritePdu; /* Writes a whole PDU to the transport */ + pTransportRead ReadBytes; /* Reads up to a requested amount of bytes from the transport */ + }; + typedef struct rdp_transport_io rdpTransportIo; + + FREERDP_API const rdpTransportIo* freerdp_get_io_callbacks(freerdp* instance); + FREERDP_API BOOL freerdp_set_io_callbacks(freerdp* instance, + const rdpTransportIo* io_callbacks); + /* PDU parser. + * incomplete: FALSE if the whole PDU is available, TRUE otherwise + * Return: 0 -> PDU header incomplete + * >0 -> PDU header complete, length of PDU. + * <0 -> Abort, an error occured + */ + FREERDP_API SSIZE_T transport_parse_pdu(rdpTransport* transport, wStream* s, BOOL* incomplete); + +#ifdef __cplusplus +} +#endif + +#endif /* FREERDP_TRANSPORT_IO_H */ diff --git a/libfreerdp/core/freerdp.c b/libfreerdp/core/freerdp.c index 9094bc346..2ec0b8cb9 100644 --- a/libfreerdp/core/freerdp.c +++ b/libfreerdp/core/freerdp.c @@ -1121,3 +1121,17 @@ const char* freerdp_nego_get_routing_token(rdpContext* context, DWORD* length) return (const char*)nego_get_routing_token(context->rdp->nego, length); } + +const rdpTransportIo* freerdp_get_io_callbacks(freerdp* instance) +{ + if (!instance || !instance->context) + return NULL; + return rdp_get_io_callbacks(instance->context->rdp); +} + +BOOL freerdp_set_io_callbacks(freerdp* instance, const rdpTransportIo* io_callbacks) +{ + if (!instance || !instance->context) + return FALSE; + return rdp_set_io_callbacks(instance->context->rdp, io_callbacks); +} diff --git a/libfreerdp/core/rdp.c b/libfreerdp/core/rdp.c index 1c8990ec8..23ff3518e 100644 --- a/libfreerdp/core/rdp.c +++ b/libfreerdp/core/rdp.c @@ -1787,6 +1787,12 @@ rdpRdp* rdp_new(rdpContext* context) if (!rdp->transport) goto fail; + if (rdp->io && rdp->transport) + { + if (!transport_set_io_callbacks(rdp->transport, rdp->io)) + goto fail; + } + rdp->license = license_new(rdp); if (!rdp->license) @@ -1906,6 +1912,8 @@ void rdp_reset(rdpRdp* rdp) transport_free(rdp->transport); fastpath_free(rdp->fastpath); rdp->transport = transport_new(context); + if (rdp->io && rdp->transport) + transport_set_io_callbacks(rdp->transport, rdp->io); rdp->license = license_new(rdp); rdp->nego = nego_new(rdp->transport); rdp->mcs = mcs_new(rdp->transport); @@ -1944,6 +1952,30 @@ void rdp_free(rdpRdp* rdp) heartbeat_free(rdp->heartbeat); multitransport_free(rdp->multitransport); bulk_free(rdp->bulk); + free(rdp->io); free(rdp); } } + +const rdpTransportIo* rdp_get_io_callbacks(rdpRdp* rdp) +{ + if (!rdp) + return NULL; + return rdp->io; +} + +BOOL rdp_set_io_callbacks(rdpRdp* rdp, const rdpTransportIo* io_callbacks) +{ + if (!rdp) + return FALSE; + free(rdp->io); + rdp->io = NULL; + if (io_callbacks) + { + rdp->io = malloc(sizeof(rdpTransportIo)); + if (!rdp->io) + return FALSE; + *rdp->io = *io_callbacks; + } + return TRUE; +} diff --git a/libfreerdp/core/rdp.h b/libfreerdp/core/rdp.h index 47281d611..a87b613a2 100644 --- a/libfreerdp/core/rdp.h +++ b/libfreerdp/core/rdp.h @@ -180,6 +180,7 @@ struct rdp_rdp UINT64 outBytes; UINT64 outPackets; CRITICAL_SECTION critical; + rdpTransportIo* io; }; FREERDP_LOCAL BOOL rdp_read_security_header(wStream* s, UINT16* flags, UINT16* length); @@ -224,6 +225,9 @@ FREERDP_LOCAL rdpRdp* rdp_new(rdpContext* context); FREERDP_LOCAL void rdp_reset(rdpRdp* rdp); FREERDP_LOCAL void rdp_free(rdpRdp* rdp); +FREERDP_LOCAL const rdpTransportIo* rdp_get_io_callbacks(rdpRdp* rdp); +FREERDP_LOCAL BOOL rdp_set_io_callbacks(rdpRdp* rdp, const rdpTransportIo* io_callbacks); + #define RDP_TAG FREERDP_TAG("core.rdp") #ifdef WITH_DEBUG_RDP #define DEBUG_RDP(...) WLog_DBG(RDP_TAG, __VA_ARGS__) diff --git a/libfreerdp/core/tcp.c b/libfreerdp/core/tcp.c index ef5f7b193..f23da2941 100644 --- a/libfreerdp/core/tcp.c +++ b/libfreerdp/core/tcp.c @@ -30,6 +30,8 @@ #include #include +#include "rdp.h" + #if !defined(_WIN32) #include @@ -1061,6 +1063,18 @@ static BOOL freerdp_tcp_set_keep_alive_mode(const rdpSettings* settings, int soc int freerdp_tcp_connect(rdpContext* context, rdpSettings* settings, const char* hostname, int port, DWORD timeout) +{ + rdpTransport* transport; + if (!context || !context->rdp) + return -1; + transport = context->rdp->transport; + if (!transport) + return -1; + return IFCALLRESULT(-1, transport->io.TCPConnect, context, settings, hostname, port, timeout); +} + +int freerdp_tcp_default_connect(rdpContext* context, rdpSettings* settings, const char* hostname, + int port, DWORD timeout) { int sockfd; UINT32 optval; diff --git a/libfreerdp/core/tcp.h b/libfreerdp/core/tcp.h index 74a23114c..d342ca990 100644 --- a/libfreerdp/core/tcp.h +++ b/libfreerdp/core/tcp.h @@ -66,6 +66,9 @@ FREERDP_LOCAL BIO_METHOD* BIO_s_buffered_socket(void); FREERDP_LOCAL int freerdp_tcp_connect(rdpContext* context, rdpSettings* settings, const char* hostname, int port, DWORD timeout); +FREERDP_LOCAL int freerdp_tcp_default_connect(rdpContext* context, rdpSettings* settings, + const char* hostname, int port, DWORD timeout); + FREERDP_LOCAL char* freerdp_tcp_get_peer_address(SOCKET sockfd); FREERDP_LOCAL struct addrinfo* freerdp_tcp_resolve_host(const char* hostname, int port, diff --git a/libfreerdp/core/transport.c b/libfreerdp/core/transport.c index 40795e941..6d2699d3e 100644 --- a/libfreerdp/core/transport.c +++ b/libfreerdp/core/transport.c @@ -226,6 +226,13 @@ wStream* transport_send_stream_init(rdpTransport* transport, int size) } BOOL transport_attach(rdpTransport* transport, int sockfd) +{ + if (!transport) + return FALSE; + return IFCALLRESULT(FALSE, transport->io.TransportAttach, transport, sockfd); +} + +static BOOL transport_default_attach(rdpTransport* transport, int sockfd) { BIO* socketBio = NULL; BIO* bufferedBio; @@ -255,11 +262,20 @@ fail: BOOL transport_connect_rdp(rdpTransport* transport) { + if (!transport) + return FALSE; /* RDP encryption */ return TRUE; } BOOL transport_connect_tls(rdpTransport* transport) +{ + if (!transport) + return FALSE; + return IFCALLRESULT(FALSE, transport->io.TLSConnect, transport); +} + +static BOOL transport_default_connect_tls(rdpTransport* transport) { int tlsStatus; rdpTls* tls = NULL; @@ -314,10 +330,17 @@ BOOL transport_connect_tls(rdpTransport* transport) BOOL transport_connect_nla(rdpTransport* transport) { - rdpContext* context = transport->context; - rdpSettings* settings = context->settings; - freerdp* instance = context->instance; - rdpRdp* rdp = context->rdp; + rdpContext* context = NULL; + rdpSettings* settings = NULL; + freerdp* instance = NULL; + rdpRdp* rdp = NULL; + if (!transport) + return FALSE; + + context = transport->context; + settings = context->settings; + instance = context->instance; + rdp = context->rdp; if (!transport_connect_tls(transport)) return FALSE; @@ -444,11 +467,20 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por BOOL transport_accept_rdp(rdpTransport* transport) { + if (!transport) + return FALSE; /* RDP encryption */ return TRUE; } BOOL transport_accept_tls(rdpTransport* transport) +{ + if (!transport) + return FALSE; + return IFCALLRESULT(FALSE, transport->io.TLSAccept, transport); +} + +static BOOL transport_default_accept_tls(rdpTransport* transport) { rdpSettings* settings = transport->settings; @@ -466,18 +498,13 @@ BOOL transport_accept_tls(rdpTransport* transport) BOOL transport_accept_nla(rdpTransport* transport) { + rdpSettings* settings = transport->settings; freerdp* instance = (freerdp*)settings->instance; - - if (!transport->tls) - transport->tls = tls_new(transport->settings); - - transport->layer = TRANSPORT_LAYER_TLS; - - if (!tls_accept(transport->tls, transport->frontBio, settings)) + if (!transport) + return FALSE; + if (!IFCALLRESULT(FALSE, transport->io.TLSAccept, transport)) return FALSE; - - transport->frontBio = transport->tls->bio; /* Network Level Authentication */ @@ -626,10 +653,13 @@ static SSIZE_T transport_read_layer(rdpTransport* transport, BYTE* data, size_t static SSIZE_T transport_read_layer_bytes(rdpTransport* transport, wStream* s, size_t toRead) { SSIZE_T status; + if (!transport) + return -1; + if (toRead > SSIZE_MAX) return 0; - status = transport_read_layer(transport, Stream_Pointer(s), toRead); + status = IFCALLRESULT(-1, transport->io.ReadBytes, transport, Stream_Pointer(s), toRead); if (status <= 0) return status; @@ -652,7 +682,13 @@ static SSIZE_T transport_read_layer_bytes(rdpTransport* transport, wStream* s, s */ int transport_read_pdu(rdpTransport* transport, wStream* s) { - int status; + if (!transport) + return -1; + return IFCALLRESULT(-1, transport->io.ReadPdu, transport, s); +} + +SSIZE_T transport_parse_pdu(rdpTransport* transport, wStream* s, BOOL* incomplete) +{ size_t position; size_t pduLength; BYTE* header; @@ -664,23 +700,19 @@ int transport_read_pdu(rdpTransport* transport, wStream* s) if (!s) return -1; + header = Stream_Buffer(s); position = Stream_GetPosition(s); - /* Make sure there is enough space for the longest header within the stream */ - if (!Stream_EnsureCapacity(s, 4)) - return -1; + if (incomplete) + *incomplete = TRUE; /* Make sure at least two bytes are read for further processing */ - if (position < 2 && (status = transport_read_layer_bytes(transport, s, 2 - position)) != 1) + if (position < 2) { /* No data available at the moment */ - return status; + return 0; } - /* update position value for further checks */ - position = Stream_GetPosition(s); - header = Stream_Buffer(s); - if (transport->NlaMode) { /* @@ -697,9 +729,8 @@ int transport_read_pdu(rdpTransport* transport, wStream* s) if ((header[1] & ~(0x80)) == 1) { /* check for header bytes already was readed in previous calls */ - if (position < 3 && - (status = transport_read_layer_bytes(transport, s, 3 - position)) != 1) - return status; + if (position < 3) + return 0; pduLength = header[2]; pduLength += 3; @@ -707,9 +738,8 @@ int transport_read_pdu(rdpTransport* transport, wStream* s) else if ((header[1] & ~(0x80)) == 2) { /* check for header bytes already was readed in previous calls */ - if (position < 4 && - (status = transport_read_layer_bytes(transport, s, 4 - position)) != 1) - return status; + if (position < 4) + return 0; pduLength = (header[2] << 8) | header[3]; pduLength += 4; @@ -733,14 +763,13 @@ int transport_read_pdu(rdpTransport* transport, wStream* s) { /* TPKT header */ /* check for header bytes already was readed in previous calls */ - if (position < 4 && - (status = transport_read_layer_bytes(transport, s, 4 - position)) != 1) - return status; + if (position < 4) + return 0; pduLength = (header[2] << 8) | header[3]; /* min and max values according to ITU-T Rec. T.123 (01/2007) section 8 */ - if (pduLength < 7 || pduLength > 0xFFFF) + if ((pduLength < 7) || (pduLength > 0xFFFF)) { WLog_Print(transport->log, WLOG_ERROR, "tpkt - invalid pduLength: %" PRIdz, pduLength); @@ -753,9 +782,8 @@ int transport_read_pdu(rdpTransport* transport, wStream* s) if (header[1] & 0x80) { /* check for header bytes already was readed in previous calls */ - if (position < 3 && - (status = transport_read_layer_bytes(transport, s, 3 - position)) != 1) - return status; + if (position < 3) + return 0; pduLength = ((header[1] & 0x7F) << 8) | header[2]; } @@ -776,7 +804,46 @@ int transport_read_pdu(rdpTransport* transport, wStream* s) } } - if (!Stream_EnsureCapacity(s, Stream_GetPosition(s) + pduLength)) + if (position > pduLength) + return -1; + + if (incomplete) + *incomplete = position >= pduLength; + + return pduLength; +} + +static int transport_default_read_pdu(rdpTransport* transport, wStream* s) +{ + BOOL incomplete; + SSIZE_T status; + size_t pduLength; + size_t position; + + /* Read in pdu length */ + status = transport_parse_pdu(transport, s, &incomplete); + while ((status == 0) && incomplete) + { + int rc; + if (!Stream_EnsureRemainingCapacity(s, 1)) + return -1; + rc = transport_read_layer_bytes(transport, s, 1); + if (rc != 1) + return rc; + status = transport_parse_pdu(transport, s, &incomplete); + } + + if (status < 0) + return -1; + + pduLength = (size_t)status; + + /* Read in rest of the PDU */ + if (!Stream_EnsureCapacity(s, pduLength)) + return -1; + + position = Stream_GetPosition(s); + if (position > pduLength) return -1; status = transport_read_layer_bytes(transport, s, pduLength - Stream_GetPosition(s)); @@ -793,6 +860,14 @@ int transport_read_pdu(rdpTransport* transport, wStream* s) } int transport_write(rdpTransport* transport, wStream* s) +{ + if (!transport) + return -1; + + return IFCALLRESULT(-1, transport->io.WritePdu, transport, s); +} + +static int transport_default_write(rdpTransport* transport, wStream* s) { size_t length; int status = -1; @@ -1108,6 +1183,13 @@ void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode) } BOOL transport_disconnect(rdpTransport* transport) +{ + if (!transport) + return FALSE; + return IFCALLRESULT(FALSE, transport->io.TransportDisconnect, transport); +} + +static BOOL transport_default_disconnect(rdpTransport* transport) { BOOL status = TRUE; @@ -1144,8 +1226,7 @@ BOOL transport_disconnect(rdpTransport* transport) rdpTransport* transport_new(rdpContext* context) { - rdpTransport* transport; - transport = (rdpTransport*)calloc(1, sizeof(rdpTransport)); + rdpTransport* transport = (rdpTransport*)calloc(1, sizeof(rdpTransport)); if (!transport) return NULL; @@ -1153,30 +1234,40 @@ rdpTransport* transport_new(rdpContext* context) transport->log = WLog_Get(TAG); if (!transport->log) - goto out_free_transport; + goto fail; + + // transport->io.DataHandler = transport_data_handler; + transport->io.TCPConnect = freerdp_tcp_default_connect; + transport->io.TLSConnect = transport_default_connect_tls; + transport->io.TLSAccept = transport_default_accept_tls; + transport->io.TransportAttach = transport_default_attach; + transport->io.TransportDisconnect = transport_default_disconnect; + transport->io.ReadPdu = transport_default_read_pdu; + transport->io.WritePdu = transport_default_write; + transport->io.ReadBytes = transport_read_layer; transport->context = context; transport->settings = context->settings; transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE); if (!transport->ReceivePool) - goto out_free_transport; + goto fail; /* receive buffer for non-blocking read. */ transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0); if (!transport->ReceiveBuffer) - goto out_free_receivepool; + goto fail; transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL); if (!transport->connectedEvent || transport->connectedEvent == INVALID_HANDLE_VALUE) - goto out_free_receivebuffer; + goto fail; transport->rereadEvent = CreateEvent(NULL, TRUE, FALSE, NULL); if (!transport->rereadEvent || transport->rereadEvent == INVALID_HANDLE_VALUE) - goto out_free_connectedEvent; + goto fail; transport->haveMoreBytesToRead = FALSE; transport->blocking = TRUE; @@ -1184,24 +1275,14 @@ rdpTransport* transport_new(rdpContext* context) transport->layer = TRANSPORT_LAYER_TCP; if (!InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000)) - goto out_free_rereadEvent; + goto fail; if (!InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000)) - goto out_free_readlock; + goto fail; return transport; -out_free_readlock: - DeleteCriticalSection(&(transport->ReadLock)); -out_free_rereadEvent: - CloseHandle(transport->rereadEvent); -out_free_connectedEvent: - CloseHandle(transport->connectedEvent); -out_free_receivebuffer: - StreamPool_Return(transport->ReceivePool, transport->ReceiveBuffer); -out_free_receivepool: - StreamPool_Free(transport->ReceivePool); -out_free_transport: - free(transport); +fail: + transport_free(transport); return NULL; } @@ -1223,3 +1304,19 @@ void transport_free(rdpTransport* transport) DeleteCriticalSection(&(transport->WriteLock)); free(transport); } + +BOOL transport_set_io_callbacks(rdpTransport* transport, const rdpTransportIo* io_callbacks) +{ + if (!transport || !io_callbacks) + return FALSE; + + transport->io = *io_callbacks; + return TRUE; +} + +const rdpTransportIo* transport_get_io_callbacks(rdpTransport* transport) +{ + if (!transport) + return NULL; + return &transport->io; +} diff --git a/libfreerdp/core/transport.h b/libfreerdp/core/transport.h index 944f5ce30..694240971 100644 --- a/libfreerdp/core/transport.h +++ b/libfreerdp/core/transport.h @@ -29,8 +29,6 @@ typedef enum TRANSPORT_LAYER_CLOSED } TRANSPORT_LAYER; -typedef struct rdp_transport rdpTransport; - #include "tcp.h" #include "nla.h" @@ -50,6 +48,7 @@ typedef struct rdp_transport rdpTransport; #include #include #include +#include typedef int (*TransportRecv)(rdpTransport* transport, wStream* stream, void* extra); @@ -77,6 +76,7 @@ struct rdp_transport HANDLE rereadEvent; BOOL haveMoreBytesToRead; wLog* log; + rdpTransportIo io; }; FREERDP_LOCAL wStream* transport_send_stream_init(rdpTransport* transport, int size); @@ -109,6 +109,9 @@ FREERDP_LOCAL int transport_drain_output_buffer(rdpTransport* transport); FREERDP_LOCAL wStream* transport_receive_pool_take(rdpTransport* transport); FREERDP_LOCAL int transport_receive_pool_return(rdpTransport* transport, wStream* pdu); +FREERDP_LOCAL const rdpTransportIo* transport_get_io_callbacks(rdpTransport* transport); +FREERDP_LOCAL BOOL transport_set_io_callbacks(rdpTransport* transport, + const rdpTransportIo* io_callbacks); FREERDP_LOCAL rdpTransport* transport_new(rdpContext* context); FREERDP_LOCAL void transport_free(rdpTransport* transport);