Allow transport_write calls to be non-blocking

This big patch allows to have non-blocking writes. To achieve
this, it slightly changes the way transport is handled. The misc transport
layers are handled with OpenSSL BIOs. In the chain we insert a
bufferedBIO that will bufferize write calls that couldn't be honored.

For an access with Tls security the BIO chain would look like this:
  FreeRdp Code ===> SSL bio ===> buffered BIO ===> socket BIO

The buffered BIO will store bytes that couldn't be send because of
blocking write calls.

This patch also rework TSG so that it would look like this in the
case of SSL security with TSG:
                                         (TSG in)
                              > SSL BIO => buffered BIO ==> socket BIO
                             /
FreeRdp => SSL BIO => TSG BIO
                             \
                              > SSL BIO => buffered BIO ==> socket BIO
                                        (TSG out)

So from the FreeRDP point of view sending something is only BIO_writing
on the frontBio (last BIO on the left).
This commit is contained in:
Hardening 2014-05-21 17:32:14 +02:00
parent 0376dcd065
commit dd6d829550
20 changed files with 1478 additions and 1113 deletions

View File

@ -70,7 +70,6 @@ struct rdp_tls
SSL* ssl;
BIO* bio;
void* tsg;
int sockfd;
SSL_CTX* ctx;
BYTE* PublicKey;
BIO_METHOD* methods;
@ -84,17 +83,11 @@ struct rdp_tls
int alertDescription;
};
FREERDP_API int tls_connect(rdpTls* tls);
FREERDP_API BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file);
FREERDP_API int tls_connect(rdpTls* tls, BIO *underlying);
FREERDP_API BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file);
FREERDP_API BOOL tls_disconnect(rdpTls* tls);
FREERDP_API int tls_read(rdpTls* tls, BYTE* data, int length);
FREERDP_API int tls_write(rdpTls* tls, BYTE* data, int length);
FREERDP_API int tls_write_all(rdpTls* tls, BYTE* data, int length);
FREERDP_API int tls_wait_read(rdpTls* tls);
FREERDP_API int tls_wait_write(rdpTls* tls);
FREERDP_API int tls_write_all(rdpTls* tls, const BYTE* data, int length);
FREERDP_API int tls_set_alert_code(rdpTls* tls, int level, int description);

View File

@ -34,7 +34,10 @@ typedef void (*psPeerContextFree)(freerdp_peer* client, rdpContext* context);
typedef BOOL (*psPeerInitialize)(freerdp_peer* client);
typedef BOOL (*psPeerGetFileDescriptor)(freerdp_peer* client, void** rfds, int* rcount);
typedef HANDLE (*psPeerGetEventHandle)(freerdp_peer* client);
typedef HANDLE (*psPeerGetReceiveEventHandle)(freerdp_peer* client);
typedef BOOL (*psPeerCheckFileDescriptor)(freerdp_peer* client);
typedef BOOL (*psPeerIsWriteBlocked)(freerdp_peer* client);
typedef int (*psPeerDrainOutputBuffer)(freerdp_peer* client);
typedef BOOL (*psPeerClose)(freerdp_peer* client);
typedef void (*psPeerDisconnect)(freerdp_peer* client);
typedef BOOL (*psPeerCapabilities)(freerdp_peer* client);
@ -62,6 +65,7 @@ struct rdp_freerdp_peer
psPeerInitialize Initialize;
psPeerGetFileDescriptor GetFileDescriptor;
psPeerGetEventHandle GetEventHandle;
psPeerGetReceiveEventHandle GetReceiveEventHandle;
psPeerCheckFileDescriptor CheckFileDescriptor;
psPeerClose Close;
psPeerDisconnect Disconnect;
@ -81,6 +85,9 @@ struct rdp_freerdp_peer
BOOL activated;
BOOL authenticated;
SEC_WINNT_AUTH_IDENTITY identity;
psPeerIsWriteBlocked IsWriteBlocked;
psPeerDrainOutputBuffer DrainOutputBuffer;
};
#ifdef __cplusplus

View File

@ -798,7 +798,8 @@ struct rdp_settings
ALIGN64 char* Password; /* 22 */
ALIGN64 char* Domain; /* 23 */
ALIGN64 char* PasswordHash; /* 24 */
UINT64 padding0064[64 - 25]; /* 25 */
ALIGN64 BOOL WaitForOutputBufferFlush; /* 25 */
UINT64 padding0064[64 - 26]; /* 26 */
UINT64 padding0128[128 - 64]; /* 64 */
/**

View File

@ -26,6 +26,10 @@
#include <winpr/stream.h>
#include <winpr/string.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "http.h"
HttpContext* http_context_new()
@ -472,7 +476,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
nbytes = 0;
length = 10000;
content = NULL;
buffer = malloc(length);
buffer = calloc(length, 1);
if (!buffer)
return NULL;
@ -487,14 +491,20 @@ HttpResponse* http_response_recv(rdpTls* tls)
{
while (nbytes < 5)
{
status = tls_read(tls, p, length - nbytes);
status = BIO_read(tls->bio, p, length - nbytes);
if (status < 0)
goto out_error;
if (status <= 0)
{
if (!BIO_should_retry(tls->bio))
goto out_error;
if (!status)
USleep(100);
continue;
}
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(p, status);
#endif
nbytes += status;
p = (BYTE*) &buffer[nbytes];
}
@ -503,7 +513,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
if (!header_end)
{
fprintf(stderr, "http_response_recv: invalid response:\n");
fprintf(stderr, "%s: invalid response:\n", __FUNCTION__);
winpr_HexDump(buffer, status);
goto out_error;
}
@ -517,7 +527,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
header_end[0] = '\0';
header_end[1] = '\0';
content = &header_end[2];
content = header_end + 2;
count = 0;
line = (char*) buffer;
@ -552,11 +562,14 @@ HttpResponse* http_response_recv(rdpTls* tls)
if (!http_response_parse_header(http_response))
goto out_error;
if (http_response->ContentLength > 0)
http_response->bodyLen = nbytes - (content - (char *)buffer);
if (http_response->bodyLen > 0)
{
http_response->Content = _strdup(content);
if (!http_response->Content)
http_response->BodyContent = (BYTE *)malloc(http_response->bodyLen);
if (!http_response->BodyContent)
goto out_error;
CopyMemory(http_response->BodyContent, content, http_response->bodyLen);
}
break;
@ -627,7 +640,7 @@ void http_response_free(HttpResponse* http_response)
ListDictionary_Free(http_response->Authenticates);
if (http_response->ContentLength > 0)
free(http_response->Content);
free(http_response->BodyContent);
free(http_response);
}

View File

@ -84,7 +84,8 @@ struct _http_response
wListDictionary *Authenticates;
int ContentLength;
char* Content;
BYTE *BodyContent;
int bodyLen;
};
void http_response_print(HttpResponse* http_response);

View File

@ -98,6 +98,8 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc)
rdpNtlm* ntlm = rpc->NtlmHttpIn->ntlm;
http_response = http_response_recv(rpc->TlsIn);
if (!http_response)
return -1;
if (ListDictionary_Contains(http_response->Authenticates, "NTLM"))
{
@ -105,14 +107,12 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc)
if (!token64)
goto out;
ntlm_token_data = NULL;
crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length);
}
out:
ntlm->inputBuffer[0].pvBuffer = ntlm_token_data;
ntlm->inputBuffer[0].cbBuffer = ntlm_token_length;
out:
http_response_free(http_response);
return 0;
@ -123,25 +123,19 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, TSG_CHANNEL channel)
rdpNtlm* ntlm = NULL;
rdpSettings* settings = rpc->settings;
freerdp* instance = (freerdp*) rpc->settings->instance;
BOOL promptPassword = FALSE;
if (channel == TSG_CHANNEL_IN)
ntlm = rpc->NtlmHttpIn->ntlm;
else if (channel == TSG_CHANNEL_OUT)
ntlm = rpc->NtlmHttpOut->ntlm;
if ((!settings->GatewayPassword) || (!settings->GatewayUsername)
|| (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername)))
{
promptPassword = TRUE;
}
if (promptPassword)
if (!settings->GatewayPassword || !settings->GatewayUsername ||
!strlen(settings->GatewayPassword) || !strlen(settings->GatewayUsername))
{
if (instance->GatewayAuthenticate)
{
BOOL proceed = instance->GatewayAuthenticate(instance,
&settings->GatewayUsername, &settings->GatewayPassword, &settings->GatewayDomain);
BOOL proceed = instance->GatewayAuthenticate(instance, &settings->GatewayUsername,
&settings->GatewayPassword, &settings->GatewayDomain);
if (!proceed)
{
@ -240,12 +234,10 @@ int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc)
char *token64 = ListDictionary_GetItemValue(http_response->Authenticates, "NTLM");
crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length);
}
ntlm->inputBuffer[0].pvBuffer = ntlm_token_data;
ntlm->inputBuffer[0].cbBuffer = ntlm_token_length;
http_response_free(http_response);
return 0;
}
@ -259,15 +251,12 @@ BOOL rpc_ntlm_http_out_connect(rdpRpc* rpc)
success = TRUE;
/* Send OUT Channel Request */
rpc_ncacn_http_send_out_channel_request(rpc);
/* Receive OUT Channel Response */
rpc_ncacn_http_recv_out_channel_response(rpc);
/* Send OUT Channel Request */
rpc_ncacn_http_send_out_channel_request(rpc);
ntlm_client_uninit(ntlm);
@ -296,13 +285,11 @@ void rpc_ntlm_http_init_channel(rdpRpc* rpc, rdpNtlmHttp* ntlm_http, TSG_CHANNEL
if (channel == TSG_CHANNEL_IN)
{
http_context_set_pragma(ntlm_http->context,
"ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729");
http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729");
}
else if (channel == TSG_CHANNEL_OUT)
{
http_context_set_pragma(ntlm_http->context,
"ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729" ", "
http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, "
"SessionId=fbd9c34f-397d-471d-a109-1b08cc554624");
}
}

View File

@ -33,6 +33,11 @@
#include <winpr/dsparse.h>
#include <openssl/rand.h>
#include <openssl/bio.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "http.h"
#include "ntlm.h"
@ -235,80 +240,77 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l
{
UINT32 alloc_hint = 0;
rpcconn_hdr_t* header;
UINT32 frag_length;
UINT32 auth_length;
UINT32 auth_pad_length;
UINT32 sec_trailer_offset;
rpc_sec_trailer* sec_trailer;
*offset = RPC_COMMON_FIELDS_LENGTH;
header = ((rpcconn_hdr_t*) buffer);
if (header->common.ptype == PTYPE_RESPONSE)
switch (header->common.ptype)
{
*offset += 8;
rpc_offset_align(offset, 8);
alloc_hint = header->response.alloc_hint;
}
else if (header->common.ptype == PTYPE_REQUEST)
{
*offset += 4;
rpc_offset_align(offset, 8);
alloc_hint = header->request.alloc_hint;
}
else if (header->common.ptype == PTYPE_RTS)
{
*offset += 4;
}
else
{
return FALSE;
case PTYPE_RESPONSE:
*offset += 8;
rpc_offset_align(offset, 8);
alloc_hint = header->response.alloc_hint;
break;
case PTYPE_REQUEST:
*offset += 4;
rpc_offset_align(offset, 8);
alloc_hint = header->request.alloc_hint;
break;
case PTYPE_RTS:
*offset += 4;
break;
default:
fprintf(stderr, "%s: unknown ptype=0x%x\n", __FUNCTION__, header->common.ptype);
return FALSE;
}
if (length)
if (!length)
return TRUE;
if (header->common.ptype == PTYPE_REQUEST)
{
if (header->common.ptype == PTYPE_REQUEST)
{
UINT32 sec_trailer_offset;
UINT32 sec_trailer_offset;
sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8;
*length = sec_trailer_offset - *offset;
}
else
{
UINT32 frag_length;
UINT32 auth_length;
UINT32 auth_pad_length;
UINT32 sec_trailer_offset;
rpc_sec_trailer* sec_trailer;
sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8;
*length = sec_trailer_offset - *offset;
return TRUE;
}
frag_length = header->common.frag_length;
auth_length = header->common.auth_length;
sec_trailer_offset = frag_length - auth_length - 8;
sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset];
auth_pad_length = sec_trailer->auth_pad_length;
frag_length = header->common.frag_length;
auth_length = header->common.auth_length;
sec_trailer_offset = frag_length - auth_length - 8;
sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset];
auth_pad_length = sec_trailer->auth_pad_length;
#if 0
fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n",
sec_trailer->auth_type,
sec_trailer->auth_level,
sec_trailer->auth_pad_length,
sec_trailer->auth_reserved,
sec_trailer->auth_context_id);
fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n",
sec_trailer->auth_type,
sec_trailer->auth_level,
sec_trailer->auth_pad_length,
sec_trailer->auth_reserved,
sec_trailer->auth_context_id);
#endif
/**
* According to [MS-RPCE], auth_pad_length is the number of padding
* octets used to 4-byte align the security trailer, but in practice
* we get values up to 15, which indicates 16-byte alignment.
*/
/**
* According to [MS-RPCE], auth_pad_length is the number of padding
* octets used to 4-byte align the security trailer, but in practice
* we get values up to 15, which indicates 16-byte alignment.
*/
if ((frag_length - (sec_trailer_offset + 8)) != auth_length)
{
fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length,
(frag_length - (sec_trailer_offset + 8)));
}
*length = frag_length - auth_length - 24 - 8 - auth_pad_length;
}
if ((frag_length - (sec_trailer_offset + 8)) != auth_length)
{
fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length,
(frag_length - (sec_trailer_offset + 8)));
}
*length = frag_length - auth_length - 24 - 8 - auth_pad_length;
return TRUE;
}
@ -316,12 +318,23 @@ int rpc_out_read(rdpRpc* rpc, BYTE* data, int length)
{
int status;
status = tls_read(rpc->TlsOut, data, length);
status = BIO_read(rpc->TlsOut->bio, data, length);
/* fprintf(stderr, "%s: length=%d => status=%d shouldRetry=%d\n", __FUNCTION__, length,
* status, BIO_should_retry(rpc->TlsOut->bio)); */
if (status > 0) {
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(data, status);
#endif
return status;
}
return status;
if (BIO_should_retry(rpc->TlsOut->bio))
return 0;
return -1;
}
int rpc_out_write(rdpRpc* rpc, BYTE* data, int length)
int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length)
{
int status;
@ -330,7 +343,7 @@ int rpc_out_write(rdpRpc* rpc, BYTE* data, int length)
return status;
}
int rpc_in_write(rdpRpc* rpc, BYTE* data, int length)
int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length)
{
int status;
@ -360,20 +373,21 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
ntlm = rpc->ntlm;
if ((!ntlm) || (!ntlm->table))
if (!ntlm || !ntlm->table)
{
fprintf(stderr, "rpc_write: invalid ntlm context\n");
fprintf(stderr, "%s: invalid ntlm context\n", __FUNCTION__);
return -1;
}
if (ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, &ntlm->ContextSizes) != SEC_E_OK)
{
fprintf(stderr, "QueryContextAttributes SECPKG_ATTR_SIZES failure\n");
fprintf(stderr, "%s: QueryContextAttributes SECPKG_ATTR_SIZES failure\n", __FUNCTION__);
return -1;
}
request_pdu = (rpcconn_request_hdr_t*) malloc(sizeof(rpcconn_request_hdr_t));
ZeroMemory(request_pdu, sizeof(rpcconn_request_hdr_t));
request_pdu = (rpcconn_request_hdr_t*) calloc(1, sizeof(rpcconn_request_hdr_t));
if (!request_pdu)
return -1;
rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu);
@ -386,7 +400,11 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
request_pdu->opnum = opnum;
clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum);
ArrayList_Add(rpc->client->ClientCallList, clientCall);
if (!clientCall)
goto out_free_pdu;
if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0)
goto out_free_clientCall;
if (request_pdu->opnum == TsProxySetupReceivePipeOpnum)
rpc->PipeCallId = request_pdu->call_id;
@ -407,8 +425,9 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
request_pdu->frag_length = offset;
buffer = (BYTE*) malloc(request_pdu->frag_length);
buffer = (BYTE*) calloc(1, request_pdu->frag_length);
if (!buffer)
goto out_free_pdu;
CopyMemory(buffer, request_pdu, 24);
offset = 24;
@ -427,15 +446,15 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
Buffers[0].cbBuffer = offset;
Buffers[1].cbBuffer = ntlm->ContextSizes.cbMaxSignature;
Buffers[1].pvBuffer = malloc(Buffers[1].cbBuffer);
ZeroMemory(Buffers[1].pvBuffer, Buffers[1].cbBuffer);
Buffers[1].pvBuffer = calloc(1, Buffers[1].cbBuffer);
if (!Buffers[1].pvBuffer)
return -1;
Message.cBuffers = 2;
Message.ulVersion = SECBUFFER_VERSION;
Message.pBuffers = (PSecBuffer) &Buffers;
encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, 0, &Message, rpc->SendSeqNum++);
if (encrypt_status != SEC_E_OK)
{
fprintf(stderr, "EncryptMessage status: 0x%08X\n", encrypt_status);
@ -447,12 +466,18 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
offset += Buffers[1].cbBuffer;
free(Buffers[1].pvBuffer);
if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) != 0)
if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) < 0)
length = -1;
free(request_pdu);
return length;
out_free_clientCall:
rpc_client_call_free(clientCall);
out_free_pdu:
free(request_pdu);
return -1;
}
BOOL rpc_connect(rdpRpc* rpc)
@ -592,13 +617,17 @@ rdpRpc* rpc_new(rdpTransport* transport)
rpc->CallId = 2;
rpc_client_new(rpc);
if (rpc_client_new(rpc) < 0)
goto out_free_virtualConnectionCookieTable;
rpc->client->SynchronousSend = TRUE;
rpc->client->SynchronousReceive = TRUE;
return rpc;
out_free_virtualConnectionCookieTable:
rpc_client_free(rpc);
ArrayList_Free(rpc->VirtualConnectionCookieTable);
out_free_virtual_connection:
rpc_client_virtual_connection_free(rpc->VirtualConnection);
out_free_ntlm_http_out:

View File

@ -772,8 +772,8 @@ UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad);
int rpc_out_read(rdpRpc* rpc, BYTE* data, int length);
int rpc_out_write(rdpRpc* rpc, BYTE* data, int length);
int rpc_in_write(rdpRpc* rpc, BYTE* data, int length);
int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length);
int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length);
BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length);

View File

@ -103,6 +103,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
DEBUG_RPC("Sending bind PDU");
rpc->ntlm = ntlm_new();
if (!rpc->ntlm)
return -1;
if ((!settings->GatewayPassword) || (!settings->GatewayUsername)
|| (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername)))
@ -129,17 +131,22 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
settings->Username = _strdup(settings->GatewayUsername);
settings->Domain = _strdup(settings->GatewayDomain);
settings->Password = _strdup(settings->GatewayPassword);
if (!settings->Username || !settings->Domain || settings->Password)
return -1;
}
}
}
ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL);
ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname);
if (!ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL) ||
!ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname) ||
!ntlm_authenticate(rpc->ntlm)
)
return -1;
ntlm_authenticate(rpc->ntlm);
bind_pdu = (rpcconn_bind_hdr_t*) malloc(sizeof(rpcconn_bind_hdr_t));
ZeroMemory(bind_pdu, sizeof(rpcconn_bind_hdr_t));
bind_pdu = (rpcconn_bind_hdr_t*) calloc(1, sizeof(rpcconn_bind_hdr_t));
if (!bind_pdu)
return -1;
rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) bind_pdu);
@ -159,6 +166,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
bind_pdu->p_context_elem.reserved2 = 0;
bind_pdu->p_context_elem.p_cont_elem = malloc(sizeof(p_cont_elem_t) * bind_pdu->p_context_elem.n_context_elem);
if (!bind_pdu->p_context_elem.p_cont_elem)
return -1;
p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0];
@ -196,6 +205,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
bind_pdu->frag_length = offset;
buffer = (BYTE*) malloc(bind_pdu->frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, bind_pdu, 24);
CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4);
@ -214,7 +225,10 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
length = bind_pdu->frag_length;
clientCall = rpc_client_call_new(bind_pdu->call_id, 0);
ArrayList_Add(rpc->client->ClientCallList, clientCall);
if (!clientCall)
return -1;
if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0)
return -1;
if (rpc_send_enqueue_pdu(rpc, buffer, length) != 0)
length = -1;

View File

@ -34,9 +34,7 @@
#include <winpr/stream.h>
#include "rpc_fault.h"
#include "rpc_client.h"
#include "../rdp.h"
#define SYNCHRONOUS_TIMEOUT 5000
@ -69,8 +67,15 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
if (!pdu)
{
pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));
pdu = (RPC_PDU *)malloc(sizeof(RPC_PDU));
if (!pdu)
return NULL;
pdu->s = Stream_New(NULL, rpc->max_recv_frag);
if (!pdu->s)
{
free(pdu);
return NULL;
}
}
pdu->CallId = 0;
@ -84,8 +89,7 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu)
{
Queue_Enqueue(rpc->client->ReceivePool, pdu);
return 0;
return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1;
}
int rpc_client_on_fragment_received_event(rdpRpc* rpc)
@ -97,7 +101,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
rpcconn_hdr_t* header;
freerdp* instance;
instance = (freerdp*) rpc->transport->settings->instance;
instance = (freerdp *)rpc->transport->settings->instance;
if (!rpc->client->pdu)
rpc->client->pdu = rpc_client_receive_pool_take(rpc);
@ -125,34 +129,29 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
return 0;
}
if (header->common.ptype == PTYPE_RTS)
switch (header->common.ptype)
{
if (rpc->VirtualConnection->State >= VIRTUAL_CONNECTION_STATE_OPENED)
{
//fprintf(stderr, "Receiving Out-of-Sequence RTS PDU\n");
case PTYPE_RTS:
if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED)
{
fprintf(stderr, "%s: warning: unhandled RTS PDU\n", __FUNCTION__);
return 0;
}
fprintf(stderr, "%s: Receiving Out-of-Sequence RTS PDU\n", __FUNCTION__);
rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length);
rpc_client_fragment_pool_return(rpc, fragment);
}
else
{
fprintf(stderr, "warning: unhandled RTS PDU\n");
}
return 0;
return 0;
}
else if (header->common.ptype == PTYPE_FAULT)
{
rpc_recv_fault_pdu(header);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
}
if (header->common.ptype != PTYPE_RESPONSE)
{
fprintf(stderr, "Unexpected RPC PDU type: %d\n", header->common.ptype);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
case PTYPE_FAULT:
rpc_recv_fault_pdu(header);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
case PTYPE_RESPONSE:
break;
default:
fprintf(stderr, "%s: unexpected RPC PDU type %d\n", __FUNCTION__, header->common.ptype);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
}
rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length;
@ -160,7 +159,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength))
{
fprintf(stderr, "rpc_recv_pdu_fragment: expected stub\n");
fprintf(stderr, "%s: expected stub\n", __FUNCTION__);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
}
@ -196,7 +195,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
if (rpc->StubCallId != header->common.call_id)
{
fprintf(stderr, "invalid call_id: actual: %d, expected: %d, frag_count: %d\n",
fprintf(stderr, "%s: invalid call_id: actual: %d, expected: %d, frag_count: %d\n", __FUNCTION__,
rpc->StubCallId, header->common.call_id, rpc->StubFragCount);
}
@ -243,27 +242,34 @@ int rpc_client_on_read_event(rdpRpc* rpc)
int status = -1;
rpcconn_common_hdr_t* header;
if (!rpc->client->RecvFrag)
rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
position = Stream_GetPosition(rpc->client->RecvFrag);
if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
while (1)
{
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag));
if (!rpc->client->RecvFrag)
rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
if (status < 0)
position = Stream_GetPosition(rpc->client->RecvFrag);
while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
{
fprintf(stderr, "rpc_client_frag_read: error reading header\n");
return -1;
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag));
if (status < 0)
{
fprintf(stderr, "rpc_client_frag_read: error reading header\n");
return -1;
}
if (!status)
return 0;
Stream_Seek(rpc->client->RecvFrag, status);
}
Stream_Seek(rpc->client->RecvFrag, status);
}
if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
return status;
if (Stream_GetPosition(rpc->client->RecvFrag) >= RPC_COMMON_FIELDS_LENGTH)
{
header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->RecvFrag);
if (header->frag_length > rpc->max_recv_frag)
@ -274,45 +280,44 @@ int rpc_client_on_read_event(rdpRpc* rpc)
return -1;
}
if (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length)
while (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length)
{
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
header->frag_length - Stream_GetPosition(rpc->client->RecvFrag));
if (status < 0)
{
fprintf(stderr, "rpc_client_frag_read: error reading fragment body\n");
fprintf(stderr, "%s: error reading fragment body\n", __FUNCTION__);
return -1;
}
if (!status)
return 0;
Stream_Seek(rpc->client->RecvFrag, status);
}
}
else
{
return status;
}
if (status < 0)
return -1;
status = Stream_GetPosition(rpc->client->RecvFrag) - position;
if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
{
/* complete fragment received */
Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
Stream_SetPosition(rpc->client->RecvFrag, 0);
Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
rpc->client->RecvFrag = NULL;
if (rpc_client_on_fragment_received_event(rpc) < 0)
if (status < 0)
return -1;
status = Stream_GetPosition(rpc->client->RecvFrag) - position;
if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
{
/* complete fragment received */
Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
Stream_SetPosition(rpc->client->RecvFrag, 0);
Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
rpc->client->RecvFrag = NULL;
if (rpc_client_on_fragment_received_event(rpc) < 0)
return -1;
}
}
return status;
return 0;
}
/**
@ -349,13 +354,12 @@ RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum)
RpcClientCall* clientCall;
clientCall = (RpcClientCall*) malloc(sizeof(RpcClientCall));
if (!clientCall)
return NULL;
if (clientCall)
{
clientCall->CallId = CallId;
clientCall->OpNum = OpNum;
clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
}
clientCall->CallId = CallId;
clientCall->OpNum = OpNum;
clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
return clientCall;
}
@ -371,16 +375,22 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
int status;
pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));
pdu->s = Stream_New(buffer, length);
if (!pdu)
return -1;
Queue_Enqueue(rpc->client->SendQueue, pdu);
pdu->s = Stream_New(buffer, length);
if (!pdu->s)
goto out_free;
if (!Queue_Enqueue(rpc->client->SendQueue, pdu))
goto out_free_stream;
if (rpc->client->SynchronousSend)
{
status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT);
if (status == WAIT_TIMEOUT)
{
fprintf(stderr, "rpc_send_enqueue_pdu: timed out waiting for pdu sent event\n");
fprintf(stderr, "%s: timed out waiting for pdu sent event %p\n", __FUNCTION__, rpc->client->PduSentEvent);
return -1;
}
@ -388,6 +398,12 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
}
return 0;
out_free_stream:
Stream_Free(pdu->s, TRUE);
out_free:
free(pdu);
return -1;
}
int rpc_send_dequeue_pdu(rdpRpc* rpc)
@ -396,13 +412,14 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
RPC_PDU* pdu;
RpcClientCall* clientCall;
rpcconn_common_hdr_t* header;
RpcInChannel *inChannel;
pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->SendQueue);
if (!pdu)
return 0;
WaitForSingleObject(rpc->VirtualConnection->DefaultInChannel->Mutex, INFINITE);
inChannel = rpc->VirtualConnection->DefaultInChannel;
WaitForSingleObject(inChannel->Mutex, INFINITE);
status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
@ -410,7 +427,7 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
clientCall = rpc_client_call_find_by_id(rpc, header->call_id);
clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED;
ReleaseMutex(rpc->VirtualConnection->DefaultInChannel->Mutex);
ReleaseMutex(inChannel->Mutex);
/*
* This protocol specifies that only RPC PDUs are subject to the flow control abstract
@ -421,8 +438,8 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
if (header->ptype == PTYPE_REQUEST)
{
rpc->VirtualConnection->DefaultInChannel->BytesSent += status;
rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow -= status;
inChannel->BytesSent += status;
inChannel->SenderAvailableWindow -= status;
}
Stream_Free(pdu->s, TRUE);
@ -440,57 +457,48 @@ RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc)
DWORD dwMilliseconds;
DWORD result;
pdu = NULL;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0;
result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
if (result == WAIT_TIMEOUT)
{
fprintf(stderr, "rpc_recv_dequeue_pdu: timed out waiting for receive event\n");
fprintf(stderr, "%s: timed out waiting for receive event\n", __FUNCTION__);
return NULL;
}
if (result == WAIT_OBJECT_0)
{
pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->ReceiveQueue);
if (result != WAIT_OBJECT_0)
return NULL;
pdu = (RPC_PDU *)Queue_Dequeue(rpc->client->ReceiveQueue);
#ifdef WITH_DEBUG_TSG
if (pdu)
{
fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId);
winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s));
fprintf(stderr, "\n");
}
#endif
return pdu;
if (pdu)
{
fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId);
winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s));
fprintf(stderr, "\n");
}
else
{
fprintf(stderr, "Receiving a NULL PDU\n");
}
#endif
return pdu;
}
RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc)
{
RPC_PDU* pdu;
DWORD dwMilliseconds;
DWORD result;
pdu = NULL;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
if (result == WAIT_TIMEOUT)
{
if (result != WAIT_OBJECT_0)
return NULL;
}
if (result == WAIT_OBJECT_0)
{
pdu = (RPC_PDU*) Queue_Peek(rpc->client->ReceiveQueue);
return pdu;
}
return pdu;
return (RPC_PDU *)Queue_Peek(rpc->client->ReceiveQueue);
}
static void* rpc_client_thread(void* arg)
@ -500,40 +508,52 @@ static void* rpc_client_thread(void* arg)
DWORD nCount;
HANDLE events[3];
HANDLE ReadEvent;
int fd;
rpc = (rdpRpc*) arg;
fd = BIO_get_fd(rpc->TlsOut->bio, NULL);
ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, rpc->TlsOut->sockfd);
ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, fd);
nCount = 0;
events[nCount++] = rpc->client->StopEvent;
events[nCount++] = Queue_Event(rpc->client->SendQueue);
events[nCount++] = ReadEvent;
/* Do a first free run in case some bytes were set from the HTTP headers.
* We also have to do it because most of the time the underlying socket has notified,
* and the ssl layer has eaten all bytes, so we won't be notified any more even if the
* bytes are buffered locally
*/
if (rpc_client_on_read_event(rpc) < 0)
{
fprintf(stderr, "%s: an error occured when treating first packet\n", __FUNCTION__);
goto out;
}
while (rpc->transport->layer != TRANSPORT_LAYER_CLOSED)
{
status = WaitForMultipleObjects(nCount, events, FALSE, 100);
if (status != WAIT_TIMEOUT)
if (status == WAIT_TIMEOUT)
continue;
if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0)
break;
if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0)
{
if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0)
{
if (rpc_client_on_read_event(rpc) < 0)
break;
}
}
if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0)
{
if (rpc_client_on_read_event(rpc) < 0)
break;
}
if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
{
rpc_send_dequeue_pdu(rpc);
}
if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
{
rpc_send_dequeue_pdu(rpc);
}
}
out:
CloseHandle(ReadEvent);
return NULL;
@ -541,6 +561,9 @@ static void* rpc_client_thread(void* arg)
static void rpc_pdu_free(RPC_PDU* pdu)
{
if (!pdu)
return;
Stream_Free(pdu->s, TRUE);
free(pdu);
}
@ -554,35 +577,55 @@ int rpc_client_new(rdpRpc* rpc)
{
RpcClient* client = NULL;
client = (RpcClient*) calloc(1, sizeof(RpcClient));
if (client)
{
client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
client->SendQueue = Queue_New(TRUE, -1, -1);
Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->pdu = NULL;
client->ReceivePool = Queue_New(TRUE, -1, -1);
client->ReceiveQueue = Queue_New(TRUE, -1, -1);
Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->RecvFrag = NULL;
client->FragmentPool = Queue_New(TRUE, -1, -1);
client->FragmentQueue = Queue_New(TRUE, -1, -1);
Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->ClientCallList = ArrayList_New(TRUE);
ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
}
client = (RpcClient *)calloc(1, sizeof(RpcClient));
rpc->client = client;
if (!client)
return -1;
client->Thread = CreateThread(NULL, 0,
(LPTHREAD_START_ROUTINE) rpc_client_thread,
rpc, CREATE_SUSPENDED, NULL);
if (!client->Thread)
return -1;
client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!client->StopEvent)
return -1;
client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!client->PduSentEvent)
return -1;
client->SendQueue = Queue_New(TRUE, -1, -1);
if (!client->SendQueue)
return -1;
Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->pdu = NULL;
client->ReceivePool = Queue_New(TRUE, -1, -1);
if (!client->ReceivePool)
return -1;
Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->ReceiveQueue = Queue_New(TRUE, -1, -1);
if (!client->ReceiveQueue)
return -1;
Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->RecvFrag = NULL;
client->FragmentPool = Queue_New(TRUE, -1, -1);
if (!client->FragmentPool)
return -1;
Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->FragmentQueue = Queue_New(TRUE, -1, -1);
if (!client->FragmentQueue)
return -1;
Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->ClientCallList = ArrayList_New(TRUE);
if (!client->ClientCallList)
return -1;
ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
return 0;
}
@ -604,9 +647,7 @@ int rpc_client_stop(rdpRpc* rpc)
rpc->client->Thread = NULL;
}
rpc_client_free(rpc);
return 0;
return rpc_client_free(rpc);
}
int rpc_client_free(rdpRpc* rpc)
@ -615,31 +656,39 @@ int rpc_client_free(rdpRpc* rpc)
client = rpc->client;
if (client)
{
if (!client)
return 0;
if (client->SendQueue)
Queue_Free(client->SendQueue);
if (client->RecvFrag)
rpc_fragment_free(client->RecvFrag);
if (client->RecvFrag)
rpc_fragment_free(client->RecvFrag);
if (client->FragmentPool)
Queue_Free(client->FragmentPool);
if (client->FragmentQueue)
Queue_Free(client->FragmentQueue);
if (client->pdu)
rpc_pdu_free(client->pdu);
if (client->pdu)
rpc_pdu_free(client->pdu);
if (client->ReceivePool)
Queue_Free(client->ReceivePool);
if (client->ReceiveQueue)
Queue_Free(client->ReceiveQueue);
if (client->ClientCallList)
ArrayList_Free(client->ClientCallList);
if (client->StopEvent)
CloseHandle(client->StopEvent);
if (client->PduSentEvent)
CloseHandle(client->PduSentEvent);
if (client->Thread)
CloseHandle(client->Thread);
free(client);
}
free(client);
return 0;
}

View File

@ -93,25 +93,25 @@ BOOL rts_connect(rdpRpc* rpc)
if (!rpc_ntlm_http_out_connect(rpc))
{
fprintf(stderr, "rpc_out_connect_http error!\n");
fprintf(stderr, "%s: rpc_out_connect_http error!\n", __FUNCTION__);
return FALSE;
}
if (rts_send_CONN_A1_pdu(rpc) != 0)
{
fprintf(stderr, "rpc_send_CONN_A1_pdu error!\n");
fprintf(stderr, "%s: rpc_send_CONN_A1_pdu error!\n", __FUNCTION__);
return FALSE;
}
if (!rpc_ntlm_http_in_connect(rpc))
{
fprintf(stderr, "rpc_in_connect_http error!\n");
fprintf(stderr, "%s: rpc_in_connect_http error!\n", __FUNCTION__);
return FALSE;
}
if (rts_send_CONN_B1_pdu(rpc) != 0)
if (rts_send_CONN_B1_pdu(rpc) < 0)
{
fprintf(stderr, "rpc_send_CONN_B1_pdu error!\n");
fprintf(stderr, "%s: rpc_send_CONN_B1_pdu error!\n", __FUNCTION__);
return FALSE;
}
@ -147,10 +147,15 @@ BOOL rts_connect(rdpRpc* rpc)
*/
http_response = http_response_recv(rpc->TlsOut);
if (!http_response)
{
fprintf(stderr, "%s: unable to retrieve OUT Channel Response!\n", __FUNCTION__);
return FALSE;
}
if (http_response->StatusCode != HTTP_STATUS_OK)
{
fprintf(stderr, "rts_connect error! Status Code: %d\n", http_response->StatusCode);
fprintf(stderr, "%s: error! Status Code: %d\n", __FUNCTION__, http_response->StatusCode);
http_response_print(http_response);
http_response_free(http_response);
@ -170,6 +175,14 @@ BOOL rts_connect(rdpRpc* rpc)
return FALSE;
}
if (http_response->bodyLen)
{
/* inject bytes we have read in the body as a received packet for the RPC client */
rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
Stream_EnsureCapacity(rpc->client->RecvFrag, http_response->bodyLen);
CopyMemory(rpc->client->RecvFrag, http_response->BodyContent, http_response->bodyLen);
}
//http_response_print(http_response);
http_response_free(http_response);
@ -195,7 +208,6 @@ BOOL rts_connect(rdpRpc* rpc)
rpc_client_start(rpc);
pdu = rpc_recv_dequeue_pdu(rpc);
if (!pdu)
return FALSE;
@ -203,7 +215,7 @@ BOOL rts_connect(rdpRpc* rpc)
if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts))
{
fprintf(stderr, "Unexpected RTS PDU: Expected CONN/A3\n");
fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/A3\n", __FUNCTION__);
return FALSE;
}
@ -236,7 +248,6 @@ BOOL rts_connect(rdpRpc* rpc)
*/
pdu = rpc_recv_dequeue_pdu(rpc);
if (!pdu)
return FALSE;
@ -244,7 +255,7 @@ BOOL rts_connect(rdpRpc* rpc)
if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts))
{
fprintf(stderr, "Unexpected RTS PDU: Expected CONN/C2\n");
fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/C2\n", __FUNCTION__);
return FALSE;
}
@ -261,7 +272,7 @@ BOOL rts_connect(rdpRpc* rpc)
return TRUE;
}
#if defined WITH_DEBUG_RTS && 0
#ifdef WITH_DEBUG_RTS
static const char* const RTS_CMD_STRINGS[] =
{
@ -317,6 +328,7 @@ static const char* const RTS_CMD_STRINGS[] =
void rts_pdu_header_init(rpcconn_rts_hdr_t* header)
{
ZeroMemory(header, sizeof(*header));
header->rpc_vers = 5;
header->rpc_vers_minor = 0;
header->ptype = PTYPE_RTS;
@ -681,6 +693,8 @@ int rts_send_CONN_A1_pdu(rdpRpc* rpc)
ReceiveWindowSize = rpc->VirtualConnection->DefaultOutChannel->ReceiveWindow;
buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
@ -718,6 +732,7 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
BYTE* INChannelCookie;
BYTE* AssociationGroupId;
BYTE* VirtualConnectionCookie;
int status;
rts_pdu_header_init(&header);
header.frag_length = 104;
@ -734,6 +749,8 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
AssociationGroupId = (BYTE*) &(rpc->VirtualConnection->AssociationGroupId);
buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
@ -745,11 +762,11 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
length = header.frag_length;
rpc_in_write(rpc, buffer, length);
status = rpc_in_write(rpc, buffer, length);
free(buffer);
return 0;
return status;
}
/* CONN/C Sequence */
@ -795,12 +812,15 @@ int rts_send_keep_alive_pdu(rdpRpc* rpc)
DEBUG_RPC("Sending Keep-Alive RTS PDU");
buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_client_keepalive_command_write(&buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */
length = header.frag_length;
rpc_in_write(rpc, buffer, length);
if (rpc_in_write(rpc, buffer, length) < 0)
return -1;
free(buffer);
return length;
@ -830,6 +850,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc)
rpc->VirtualConnection->DefaultOutChannel->AvailableWindowAdvertised;
buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */
@ -839,7 +861,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc)
length = header.frag_length;
rpc_in_write(rpc, buffer, length);
if (rpc_in_write(rpc, buffer, length) < 0)
return -1;
free(buffer);
return 0;
@ -923,12 +946,15 @@ int rts_send_ping_pdu(rdpRpc* rpc)
DEBUG_RPC("Sending Ping RTS PDU");
buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
length = header.frag_length;
rpc_in_write(rpc, buffer, length);
if (rpc_in_write(rpc, buffer, length) < 0)
return -1;
free(buffer);
return length;
@ -1020,22 +1046,18 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
rts_extract_pdu_signature(rpc, &signature, rts);
SignatureId = rts_identify_pdu_signature(rpc, &signature, NULL);
if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK)
switch (SignatureId)
{
return rts_recv_flow_control_ack_pdu(rpc, buffer, length);
}
else if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION)
{
return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length);
}
else if (SignatureId == RTS_PDU_PING)
{
rts_send_ping_pdu(rpc);
}
else
{
fprintf(stderr, "Unimplemented signature id: 0x%08X\n", SignatureId);
rts_print_pdu_signature(rpc, &signature);
case RTS_PDU_FLOW_CONTROL_ACK:
return rts_recv_flow_control_ack_pdu(rpc, buffer, length);
case RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION:
return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length);
case RTS_PDU_PING:
return rts_send_ping_pdu(rpc);
default:
fprintf(stderr, "%s: unimplemented signature id: 0x%08X\n", __FUNCTION__, SignatureId);
rts_print_pdu_signature(rpc, &signature);
break;
}
return 0;

View File

@ -234,7 +234,6 @@ BOOL rts_match_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rt
return FALSE;
status = rts_command_length(rpc, CommandType, &buffer[offset], length);
if (status < 0)
return FALSE;
@ -272,7 +271,6 @@ int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_r
signature->CommandTypes[i] = CommandType;
status = rts_command_length(rpc, CommandType, &buffer[offset], length);
if (status < 0)
return FALSE;
@ -294,22 +292,22 @@ UINT32 rts_identify_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, RTS_P
{
pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature;
if (signature->Flags == pSignature->Flags)
if (signature->Flags != pSignature->Flags)
continue;
if (signature->NumberOfCommands != pSignature->NumberOfCommands)
continue;
for (j = 0; j < signature->NumberOfCommands; j++)
{
if (signature->NumberOfCommands == pSignature->NumberOfCommands)
{
for (j = 0; j < signature->NumberOfCommands; j++)
{
if (signature->CommandTypes[j] != pSignature->CommandTypes[j])
continue;
}
if (entry)
*entry = &RTS_PDU_SIGNATURE_TABLE[i];
return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
}
if (signature->CommandTypes[j] != pSignature->CommandTypes[j])
continue;
}
if (entry)
*entry = &RTS_PDU_SIGNATURE_TABLE[i];
return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
}
return 0;

View File

@ -33,9 +33,9 @@
#include <winpr/stream.h>
#include "rpc_client.h"
#include "tsg.h"
/**
* RPC Functions: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378623/
* Remote Procedure Call: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378651/
@ -96,7 +96,9 @@ DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 count,
}
length = 28 + totalDataBytes;
buffer = (BYTE*) malloc(length);
buffer = (BYTE*) calloc(1, length);
if (!buffer)
return -1;
s = Stream_New(buffer, length);
@ -228,8 +230,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
buffer = &buffer[24];
packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET));
ZeroMemory(packet, sizeof(TSG_PACKET));
packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET));
if (!packet)
return FALSE;
offset = 4; // Skip Packet Pointer
packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */
@ -237,8 +240,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
if ((packet->packetId == TSG_PACKET_TYPE_CAPS_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_CAPS_RESPONSE))
{
packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) malloc(sizeof(TSG_PACKET_CAPS_RESPONSE));
ZeroMemory(packetCapsResponse, sizeof(TSG_PACKET_CAPS_RESPONSE));
packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) calloc(1, sizeof(TSG_PACKET_CAPS_RESPONSE));
if (!packetCapsResponse) // TODO: correct cleanup
return FALSE;
packet->tsgPacket.packetCapsResponse = packetCapsResponse;
/* PacketQuarResponsePtr (4 bytes) */
@ -258,8 +262,7 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
IsMessagePresent = *((UINT32*) &buffer[offset]);
offset += 4;
MessageSwitchValue = *((UINT32*) &buffer[offset]);
DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d",
IsMessagePresent, MessageSwitchValue);
DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", IsMessagePresent, MessageSwitchValue);
offset += 4;
}
@ -289,8 +292,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
offset += 4;
}
versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS));
ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS));
versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS));
if (!versionCaps) // TODO: correct cleanup
return FALSE;
packetCapsResponse->pktQuarEncResponse.versionCaps = versionCaps;
versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */
@ -317,8 +321,10 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
/* 4-byte alignment */
rpc_offset_align(&offset, 4);
tsgCaps = (PTSG_PACKET_CAPABILITIES) malloc(sizeof(TSG_PACKET_CAPABILITIES));
ZeroMemory(tsgCaps, sizeof(TSG_PACKET_CAPABILITIES));
tsgCaps = (PTSG_PACKET_CAPABILITIES) calloc(1, sizeof(TSG_PACKET_CAPABILITIES));
if (!tsgCaps)
return FALSE;
versionCaps->tsgCaps = tsgCaps;
offset += 4; /* MaxCount (4 bytes) */
@ -406,8 +412,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
}
else if ((packet->packetId == TSG_PACKET_TYPE_QUARENC_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_QUARENC_RESPONSE))
{
packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) malloc(sizeof(TSG_PACKET_QUARENC_RESPONSE));
ZeroMemory(packetQuarEncResponse, sizeof(TSG_PACKET_QUARENC_RESPONSE));
packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) calloc(1, sizeof(TSG_PACKET_QUARENC_RESPONSE));
if (!packetQuarEncResponse) // TODO: handle cleanup
return FALSE;
packet->tsgPacket.packetQuarEncResponse = packetQuarEncResponse;
/* PacketQuarResponsePtr (4 bytes) */
@ -443,8 +450,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
offset += 4;
}
versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS));
ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS));
versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS));
if (!versionCaps) // TODO: handle cleanup
return FALSE;
packetQuarEncResponse->versionCaps = versionCaps;
versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */
@ -779,8 +787,9 @@ BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
buffer = &buffer[24];
packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET));
ZeroMemory(packet, sizeof(TSG_PACKET));
packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET));
if (!packet)
return FALSE;
offset = 4;
packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */
@ -923,6 +932,8 @@ BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_NOSERI
length = 60 + (count * 2);
buffer = (BYTE*) malloc(length);
if (!buffer)
return FALSE;
/* TunnelContext */
handle = (CONTEXT_HANDLE*) tunnelContext;
@ -1526,48 +1537,53 @@ int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length)
return CopyLength;
}
else
tsg->pdu = rpc_recv_peek_pdu(rpc);
if (!tsg->pdu)
{
tsg->pdu = rpc_recv_peek_pdu(rpc);
if (!tsg->rpc->client->SynchronousReceive)
return 0;
if (!tsg->pdu)
{
if (tsg->rpc->client->SynchronousReceive)
return tsg_read(tsg, data, length);
else
return 0;
}
tsg->PendingPdu = TRUE;
tsg->BytesAvailable = Stream_Length(tsg->pdu->s);
tsg->BytesRead = 0;
CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable;
CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength);
tsg->BytesAvailable -= CopyLength;
tsg->BytesRead += CopyLength;
if (tsg->BytesAvailable < 1)
{
tsg->PendingPdu = FALSE;
rpc_recv_dequeue_pdu(rpc);
rpc_client_receive_pool_return(rpc, tsg->pdu);
}
return CopyLength;
// weird !!!!
return tsg_read(tsg, data, length);
}
tsg->PendingPdu = TRUE;
tsg->BytesAvailable = Stream_Length(tsg->pdu->s);
tsg->BytesRead = 0;
CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable;
CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength);
tsg->BytesAvailable -= CopyLength;
tsg->BytesRead += CopyLength;
if (tsg->BytesAvailable < 1)
{
tsg->PendingPdu = FALSE;
rpc_recv_dequeue_pdu(rpc);
rpc_client_receive_pool_return(rpc, tsg->pdu);
}
return CopyLength;
}
int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length)
{
int status;
if (tsg->rpc->transport->layer == TRANSPORT_LAYER_CLOSED)
{
fprintf(stderr, "tsg_write error: connection lost\n");
fprintf(stderr, "%s: error, connection lost\n", __FUNCTION__);
return -1;
}
return TsProxySendToServer((handle_t) tsg, data, 1, &length);
status = TsProxySendToServer((handle_t) tsg, data, 1, &length);
if (status < 0)
return -1;
return length;
}
BOOL tsg_set_blocking_mode(rdpTsg* tsg, BOOL blocking)
@ -1584,18 +1600,21 @@ rdpTsg* tsg_new(rdpTransport* transport)
{
rdpTsg* tsg;
tsg = (rdpTsg*) malloc(sizeof(rdpTsg));
ZeroMemory(tsg, sizeof(rdpTsg));
if (tsg != NULL)
{
tsg->transport = transport;
tsg->settings = transport->settings;
tsg->rpc = rpc_new(tsg->transport);
tsg->PendingPdu = FALSE;
}
tsg = (rdpTsg*) calloc(1, sizeof(rdpTsg));
if (!tsg)
return NULL;
tsg->transport = transport;
tsg->settings = transport->settings;
tsg->rpc = rpc_new(tsg->transport);
if (!tsg->rpc)
goto out_free;
tsg->PendingPdu = FALSE;
return tsg;
out_free:
free(tsg);
return NULL;
}
void tsg_free(rdpTsg* tsg)

View File

@ -52,13 +52,13 @@ static BOOL freerdp_peer_initialize(freerdp_peer* client)
fprintf(stderr, "%s: inavlid RDP key file %s\n", __FUNCTION__, settings->RdpKeyFile);
return FALSE;
}
if (settings->RdpServerRsaKey->ModulusLength > 256)
{
fprintf(stderr, "%s: Key sizes > 2048 are currently not supported for RDP security.\n", __FUNCTION__);
fprintf(stderr, "%s: Set a different key file than %s\n", __FUNCTION__, settings->RdpKeyFile);
exit(1);
}
}
return TRUE;
@ -77,12 +77,13 @@ static HANDLE freerdp_peer_get_event_handle(freerdp_peer* client)
return client->context->rdp->transport->TcpIn->event;
}
static BOOL freerdp_peer_check_fds(freerdp_peer* client)
static BOOL freerdp_peer_check_fds(freerdp_peer* peer)
{
int status;
rdpRdp* rdp;
rdp = client->context->rdp;
rdp = peer->context->rdp;
status = rdp_check_fds(rdp);
@ -413,6 +414,19 @@ static int freerdp_peer_send_channel_data(freerdp_peer* client, UINT16 channelId
return rdp_send_channel_data(client->context->rdp, channelId, data, size);
}
static BOOL freerdp_peer_is_write_blocked(freerdp_peer* peer)
{
return tranport_is_write_blocked(peer->context->rdp->transport);
}
static int freerdp_peer_drain_output_buffer(freerdp_peer* peer)
{
rdpTransport *transport = peer->context->rdp->transport;
return tranport_drain_output_buffer(transport);
}
void freerdp_peer_context_new(freerdp_peer* client)
{
rdpRdp* rdp;
@ -445,6 +459,9 @@ void freerdp_peer_context_new(freerdp_peer* client)
rdp->transport->ReceiveExtra = client;
transport_set_blocking_mode(rdp->transport, FALSE);
client->IsWriteBlocked = freerdp_peer_is_write_blocked;
client->DrainOutputBuffer = freerdp_peer_drain_output_buffer;
IFCALL(client->ContextNew, client, client->context);
}
@ -473,6 +490,8 @@ freerdp_peer* freerdp_peer_new(int sockfd)
client->Close = freerdp_peer_close;
client->Disconnect = freerdp_peer_disconnect;
client->SendChannelData = freerdp_peer_send_channel_data;
client->IsWriteBlocked = freerdp_peer_is_write_blocked;
client->DrainOutputBuffer = freerdp_peer_drain_output_buffer;
}
return client;
@ -480,10 +499,10 @@ freerdp_peer* freerdp_peer_new(int sockfd)
void freerdp_peer_free(freerdp_peer* client)
{
if (client)
{
rdp_free(client->context->rdp);
free(client->context);
free(client);
}
if (!client)
return;
rdp_free(client->context->rdp);
free(client->context);
free(client);
}

View File

@ -209,6 +209,7 @@ rdpSettings* freerdp_settings_new(DWORD flags)
ZeroMemory(settings, sizeof(rdpSettings));
settings->ServerMode = (flags & FREERDP_SETTINGS_SERVER_MODE) ? TRUE : FALSE;
settings->WaitForOutputBufferFlush = TRUE;
settings->DesktopWidth = 1024;
settings->DesktopHeight = 768;
@ -579,6 +580,7 @@ rdpSettings* freerdp_settings_clone(rdpSettings* settings)
/* BOOL values */
_settings->ServerMode = settings->ServerMode; /* 16 */
_settings->WaitForOutputBufferFlush = settings->WaitForOutputBufferFlush; /* 25 */
_settings->NetworkAutoDetect = settings->NetworkAutoDetect; /* 137 */
_settings->SupportAsymetricKeys = settings->SupportAsymetricKeys; /* 138 */
_settings->SupportErrorInfoPdu = settings->SupportErrorInfoPdu; /* 139 */

View File

@ -66,6 +66,165 @@
#include "tcp.h"
long transport_bio_buffered_callback(BIO* bio, int mode, const char* argp, int argi, long argl, long ret)
{
return 1;
}
static int transport_bio_buffered_write(BIO* bio, const char* buf, int num)
{
int status, ret;
rdpTcp *tcp = (rdpTcp *)bio->ptr;
int nchunks, committedBytes, i;
DataChunk chunks[2];
ret = num;
BIO_clear_retry_flags(bio);
tcp->writeBlocked = FALSE;
/* we directly append extra bytes in the xmit buffer, this could be prevented
* but for now it makes the code more simple.
*/
if (buf && num && !ringbuffer_write(&tcp->xmitBuffer, buf, num))
{
fprintf(stderr, "%s: an error occured when writing(toWrite=%d)\n", __FUNCTION__, num);
return -1;
}
committedBytes = 0;
nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer));
for (i = 0; i < nchunks; i++)
{
while (chunks[i].size)
{
status = BIO_write(bio->next_bio, chunks[i].data, chunks[i].size);
/*fprintf(stderr, "%s: i=%d/%d size=%d/%d status=%d retry=%d\n", __FUNCTION__, i, nchunks,
chunks[i].size, ringbuffer_used(&tcp->xmitBuffer), status,
BIO_should_retry(bio->next_bio)
);*/
if (status <= 0)
{
if (BIO_should_retry(bio->next_bio))
{
tcp->writeBlocked = TRUE;
goto out; /* EWOULDBLOCK */
}
/* any other is an error, but we still have to commit written bytes */
ret = -1;
goto out;
}
committedBytes += status;
chunks[i].size -= status;
chunks[i].data += status;
}
}
out:
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, committedBytes);
return ret;
}
static int transport_bio_buffered_read(BIO* bio, char* buf, int size)
{
int status;
rdpTcp *tcp = (rdpTcp *)bio->ptr;
tcp->readBlocked = FALSE;
BIO_clear_retry_flags(bio);
status = BIO_read(bio->next_bio, buf, size);
/*fprintf(stderr, "%s: size=%d status=%d shouldRetry=%d\n", __FUNCTION__, size, status, BIO_should_retry(bio->next_bio)); */
if (status <= 0 && BIO_should_retry(bio->next_bio))
{
BIO_set_retry_read(bio);
tcp->readBlocked = TRUE;
}
return status;
}
static int transport_bio_buffered_puts(BIO* bio, const char* str)
{
return 1;
}
static int transport_bio_buffered_gets(BIO* bio, char* str, int size)
{
return 1;
}
static long transport_bio_buffered_ctrl(BIO* bio, int cmd, long arg1, void* arg2)
{
rdpTcp *tcp = (rdpTcp *)bio->ptr;
switch (cmd)
{
case BIO_CTRL_FLUSH:
return 1;
case BIO_CTRL_WPENDING:
return ringbuffer_used(&tcp->xmitBuffer);
case BIO_CTRL_PENDING:
return 0;
default:
/*fprintf(stderr, "%s: passing to next BIO, bio=%p cmd=%d arg1=%d arg2=%p\n", __FUNCTION__, bio, cmd, arg1, arg2); */
return BIO_ctrl(bio->next_bio, cmd, arg1, arg2);
}
return 0;
}
static int transport_bio_buffered_new(BIO* bio)
{
bio->init = 1;
bio->num = 0;
bio->ptr = NULL;
bio->flags = 0;
return 1;
}
static int transport_bio_buffered_free(BIO* bio)
{
return 1;
}
static BIO_METHOD transport_bio_buffered_socket_methods =
{
BIO_TYPE_BUFFERED,
"BufferedSocket",
transport_bio_buffered_write,
transport_bio_buffered_read,
transport_bio_buffered_puts,
transport_bio_buffered_gets,
transport_bio_buffered_ctrl,
transport_bio_buffered_new,
transport_bio_buffered_free,
NULL,
};
BIO_METHOD* BIO_s_buffered_socket(void)
{
return &transport_bio_buffered_socket_methods;
}
BOOL transport_bio_buffered_drain(BIO *bio)
{
rdpTcp *tcp = (rdpTcp *)bio->ptr;
int status;
if (!ringbuffer_used(&tcp->xmitBuffer))
return 1;
status = transport_bio_buffered_write(bio, NULL, 0);
return status >= 0;
}
void tcp_get_ip_address(rdpTcp* tcp)
{
BYTE* ip;
@ -136,62 +295,65 @@ BOOL tcp_connect(rdpTcp* tcp, const char* hostname, int port)
if (hostname[0] == '/')
{
tcp->sockfd = freerdp_uds_connect(hostname);
if (tcp->sockfd < 0)
return FALSE;
tcp->socketBio = BIO_new_fd(tcp->sockfd, 1);
if (!tcp->socketBio)
return FALSE;
}
else
{
tcp->sockfd = freerdp_tcp_connect(hostname, port);
if (tcp->sockfd < 0)
tcp->socketBio = BIO_new(BIO_s_connect());
if (!tcp->socketBio)
return FALSE;
SetEventFileDescriptor(tcp->event, tcp->sockfd);
if (BIO_set_conn_hostname(tcp->socketBio, hostname) < 0 || BIO_set_conn_int_port(tcp->socketBio, &port) < 0)
return FALSE;
tcp_get_ip_address(tcp);
tcp_get_mac_address(tcp);
if (BIO_do_connect(tcp->socketBio) <= 0)
return FALSE;
option_value = 1;
option_len = sizeof(option_value);
setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len);
/* receive buffer must be a least 32 K */
if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0)
{
if (option_value < (1024 * 32))
{
option_value = 1024 * 32;
option_len = sizeof(option_value);
setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len);
}
}
tcp_set_keep_alive_mode(tcp);
tcp->sockfd = BIO_get_fd(tcp->socketBio, NULL);
}
SetEventFileDescriptor(tcp->event, tcp->sockfd);
tcp_get_ip_address(tcp);
tcp_get_mac_address(tcp);
option_value = 1;
option_len = sizeof(option_value);
if (setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len) < 0)
fprintf(stderr, "%s: unable to set TCP_NODELAY\n", __FUNCTION__);
/* receive buffer must be a least 32 K */
if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0)
{
if (option_value < (1024 * 32))
{
option_value = 1024 * 32;
option_len = sizeof(option_value);
if (setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len) < 0)
{
fprintf(stderr, "%s: unable to set receive buffer len\n", __FUNCTION__);
return FALSE;
}
}
}
if (!tcp_set_keep_alive_mode(tcp))
return FALSE;
tcp->bufferedBio = BIO_new(BIO_s_buffered_socket());
if (!tcp->bufferedBio)
return FALSE;
tcp->bufferedBio->ptr = tcp;
tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio);
return TRUE;
}
int tcp_read(rdpTcp* tcp, BYTE* data, int length)
{
return freerdp_tcp_read(tcp->sockfd, data, length);
}
int tcp_write(rdpTcp* tcp, BYTE* data, int length)
{
return freerdp_tcp_write(tcp->sockfd, data, length);
}
int tcp_wait_read(rdpTcp* tcp)
{
return freerdp_tcp_wait_read(tcp->sockfd);
}
int tcp_wait_write(rdpTcp* tcp)
{
return freerdp_tcp_wait_write(tcp->sockfd);
}
BOOL tcp_disconnect(rdpTcp* tcp)
{
@ -209,7 +371,7 @@ BOOL tcp_set_blocking_mode(rdpTcp* tcp, BOOL blocking)
if (flags == -1)
{
fprintf(stderr, "tcp_set_blocking_mode: fcntl failed.\n");
fprintf(stderr, "%s: fcntl failed, %s.\n", __FUNCTION__, strerror(errno));
return FALSE;
}
@ -297,6 +459,31 @@ int tcp_attach(rdpTcp* tcp, int sockfd)
{
tcp->sockfd = sockfd;
SetEventFileDescriptor(tcp->event, tcp->sockfd);
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, ringbuffer_used(&tcp->xmitBuffer));
if (tcp->socketBio)
{
if (BIO_set_fd(tcp->socketBio, sockfd, 1) < 0)
return -1;
}
else
{
tcp->socketBio = BIO_new_socket(sockfd, 1);
if (!tcp->socketBio)
return -1;
}
if (!tcp->bufferedBio)
{
tcp->bufferedBio = BIO_new(BIO_s_buffered_socket());
if (!tcp->bufferedBio)
return FALSE;
tcp->bufferedBio->ptr = tcp;
tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio);
}
return 0;
}
@ -316,25 +503,34 @@ rdpTcp* tcp_new(rdpSettings* settings)
{
rdpTcp* tcp;
tcp = (rdpTcp*) malloc(sizeof(rdpTcp));
tcp = (rdpTcp *)calloc(1, sizeof(rdpTcp));
if (!tcp)
return NULL;
if (tcp)
{
ZeroMemory(tcp, sizeof(rdpTcp));
if (!ringbuffer_init(&tcp->xmitBuffer, 0x10000))
goto out_free;
tcp->sockfd = -1;
tcp->settings = settings;
tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd);
}
tcp->sockfd = -1;
tcp->settings = settings;
tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd);
if (!tcp->event || tcp->event == INVALID_HANDLE_VALUE)
goto out_ringbuffer;
return tcp;
out_ringbuffer:
ringbuffer_destroy(&tcp->xmitBuffer);
out_free:
free(tcp);
return NULL;
}
void tcp_free(rdpTcp* tcp)
{
if (tcp)
{
CloseHandle(tcp->event);
free(tcp);
}
if (!tcp)
return;
ringbuffer_destroy(&tcp->xmitBuffer);
CloseHandle(tcp->event);
free(tcp);
}

View File

@ -31,10 +31,15 @@
#include <winpr/stream.h>
#include <winpr/winsock.h>
#include <freerdp/utils/ringbuffer.h>
#include <openssl/bio.h>
#ifndef MSG_NOSIGNAL
#define MSG_NOSIGNAL 0
#endif
#define BIO_TYPE_BUFFERED 66
typedef struct rdp_tcp rdpTcp;
struct rdp_tcp
@ -46,6 +51,12 @@ struct rdp_tcp
#ifdef _WIN32
WSAEVENT wsa_event;
#endif
BIO *socketBio;
BIO *bufferedBio;
RingBuffer xmitBuffer;
BOOL writeBlocked;
BOOL readBlocked;
HANDLE event;
};

View File

@ -33,7 +33,9 @@
#include <freerdp/error.h>
#include <freerdp/utils/tcp.h>
#include <freerdp/utils/ringbuffer.h>
#include <openssl/bio.h>
#include <time.h>
#include <errno.h>
#include <fcntl.h>
@ -41,6 +43,12 @@
#ifndef _WIN32
#include <netdb.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/time.h>
#endif
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "tpkt.h"
@ -48,6 +56,7 @@
#include "transport.h"
#include "rdp.h"
#define BUFFER_SIZE 16384
static void* transport_client_thread(void* arg);
@ -69,6 +78,7 @@ void transport_attach(rdpTransport* transport, int sockfd)
tcp_attach(transport->TcpIn, sockfd);
transport->SplitInputOutput = FALSE;
transport->TcpOut = transport->TcpIn;
transport->frontBio = transport->TcpIn->bufferedBio;
}
void transport_stop(rdpTransport* transport)
@ -98,18 +108,9 @@ BOOL transport_disconnect(rdpTransport* transport)
transport_stop(transport);
if (transport->layer == TRANSPORT_LAYER_TLS)
status &= tls_disconnect(transport->TlsIn);
if ((transport->layer == TRANSPORT_LAYER_TSG) || (transport->layer == TRANSPORT_LAYER_TSG_TLS))
{
status &= tsg_disconnect(transport->tsg);
}
else
{
status &= tcp_disconnect(transport->TcpIn);
}
BIO_free_all(transport->frontBio);
transport->frontBio = 0;
return status;
}
@ -131,16 +132,16 @@ static int transport_bio_tsg_write(BIO* bio, const char* buf, int num)
rdpTsg* tsg;
tsg = (rdpTsg*) bio->ptr;
status = tsg_write(tsg, (BYTE*) buf, num);
BIO_clear_retry_flags(bio);
status = tsg_write(tsg, (BYTE*) buf, num);
if (status > 0)
return status;
if (status == 0)
{
BIO_set_retry_write(bio);
}
return status < 0 ? 0 : num;
return -1;
}
static int transport_bio_tsg_read(BIO* bio, char* buf, int size)
@ -222,8 +223,13 @@ BIO_METHOD* BIO_s_tsg(void)
return &transport_bio_tsg_methods;
}
BOOL transport_connect_tls(rdpTransport* transport)
{
rdpSettings *settings = transport->settings;
rdpTls *targetTls;
BIO *targetBio;
int tls_status;
freerdp* instance;
rdpContext* context;
@ -234,61 +240,33 @@ BOOL transport_connect_tls(rdpTransport* transport)
if (transport->layer == TRANSPORT_LAYER_TSG)
{
transport->TsgTls = tls_new(transport->settings);
transport->TsgTls->methods = BIO_s_tsg();
transport->TsgTls->tsg = (void*) transport->tsg;
transport->layer = TRANSPORT_LAYER_TSG_TLS;
transport->TsgTls->hostname = transport->settings->ServerHostname;
transport->TsgTls->port = transport->settings->ServerPort;
targetTls = transport->TsgTls;
targetBio = transport->frontBio;
}
else
{
if (!transport->TlsIn)
transport->TlsIn = tls_new(settings);
if (transport->TsgTls->port == 0)
transport->TsgTls->port = 3389;
if (!transport->TlsOut)
transport->TlsOut = transport->TlsIn;
tls_status = tls_connect(transport->TsgTls);
targetTls = transport->TlsIn;
targetBio = transport->TcpIn->bufferedBio;
if (tls_status < 1)
{
if (tls_status < 0)
{
if (!connectErrorCode)
connectErrorCode = TLSCONNECTERROR;
if (!freerdp_get_last_error(context))
freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED);
}
else
{
if (!freerdp_get_last_error(context))
freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
}
tls_free(transport->TsgTls);
transport->TsgTls = NULL;
return FALSE;
}
return TRUE;
transport->layer = TRANSPORT_LAYER_TLS;
}
if (!transport->TlsIn)
transport->TlsIn = tls_new(transport->settings);
if (!transport->TlsOut)
transport->TlsOut = transport->TlsIn;
targetTls->hostname = settings->ServerHostname;
targetTls->port = settings->ServerPort;
transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (targetTls->port == 0)
targetTls->port = 3389;
transport->TlsIn->hostname = transport->settings->ServerHostname;
transport->TlsIn->port = transport->settings->ServerPort;
if (transport->TlsIn->port == 0)
transport->TlsIn->port = 3389;
tls_status = tls_connect(transport->TlsIn);
tls_status = tls_connect(targetTls, targetBio);
if (tls_status < 1)
{
@ -306,13 +284,13 @@ BOOL transport_connect_tls(rdpTransport* transport)
freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
}
tls_free(transport->TlsIn);
if (transport->TlsIn == transport->TlsOut)
transport->TlsIn = transport->TlsOut = NULL;
else
transport->TlsIn = NULL;
return FALSE;
}
transport->frontBio = targetTls->bio;
if (!transport->frontBio)
{
fprintf(stderr, "%s: unable to prepend a filtering TLS bio");
return FALSE;
}
@ -323,6 +301,7 @@ BOOL transport_connect_nla(rdpTransport* transport)
{
freerdp* instance;
rdpSettings* settings;
rdpCredssp *credSsp;
settings = transport->settings;
instance = (freerdp*) settings->instance;
@ -338,16 +317,22 @@ BOOL transport_connect_nla(rdpTransport* transport)
if (!transport->credssp)
{
transport->credssp = credssp_new(instance, transport, settings);
if (!transport->credssp)
return FALSE;
transport_set_nla_mode(transport, TRUE);
if (settings->AuthenticationServiceClass)
{
transport->credssp->ServicePrincipalName =
credssp_make_spn(settings->AuthenticationServiceClass, settings->ServerHostname);
if (!transport->credssp->ServicePrincipalName)
return FALSE;
}
}
if (credssp_authenticate(transport->credssp) < 0)
credSsp = transport->credssp;
if (credssp_authenticate(credSsp) < 0)
{
if (!connectErrorCode)
connectErrorCode = AUTHENTICATIONERROR;
@ -361,14 +346,14 @@ BOOL transport_connect_nla(rdpTransport* transport)
"If credentials are valid, the NTLMSSP implementation may be to blame.\n");
transport_set_nla_mode(transport, FALSE);
credssp_free(transport->credssp);
credssp_free(credSsp);
transport->credssp = NULL;
return FALSE;
}
transport_set_nla_mode(transport, FALSE);
credssp_free(transport->credssp);
credssp_free(credSsp);
transport->credssp = NULL;
return TRUE;
@ -380,38 +365,41 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
int tls_status;
freerdp* instance;
rdpContext* context;
rdpSettings *settings = transport->settings;
instance = (freerdp*) transport->settings->instance;
context = instance->context;
tsg = tsg_new(transport);
if (!tsg)
return FALSE;
tsg->transport = transport;
transport->tsg = tsg;
transport->SplitInputOutput = TRUE;
if (!transport->TlsIn)
transport->TlsIn = tls_new(transport->settings);
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
transport->TlsIn->hostname = transport->settings->GatewayHostname;
transport->TlsIn->port = transport->settings->GatewayPort;
if (transport->TlsIn->port == 0)
transport->TlsIn->port = 443;
{
transport->TlsIn = tls_new(settings);
if (!transport->TlsIn)
return FALSE;
}
if (!transport->TlsOut)
transport->TlsOut = tls_new(transport->settings);
{
transport->TlsOut = tls_new(settings);
if (!transport->TlsOut)
return FALSE;
}
transport->TlsOut->sockfd = transport->TcpOut->sockfd;
transport->TlsOut->hostname = transport->settings->GatewayHostname;
transport->TlsOut->port = transport->settings->GatewayPort;
/* put a decent default value for gateway port */
if (!settings->GatewayPort)
settings->GatewayPort = 443;
if (transport->TlsOut->port == 0)
transport->TlsOut->port = 443;
transport->TlsIn->hostname = transport->TlsOut->hostname = settings->GatewayHostname;
transport->TlsIn->port = transport->TlsOut->port = settings->GatewayPort;
tls_status = tls_connect(transport->TlsIn);
tls_status = tls_connect(transport->TlsIn, transport->TcpIn->bufferedBio);
if (tls_status < 1)
{
if (tls_status < 0)
@ -428,8 +416,7 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
return FALSE;
}
tls_status = tls_connect(transport->TlsOut);
tls_status = tls_connect(transport->TlsOut, transport->TcpOut->bufferedBio);
if (tls_status < 1)
{
if (tls_status < 0)
@ -449,6 +436,8 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
if (!tsg_connect(tsg, hostname, port))
return FALSE;
transport->frontBio = BIO_new(BIO_s_tsg());
transport->frontBio->ptr = tsg;
return TRUE;
}
@ -462,15 +451,20 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por
if (transport->GatewayEnabled)
{
transport->layer = TRANSPORT_LAYER_TSG;
transport->SplitInputOutput = TRUE;
transport->TcpOut = tcp_new(settings);
status = tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort);
if (!tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort) ||
!tcp_set_blocking_mode(transport->TcpIn, FALSE))
return FALSE;
if (status)
status = tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort);
if (!tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort) ||
!tcp_set_blocking_mode(transport->TcpOut, FALSE))
return FALSE;
if (status)
status = transport_tsg_connect(transport, hostname, port);
if (!transport_tsg_connect(transport, hostname, port))
return FALSE;
status = TRUE;
}
else
{
@ -478,6 +472,7 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por
transport->SplitInputOutput = FALSE;
transport->TcpOut = transport->TcpIn;
transport->frontBio = transport->TcpIn->bufferedBio;
}
if (status)
@ -510,11 +505,11 @@ BOOL transport_accept_tls(rdpTransport* transport)
transport->TlsOut = transport->TlsIn;
transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
return FALSE;
transport->frontBio = transport->TlsIn->bio;
return TRUE;
}
@ -533,10 +528,10 @@ BOOL transport_accept_nla(rdpTransport* transport)
transport->TlsOut = transport->TlsIn;
transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, settings->CertificateFile, settings->PrivateKeyFile))
return FALSE;
transport->frontBio = transport->TlsIn->bio;
/* Network Level Authentication */
@ -630,56 +625,131 @@ UINT32 nla_header_length(wStream* s)
return length;
}
static int transport_wait_for_read(rdpTransport* transport)
{
struct timeval tv;
fd_set rset, wset;
fd_set *rsetPtr = NULL, *wsetPtr = NULL;
rdpTcp *tcpIn;
tcpIn = transport->TcpIn;
if (tcpIn->readBlocked)
{
rsetPtr = &rset;
FD_ZERO(rsetPtr);
FD_SET(tcpIn->sockfd, rsetPtr);
}
else if (tcpIn->writeBlocked)
{
wsetPtr = &wset;
FD_ZERO(wsetPtr);
FD_SET(tcpIn->sockfd, wsetPtr);
}
if (!wsetPtr && !rsetPtr)
{
USleep(1000);
return 0;
}
tv.tv_sec = 0;
tv.tv_usec = 1000;
return select(tcpIn->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
}
static int transport_wait_for_write(rdpTransport* transport)
{
struct timeval tv;
fd_set rset, wset;
fd_set *rsetPtr = NULL, *wsetPtr = NULL;
rdpTcp *tcpOut;
tcpOut = transport->SplitInputOutput ? transport->TcpOut : transport->TcpIn;
if (tcpOut->writeBlocked)
{
wsetPtr = &wset;
FD_ZERO(wsetPtr);
FD_SET(tcpOut->sockfd, wsetPtr);
}
else if (tcpOut->readBlocked)
{
rsetPtr = &rset;
FD_ZERO(rsetPtr);
FD_SET(tcpOut->sockfd, rsetPtr);
}
if (!wsetPtr && !rsetPtr)
{
USleep(1000);
return 0;
}
tv.tv_sec = 0;
tv.tv_usec = 1000;
return select(tcpOut->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
}
int transport_read_layer(rdpTransport* transport, BYTE* data, int bytes)
{
int read = 0;
int status = -1;
while (read < bytes)
{
if (transport->layer == TRANSPORT_LAYER_TLS)
status = tls_read(transport->TlsIn, data + read, bytes - read);
else if (transport->layer == TRANSPORT_LAYER_TCP)
status = tcp_read(transport->TcpIn, data + read, bytes - read);
else if (transport->layer == TRANSPORT_LAYER_TSG)
status = tsg_read(transport->tsg, data + read, bytes - read);
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) {
status = tls_read(transport->TsgTls, data + read, bytes - read);
status = BIO_read(transport->frontBio, data + read, bytes - read);
if (!status)
{
transport->layer = TRANSPORT_LAYER_CLOSED;
return -1;
}
/* blocking means that we can't continue until this is read */
if (!transport->blocking)
return status;
if (status < 0)
{
/* A read error indicates that the peer has dropped the connection */
transport->layer = TRANSPORT_LAYER_CLOSED;
return status;
if (!BIO_should_retry(transport->frontBio))
{
/* something unexpected happened, let's close */
transport->layer = TRANSPORT_LAYER_CLOSED;
return -1;
}
/* non blocking will survive a partial read */
if (!transport->blocking)
return read;
/* blocking means that we can't continue until we have read the number of
* requested bytes */
if (transport_wait_for_read(transport) < 0)
{
fprintf(stderr, "%s: error when selecting for read\n", __FUNCTION__);
return -1;
}
continue;
}
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(data + read, bytes - read);
#endif
read += status;
if (status == 0)
{
/*
* instead of sleeping, we should wait timeout on the
* socket but this only happens on initial connection
*/
USleep(transport->SleepInterval);
}
}
return read;
}
int transport_read(rdpTransport* transport, wStream* s)
{
int status;
int position;
int pduLength;
BYTE header[4];
BYTE *header;
int transport_status;
position = 0;
@ -710,7 +780,7 @@ int transport_read(rdpTransport* transport, wStream* s)
position += status;
}
CopyMemory(header, Stream_Buffer(s), 4); /* peek at first 4 bytes */
header = Stream_Buffer(s);
/* if header is present, read exactly one PDU */
@ -802,6 +872,8 @@ static int transport_read_nonblocking(rdpTransport* transport)
return status;
}
BOOL transport_bio_buffered_drain(BIO *bio);
int transport_write(rdpTransport* transport, wStream* s)
{
int length;
@ -827,36 +899,48 @@ int transport_write(rdpTransport* transport, wStream* s)
while (length > 0)
{
if (transport->layer == TRANSPORT_LAYER_TLS)
status = tls_write(transport->TlsOut, Stream_Pointer(s), length);
else if (transport->layer == TRANSPORT_LAYER_TCP)
status = tcp_write(transport->TcpOut, Stream_Pointer(s), length);
else if (transport->layer == TRANSPORT_LAYER_TSG)
status = tsg_write(transport->tsg, Stream_Pointer(s), length);
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
status = tls_write(transport->TsgTls, Stream_Pointer(s), length);
status = BIO_write(transport->frontBio, Stream_Pointer(s), length);
if (status < 0)
break; /* error occurred */
if (status == 0)
if (status <= 0)
{
/* when sending is blocked in nonblocking mode, the receiving buffer should be checked */
if (!transport->blocking)
{
/* and in case we do have buffered some data, we set the event so next loop will get it */
if (transport_read_nonblocking(transport) > 0)
SetEvent(transport->ReceiveEvent);
}
/* the buffered BIO that is at the end of the chain always says OK for writing,
* so a retry means that for any reason we need to read. The most probable
* is a SSL or TSG BIO in the chain.
*/
if (!BIO_should_retry(transport->frontBio))
return status;
if (transport->layer == TRANSPORT_LAYER_TLS)
tls_wait_write(transport->TlsOut);
else if (transport->layer == TRANSPORT_LAYER_TCP)
tcp_wait_write(transport->TcpOut);
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
tls_wait_write(transport->TsgTls);
else
USleep(transport->SleepInterval);
/* non-blocking can live with blocked IOs */
if (!transport->blocking)
return status;
if (transport_wait_for_write(transport) < 0)
{
fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__);
return -1;
}
continue;
}
if (transport->blocking || transport->settings->WaitForOutputBufferFlush)
{
/* blocking transport, we must ensure the write buffer is really empty */
rdpTcp *out = transport->TcpOut;
while (out->writeBlocked)
{
if (transport_wait_for_write(transport) < 0)
{
fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__);
return -1;
}
if (!transport_bio_buffered_drain(out->bufferedBio))
{
fprintf(stderr, "%s: error when draining outputBuffer\n", __FUNCTION__);
return -1;
}
}
}
length -= status;
@ -945,6 +1029,38 @@ void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD*
}
}
BOOL tranport_is_write_blocked(rdpTransport* transport)
{
if (transport->TcpIn->writeBlocked)
return TRUE;
return transport->SplitInputOutput &&
transport->TcpOut &&
transport->TcpOut->writeBlocked;
}
int tranport_drain_output_buffer(rdpTransport* transport)
{
BOOL ret = FALSE;
/* First try to send some accumulated bytes in the send buffer */
if (transport->TcpIn->writeBlocked)
{
if (!transport_bio_buffered_drain(transport->TcpIn->bufferedBio))
return -1;
ret |= transport->TcpIn->writeBlocked;
}
if (transport->SplitInputOutput && transport->TcpOut && transport->TcpOut->writeBlocked)
{
if (!transport_bio_buffered_drain(transport->TcpOut->bufferedBio))
return -1;
ret |= transport->TcpOut->writeBlocked;
}
return ret;
}
int transport_check_fds(rdpTransport* transport)
{
int pos;
@ -1079,15 +1195,14 @@ int transport_check_fds(rdpTransport* transport)
recv_status = transport->ReceiveCallback(transport, received, transport->ReceiveExtra);
Stream_Release(received);
if (recv_status < 0)
return -1;
if (recv_status == 1)
{
return 1; /* session redirection */
}
Stream_Release(received);
if (recv_status < 0)
return -1;
}
return 0;
@ -1198,80 +1313,107 @@ rdpTransport* transport_new(rdpSettings* settings)
{
rdpTransport* transport;
transport = (rdpTransport*) malloc(sizeof(rdpTransport));
transport = (rdpTransport *)calloc(1, sizeof(rdpTransport));
if (!transport)
return NULL;
if (transport)
{
ZeroMemory(transport, sizeof(rdpTransport));
WLog_Init();
transport->log = WLog_Get("com.freerdp.core.transport");
if (!transport->log)
goto out_free;
WLog_Init();
transport->log = WLog_Get("com.freerdp.core.transport");
transport->TcpIn = tcp_new(settings);
if (!transport->TcpIn)
goto out_free;
transport->TcpIn = tcp_new(settings);
transport->settings = settings;
transport->settings = settings;
/* a small 0.1ms delay when transport is blocking. */
transport->SleepInterval = 100;
/* a small 0.1ms delay when transport is blocking. */
transport->SleepInterval = 100;
transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE);
if (!transport->ReceivePool)
goto out_free_tcpin;
transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE);
/* receive buffer for non-blocking read. */
transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
if (!transport->ReceiveBuffer)
goto out_free_receivepool;
/* receive buffer for non-blocking read. */
transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!transport->ReceiveEvent || transport->ReceiveEvent == INVALID_HANDLE_VALUE)
goto out_free_receivebuffer;
transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!transport->connectedEvent || transport->connectedEvent == INVALID_HANDLE_VALUE)
goto out_free_receiveEvent;
transport->blocking = TRUE;
transport->GatewayEnabled = FALSE;
transport->blocking = TRUE;
transport->GatewayEnabled = FALSE;
transport->layer = TRANSPORT_LAYER_TCP;
InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000);
InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000);
transport->layer = TRANSPORT_LAYER_TCP;
}
if (!InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000))
goto out_free_connectedEvent;
if (!InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000))
goto out_free_readlock;
return transport;
out_free_readlock:
DeleteCriticalSection(&(transport->ReadLock));
out_free_connectedEvent:
CloseHandle(transport->connectedEvent);
out_free_receiveEvent:
CloseHandle(transport->ReceiveEvent);
out_free_receivebuffer:
StreamPool_Return(transport->ReceivePool, transport->ReceiveBuffer);
out_free_receivepool:
StreamPool_Free(transport->ReceivePool);
out_free_tcpin:
tcp_free(transport->TcpIn);
out_free:
free(transport);
return NULL;
}
void transport_free(rdpTransport* transport)
{
if (transport)
{
transport_stop(transport);
if (!transport)
return;
if (transport->ReceiveBuffer)
Stream_Release(transport->ReceiveBuffer);
transport_stop(transport);
StreamPool_Free(transport->ReceivePool);
if (transport->ReceiveBuffer)
Stream_Release(transport->ReceiveBuffer);
CloseHandle(transport->ReceiveEvent);
CloseHandle(transport->connectedEvent);
StreamPool_Free(transport->ReceivePool);
if (transport->TlsIn)
tls_free(transport->TlsIn);
CloseHandle(transport->ReceiveEvent);
CloseHandle(transport->connectedEvent);
if (transport->TlsOut != transport->TlsIn)
tls_free(transport->TlsOut);
if (transport->TlsIn)
tls_free(transport->TlsIn);
transport->TlsIn = NULL;
transport->TlsOut = NULL;
if (transport->TlsOut != transport->TlsIn)
tls_free(transport->TlsOut);
if (transport->TcpIn)
tcp_free(transport->TcpIn);
transport->TlsIn = NULL;
transport->TlsOut = NULL;
if (transport->TcpOut != transport->TcpIn)
tcp_free(transport->TcpOut);
if (transport->TcpIn)
tcp_free(transport->TcpIn);
transport->TcpIn = NULL;
transport->TcpOut = NULL;
if (transport->TcpOut != transport->TcpIn)
tcp_free(transport->TcpOut);
tsg_free(transport->tsg);
transport->tsg = NULL;
transport->TcpIn = NULL;
transport->TcpOut = NULL;
DeleteCriticalSection(&(transport->ReadLock));
DeleteCriticalSection(&(transport->WriteLock));
tsg_free(transport->tsg);
transport->tsg = NULL;
free(transport);
}
DeleteCriticalSection(&(transport->ReadLock));
DeleteCriticalSection(&(transport->WriteLock));
free(transport);
}

View File

@ -49,11 +49,13 @@ typedef struct rdp_transport rdpTransport;
#include <freerdp/types.h>
#include <freerdp/settings.h>
typedef int (*TransportRecv) (rdpTransport* transport, wStream* stream, void* extra);
struct rdp_transport
{
TRANSPORT_LAYER layer;
BIO *frontBio;
rdpTsg* tsg;
rdpTcp* TcpIn;
rdpTcp* TcpOut;
@ -102,6 +104,8 @@ BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking);
void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled);
void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode);
void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count);
BOOL tranport_is_write_blocked(rdpTransport* transport);
BOOL tranport_drain_output_buffer(rdpTransport* transport);
wStream* transport_receive_pool_take(rdpTransport* transport);
int transport_receive_pool_return(rdpTransport* transport, wStream* pdu);

View File

@ -28,34 +28,35 @@
#include <winpr/stream.h>
#include <freerdp/utils/tcp.h>
#include <freerdp/utils/ringbuffer.h>
#include <freerdp/crypto/tls.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "../core/tcp.h"
static CryptoCert tls_get_certificate(rdpTls* tls, BOOL peer)
{
CryptoCert cert;
X509* server_cert;
X509* remote_cert;
if (peer)
server_cert = SSL_get_peer_certificate(tls->ssl);
remote_cert = SSL_get_peer_certificate(tls->ssl);
else
server_cert = SSL_get_certificate(tls->ssl);
remote_cert = SSL_get_certificate(tls->ssl);
if (!server_cert)
if (!remote_cert)
{
fprintf(stderr, "tls_get_certificate: failed to get the server TLS certificate\n");
cert = NULL;
}
else
{
cert = malloc(sizeof(*cert));
cert->px509 = server_cert;
fprintf(stderr, "%s: failed to get the server TLS certificate\n", __FUNCTION__);
return NULL;
}
cert = malloc(sizeof(*cert));
if (!cert)
{
X509_free(remote_cert);
return NULL;
}
cert->px509 = remote_cert;
return cert;
}
@ -83,12 +84,14 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert)
PrefixLength = strlen(TLS_SERVER_END_POINT);
ChannelBindingTokenLength = PrefixLength + CertificateHashLength;
ContextBindings = (SecPkgContext_Bindings*) malloc(sizeof(SecPkgContext_Bindings));
ZeroMemory(ContextBindings, sizeof(SecPkgContext_Bindings));
ContextBindings = (SecPkgContext_Bindings*) calloc(1, sizeof(SecPkgContext_Bindings));
if (!ContextBindings)
return NULL;
ContextBindings->BindingsLength = sizeof(SEC_CHANNEL_BINDINGS) + ChannelBindingTokenLength;
ChannelBindings = (SEC_CHANNEL_BINDINGS*) malloc(ContextBindings->BindingsLength);
ZeroMemory(ChannelBindings, ContextBindings->BindingsLength);
ChannelBindings = (SEC_CHANNEL_BINDINGS*) calloc(1, ContextBindings->BindingsLength);
if (!ChannelBindings)
goto out_free;
ContextBindings->Bindings = ChannelBindings;
ChannelBindings->cbApplicationDataLength = ChannelBindingTokenLength;
@ -99,32 +102,121 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert)
CopyMemory(&ChannelBindingToken[PrefixLength], CertificateHash, CertificateHashLength);
return ContextBindings;
out_free:
free(ContextBindings);
return NULL;
}
static void tls_ssl_info_callback(const SSL* ssl, int type, int val)
BOOL tls_prepare(rdpTls* tls, BIO *underlying, const SSL_METHOD *method, int options, BOOL clientMode)
{
if (type & SSL_CB_HANDSHAKE_START)
{
}
}
int tls_connect(rdpTls* tls)
{
CryptoCert cert;
long options = 0;
int verify_status;
int connection_status;
tls->ctx = SSL_CTX_new(TLSv1_client_method());
tls->ctx = SSL_CTX_new(method);
if (!tls->ctx)
{
fprintf(stderr, "SSL_CTX_new failed\n");
fprintf(stderr, "%s: SSL_CTX_new failed\n", __FUNCTION__);
return FALSE;
}
SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_options(tls->ctx, options);
SSL_CTX_set_read_ahead(tls->ctx, 1);
tls->bio = BIO_new_ssl(tls->ctx, clientMode);
if (BIO_get_ssl(tls->bio, &tls->ssl) < 0)
{
fprintf(stderr, "%s: unable to retrieve the SSL of the connection\n", __FUNCTION__);
return FALSE;
}
BIO_push(tls->bio, underlying);
return TRUE;
}
int tls_do_handshake(rdpTls* tls, BOOL clientMode)
{
CryptoCert cert;
int verify_status, status;
do
{
struct timeval tv;
fd_set rset;
int fd;
status = BIO_do_handshake(tls->bio);
if (status == 1)
break;
if (!BIO_should_retry(tls->bio))
return -1;
/* we select() only for read even if we should test both read and write
* depending of what have blocked */
FD_ZERO(&rset);
fd = BIO_get_fd(tls->bio, NULL);
if (fd < 0)
{
fprintf(stderr, "%s: unable to retrieve BIO fd\n", __FUNCTION__);
return -1;
}
FD_SET(fd, &rset);
tv.tv_sec = 0;
tv.tv_usec = 10 * 1000; /* 10ms */
status = select(fd + 1, &rset, NULL, NULL, &tv);
if (status < 0)
{
fprintf(stderr, "%s: error during select()\n", __FUNCTION__);
return -1;
}
}
while (TRUE);
if (!clientMode)
return 1;
cert = tls_get_certificate(tls, clientMode);
if (!cert)
{
fprintf(stderr, "%s: tls_get_certificate failed to return the server certificate.\n", __FUNCTION__);
return -1;
}
//SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
tls->Bindings = tls_get_channel_bindings(cert->px509);
if (!tls->Bindings)
{
fprintf(stderr, "%s: unable to retrieve bindings\n", __FUNCTION__);
return -1;
}
if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
{
fprintf(stderr, "%s: crypto_cert_get_public_key failed to return the server public key.\n", __FUNCTION__);
tls_free_certificate(cert);
return -1;
}
verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port);
if (verify_status < 1)
{
fprintf(stderr, "%s: certificate not trusted, aborting.\n", __FUNCTION__);
tls_disconnect(tls);
tls_free_certificate(cert);
return 0;
}
tls_free_certificate(cert);
return verify_status;
}
int tls_connect(rdpTls* tls, BIO *underlying)
{
int options = 0;
/**
* SSL_OP_NO_COMPRESSION:
@ -138,7 +230,7 @@ int tls_connect(rdpTls* tls)
#ifdef SSL_OP_NO_COMPRESSION
options |= SSL_OP_NO_COMPRESSION;
#endif
/**
* SSL_OP_TLS_BLOCK_PADDING_BUG:
*
@ -155,96 +247,19 @@ int tls_connect(rdpTls* tls)
*/
options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
SSL_CTX_set_options(tls->ctx, options);
if (!tls_prepare(tls, underlying, TLSv1_client_method(), options, TRUE))
return FALSE;
tls->ssl = SSL_new(tls->ctx);
if (!tls->ssl)
{
fprintf(stderr, "SSL_new failed\n");
return -1;
}
if (tls->tsg)
{
tls->bio = BIO_new(tls->methods);
if (!tls->bio)
{
fprintf(stderr, "BIO_new failed\n");
return -1;
}
tls->bio->ptr = tls->tsg;
SSL_set_bio(tls->ssl, tls->bio, tls->bio);
SSL_CTX_set_info_callback(tls->ctx, tls_ssl_info_callback);
}
else
{
if (SSL_set_fd(tls->ssl, tls->sockfd) < 1)
{
fprintf(stderr, "SSL_set_fd failed\n");
return -1;
}
}
connection_status = SSL_connect(tls->ssl);
if (connection_status <= 0)
{
if (tls_print_error("SSL_connect", tls->ssl, connection_status))
{
return -1;
}
}
cert = tls_get_certificate(tls, TRUE);
if (!cert)
{
fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n");
return -1;
}
tls->Bindings = tls_get_channel_bindings(cert->px509);
if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
{
fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n");
tls_free_certificate(cert);
return -1;
}
verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port);
if (verify_status < 1)
{
fprintf(stderr, "tls_connect: certificate not trusted, aborting.\n");
tls_disconnect(tls);
}
tls_free_certificate(cert);
return verify_status;
return tls_do_handshake(tls, TRUE);
}
BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file)
BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file)
{
CryptoCert cert;
long options = 0;
int connection_status;
tls->ctx = SSL_CTX_new(SSLv23_server_method());
if (tls->ctx == NULL)
{
fprintf(stderr, "SSL_CTX_new failed\n");
return FALSE;
}
/*
/**
* SSL_OP_NO_SSLv2:
*
* We only want SSLv3 and TLSv1, so disable SSLv2.
@ -281,80 +296,23 @@ BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file)
*/
options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
SSL_CTX_set_options(tls->ctx, options);
if (SSL_CTX_use_RSAPrivateKey_file(tls->ctx, privatekey_file, SSL_FILETYPE_PEM) <= 0)
{
fprintf(stderr, "SSL_CTX_use_RSAPrivateKey_file failed\n");
fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file);
if (!tls_prepare(tls, underlying, SSLv23_server_method(), options, FALSE))
return FALSE;
}
tls->ssl = SSL_new(tls->ctx);
if (!tls->ssl)
if (SSL_use_RSAPrivateKey_file(tls->ssl, privatekey_file, SSL_FILETYPE_PEM) <= 0)
{
fprintf(stderr, "SSL_new failed\n");
fprintf(stderr, "%s: SSL_CTX_use_RSAPrivateKey_file failed\n", __FUNCTION__);
fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file);
return FALSE;
}
if (SSL_use_certificate_file(tls->ssl, cert_file, SSL_FILETYPE_PEM) <= 0)
{
fprintf(stderr, "SSL_use_certificate_file failed\n");
fprintf(stderr, "%s: SSL_use_certificate_file failed\n", __FUNCTION__);
return FALSE;
}
if (SSL_set_fd(tls->ssl, tls->sockfd) < 1)
{
fprintf(stderr, "SSL_set_fd failed\n");
return FALSE;
}
while (1)
{
connection_status = SSL_accept(tls->ssl);
if (connection_status <= 0)
{
switch (SSL_get_error(tls->ssl, connection_status))
{
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
break;
default:
if (tls_print_error("SSL_accept", tls->ssl, connection_status))
return FALSE;
break;
}
}
else
{
break;
}
}
cert = tls_get_certificate(tls, FALSE);
if (!cert)
{
fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n");
return FALSE;
}
if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
{
fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n");
tls_free_certificate(cert);
return FALSE;
}
free(cert);
fprintf(stderr, "TLS connection accepted\n");
return TRUE;
return tls_do_handshake(tls, FALSE) > 0;
}
BOOL tls_disconnect(rdpTls* tls)
@ -362,256 +320,161 @@ BOOL tls_disconnect(rdpTls* tls)
if (!tls)
return FALSE;
if (tls->ssl)
if (!tls->ssl)
return TRUE;
if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY)
{
if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY)
{
/**
* OpenSSL doesn't really expose an API for sending a TLS alert manually.
*
* The following code disables the sending of the default "close notify"
* and then proceeds to force sending a custom TLS alert before shutting down.
*
* Manually sending a TLS alert is necessary in certain cases,
* like when server-side NLA results in an authentication failure.
*/
/**
* OpenSSL doesn't really expose an API for sending a TLS alert manually.
*
* The following code disables the sending of the default "close notify"
* and then proceeds to force sending a custom TLS alert before shutting down.
*
* Manually sending a TLS alert is necessary in certain cases,
* like when server-side NLA results in an authentication failure.
*/
SSL_set_quiet_shutdown(tls->ssl, 1);
SSL_set_quiet_shutdown(tls->ssl, 1);
if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session))
SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session);
if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session))
SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session);
tls->ssl->s3->alert_dispatch = 1;
tls->ssl->s3->send_alert[0] = tls->alertLevel;
tls->ssl->s3->send_alert[1] = tls->alertDescription;
tls->ssl->s3->alert_dispatch = 1;
tls->ssl->s3->send_alert[0] = tls->alertLevel;
tls->ssl->s3->send_alert[1] = tls->alertDescription;
if (tls->ssl->s3->wbuf.left == 0)
tls->ssl->method->ssl_dispatch_alert(tls->ssl);
if (tls->ssl->s3->wbuf.left == 0)
tls->ssl->method->ssl_dispatch_alert(tls->ssl);
SSL_shutdown(tls->ssl);
}
else
{
SSL_shutdown(tls->ssl);
}
SSL_shutdown(tls->ssl);
}
else
{
SSL_shutdown(tls->ssl);
}
return TRUE;
}
int tls_read(rdpTls* tls, BYTE* data, int length)
BIO *findBufferedBio(BIO *front)
{
int error;
int status;
BIO *ret = front;
if (!tls)
return -1;
if (!tls->ssl)
return -1;
status = SSL_read(tls->ssl, data, length);
if (status == 0)
while (ret)
{
return -1; /* peer disconnected */
if (BIO_method_type(ret) == BIO_TYPE_BUFFERED)
return ret;
ret = ret->next_bio;
}
if (status <= 0)
{
error = SSL_get_error(tls->ssl, status);
//fprintf(stderr, "tls_read: length: %d status: %d error: 0x%08X\n",
// length, status, error);
switch (error)
{
case SSL_ERROR_NONE:
break;
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
status = 0;
break;
case SSL_ERROR_SYSCALL:
#ifdef _WIN32
if (WSAGetLastError() == WSAEWOULDBLOCK)
#else
if ((errno == EAGAIN) || (errno == 0))
#endif
{
status = 0;
}
else
{
if (tls_print_error("SSL_read", tls->ssl, status))
{
status = -1;
}
else
{
status = 0;
}
}
break;
default:
if (tls_print_error("SSL_read", tls->ssl, status))
{
status = -1;
}
else
{
status = 0;
}
break;
}
}
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(data, status);
#endif
return status;
return ret;
}
int tls_write(rdpTls* tls, BYTE* data, int length)
int tls_write_all(rdpTls* tls, const BYTE* data, int length)
{
int error;
int status;
int status, nchunks, commitedBytes;
rdpTcp *tcp;
fd_set rset, wset;
fd_set *rsetPtr, *wsetPtr;
struct timeval tv;
BIO *bio = tls->bio;
DataChunk chunks[2];
if (!tls)
return -1;
if (!tls->ssl)
return -1;
status = SSL_write(tls->ssl, data, length);
if (status == 0)
BIO *bufferedBio = findBufferedBio(bio);
if (!bufferedBio)
{
return -1; /* peer disconnected */
fprintf(stderr, "%s: error unable to retrieve the bufferedBio in the BIO chain\n", __FUNCTION__);
return -1;
}
if (status < 0)
{
error = SSL_get_error(tls->ssl, status);
//fprintf(stderr, "tls_write: length: %d status: %d error: 0x%08X\n", length, status, error);
switch (error)
{
case SSL_ERROR_NONE:
break;
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
status = 0;
break;
case SSL_ERROR_SYSCALL:
if (errno == EAGAIN)
{
status = 0;
}
else
{
tls_print_error("SSL_write", tls->ssl, status);
status = -1;
}
break;
default:
tls_print_error("SSL_write", tls->ssl, status);
status = -1;
break;
}
}
return status;
}
int tls_write_all(rdpTls* tls, BYTE* data, int length)
{
int status;
int sent = 0;
tcp = (rdpTcp *)bufferedBio->ptr;
do
{
status = tls_write(tls, &data[sent], length - sent);
status = BIO_write(bio, data, length);
/*fprintf(stderr, "%s: BIO_write(len=%d) = %d (retry=%d)\n", __FUNCTION__, length, status, BIO_should_retry(bio));*/
if (status > 0)
sent += status;
else if (status == 0)
tls_wait_write(tls);
if (sent >= length)
break;
if (!BIO_should_retry(bio))
return -1;
/* we try to handle SSL want_read and want_write nicely */
rsetPtr = wsetPtr = 0;
if (tcp->writeBlocked)
{
wsetPtr = &wset;
FD_ZERO(&wset);
FD_SET(tcp->sockfd, &wset);
}
else if (tcp->readBlocked)
{
rsetPtr = &rset;
FD_ZERO(&rset);
FD_SET(tcp->sockfd, &rset);
}
else
{
fprintf(stderr, "%s: weird we're blocked but the underlying is not read or write blocked !\n", __FUNCTION__);
USleep(10);
continue;
}
tv.tv_sec = 0;
tv.tv_usec = 100 * 1000;
status = select(tcp->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
if (status < 0)
return -1;
}
while (status >= 0);
while (TRUE);
if (status > 0)
return length;
else
return status;
}
int tls_wait_read(rdpTls* tls)
{
return freerdp_tcp_wait_read(tls->sockfd);
}
int tls_wait_write(rdpTls* tls)
{
return freerdp_tcp_wait_write(tls->sockfd);
}
static void tls_errors(const char *prefix)
{
unsigned long error;
while ((error = ERR_get_error()) != 0)
fprintf(stderr, "%s: %s\n", prefix, ERR_error_string(error, NULL));
}
BOOL tls_print_error(char* func, SSL* connection, int value)
{
switch (SSL_get_error(connection, value))
/* make sure the output buffer is empty */
commitedBytes = 0;
while ((nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer))))
{
case SSL_ERROR_ZERO_RETURN:
fprintf(stderr, "%s: Server closed TLS connection\n", func);
return TRUE;
int i;
case SSL_ERROR_WANT_READ:
fprintf(stderr, "%s: SSL_ERROR_WANT_READ\n", func);
return FALSE;
for (i = 0; i < nchunks; i++)
{
while (chunks[i].size)
{
status = BIO_write(tcp->socketBio, chunks[i].data, chunks[i].size);
if (status > 0)
{
chunks[i].size -= status;
chunks[i].data += status;
commitedBytes += status;
continue;
}
case SSL_ERROR_WANT_WRITE:
fprintf(stderr, "%s: SSL_ERROR_WANT_WRITE\n", func);
return FALSE;
if (!BIO_should_retry(tcp->socketBio))
goto out_fail;
FD_ZERO(&rset);
FD_SET(tcp->sockfd, &rset);
tv.tv_sec = 0;
tv.tv_usec = 100 * 1000;
case SSL_ERROR_SYSCALL:
#ifdef _WIN32
fprintf(stderr, "%s: I/O error: %d\n", func, WSAGetLastError());
#else
fprintf(stderr, "%s: I/O error: %s (%d)\n", func, strerror(errno), errno);
#endif
tls_errors(func);
return TRUE;
status = select(tcp->sockfd + 1, &rset, NULL, NULL, &tv);
if (status < 0)
goto out_fail;
}
case SSL_ERROR_SSL:
fprintf(stderr, "%s: Failure in SSL library (protocol error?)\n", func);
tls_errors(func);
return TRUE;
default:
fprintf(stderr, "%s: Unknown error\n", func);
tls_errors(func);
return TRUE;
}
}
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes);
return length;
out_fail:
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes);
return -1;
}
int tls_set_alert_code(rdpTls* tls, int level, int description)
{
tls->alertLevel = level;
@ -672,7 +535,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (!bio)
{
fprintf(stderr, "tls_verify_certificate: BIO_new() failure\n");
fprintf(stderr, "%s: BIO_new() failure\n", __FUNCTION__);
return -1;
}
@ -680,7 +543,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (status < 0)
{
fprintf(stderr, "tls_verify_certificate: PEM_write_bio_X509 failure: %d\n", status);
fprintf(stderr, "%s: PEM_write_bio_X509 failure: %d\n", __FUNCTION__, status);
return -1;
}
@ -692,7 +555,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (status < 0)
{
fprintf(stderr, "tls_verify_certificate: failed to read certificate\n");
fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__);
return -1;
}
@ -713,7 +576,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (status < 0)
{
fprintf(stderr, "tls_verify_certificate: failed to read certificate\n");
fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__);
return -1;
}
@ -727,8 +590,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
status = instance->VerifyX509Certificate(instance, pemCert, length, hostname, port, 0);
}
fprintf(stderr, "VerifyX509Certificate: (length = %d) status: %d\n%s\n",
length, status, pemCert);
fprintf(stderr, "%s: (length = %d) status: %d\n%s\n", __FUNCTION__, length, status, pemCert);
free(pemCert);
BIO_free(bio);
@ -932,57 +794,53 @@ rdpTls* tls_new(rdpSettings* settings)
{
rdpTls* tls;
tls = (rdpTls*) malloc(sizeof(rdpTls));
tls = (rdpTls *)calloc(1, sizeof(rdpTls));
if (!tls)
return NULL;
if (tls)
{
ZeroMemory(tls, sizeof(rdpTls));
SSL_load_error_strings();
SSL_library_init();
SSL_load_error_strings();
SSL_library_init();
tls->settings = settings;
tls->certificate_store = certificate_store_new(settings);
tls->alertLevel = TLS_ALERT_LEVEL_WARNING;
tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY;
}
tls->settings = settings;
tls->certificate_store = certificate_store_new(settings);
if (!tls->certificate_store)
goto out_free;
tls->alertLevel = TLS_ALERT_LEVEL_WARNING;
tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY;
return tls;
out_free:
free(tls);
return NULL;
}
void tls_free(rdpTls* tls)
{
if (tls)
if (!tls)
return;
if (tls->ctx)
{
if (tls->ssl)
{
SSL_free(tls->ssl);
tls->ssl = NULL;
}
if (tls->ctx)
{
SSL_CTX_free(tls->ctx);
tls->ctx = NULL;
}
if (tls->PublicKey)
{
free(tls->PublicKey);
tls->PublicKey = NULL;
}
if (tls->Bindings)
{
free(tls->Bindings->Bindings);
free(tls->Bindings);
tls->Bindings = NULL;
}
certificate_store_free(tls->certificate_store);
tls->certificate_store = NULL;
free(tls);
SSL_CTX_free(tls->ctx);
tls->ctx = NULL;
}
if (tls->PublicKey)
{
free(tls->PublicKey);
tls->PublicKey = NULL;
}
if (tls->Bindings)
{
free(tls->Bindings->Bindings);
free(tls->Bindings);
tls->Bindings = NULL;
}
certificate_store_free(tls->certificate_store);
tls->certificate_store = NULL;
free(tls);
}