diff --git a/include/freerdp/transport_io.h b/include/freerdp/transport_io.h index ab7f39a59..fbbbeaef1 100644 --- a/include/freerdp/transport_io.h +++ b/include/freerdp/transport_io.h @@ -41,6 +41,8 @@ extern "C" typedef BOOL (*pTransportAttach)(rdpTransport* transport, int sockfd); typedef int (*pTransportRWFkt)(rdpTransport* transport, wStream* s); typedef SSIZE_T (*pTransportRead)(rdpTransport* transport, BYTE* data, size_t bytes); + typedef BOOL (*pTransportGetPublicKey)(rdpTransport* transport, const BYTE** data, + DWORD* length); struct rdp_transport_io { @@ -52,6 +54,7 @@ extern "C" pTransportRWFkt ReadPdu; /* Reads a whole PDU from the transport */ pTransportRWFkt WritePdu; /* Writes a whole PDU to the transport */ pTransportRead ReadBytes; /* Reads up to a requested amount of bytes from the transport */ + pTransportGetPublicKey GetPublicKey; }; typedef struct rdp_transport_io rdpTransportIo; diff --git a/libfreerdp/core/nla.c b/libfreerdp/core/nla.c index 7b1e10992..ad3583c52 100644 --- a/libfreerdp/core/nla.c +++ b/libfreerdp/core/nla.c @@ -451,15 +451,15 @@ static int nla_client_init(rdpNla* nla) if (!credssp_auth_setup_client(nla->auth, "TERMSRV", hostname, nla->identity, nla->pkinitArgs)) return -1; - rdpTls* tls = transport_get_tls(nla->transport); - - if (!tls) + const BYTE* data = NULL; + DWORD length = 0; + if (!transport_get_public_key(nla->transport, &data, &length)) { - WLog_ERR(TAG, "Unknown NLA transport layer"); + WLog_ERR(TAG, "Failed to get public key"); return -1; } - if (!nla_sec_buffer_alloc_from_data(&nla->PublicKey, tls->PublicKey, 0, tls->PublicKeyLength)) + if (!nla_sec_buffer_alloc_from_data(&nla->PublicKey, data, 0, length)) { WLog_ERR(TAG, "Failed to allocate sspi secBuffer"); return -1; @@ -662,10 +662,15 @@ static int nla_server_init(rdpNla* nla) { WINPR_ASSERT(nla); - rdpTls* tls = transport_get_tls(nla->transport); - WINPR_ASSERT(tls); + const BYTE* data = NULL; + DWORD length = 0; + if (!transport_get_public_key(nla->transport, &data, &length)) + { + WLog_ERR(TAG, "Failed to get public key"); + return -1; + } - if (!nla_sec_buffer_alloc_from_data(&nla->PublicKey, tls->PublicKey, 0, tls->PublicKeyLength)) + if (!nla_sec_buffer_alloc_from_data(&nla->PublicKey, data, 0, length)) { WLog_ERR(TAG, "Failed to allocate SecBuffer for public key"); return -1; diff --git a/libfreerdp/core/transport.c b/libfreerdp/core/transport.c index 83f4a172f..d756be352 100644 --- a/libfreerdp/core/transport.c +++ b/libfreerdp/core/transport.c @@ -1218,6 +1218,24 @@ fail: return status; } +BOOL transport_get_public_key(rdpTransport* transport, const BYTE** data, DWORD* length) +{ + return IFCALLRESULT(FALSE, transport->io.GetPublicKey, transport, data, length); +} + +static BOOL transport_default_get_public_key(rdpTransport* transport, const BYTE** data, + DWORD* length) +{ + rdpTls* tls = transport_get_tls(transport); + if (!tls) + return FALSE; + + *data = tls->PublicKey; + *length = tls->PublicKeyLength; + + return TRUE; +} + DWORD transport_get_event_handles(rdpTransport* transport, HANDLE* events, DWORD count) { DWORD nCount = 0; /* always the reread Event */ @@ -1535,6 +1553,7 @@ rdpTransport* transport_new(rdpContext* context) transport->io.ReadPdu = transport_default_read_pdu; transport->io.WritePdu = transport_default_write; transport->io.ReadBytes = transport_read_layer; + transport->io.GetPublicKey = transport_default_get_public_key; transport->context = context; transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE); diff --git a/libfreerdp/core/transport.h b/libfreerdp/core/transport.h index aff067bcd..490d94416 100644 --- a/libfreerdp/core/transport.h +++ b/libfreerdp/core/transport.h @@ -73,6 +73,9 @@ FREERDP_LOCAL BOOL transport_accept_rdstls(rdpTransport* transport); FREERDP_LOCAL int transport_read_pdu(rdpTransport* transport, wStream* s); FREERDP_LOCAL int transport_write(rdpTransport* transport, wStream* s); +FREERDP_LOCAL BOOL transport_get_public_key(rdpTransport* transport, const BYTE** data, + DWORD* length); + #if defined(WITH_FREERDP_DEPRECATED) FREERDP_LOCAL void transport_get_fds(rdpTransport* transport, void** rfds, int* rcount); #endif