Merge branch 'non_blocking_writes' of github.com:hardening/FreeRDP into non_blocking_writes

This commit is contained in:
Marc-André Moreau 2014-05-22 14:01:44 -04:00
commit af4a413287
30 changed files with 2155 additions and 1133 deletions

View File

@ -70,7 +70,6 @@ struct rdp_tls
SSL* ssl;
BIO* bio;
void* tsg;
int sockfd;
SSL_CTX* ctx;
BYTE* PublicKey;
BIO_METHOD* methods;
@ -84,17 +83,11 @@ struct rdp_tls
int alertDescription;
};
FREERDP_API int tls_connect(rdpTls* tls);
FREERDP_API BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file);
FREERDP_API int tls_connect(rdpTls* tls, BIO *underlying);
FREERDP_API BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file);
FREERDP_API BOOL tls_disconnect(rdpTls* tls);
FREERDP_API int tls_read(rdpTls* tls, BYTE* data, int length);
FREERDP_API int tls_write(rdpTls* tls, BYTE* data, int length);
FREERDP_API int tls_write_all(rdpTls* tls, BYTE* data, int length);
FREERDP_API int tls_wait_read(rdpTls* tls);
FREERDP_API int tls_wait_write(rdpTls* tls);
FREERDP_API int tls_write_all(rdpTls* tls, const BYTE* data, int length);
FREERDP_API int tls_set_alert_code(rdpTls* tls, int level, int description);

View File

@ -34,7 +34,10 @@ typedef void (*psPeerContextFree)(freerdp_peer* client, rdpContext* context);
typedef BOOL (*psPeerInitialize)(freerdp_peer* client);
typedef BOOL (*psPeerGetFileDescriptor)(freerdp_peer* client, void** rfds, int* rcount);
typedef HANDLE (*psPeerGetEventHandle)(freerdp_peer* client);
typedef HANDLE (*psPeerGetReceiveEventHandle)(freerdp_peer* client);
typedef BOOL (*psPeerCheckFileDescriptor)(freerdp_peer* client);
typedef BOOL (*psPeerIsWriteBlocked)(freerdp_peer* client);
typedef int (*psPeerDrainOutputBuffer)(freerdp_peer* client);
typedef BOOL (*psPeerClose)(freerdp_peer* client);
typedef void (*psPeerDisconnect)(freerdp_peer* client);
typedef BOOL (*psPeerCapabilities)(freerdp_peer* client);
@ -62,6 +65,7 @@ struct rdp_freerdp_peer
psPeerInitialize Initialize;
psPeerGetFileDescriptor GetFileDescriptor;
psPeerGetEventHandle GetEventHandle;
psPeerGetReceiveEventHandle GetReceiveEventHandle;
psPeerCheckFileDescriptor CheckFileDescriptor;
psPeerClose Close;
psPeerDisconnect Disconnect;
@ -81,6 +85,9 @@ struct rdp_freerdp_peer
BOOL activated;
BOOL authenticated;
SEC_WINNT_AUTH_IDENTITY identity;
psPeerIsWriteBlocked IsWriteBlocked;
psPeerDrainOutputBuffer DrainOutputBuffer;
};
#ifdef __cplusplus

View File

@ -799,7 +799,8 @@ struct rdp_settings
ALIGN64 char* Password; /* 22 */
ALIGN64 char* Domain; /* 23 */
ALIGN64 char* PasswordHash; /* 24 */
UINT64 padding0064[64 - 25]; /* 25 */
ALIGN64 BOOL WaitForOutputBufferFlush; /* 25 */
UINT64 padding0064[64 - 26]; /* 26 */
UINT64 padding0128[128 - 64]; /* 64 */
/**

View File

@ -0,0 +1,132 @@
/**
* Copyright © 2014 Thincast Technologies GmbH
* Copyright © 2014 Hardening <contact@hardening-consulting.com>
*
* Permission to use, copy, modify, distribute, and sell this software and
* its documentation for any purpose is hereby granted without fee, provided
* that the above copyright notice appear in all copies and that both that
* copyright notice and this permission notice appear in supporting
* documentation, and that the name of the copyright holders not be used in
* advertising or publicity pertaining to distribution of the software
* without specific, written prior permission. The copyright holders make
* no representations about the suitability of this software for any
* purpose. It is provided "as is" without express or implied warranty.
*
* THE COPYRIGHT HOLDERS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS
* SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS, IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY
* SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER
* RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF
* CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
* CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
#ifndef __RINGBUFFER_H___
#define __RINGBUFFER_H___
#include <winpr/wtypes.h>
#include <freerdp/api.h>
/** @brief ring buffer meta data */
struct _RingBuffer {
size_t initialSize;
size_t freeSize;
size_t size;
size_t readPtr;
size_t writePtr;
BYTE *buffer;
};
typedef struct _RingBuffer RingBuffer;
/** @brief a piece of data in the ring buffer, exactly like a glibc iovec */
struct _DataChunk {
size_t size;
const BYTE *data;
};
typedef struct _DataChunk DataChunk;
#ifdef __cplusplus
extern "C" {
#endif
/** initialise a ringbuffer
* @param initialSize the initial capacity of the ringBuffer
* @return if the initialisation was successful
*/
FREERDP_API BOOL ringbuffer_init(RingBuffer *rb, size_t initialSize);
/** destroys internal data used by this ringbuffer
* @param ringbuffer
*/
FREERDP_API void ringbuffer_destroy(RingBuffer *ringbuffer);
/** computes the space used in this ringbuffer
* @param ringbuffer
* @return the number of bytes stored in that ringbuffer
*/
FREERDP_API size_t ringbuffer_used(const RingBuffer *ringbuffer);
/** returns the capacity of the ring buffer
* @param ringbuffer
* @return the capacity of this ring buffer
*/
FREERDP_API size_t ringbuffer_capacity(const RingBuffer *ringbuffer);
/** writes some bytes in the ringbuffer, if the data doesn't fit, the ringbuffer
* is resized automatically
*
* @param rb the ringbuffer
* @param ptr a pointer on the data to add
* @param sz the size of the data to add
* @return if the operation was successful, it could fail in case of OOM during realloc()
*/
FREERDP_API BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz);
/** ensures that we have sz bytes available at the write head, and return a pointer
* on the write head
*
* @param rb the ring buffer
* @param sz the size to ensure
* @return a pointer on the write head, or NULL in case of OOM
*/
FREERDP_API BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz);
/** move ahead the write head in case some byte were written directly by using
* a pointer retrieved via ringbuffer_ensure_linear_write(). This function is
* used to commit the written bytes. The provided size should not exceed the
* size ensured by ringbuffer_ensure_linear_write()
*
* @param rb the ring buffer
* @param sz the number of bytes that have been written
* @return if the operation was successful, FALSE is sz is too big
*/
FREERDP_API BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz);
/** peeks the buffer chunks for sz bytes and returns how many chunks are filled.
* Note that the sum of the resulting chunks may be smaller than sz.
*
* @param rb the ringbuffer
* @param chunks an array of data chunks that will contain data / size of chunks
* @param sz the requested size
* @return the number of chunks used for reading sz bytes
*/
FREERDP_API int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz);
/** move ahead the read head in case some byte were read using ringbuffer_peek()
* This function is used to commit the bytes that were effectively consumed.
*
* @param rb the ring buffer
* @param sz the
*/
FREERDP_API void ringbuffer_commit_read_bytes(RingBuffer *rb, size_t sz);
#ifdef __cplusplus
}
#endif
#endif /* __RINGBUFFER_H___ */

View File

@ -26,6 +26,10 @@
#include <winpr/stream.h>
#include <winpr/string.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "http.h"
HttpContext* http_context_new()
@ -472,7 +476,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
nbytes = 0;
length = 10000;
content = NULL;
buffer = malloc(length);
buffer = calloc(length, 1);
if (!buffer)
return NULL;
@ -487,14 +491,20 @@ HttpResponse* http_response_recv(rdpTls* tls)
{
while (nbytes < 5)
{
status = tls_read(tls, p, length - nbytes);
status = BIO_read(tls->bio, p, length - nbytes);
if (status < 0)
goto out_error;
if (status <= 0)
{
if (!BIO_should_retry(tls->bio))
goto out_error;
if (!status)
USleep(100);
continue;
}
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(p, status);
#endif
nbytes += status;
p = (BYTE*) &buffer[nbytes];
}
@ -503,7 +513,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
if (!header_end)
{
fprintf(stderr, "http_response_recv: invalid response:\n");
fprintf(stderr, "%s: invalid response:\n", __FUNCTION__);
winpr_HexDump(buffer, status);
goto out_error;
}
@ -517,7 +527,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
header_end[0] = '\0';
header_end[1] = '\0';
content = &header_end[2];
content = header_end + 2;
count = 0;
line = (char*) buffer;
@ -552,11 +562,14 @@ HttpResponse* http_response_recv(rdpTls* tls)
if (!http_response_parse_header(http_response))
goto out_error;
if (http_response->ContentLength > 0)
http_response->bodyLen = nbytes - (content - (char *)buffer);
if (http_response->bodyLen > 0)
{
http_response->Content = _strdup(content);
if (!http_response->Content)
http_response->BodyContent = (BYTE *)malloc(http_response->bodyLen);
if (!http_response->BodyContent)
goto out_error;
CopyMemory(http_response->BodyContent, content, http_response->bodyLen);
}
break;
@ -627,7 +640,7 @@ void http_response_free(HttpResponse* http_response)
ListDictionary_Free(http_response->Authenticates);
if (http_response->ContentLength > 0)
free(http_response->Content);
free(http_response->BodyContent);
free(http_response);
}

View File

@ -84,7 +84,8 @@ struct _http_response
wListDictionary *Authenticates;
int ContentLength;
char* Content;
BYTE *BodyContent;
int bodyLen;
};
void http_response_print(HttpResponse* http_response);

View File

@ -98,6 +98,8 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc)
rdpNtlm* ntlm = rpc->NtlmHttpIn->ntlm;
http_response = http_response_recv(rpc->TlsIn);
if (!http_response)
return -1;
if (ListDictionary_Contains(http_response->Authenticates, "NTLM"))
{
@ -105,14 +107,12 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc)
if (!token64)
goto out;
ntlm_token_data = NULL;
crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length);
}
out:
ntlm->inputBuffer[0].pvBuffer = ntlm_token_data;
ntlm->inputBuffer[0].cbBuffer = ntlm_token_length;
out:
http_response_free(http_response);
return 0;
@ -123,25 +123,19 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, TSG_CHANNEL channel)
rdpNtlm* ntlm = NULL;
rdpSettings* settings = rpc->settings;
freerdp* instance = (freerdp*) rpc->settings->instance;
BOOL promptPassword = FALSE;
if (channel == TSG_CHANNEL_IN)
ntlm = rpc->NtlmHttpIn->ntlm;
else if (channel == TSG_CHANNEL_OUT)
ntlm = rpc->NtlmHttpOut->ntlm;
if ((!settings->GatewayPassword) || (!settings->GatewayUsername)
|| (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername)))
{
promptPassword = TRUE;
}
if (promptPassword)
if (!settings->GatewayPassword || !settings->GatewayUsername ||
!strlen(settings->GatewayPassword) || !strlen(settings->GatewayUsername))
{
if (instance->GatewayAuthenticate)
{
BOOL proceed = instance->GatewayAuthenticate(instance,
&settings->GatewayUsername, &settings->GatewayPassword, &settings->GatewayDomain);
BOOL proceed = instance->GatewayAuthenticate(instance, &settings->GatewayUsername,
&settings->GatewayPassword, &settings->GatewayDomain);
if (!proceed)
{
@ -240,12 +234,10 @@ int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc)
char *token64 = ListDictionary_GetItemValue(http_response->Authenticates, "NTLM");
crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length);
}
ntlm->inputBuffer[0].pvBuffer = ntlm_token_data;
ntlm->inputBuffer[0].cbBuffer = ntlm_token_length;
http_response_free(http_response);
return 0;
}
@ -259,15 +251,12 @@ BOOL rpc_ntlm_http_out_connect(rdpRpc* rpc)
success = TRUE;
/* Send OUT Channel Request */
rpc_ncacn_http_send_out_channel_request(rpc);
/* Receive OUT Channel Response */
rpc_ncacn_http_recv_out_channel_response(rpc);
/* Send OUT Channel Request */
rpc_ncacn_http_send_out_channel_request(rpc);
ntlm_client_uninit(ntlm);
@ -296,13 +285,11 @@ void rpc_ntlm_http_init_channel(rdpRpc* rpc, rdpNtlmHttp* ntlm_http, TSG_CHANNEL
if (channel == TSG_CHANNEL_IN)
{
http_context_set_pragma(ntlm_http->context,
"ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729");
http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729");
}
else if (channel == TSG_CHANNEL_OUT)
{
http_context_set_pragma(ntlm_http->context,
"ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729" ", "
http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, "
"SessionId=fbd9c34f-397d-471d-a109-1b08cc554624");
}
}

View File

@ -33,6 +33,11 @@
#include <winpr/dsparse.h>
#include <openssl/rand.h>
#include <openssl/bio.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "http.h"
#include "ntlm.h"
@ -235,80 +240,77 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l
{
UINT32 alloc_hint = 0;
rpcconn_hdr_t* header;
UINT32 frag_length;
UINT32 auth_length;
UINT32 auth_pad_length;
UINT32 sec_trailer_offset;
rpc_sec_trailer* sec_trailer;
*offset = RPC_COMMON_FIELDS_LENGTH;
header = ((rpcconn_hdr_t*) buffer);
if (header->common.ptype == PTYPE_RESPONSE)
switch (header->common.ptype)
{
*offset += 8;
rpc_offset_align(offset, 8);
alloc_hint = header->response.alloc_hint;
}
else if (header->common.ptype == PTYPE_REQUEST)
{
*offset += 4;
rpc_offset_align(offset, 8);
alloc_hint = header->request.alloc_hint;
}
else if (header->common.ptype == PTYPE_RTS)
{
*offset += 4;
}
else
{
return FALSE;
case PTYPE_RESPONSE:
*offset += 8;
rpc_offset_align(offset, 8);
alloc_hint = header->response.alloc_hint;
break;
case PTYPE_REQUEST:
*offset += 4;
rpc_offset_align(offset, 8);
alloc_hint = header->request.alloc_hint;
break;
case PTYPE_RTS:
*offset += 4;
break;
default:
fprintf(stderr, "%s: unknown ptype=0x%x\n", __FUNCTION__, header->common.ptype);
return FALSE;
}
if (length)
if (!length)
return TRUE;
if (header->common.ptype == PTYPE_REQUEST)
{
if (header->common.ptype == PTYPE_REQUEST)
{
UINT32 sec_trailer_offset;
UINT32 sec_trailer_offset;
sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8;
*length = sec_trailer_offset - *offset;
}
else
{
UINT32 frag_length;
UINT32 auth_length;
UINT32 auth_pad_length;
UINT32 sec_trailer_offset;
rpc_sec_trailer* sec_trailer;
sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8;
*length = sec_trailer_offset - *offset;
return TRUE;
}
frag_length = header->common.frag_length;
auth_length = header->common.auth_length;
sec_trailer_offset = frag_length - auth_length - 8;
sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset];
auth_pad_length = sec_trailer->auth_pad_length;
frag_length = header->common.frag_length;
auth_length = header->common.auth_length;
sec_trailer_offset = frag_length - auth_length - 8;
sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset];
auth_pad_length = sec_trailer->auth_pad_length;
#if 0
fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n",
sec_trailer->auth_type,
sec_trailer->auth_level,
sec_trailer->auth_pad_length,
sec_trailer->auth_reserved,
sec_trailer->auth_context_id);
fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n",
sec_trailer->auth_type,
sec_trailer->auth_level,
sec_trailer->auth_pad_length,
sec_trailer->auth_reserved,
sec_trailer->auth_context_id);
#endif
/**
* According to [MS-RPCE], auth_pad_length is the number of padding
* octets used to 4-byte align the security trailer, but in practice
* we get values up to 15, which indicates 16-byte alignment.
*/
/**
* According to [MS-RPCE], auth_pad_length is the number of padding
* octets used to 4-byte align the security trailer, but in practice
* we get values up to 15, which indicates 16-byte alignment.
*/
if ((frag_length - (sec_trailer_offset + 8)) != auth_length)
{
fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length,
(frag_length - (sec_trailer_offset + 8)));
}
*length = frag_length - auth_length - 24 - 8 - auth_pad_length;
}
if ((frag_length - (sec_trailer_offset + 8)) != auth_length)
{
fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length,
(frag_length - (sec_trailer_offset + 8)));
}
*length = frag_length - auth_length - 24 - 8 - auth_pad_length;
return TRUE;
}
@ -316,12 +318,23 @@ int rpc_out_read(rdpRpc* rpc, BYTE* data, int length)
{
int status;
status = tls_read(rpc->TlsOut, data, length);
status = BIO_read(rpc->TlsOut->bio, data, length);
/* fprintf(stderr, "%s: length=%d => status=%d shouldRetry=%d\n", __FUNCTION__, length,
* status, BIO_should_retry(rpc->TlsOut->bio)); */
if (status > 0) {
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(data, status);
#endif
return status;
}
return status;
if (BIO_should_retry(rpc->TlsOut->bio))
return 0;
return -1;
}
int rpc_out_write(rdpRpc* rpc, BYTE* data, int length)
int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length)
{
int status;
@ -330,7 +343,7 @@ int rpc_out_write(rdpRpc* rpc, BYTE* data, int length)
return status;
}
int rpc_in_write(rdpRpc* rpc, BYTE* data, int length)
int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length)
{
int status;
@ -360,20 +373,21 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
ntlm = rpc->ntlm;
if ((!ntlm) || (!ntlm->table))
if (!ntlm || !ntlm->table)
{
fprintf(stderr, "rpc_write: invalid ntlm context\n");
fprintf(stderr, "%s: invalid ntlm context\n", __FUNCTION__);
return -1;
}
if (ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, &ntlm->ContextSizes) != SEC_E_OK)
{
fprintf(stderr, "QueryContextAttributes SECPKG_ATTR_SIZES failure\n");
fprintf(stderr, "%s: QueryContextAttributes SECPKG_ATTR_SIZES failure\n", __FUNCTION__);
return -1;
}
request_pdu = (rpcconn_request_hdr_t*) malloc(sizeof(rpcconn_request_hdr_t));
ZeroMemory(request_pdu, sizeof(rpcconn_request_hdr_t));
request_pdu = (rpcconn_request_hdr_t*) calloc(1, sizeof(rpcconn_request_hdr_t));
if (!request_pdu)
return -1;
rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu);
@ -386,7 +400,11 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
request_pdu->opnum = opnum;
clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum);
ArrayList_Add(rpc->client->ClientCallList, clientCall);
if (!clientCall)
goto out_free_pdu;
if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0)
goto out_free_clientCall;
if (request_pdu->opnum == TsProxySetupReceivePipeOpnum)
rpc->PipeCallId = request_pdu->call_id;
@ -407,8 +425,9 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
request_pdu->frag_length = offset;
buffer = (BYTE*) malloc(request_pdu->frag_length);
buffer = (BYTE*) calloc(1, request_pdu->frag_length);
if (!buffer)
goto out_free_pdu;
CopyMemory(buffer, request_pdu, 24);
offset = 24;
@ -427,15 +446,15 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
Buffers[0].cbBuffer = offset;
Buffers[1].cbBuffer = ntlm->ContextSizes.cbMaxSignature;
Buffers[1].pvBuffer = malloc(Buffers[1].cbBuffer);
ZeroMemory(Buffers[1].pvBuffer, Buffers[1].cbBuffer);
Buffers[1].pvBuffer = calloc(1, Buffers[1].cbBuffer);
if (!Buffers[1].pvBuffer)
return -1;
Message.cBuffers = 2;
Message.ulVersion = SECBUFFER_VERSION;
Message.pBuffers = (PSecBuffer) &Buffers;
encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, 0, &Message, rpc->SendSeqNum++);
if (encrypt_status != SEC_E_OK)
{
fprintf(stderr, "EncryptMessage status: 0x%08X\n", encrypt_status);
@ -447,12 +466,18 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
offset += Buffers[1].cbBuffer;
free(Buffers[1].pvBuffer);
if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) != 0)
if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) < 0)
length = -1;
free(request_pdu);
return length;
out_free_clientCall:
rpc_client_call_free(clientCall);
out_free_pdu:
free(request_pdu);
return -1;
}
BOOL rpc_connect(rdpRpc* rpc)
@ -592,13 +617,17 @@ rdpRpc* rpc_new(rdpTransport* transport)
rpc->CallId = 2;
rpc_client_new(rpc);
if (rpc_client_new(rpc) < 0)
goto out_free_virtualConnectionCookieTable;
rpc->client->SynchronousSend = TRUE;
rpc->client->SynchronousReceive = TRUE;
return rpc;
out_free_virtualConnectionCookieTable:
rpc_client_free(rpc);
ArrayList_Free(rpc->VirtualConnectionCookieTable);
out_free_virtual_connection:
rpc_client_virtual_connection_free(rpc->VirtualConnection);
out_free_ntlm_http_out:

View File

@ -772,8 +772,8 @@ UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad);
int rpc_out_read(rdpRpc* rpc, BYTE* data, int length);
int rpc_out_write(rdpRpc* rpc, BYTE* data, int length);
int rpc_in_write(rdpRpc* rpc, BYTE* data, int length);
int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length);
int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length);
BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length);

View File

@ -103,6 +103,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
DEBUG_RPC("Sending bind PDU");
rpc->ntlm = ntlm_new();
if (!rpc->ntlm)
return -1;
if ((!settings->GatewayPassword) || (!settings->GatewayUsername)
|| (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername)))
@ -129,17 +131,22 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
settings->Username = _strdup(settings->GatewayUsername);
settings->Domain = _strdup(settings->GatewayDomain);
settings->Password = _strdup(settings->GatewayPassword);
if (!settings->Username || !settings->Domain || settings->Password)
return -1;
}
}
}
ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL);
ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname);
if (!ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL) ||
!ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname) ||
!ntlm_authenticate(rpc->ntlm)
)
return -1;
ntlm_authenticate(rpc->ntlm);
bind_pdu = (rpcconn_bind_hdr_t*) malloc(sizeof(rpcconn_bind_hdr_t));
ZeroMemory(bind_pdu, sizeof(rpcconn_bind_hdr_t));
bind_pdu = (rpcconn_bind_hdr_t*) calloc(1, sizeof(rpcconn_bind_hdr_t));
if (!bind_pdu)
return -1;
rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) bind_pdu);
@ -159,6 +166,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
bind_pdu->p_context_elem.reserved2 = 0;
bind_pdu->p_context_elem.p_cont_elem = malloc(sizeof(p_cont_elem_t) * bind_pdu->p_context_elem.n_context_elem);
if (!bind_pdu->p_context_elem.p_cont_elem)
return -1;
p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0];
@ -196,6 +205,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
bind_pdu->frag_length = offset;
buffer = (BYTE*) malloc(bind_pdu->frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, bind_pdu, 24);
CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4);
@ -214,7 +225,10 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
length = bind_pdu->frag_length;
clientCall = rpc_client_call_new(bind_pdu->call_id, 0);
ArrayList_Add(rpc->client->ClientCallList, clientCall);
if (!clientCall)
return -1;
if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0)
return -1;
if (rpc_send_enqueue_pdu(rpc, buffer, length) != 0)
length = -1;

View File

@ -34,9 +34,7 @@
#include <winpr/stream.h>
#include "rpc_fault.h"
#include "rpc_client.h"
#include "../rdp.h"
#define SYNCHRONOUS_TIMEOUT 5000
@ -69,8 +67,15 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
if (!pdu)
{
pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));
pdu = (RPC_PDU *)malloc(sizeof(RPC_PDU));
if (!pdu)
return NULL;
pdu->s = Stream_New(NULL, rpc->max_recv_frag);
if (!pdu->s)
{
free(pdu);
return NULL;
}
}
pdu->CallId = 0;
@ -84,8 +89,7 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu)
{
Queue_Enqueue(rpc->client->ReceivePool, pdu);
return 0;
return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1;
}
int rpc_client_on_fragment_received_event(rdpRpc* rpc)
@ -97,7 +101,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
rpcconn_hdr_t* header;
freerdp* instance;
instance = (freerdp*) rpc->transport->settings->instance;
instance = (freerdp *)rpc->transport->settings->instance;
if (!rpc->client->pdu)
rpc->client->pdu = rpc_client_receive_pool_take(rpc);
@ -125,34 +129,29 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
return 0;
}
if (header->common.ptype == PTYPE_RTS)
switch (header->common.ptype)
{
if (rpc->VirtualConnection->State >= VIRTUAL_CONNECTION_STATE_OPENED)
{
//fprintf(stderr, "Receiving Out-of-Sequence RTS PDU\n");
case PTYPE_RTS:
if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED)
{
fprintf(stderr, "%s: warning: unhandled RTS PDU\n", __FUNCTION__);
return 0;
}
fprintf(stderr, "%s: Receiving Out-of-Sequence RTS PDU\n", __FUNCTION__);
rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length);
rpc_client_fragment_pool_return(rpc, fragment);
}
else
{
fprintf(stderr, "warning: unhandled RTS PDU\n");
}
return 0;
return 0;
}
else if (header->common.ptype == PTYPE_FAULT)
{
rpc_recv_fault_pdu(header);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
}
if (header->common.ptype != PTYPE_RESPONSE)
{
fprintf(stderr, "Unexpected RPC PDU type: %d\n", header->common.ptype);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
case PTYPE_FAULT:
rpc_recv_fault_pdu(header);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
case PTYPE_RESPONSE:
break;
default:
fprintf(stderr, "%s: unexpected RPC PDU type %d\n", __FUNCTION__, header->common.ptype);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
}
rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length;
@ -160,7 +159,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength))
{
fprintf(stderr, "rpc_recv_pdu_fragment: expected stub\n");
fprintf(stderr, "%s: expected stub\n", __FUNCTION__);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
}
@ -196,7 +195,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
if (rpc->StubCallId != header->common.call_id)
{
fprintf(stderr, "invalid call_id: actual: %d, expected: %d, frag_count: %d\n",
fprintf(stderr, "%s: invalid call_id: actual: %d, expected: %d, frag_count: %d\n", __FUNCTION__,
rpc->StubCallId, header->common.call_id, rpc->StubFragCount);
}
@ -243,27 +242,34 @@ int rpc_client_on_read_event(rdpRpc* rpc)
int status = -1;
rpcconn_common_hdr_t* header;
if (!rpc->client->RecvFrag)
rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
position = Stream_GetPosition(rpc->client->RecvFrag);
if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
while (1)
{
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag));
if (!rpc->client->RecvFrag)
rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
if (status < 0)
position = Stream_GetPosition(rpc->client->RecvFrag);
while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
{
fprintf(stderr, "rpc_client_frag_read: error reading header\n");
return -1;
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag));
if (status < 0)
{
fprintf(stderr, "rpc_client_frag_read: error reading header\n");
return -1;
}
if (!status)
return 0;
Stream_Seek(rpc->client->RecvFrag, status);
}
Stream_Seek(rpc->client->RecvFrag, status);
}
if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
return status;
if (Stream_GetPosition(rpc->client->RecvFrag) >= RPC_COMMON_FIELDS_LENGTH)
{
header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->RecvFrag);
if (header->frag_length > rpc->max_recv_frag)
@ -274,45 +280,44 @@ int rpc_client_on_read_event(rdpRpc* rpc)
return -1;
}
if (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length)
while (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length)
{
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
header->frag_length - Stream_GetPosition(rpc->client->RecvFrag));
if (status < 0)
{
fprintf(stderr, "rpc_client_frag_read: error reading fragment body\n");
fprintf(stderr, "%s: error reading fragment body\n", __FUNCTION__);
return -1;
}
if (!status)
return 0;
Stream_Seek(rpc->client->RecvFrag, status);
}
}
else
{
return status;
}
if (status < 0)
return -1;
status = Stream_GetPosition(rpc->client->RecvFrag) - position;
if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
{
/* complete fragment received */
Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
Stream_SetPosition(rpc->client->RecvFrag, 0);
Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
rpc->client->RecvFrag = NULL;
if (rpc_client_on_fragment_received_event(rpc) < 0)
if (status < 0)
return -1;
status = Stream_GetPosition(rpc->client->RecvFrag) - position;
if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
{
/* complete fragment received */
Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
Stream_SetPosition(rpc->client->RecvFrag, 0);
Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
rpc->client->RecvFrag = NULL;
if (rpc_client_on_fragment_received_event(rpc) < 0)
return -1;
}
}
return status;
return 0;
}
/**
@ -349,13 +354,12 @@ RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum)
RpcClientCall* clientCall;
clientCall = (RpcClientCall*) malloc(sizeof(RpcClientCall));
if (!clientCall)
return NULL;
if (clientCall)
{
clientCall->CallId = CallId;
clientCall->OpNum = OpNum;
clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
}
clientCall->CallId = CallId;
clientCall->OpNum = OpNum;
clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
return clientCall;
}
@ -371,16 +375,22 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
int status;
pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));
pdu->s = Stream_New(buffer, length);
if (!pdu)
return -1;
Queue_Enqueue(rpc->client->SendQueue, pdu);
pdu->s = Stream_New(buffer, length);
if (!pdu->s)
goto out_free;
if (!Queue_Enqueue(rpc->client->SendQueue, pdu))
goto out_free_stream;
if (rpc->client->SynchronousSend)
{
status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT);
if (status == WAIT_TIMEOUT)
{
fprintf(stderr, "rpc_send_enqueue_pdu: timed out waiting for pdu sent event\n");
fprintf(stderr, "%s: timed out waiting for pdu sent event %p\n", __FUNCTION__, rpc->client->PduSentEvent);
return -1;
}
@ -388,6 +398,12 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
}
return 0;
out_free_stream:
Stream_Free(pdu->s, TRUE);
out_free:
free(pdu);
return -1;
}
int rpc_send_dequeue_pdu(rdpRpc* rpc)
@ -396,13 +412,14 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
RPC_PDU* pdu;
RpcClientCall* clientCall;
rpcconn_common_hdr_t* header;
RpcInChannel *inChannel;
pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->SendQueue);
if (!pdu)
return 0;
WaitForSingleObject(rpc->VirtualConnection->DefaultInChannel->Mutex, INFINITE);
inChannel = rpc->VirtualConnection->DefaultInChannel;
WaitForSingleObject(inChannel->Mutex, INFINITE);
status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
@ -410,7 +427,7 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
clientCall = rpc_client_call_find_by_id(rpc, header->call_id);
clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED;
ReleaseMutex(rpc->VirtualConnection->DefaultInChannel->Mutex);
ReleaseMutex(inChannel->Mutex);
/*
* This protocol specifies that only RPC PDUs are subject to the flow control abstract
@ -421,8 +438,8 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
if (header->ptype == PTYPE_REQUEST)
{
rpc->VirtualConnection->DefaultInChannel->BytesSent += status;
rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow -= status;
inChannel->BytesSent += status;
inChannel->SenderAvailableWindow -= status;
}
Stream_Free(pdu->s, TRUE);
@ -440,57 +457,48 @@ RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc)
DWORD dwMilliseconds;
DWORD result;
pdu = NULL;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0;
result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
if (result == WAIT_TIMEOUT)
{
fprintf(stderr, "rpc_recv_dequeue_pdu: timed out waiting for receive event\n");
fprintf(stderr, "%s: timed out waiting for receive event\n", __FUNCTION__);
return NULL;
}
if (result == WAIT_OBJECT_0)
{
pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->ReceiveQueue);
if (result != WAIT_OBJECT_0)
return NULL;
pdu = (RPC_PDU *)Queue_Dequeue(rpc->client->ReceiveQueue);
#ifdef WITH_DEBUG_TSG
if (pdu)
{
fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId);
winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s));
fprintf(stderr, "\n");
}
#endif
return pdu;
if (pdu)
{
fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId);
winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s));
fprintf(stderr, "\n");
}
else
{
fprintf(stderr, "Receiving a NULL PDU\n");
}
#endif
return pdu;
}
RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc)
{
RPC_PDU* pdu;
DWORD dwMilliseconds;
DWORD result;
pdu = NULL;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
if (result == WAIT_TIMEOUT)
{
if (result != WAIT_OBJECT_0)
return NULL;
}
if (result == WAIT_OBJECT_0)
{
pdu = (RPC_PDU*) Queue_Peek(rpc->client->ReceiveQueue);
return pdu;
}
return pdu;
return (RPC_PDU *)Queue_Peek(rpc->client->ReceiveQueue);
}
static void* rpc_client_thread(void* arg)
@ -500,40 +508,52 @@ static void* rpc_client_thread(void* arg)
DWORD nCount;
HANDLE events[3];
HANDLE ReadEvent;
int fd;
rpc = (rdpRpc*) arg;
fd = BIO_get_fd(rpc->TlsOut->bio, NULL);
ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, rpc->TlsOut->sockfd);
ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, fd);
nCount = 0;
events[nCount++] = rpc->client->StopEvent;
events[nCount++] = Queue_Event(rpc->client->SendQueue);
events[nCount++] = ReadEvent;
/* Do a first free run in case some bytes were set from the HTTP headers.
* We also have to do it because most of the time the underlying socket has notified,
* and the ssl layer has eaten all bytes, so we won't be notified any more even if the
* bytes are buffered locally
*/
if (rpc_client_on_read_event(rpc) < 0)
{
fprintf(stderr, "%s: an error occured when treating first packet\n", __FUNCTION__);
goto out;
}
while (rpc->transport->layer != TRANSPORT_LAYER_CLOSED)
{
status = WaitForMultipleObjects(nCount, events, FALSE, 100);
if (status != WAIT_TIMEOUT)
if (status == WAIT_TIMEOUT)
continue;
if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0)
break;
if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0)
{
if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0)
{
if (rpc_client_on_read_event(rpc) < 0)
break;
}
}
if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0)
{
if (rpc_client_on_read_event(rpc) < 0)
break;
}
if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
{
rpc_send_dequeue_pdu(rpc);
}
if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
{
rpc_send_dequeue_pdu(rpc);
}
}
out:
CloseHandle(ReadEvent);
return NULL;
@ -541,6 +561,9 @@ static void* rpc_client_thread(void* arg)
static void rpc_pdu_free(RPC_PDU* pdu)
{
if (!pdu)
return;
Stream_Free(pdu->s, TRUE);
free(pdu);
}
@ -554,35 +577,55 @@ int rpc_client_new(rdpRpc* rpc)
{
RpcClient* client = NULL;
client = (RpcClient*) calloc(1, sizeof(RpcClient));
if (client)
{
client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
client->SendQueue = Queue_New(TRUE, -1, -1);
Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->pdu = NULL;
client->ReceivePool = Queue_New(TRUE, -1, -1);
client->ReceiveQueue = Queue_New(TRUE, -1, -1);
Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->RecvFrag = NULL;
client->FragmentPool = Queue_New(TRUE, -1, -1);
client->FragmentQueue = Queue_New(TRUE, -1, -1);
Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->ClientCallList = ArrayList_New(TRUE);
ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
}
client = (RpcClient *)calloc(1, sizeof(RpcClient));
rpc->client = client;
if (!client)
return -1;
client->Thread = CreateThread(NULL, 0,
(LPTHREAD_START_ROUTINE) rpc_client_thread,
rpc, CREATE_SUSPENDED, NULL);
if (!client->Thread)
return -1;
client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!client->StopEvent)
return -1;
client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!client->PduSentEvent)
return -1;
client->SendQueue = Queue_New(TRUE, -1, -1);
if (!client->SendQueue)
return -1;
Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->pdu = NULL;
client->ReceivePool = Queue_New(TRUE, -1, -1);
if (!client->ReceivePool)
return -1;
Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->ReceiveQueue = Queue_New(TRUE, -1, -1);
if (!client->ReceiveQueue)
return -1;
Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->RecvFrag = NULL;
client->FragmentPool = Queue_New(TRUE, -1, -1);
if (!client->FragmentPool)
return -1;
Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->FragmentQueue = Queue_New(TRUE, -1, -1);
if (!client->FragmentQueue)
return -1;
Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->ClientCallList = ArrayList_New(TRUE);
if (!client->ClientCallList)
return -1;
ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
return 0;
}
@ -604,9 +647,7 @@ int rpc_client_stop(rdpRpc* rpc)
rpc->client->Thread = NULL;
}
rpc_client_free(rpc);
return 0;
return rpc_client_free(rpc);
}
int rpc_client_free(rdpRpc* rpc)
@ -615,31 +656,39 @@ int rpc_client_free(rdpRpc* rpc)
client = rpc->client;
if (client)
{
if (!client)
return 0;
if (client->SendQueue)
Queue_Free(client->SendQueue);
if (client->RecvFrag)
rpc_fragment_free(client->RecvFrag);
if (client->RecvFrag)
rpc_fragment_free(client->RecvFrag);
if (client->FragmentPool)
Queue_Free(client->FragmentPool);
if (client->FragmentQueue)
Queue_Free(client->FragmentQueue);
if (client->pdu)
rpc_pdu_free(client->pdu);
if (client->pdu)
rpc_pdu_free(client->pdu);
if (client->ReceivePool)
Queue_Free(client->ReceivePool);
if (client->ReceiveQueue)
Queue_Free(client->ReceiveQueue);
if (client->ClientCallList)
ArrayList_Free(client->ClientCallList);
if (client->StopEvent)
CloseHandle(client->StopEvent);
if (client->PduSentEvent)
CloseHandle(client->PduSentEvent);
if (client->Thread)
CloseHandle(client->Thread);
free(client);
}
free(client);
return 0;
}

View File

@ -93,25 +93,25 @@ BOOL rts_connect(rdpRpc* rpc)
if (!rpc_ntlm_http_out_connect(rpc))
{
fprintf(stderr, "rpc_out_connect_http error!\n");
fprintf(stderr, "%s: rpc_out_connect_http error!\n", __FUNCTION__);
return FALSE;
}
if (rts_send_CONN_A1_pdu(rpc) != 0)
{
fprintf(stderr, "rpc_send_CONN_A1_pdu error!\n");
fprintf(stderr, "%s: rpc_send_CONN_A1_pdu error!\n", __FUNCTION__);
return FALSE;
}
if (!rpc_ntlm_http_in_connect(rpc))
{
fprintf(stderr, "rpc_in_connect_http error!\n");
fprintf(stderr, "%s: rpc_in_connect_http error!\n", __FUNCTION__);
return FALSE;
}
if (rts_send_CONN_B1_pdu(rpc) != 0)
if (rts_send_CONN_B1_pdu(rpc) < 0)
{
fprintf(stderr, "rpc_send_CONN_B1_pdu error!\n");
fprintf(stderr, "%s: rpc_send_CONN_B1_pdu error!\n", __FUNCTION__);
return FALSE;
}
@ -147,10 +147,15 @@ BOOL rts_connect(rdpRpc* rpc)
*/
http_response = http_response_recv(rpc->TlsOut);
if (!http_response)
{
fprintf(stderr, "%s: unable to retrieve OUT Channel Response!\n", __FUNCTION__);
return FALSE;
}
if (http_response->StatusCode != HTTP_STATUS_OK)
{
fprintf(stderr, "rts_connect error! Status Code: %d\n", http_response->StatusCode);
fprintf(stderr, "%s: error! Status Code: %d\n", __FUNCTION__, http_response->StatusCode);
http_response_print(http_response);
http_response_free(http_response);
@ -170,6 +175,14 @@ BOOL rts_connect(rdpRpc* rpc)
return FALSE;
}
if (http_response->bodyLen)
{
/* inject bytes we have read in the body as a received packet for the RPC client */
rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
Stream_EnsureCapacity(rpc->client->RecvFrag, http_response->bodyLen);
CopyMemory(rpc->client->RecvFrag, http_response->BodyContent, http_response->bodyLen);
}
//http_response_print(http_response);
http_response_free(http_response);
@ -195,7 +208,6 @@ BOOL rts_connect(rdpRpc* rpc)
rpc_client_start(rpc);
pdu = rpc_recv_dequeue_pdu(rpc);
if (!pdu)
return FALSE;
@ -203,7 +215,7 @@ BOOL rts_connect(rdpRpc* rpc)
if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts))
{
fprintf(stderr, "Unexpected RTS PDU: Expected CONN/A3\n");
fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/A3\n", __FUNCTION__);
return FALSE;
}
@ -236,7 +248,6 @@ BOOL rts_connect(rdpRpc* rpc)
*/
pdu = rpc_recv_dequeue_pdu(rpc);
if (!pdu)
return FALSE;
@ -244,7 +255,7 @@ BOOL rts_connect(rdpRpc* rpc)
if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts))
{
fprintf(stderr, "Unexpected RTS PDU: Expected CONN/C2\n");
fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/C2\n", __FUNCTION__);
return FALSE;
}
@ -261,7 +272,7 @@ BOOL rts_connect(rdpRpc* rpc)
return TRUE;
}
#if defined WITH_DEBUG_RTS && 0
#ifdef WITH_DEBUG_RTS
static const char* const RTS_CMD_STRINGS[] =
{
@ -317,6 +328,7 @@ static const char* const RTS_CMD_STRINGS[] =
void rts_pdu_header_init(rpcconn_rts_hdr_t* header)
{
ZeroMemory(header, sizeof(*header));
header->rpc_vers = 5;
header->rpc_vers_minor = 0;
header->ptype = PTYPE_RTS;
@ -681,6 +693,8 @@ int rts_send_CONN_A1_pdu(rdpRpc* rpc)
ReceiveWindowSize = rpc->VirtualConnection->DefaultOutChannel->ReceiveWindow;
buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
@ -718,6 +732,7 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
BYTE* INChannelCookie;
BYTE* AssociationGroupId;
BYTE* VirtualConnectionCookie;
int status;
rts_pdu_header_init(&header);
header.frag_length = 104;
@ -734,6 +749,8 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
AssociationGroupId = (BYTE*) &(rpc->VirtualConnection->AssociationGroupId);
buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
@ -745,11 +762,11 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
length = header.frag_length;
rpc_in_write(rpc, buffer, length);
status = rpc_in_write(rpc, buffer, length);
free(buffer);
return 0;
return status;
}
/* CONN/C Sequence */
@ -795,12 +812,15 @@ int rts_send_keep_alive_pdu(rdpRpc* rpc)
DEBUG_RPC("Sending Keep-Alive RTS PDU");
buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_client_keepalive_command_write(&buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */
length = header.frag_length;
rpc_in_write(rpc, buffer, length);
if (rpc_in_write(rpc, buffer, length) < 0)
return -1;
free(buffer);
return length;
@ -830,6 +850,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc)
rpc->VirtualConnection->DefaultOutChannel->AvailableWindowAdvertised;
buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */
@ -839,7 +861,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc)
length = header.frag_length;
rpc_in_write(rpc, buffer, length);
if (rpc_in_write(rpc, buffer, length) < 0)
return -1;
free(buffer);
return 0;
@ -923,12 +946,15 @@ int rts_send_ping_pdu(rdpRpc* rpc)
DEBUG_RPC("Sending Ping RTS PDU");
buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
length = header.frag_length;
rpc_in_write(rpc, buffer, length);
if (rpc_in_write(rpc, buffer, length) < 0)
return -1;
free(buffer);
return length;
@ -1020,22 +1046,18 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
rts_extract_pdu_signature(rpc, &signature, rts);
SignatureId = rts_identify_pdu_signature(rpc, &signature, NULL);
if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK)
switch (SignatureId)
{
return rts_recv_flow_control_ack_pdu(rpc, buffer, length);
}
else if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION)
{
return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length);
}
else if (SignatureId == RTS_PDU_PING)
{
rts_send_ping_pdu(rpc);
}
else
{
fprintf(stderr, "Unimplemented signature id: 0x%08X\n", SignatureId);
rts_print_pdu_signature(rpc, &signature);
case RTS_PDU_FLOW_CONTROL_ACK:
return rts_recv_flow_control_ack_pdu(rpc, buffer, length);
case RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION:
return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length);
case RTS_PDU_PING:
return rts_send_ping_pdu(rpc);
default:
fprintf(stderr, "%s: unimplemented signature id: 0x%08X\n", __FUNCTION__, SignatureId);
rts_print_pdu_signature(rpc, &signature);
break;
}
return 0;

View File

@ -234,7 +234,6 @@ BOOL rts_match_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rt
return FALSE;
status = rts_command_length(rpc, CommandType, &buffer[offset], length);
if (status < 0)
return FALSE;
@ -272,7 +271,6 @@ int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_r
signature->CommandTypes[i] = CommandType;
status = rts_command_length(rpc, CommandType, &buffer[offset], length);
if (status < 0)
return FALSE;
@ -294,22 +292,22 @@ UINT32 rts_identify_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, RTS_P
{
pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature;
if (signature->Flags == pSignature->Flags)
if (signature->Flags != pSignature->Flags)
continue;
if (signature->NumberOfCommands != pSignature->NumberOfCommands)
continue;
for (j = 0; j < signature->NumberOfCommands; j++)
{
if (signature->NumberOfCommands == pSignature->NumberOfCommands)
{
for (j = 0; j < signature->NumberOfCommands; j++)
{
if (signature->CommandTypes[j] != pSignature->CommandTypes[j])
continue;
}
if (entry)
*entry = &RTS_PDU_SIGNATURE_TABLE[i];
return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
}
if (signature->CommandTypes[j] != pSignature->CommandTypes[j])
continue;
}
if (entry)
*entry = &RTS_PDU_SIGNATURE_TABLE[i];
return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
}
return 0;

View File

@ -33,9 +33,9 @@
#include <winpr/stream.h>
#include "rpc_client.h"
#include "tsg.h"
/**
* RPC Functions: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378623/
* Remote Procedure Call: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378651/
@ -96,7 +96,9 @@ DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 count,
}
length = 28 + totalDataBytes;
buffer = (BYTE*) malloc(length);
buffer = (BYTE*) calloc(1, length);
if (!buffer)
return -1;
s = Stream_New(buffer, length);
@ -228,8 +230,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
buffer = &buffer[24];
packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET));
ZeroMemory(packet, sizeof(TSG_PACKET));
packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET));
if (!packet)
return FALSE;
offset = 4; // Skip Packet Pointer
packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */
@ -237,8 +240,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
if ((packet->packetId == TSG_PACKET_TYPE_CAPS_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_CAPS_RESPONSE))
{
packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) malloc(sizeof(TSG_PACKET_CAPS_RESPONSE));
ZeroMemory(packetCapsResponse, sizeof(TSG_PACKET_CAPS_RESPONSE));
packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) calloc(1, sizeof(TSG_PACKET_CAPS_RESPONSE));
if (!packetCapsResponse) // TODO: correct cleanup
return FALSE;
packet->tsgPacket.packetCapsResponse = packetCapsResponse;
/* PacketQuarResponsePtr (4 bytes) */
@ -258,8 +262,7 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
IsMessagePresent = *((UINT32*) &buffer[offset]);
offset += 4;
MessageSwitchValue = *((UINT32*) &buffer[offset]);
DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d",
IsMessagePresent, MessageSwitchValue);
DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", IsMessagePresent, MessageSwitchValue);
offset += 4;
}
@ -289,8 +292,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
offset += 4;
}
versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS));
ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS));
versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS));
if (!versionCaps) // TODO: correct cleanup
return FALSE;
packetCapsResponse->pktQuarEncResponse.versionCaps = versionCaps;
versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */
@ -317,8 +321,10 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
/* 4-byte alignment */
rpc_offset_align(&offset, 4);
tsgCaps = (PTSG_PACKET_CAPABILITIES) malloc(sizeof(TSG_PACKET_CAPABILITIES));
ZeroMemory(tsgCaps, sizeof(TSG_PACKET_CAPABILITIES));
tsgCaps = (PTSG_PACKET_CAPABILITIES) calloc(1, sizeof(TSG_PACKET_CAPABILITIES));
if (!tsgCaps)
return FALSE;
versionCaps->tsgCaps = tsgCaps;
offset += 4; /* MaxCount (4 bytes) */
@ -406,8 +412,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
}
else if ((packet->packetId == TSG_PACKET_TYPE_QUARENC_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_QUARENC_RESPONSE))
{
packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) malloc(sizeof(TSG_PACKET_QUARENC_RESPONSE));
ZeroMemory(packetQuarEncResponse, sizeof(TSG_PACKET_QUARENC_RESPONSE));
packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) calloc(1, sizeof(TSG_PACKET_QUARENC_RESPONSE));
if (!packetQuarEncResponse) // TODO: handle cleanup
return FALSE;
packet->tsgPacket.packetQuarEncResponse = packetQuarEncResponse;
/* PacketQuarResponsePtr (4 bytes) */
@ -443,8 +450,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
offset += 4;
}
versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS));
ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS));
versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS));
if (!versionCaps) // TODO: handle cleanup
return FALSE;
packetQuarEncResponse->versionCaps = versionCaps;
versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */
@ -779,8 +787,9 @@ BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
buffer = &buffer[24];
packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET));
ZeroMemory(packet, sizeof(TSG_PACKET));
packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET));
if (!packet)
return FALSE;
offset = 4;
packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */
@ -923,6 +932,8 @@ BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_NOSERI
length = 60 + (count * 2);
buffer = (BYTE*) malloc(length);
if (!buffer)
return FALSE;
/* TunnelContext */
handle = (CONTEXT_HANDLE*) tunnelContext;
@ -1526,48 +1537,53 @@ int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length)
return CopyLength;
}
else
tsg->pdu = rpc_recv_peek_pdu(rpc);
if (!tsg->pdu)
{
tsg->pdu = rpc_recv_peek_pdu(rpc);
if (!tsg->rpc->client->SynchronousReceive)
return 0;
if (!tsg->pdu)
{
if (tsg->rpc->client->SynchronousReceive)
return tsg_read(tsg, data, length);
else
return 0;
}
tsg->PendingPdu = TRUE;
tsg->BytesAvailable = Stream_Length(tsg->pdu->s);
tsg->BytesRead = 0;
CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable;
CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength);
tsg->BytesAvailable -= CopyLength;
tsg->BytesRead += CopyLength;
if (tsg->BytesAvailable < 1)
{
tsg->PendingPdu = FALSE;
rpc_recv_dequeue_pdu(rpc);
rpc_client_receive_pool_return(rpc, tsg->pdu);
}
return CopyLength;
// weird !!!!
return tsg_read(tsg, data, length);
}
tsg->PendingPdu = TRUE;
tsg->BytesAvailable = Stream_Length(tsg->pdu->s);
tsg->BytesRead = 0;
CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable;
CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength);
tsg->BytesAvailable -= CopyLength;
tsg->BytesRead += CopyLength;
if (tsg->BytesAvailable < 1)
{
tsg->PendingPdu = FALSE;
rpc_recv_dequeue_pdu(rpc);
rpc_client_receive_pool_return(rpc, tsg->pdu);
}
return CopyLength;
}
int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length)
{
int status;
if (tsg->rpc->transport->layer == TRANSPORT_LAYER_CLOSED)
{
fprintf(stderr, "tsg_write error: connection lost\n");
fprintf(stderr, "%s: error, connection lost\n", __FUNCTION__);
return -1;
}
return TsProxySendToServer((handle_t) tsg, data, 1, &length);
status = TsProxySendToServer((handle_t) tsg, data, 1, &length);
if (status < 0)
return -1;
return length;
}
BOOL tsg_set_blocking_mode(rdpTsg* tsg, BOOL blocking)
@ -1584,18 +1600,21 @@ rdpTsg* tsg_new(rdpTransport* transport)
{
rdpTsg* tsg;
tsg = (rdpTsg*) malloc(sizeof(rdpTsg));
ZeroMemory(tsg, sizeof(rdpTsg));
if (tsg != NULL)
{
tsg->transport = transport;
tsg->settings = transport->settings;
tsg->rpc = rpc_new(tsg->transport);
tsg->PendingPdu = FALSE;
}
tsg = (rdpTsg*) calloc(1, sizeof(rdpTsg));
if (!tsg)
return NULL;
tsg->transport = transport;
tsg->settings = transport->settings;
tsg->rpc = rpc_new(tsg->transport);
if (!tsg->rpc)
goto out_free;
tsg->PendingPdu = FALSE;
return tsg;
out_free:
free(tsg);
return NULL;
}
void tsg_free(rdpTsg* tsg)

View File

@ -241,7 +241,7 @@ int license_recv(rdpLicense* license, wStream* s)
if (!rdp_read_header(license->rdp, s, &length, &channelId))
{
fprintf(stderr, "Incorrect RDP header.\n");
fprintf(stderr, "%s: Incorrect RDP header.\n", __FUNCTION__);
return -1;
}
@ -252,7 +252,7 @@ int license_recv(rdpLicense* license, wStream* s)
{
if (!rdp_decrypt(license->rdp, s, length - 4, securityFlags))
{
fprintf(stderr, "rdp_decrypt failed\n");
fprintf(stderr, "%s: rdp_decrypt failed\n", __FUNCTION__);
return -1;
}
}
@ -268,7 +268,7 @@ int license_recv(rdpLicense* license, wStream* s)
if (status < 0)
{
fprintf(stderr, "Unexpected license packet.\n");
fprintf(stderr, "%s: unexpected license packet.\n", __FUNCTION__);
return status;
}
@ -308,7 +308,7 @@ int license_recv(rdpLicense* license, wStream* s)
break;
default:
fprintf(stderr, "invalid bMsgType:%d\n", bMsgType);
fprintf(stderr, "%s: invalid bMsgType:%d\n", __FUNCTION__, bMsgType);
return FALSE;
}

View File

@ -1056,26 +1056,29 @@ rdpMcs* mcs_new(rdpTransport* transport)
{
rdpMcs* mcs;
mcs = (rdpMcs*) malloc(sizeof(rdpMcs));
mcs = (rdpMcs *)calloc(1, sizeof(rdpMcs));
if (!mcs)
return NULL;
if (mcs)
{
ZeroMemory(mcs, sizeof(rdpMcs));
mcs->transport = transport;
mcs->settings = transport->settings;
mcs->transport = transport;
mcs->settings = transport->settings;
mcs_init_domain_parameters(&mcs->targetParameters, 34, 2, 0, 0xFFFF);
mcs_init_domain_parameters(&mcs->minimumParameters, 1, 1, 1, 0x420);
mcs_init_domain_parameters(&mcs->maximumParameters, 0xFFFF, 0xFC17, 0xFFFF, 0xFFFF);
mcs_init_domain_parameters(&mcs->domainParameters, 0, 0, 0, 0xFFFF);
mcs_init_domain_parameters(&mcs->targetParameters, 34, 2, 0, 0xFFFF);
mcs_init_domain_parameters(&mcs->minimumParameters, 1, 1, 1, 0x420);
mcs_init_domain_parameters(&mcs->maximumParameters, 0xFFFF, 0xFC17, 0xFFFF, 0xFFFF);
mcs_init_domain_parameters(&mcs->domainParameters, 0, 0, 0, 0xFFFF);
mcs->channelCount = 0;
mcs->channelMaxCount = CHANNEL_MAX_COUNT;
mcs->channels = (rdpMcsChannel*) calloc(mcs->channelMaxCount, sizeof(rdpMcsChannel));
}
mcs->channelCount = 0;
mcs->channelMaxCount = CHANNEL_MAX_COUNT;
mcs->channels = (rdpMcsChannel *)calloc(mcs->channelMaxCount, sizeof(rdpMcsChannel));
if (!mcs->channels)
goto out_free;
return mcs;
out_free:
free(mcs);
return NULL;
}
/**

View File

@ -52,13 +52,13 @@ static BOOL freerdp_peer_initialize(freerdp_peer* client)
fprintf(stderr, "%s: inavlid RDP key file %s\n", __FUNCTION__, settings->RdpKeyFile);
return FALSE;
}
if (settings->RdpServerRsaKey->ModulusLength > 256)
{
fprintf(stderr, "%s: Key sizes > 2048 are currently not supported for RDP security.\n", __FUNCTION__);
fprintf(stderr, "%s: Set a different key file than %s\n", __FUNCTION__, settings->RdpKeyFile);
exit(1);
}
}
return TRUE;
@ -77,12 +77,13 @@ static HANDLE freerdp_peer_get_event_handle(freerdp_peer* client)
return client->context->rdp->transport->TcpIn->event;
}
static BOOL freerdp_peer_check_fds(freerdp_peer* client)
static BOOL freerdp_peer_check_fds(freerdp_peer* peer)
{
int status;
rdpRdp* rdp;
rdp = client->context->rdp;
rdp = peer->context->rdp;
status = rdp_check_fds(rdp);
@ -413,6 +414,19 @@ static int freerdp_peer_send_channel_data(freerdp_peer* client, UINT16 channelId
return rdp_send_channel_data(client->context->rdp, channelId, data, size);
}
static BOOL freerdp_peer_is_write_blocked(freerdp_peer* peer)
{
return tranport_is_write_blocked(peer->context->rdp->transport);
}
static int freerdp_peer_drain_output_buffer(freerdp_peer* peer)
{
rdpTransport *transport = peer->context->rdp->transport;
return tranport_drain_output_buffer(transport);
}
void freerdp_peer_context_new(freerdp_peer* client)
{
rdpRdp* rdp;
@ -445,6 +459,9 @@ void freerdp_peer_context_new(freerdp_peer* client)
rdp->transport->ReceiveExtra = client;
transport_set_blocking_mode(rdp->transport, FALSE);
client->IsWriteBlocked = freerdp_peer_is_write_blocked;
client->DrainOutputBuffer = freerdp_peer_drain_output_buffer;
IFCALL(client->ContextNew, client, client->context);
}
@ -473,6 +490,8 @@ freerdp_peer* freerdp_peer_new(int sockfd)
client->Close = freerdp_peer_close;
client->Disconnect = freerdp_peer_disconnect;
client->SendChannelData = freerdp_peer_send_channel_data;
client->IsWriteBlocked = freerdp_peer_is_write_blocked;
client->DrainOutputBuffer = freerdp_peer_drain_output_buffer;
}
return client;
@ -480,10 +499,10 @@ freerdp_peer* freerdp_peer_new(int sockfd)
void freerdp_peer_free(freerdp_peer* client)
{
if (client)
{
rdp_free(client->context->rdp);
free(client->context);
free(client);
}
if (!client)
return;
rdp_free(client->context->rdp);
free(client->context);
free(client);
}

View File

@ -209,6 +209,7 @@ rdpSettings* freerdp_settings_new(DWORD flags)
ZeroMemory(settings, sizeof(rdpSettings));
settings->ServerMode = (flags & FREERDP_SETTINGS_SERVER_MODE) ? TRUE : FALSE;
settings->WaitForOutputBufferFlush = TRUE;
settings->DesktopWidth = 1024;
settings->DesktopHeight = 768;
@ -582,6 +583,7 @@ rdpSettings* freerdp_settings_clone(rdpSettings* settings)
/* BOOL values */
_settings->ServerMode = settings->ServerMode; /* 16 */
_settings->WaitForOutputBufferFlush = settings->WaitForOutputBufferFlush; /* 25 */
_settings->NetworkAutoDetect = settings->NetworkAutoDetect; /* 137 */
_settings->SupportAsymetricKeys = settings->SupportAsymetricKeys; /* 138 */
_settings->SupportErrorInfoPdu = settings->SupportErrorInfoPdu; /* 139 */

View File

@ -66,6 +66,165 @@
#include "tcp.h"
long transport_bio_buffered_callback(BIO* bio, int mode, const char* argp, int argi, long argl, long ret)
{
return 1;
}
static int transport_bio_buffered_write(BIO* bio, const char* buf, int num)
{
int status, ret;
rdpTcp *tcp = (rdpTcp *)bio->ptr;
int nchunks, committedBytes, i;
DataChunk chunks[2];
ret = num;
BIO_clear_retry_flags(bio);
tcp->writeBlocked = FALSE;
/* we directly append extra bytes in the xmit buffer, this could be prevented
* but for now it makes the code more simple.
*/
if (buf && num && !ringbuffer_write(&tcp->xmitBuffer, (const BYTE *)buf, num))
{
fprintf(stderr, "%s: an error occured when writing(toWrite=%d)\n", __FUNCTION__, num);
return -1;
}
committedBytes = 0;
nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer));
for (i = 0; i < nchunks; i++)
{
while (chunks[i].size)
{
status = BIO_write(bio->next_bio, chunks[i].data, chunks[i].size);
/*fprintf(stderr, "%s: i=%d/%d size=%d/%d status=%d retry=%d\n", __FUNCTION__, i, nchunks,
chunks[i].size, ringbuffer_used(&tcp->xmitBuffer), status,
BIO_should_retry(bio->next_bio)
);*/
if (status <= 0)
{
if (BIO_should_retry(bio->next_bio))
{
tcp->writeBlocked = TRUE;
goto out; /* EWOULDBLOCK */
}
/* any other is an error, but we still have to commit written bytes */
ret = -1;
goto out;
}
committedBytes += status;
chunks[i].size -= status;
chunks[i].data += status;
}
}
out:
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, committedBytes);
return ret;
}
static int transport_bio_buffered_read(BIO* bio, char* buf, int size)
{
int status;
rdpTcp *tcp = (rdpTcp *)bio->ptr;
tcp->readBlocked = FALSE;
BIO_clear_retry_flags(bio);
status = BIO_read(bio->next_bio, buf, size);
/*fprintf(stderr, "%s: size=%d status=%d shouldRetry=%d\n", __FUNCTION__, size, status, BIO_should_retry(bio->next_bio)); */
if (status <= 0 && BIO_should_retry(bio->next_bio))
{
BIO_set_retry_read(bio);
tcp->readBlocked = TRUE;
}
return status;
}
static int transport_bio_buffered_puts(BIO* bio, const char* str)
{
return 1;
}
static int transport_bio_buffered_gets(BIO* bio, char* str, int size)
{
return 1;
}
static long transport_bio_buffered_ctrl(BIO* bio, int cmd, long arg1, void* arg2)
{
rdpTcp *tcp = (rdpTcp *)bio->ptr;
switch (cmd)
{
case BIO_CTRL_FLUSH:
return 1;
case BIO_CTRL_WPENDING:
return ringbuffer_used(&tcp->xmitBuffer);
case BIO_CTRL_PENDING:
return 0;
default:
/*fprintf(stderr, "%s: passing to next BIO, bio=%p cmd=%d arg1=%d arg2=%p\n", __FUNCTION__, bio, cmd, arg1, arg2); */
return BIO_ctrl(bio->next_bio, cmd, arg1, arg2);
}
return 0;
}
static int transport_bio_buffered_new(BIO* bio)
{
bio->init = 1;
bio->num = 0;
bio->ptr = NULL;
bio->flags = 0;
return 1;
}
static int transport_bio_buffered_free(BIO* bio)
{
return 1;
}
static BIO_METHOD transport_bio_buffered_socket_methods =
{
BIO_TYPE_BUFFERED,
"BufferedSocket",
transport_bio_buffered_write,
transport_bio_buffered_read,
transport_bio_buffered_puts,
transport_bio_buffered_gets,
transport_bio_buffered_ctrl,
transport_bio_buffered_new,
transport_bio_buffered_free,
NULL,
};
BIO_METHOD* BIO_s_buffered_socket(void)
{
return &transport_bio_buffered_socket_methods;
}
BOOL transport_bio_buffered_drain(BIO *bio)
{
rdpTcp *tcp = (rdpTcp *)bio->ptr;
int status;
if (!ringbuffer_used(&tcp->xmitBuffer))
return 1;
status = transport_bio_buffered_write(bio, NULL, 0);
return status >= 0;
}
void tcp_get_ip_address(rdpTcp* tcp)
{
BYTE* ip;
@ -136,62 +295,65 @@ BOOL tcp_connect(rdpTcp* tcp, const char* hostname, int port)
if (hostname[0] == '/')
{
tcp->sockfd = freerdp_uds_connect(hostname);
if (tcp->sockfd < 0)
return FALSE;
tcp->socketBio = BIO_new_fd(tcp->sockfd, 1);
if (!tcp->socketBio)
return FALSE;
}
else
{
tcp->sockfd = freerdp_tcp_connect(hostname, port);
if (tcp->sockfd < 0)
tcp->socketBio = BIO_new(BIO_s_connect());
if (!tcp->socketBio)
return FALSE;
SetEventFileDescriptor(tcp->event, tcp->sockfd);
if (BIO_set_conn_hostname(tcp->socketBio, hostname) < 0 || BIO_set_conn_int_port(tcp->socketBio, &port) < 0)
return FALSE;
tcp_get_ip_address(tcp);
tcp_get_mac_address(tcp);
if (BIO_do_connect(tcp->socketBio) <= 0)
return FALSE;
option_value = 1;
option_len = sizeof(option_value);
setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len);
/* receive buffer must be a least 32 K */
if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0)
{
if (option_value < (1024 * 32))
{
option_value = 1024 * 32;
option_len = sizeof(option_value);
setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len);
}
}
tcp_set_keep_alive_mode(tcp);
tcp->sockfd = BIO_get_fd(tcp->socketBio, NULL);
}
SetEventFileDescriptor(tcp->event, tcp->sockfd);
tcp_get_ip_address(tcp);
tcp_get_mac_address(tcp);
option_value = 1;
option_len = sizeof(option_value);
if (setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len) < 0)
fprintf(stderr, "%s: unable to set TCP_NODELAY\n", __FUNCTION__);
/* receive buffer must be a least 32 K */
if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0)
{
if (option_value < (1024 * 32))
{
option_value = 1024 * 32;
option_len = sizeof(option_value);
if (setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len) < 0)
{
fprintf(stderr, "%s: unable to set receive buffer len\n", __FUNCTION__);
return FALSE;
}
}
}
if (!tcp_set_keep_alive_mode(tcp))
return FALSE;
tcp->bufferedBio = BIO_new(BIO_s_buffered_socket());
if (!tcp->bufferedBio)
return FALSE;
tcp->bufferedBio->ptr = tcp;
tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio);
return TRUE;
}
int tcp_read(rdpTcp* tcp, BYTE* data, int length)
{
return freerdp_tcp_read(tcp->sockfd, data, length);
}
int tcp_write(rdpTcp* tcp, BYTE* data, int length)
{
return freerdp_tcp_write(tcp->sockfd, data, length);
}
int tcp_wait_read(rdpTcp* tcp)
{
return freerdp_tcp_wait_read(tcp->sockfd);
}
int tcp_wait_write(rdpTcp* tcp)
{
return freerdp_tcp_wait_write(tcp->sockfd);
}
BOOL tcp_disconnect(rdpTcp* tcp)
{
@ -209,7 +371,7 @@ BOOL tcp_set_blocking_mode(rdpTcp* tcp, BOOL blocking)
if (flags == -1)
{
fprintf(stderr, "tcp_set_blocking_mode: fcntl failed.\n");
fprintf(stderr, "%s: fcntl failed, %s.\n", __FUNCTION__, strerror(errno));
return FALSE;
}
@ -297,6 +459,31 @@ int tcp_attach(rdpTcp* tcp, int sockfd)
{
tcp->sockfd = sockfd;
SetEventFileDescriptor(tcp->event, tcp->sockfd);
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, ringbuffer_used(&tcp->xmitBuffer));
if (tcp->socketBio)
{
if (BIO_set_fd(tcp->socketBio, sockfd, 1) < 0)
return -1;
}
else
{
tcp->socketBio = BIO_new_socket(sockfd, 1);
if (!tcp->socketBio)
return -1;
}
if (!tcp->bufferedBio)
{
tcp->bufferedBio = BIO_new(BIO_s_buffered_socket());
if (!tcp->bufferedBio)
return FALSE;
tcp->bufferedBio->ptr = tcp;
tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio);
}
return 0;
}
@ -316,25 +503,34 @@ rdpTcp* tcp_new(rdpSettings* settings)
{
rdpTcp* tcp;
tcp = (rdpTcp*) malloc(sizeof(rdpTcp));
tcp = (rdpTcp *)calloc(1, sizeof(rdpTcp));
if (!tcp)
return NULL;
if (tcp)
{
ZeroMemory(tcp, sizeof(rdpTcp));
if (!ringbuffer_init(&tcp->xmitBuffer, 0x10000))
goto out_free;
tcp->sockfd = -1;
tcp->settings = settings;
tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd);
}
tcp->sockfd = -1;
tcp->settings = settings;
tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd);
if (!tcp->event || tcp->event == INVALID_HANDLE_VALUE)
goto out_ringbuffer;
return tcp;
out_ringbuffer:
ringbuffer_destroy(&tcp->xmitBuffer);
out_free:
free(tcp);
return NULL;
}
void tcp_free(rdpTcp* tcp)
{
if (tcp)
{
CloseHandle(tcp->event);
free(tcp);
}
if (!tcp)
return;
ringbuffer_destroy(&tcp->xmitBuffer);
CloseHandle(tcp->event);
free(tcp);
}

View File

@ -31,10 +31,15 @@
#include <winpr/stream.h>
#include <winpr/winsock.h>
#include <freerdp/utils/ringbuffer.h>
#include <openssl/bio.h>
#ifndef MSG_NOSIGNAL
#define MSG_NOSIGNAL 0
#endif
#define BIO_TYPE_BUFFERED 66
typedef struct rdp_tcp rdpTcp;
struct rdp_tcp
@ -46,6 +51,12 @@ struct rdp_tcp
#ifdef _WIN32
WSAEVENT wsa_event;
#endif
BIO *socketBio;
BIO *bufferedBio;
RingBuffer xmitBuffer;
BOOL writeBlocked;
BOOL readBlocked;
HANDLE event;
};

View File

@ -33,7 +33,9 @@
#include <freerdp/error.h>
#include <freerdp/utils/tcp.h>
#include <freerdp/utils/ringbuffer.h>
#include <openssl/bio.h>
#include <time.h>
#include <errno.h>
#include <fcntl.h>
@ -41,6 +43,12 @@
#ifndef _WIN32
#include <netdb.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/time.h>
#endif
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "tpkt.h"
@ -48,6 +56,7 @@
#include "transport.h"
#include "rdp.h"
#define BUFFER_SIZE 16384
static void* transport_client_thread(void* arg);
@ -69,6 +78,7 @@ void transport_attach(rdpTransport* transport, int sockfd)
tcp_attach(transport->TcpIn, sockfd);
transport->SplitInputOutput = FALSE;
transport->TcpOut = transport->TcpIn;
transport->frontBio = transport->TcpIn->bufferedBio;
}
void transport_stop(rdpTransport* transport)
@ -98,18 +108,9 @@ BOOL transport_disconnect(rdpTransport* transport)
transport_stop(transport);
if (transport->layer == TRANSPORT_LAYER_TLS)
status &= tls_disconnect(transport->TlsIn);
if ((transport->layer == TRANSPORT_LAYER_TSG) || (transport->layer == TRANSPORT_LAYER_TSG_TLS))
{
status &= tsg_disconnect(transport->tsg);
}
else
{
status &= tcp_disconnect(transport->TcpIn);
}
BIO_free_all(transport->frontBio);
transport->frontBio = 0;
return status;
}
@ -131,16 +132,16 @@ static int transport_bio_tsg_write(BIO* bio, const char* buf, int num)
rdpTsg* tsg;
tsg = (rdpTsg*) bio->ptr;
status = tsg_write(tsg, (BYTE*) buf, num);
BIO_clear_retry_flags(bio);
status = tsg_write(tsg, (BYTE*) buf, num);
if (status > 0)
return status;
if (status == 0)
{
BIO_set_retry_write(bio);
}
return status < 0 ? 0 : num;
return -1;
}
static int transport_bio_tsg_read(BIO* bio, char* buf, int size)
@ -222,8 +223,13 @@ BIO_METHOD* BIO_s_tsg(void)
return &transport_bio_tsg_methods;
}
BOOL transport_connect_tls(rdpTransport* transport)
{
rdpSettings *settings = transport->settings;
rdpTls *targetTls;
BIO *targetBio;
int tls_status;
freerdp* instance;
rdpContext* context;
@ -234,61 +240,33 @@ BOOL transport_connect_tls(rdpTransport* transport)
if (transport->layer == TRANSPORT_LAYER_TSG)
{
transport->TsgTls = tls_new(transport->settings);
transport->TsgTls->methods = BIO_s_tsg();
transport->TsgTls->tsg = (void*) transport->tsg;
transport->layer = TRANSPORT_LAYER_TSG_TLS;
transport->TsgTls->hostname = transport->settings->ServerHostname;
transport->TsgTls->port = transport->settings->ServerPort;
targetTls = transport->TsgTls;
targetBio = transport->frontBio;
}
else
{
if (!transport->TlsIn)
transport->TlsIn = tls_new(settings);
if (transport->TsgTls->port == 0)
transport->TsgTls->port = 3389;
if (!transport->TlsOut)
transport->TlsOut = transport->TlsIn;
tls_status = tls_connect(transport->TsgTls);
targetTls = transport->TlsIn;
targetBio = transport->TcpIn->bufferedBio;
if (tls_status < 1)
{
if (tls_status < 0)
{
if (!connectErrorCode)
connectErrorCode = TLSCONNECTERROR;
if (!freerdp_get_last_error(context))
freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED);
}
else
{
if (!freerdp_get_last_error(context))
freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
}
tls_free(transport->TsgTls);
transport->TsgTls = NULL;
return FALSE;
}
return TRUE;
transport->layer = TRANSPORT_LAYER_TLS;
}
if (!transport->TlsIn)
transport->TlsIn = tls_new(transport->settings);
if (!transport->TlsOut)
transport->TlsOut = transport->TlsIn;
targetTls->hostname = settings->ServerHostname;
targetTls->port = settings->ServerPort;
transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (targetTls->port == 0)
targetTls->port = 3389;
transport->TlsIn->hostname = transport->settings->ServerHostname;
transport->TlsIn->port = transport->settings->ServerPort;
if (transport->TlsIn->port == 0)
transport->TlsIn->port = 3389;
tls_status = tls_connect(transport->TlsIn);
tls_status = tls_connect(targetTls, targetBio);
if (tls_status < 1)
{
@ -306,13 +284,13 @@ BOOL transport_connect_tls(rdpTransport* transport)
freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
}
tls_free(transport->TlsIn);
if (transport->TlsIn == transport->TlsOut)
transport->TlsIn = transport->TlsOut = NULL;
else
transport->TlsIn = NULL;
return FALSE;
}
transport->frontBio = targetTls->bio;
if (!transport->frontBio)
{
fprintf(stderr, "%s: unable to prepend a filtering TLS bio", __FUNCTION__);
return FALSE;
}
@ -323,6 +301,7 @@ BOOL transport_connect_nla(rdpTransport* transport)
{
freerdp* instance;
rdpSettings* settings;
rdpCredssp *credSsp;
settings = transport->settings;
instance = (freerdp*) settings->instance;
@ -338,16 +317,22 @@ BOOL transport_connect_nla(rdpTransport* transport)
if (!transport->credssp)
{
transport->credssp = credssp_new(instance, transport, settings);
if (!transport->credssp)
return FALSE;
transport_set_nla_mode(transport, TRUE);
if (settings->AuthenticationServiceClass)
{
transport->credssp->ServicePrincipalName =
credssp_make_spn(settings->AuthenticationServiceClass, settings->ServerHostname);
if (!transport->credssp->ServicePrincipalName)
return FALSE;
}
}
if (credssp_authenticate(transport->credssp) < 0)
credSsp = transport->credssp;
if (credssp_authenticate(credSsp) < 0)
{
if (!connectErrorCode)
connectErrorCode = AUTHENTICATIONERROR;
@ -361,14 +346,14 @@ BOOL transport_connect_nla(rdpTransport* transport)
"If credentials are valid, the NTLMSSP implementation may be to blame.\n");
transport_set_nla_mode(transport, FALSE);
credssp_free(transport->credssp);
credssp_free(credSsp);
transport->credssp = NULL;
return FALSE;
}
transport_set_nla_mode(transport, FALSE);
credssp_free(transport->credssp);
credssp_free(credSsp);
transport->credssp = NULL;
return TRUE;
@ -380,38 +365,41 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
int tls_status;
freerdp* instance;
rdpContext* context;
rdpSettings *settings = transport->settings;
instance = (freerdp*) transport->settings->instance;
context = instance->context;
tsg = tsg_new(transport);
if (!tsg)
return FALSE;
tsg->transport = transport;
transport->tsg = tsg;
transport->SplitInputOutput = TRUE;
if (!transport->TlsIn)
transport->TlsIn = tls_new(transport->settings);
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
transport->TlsIn->hostname = transport->settings->GatewayHostname;
transport->TlsIn->port = transport->settings->GatewayPort;
if (transport->TlsIn->port == 0)
transport->TlsIn->port = 443;
{
transport->TlsIn = tls_new(settings);
if (!transport->TlsIn)
return FALSE;
}
if (!transport->TlsOut)
transport->TlsOut = tls_new(transport->settings);
{
transport->TlsOut = tls_new(settings);
if (!transport->TlsOut)
return FALSE;
}
transport->TlsOut->sockfd = transport->TcpOut->sockfd;
transport->TlsOut->hostname = transport->settings->GatewayHostname;
transport->TlsOut->port = transport->settings->GatewayPort;
/* put a decent default value for gateway port */
if (!settings->GatewayPort)
settings->GatewayPort = 443;
if (transport->TlsOut->port == 0)
transport->TlsOut->port = 443;
transport->TlsIn->hostname = transport->TlsOut->hostname = settings->GatewayHostname;
transport->TlsIn->port = transport->TlsOut->port = settings->GatewayPort;
tls_status = tls_connect(transport->TlsIn);
tls_status = tls_connect(transport->TlsIn, transport->TcpIn->bufferedBio);
if (tls_status < 1)
{
if (tls_status < 0)
@ -428,8 +416,7 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
return FALSE;
}
tls_status = tls_connect(transport->TlsOut);
tls_status = tls_connect(transport->TlsOut, transport->TcpOut->bufferedBio);
if (tls_status < 1)
{
if (tls_status < 0)
@ -449,6 +436,8 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
if (!tsg_connect(tsg, hostname, port))
return FALSE;
transport->frontBio = BIO_new(BIO_s_tsg());
transport->frontBio->ptr = tsg;
return TRUE;
}
@ -462,15 +451,20 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por
if (transport->GatewayEnabled)
{
transport->layer = TRANSPORT_LAYER_TSG;
transport->SplitInputOutput = TRUE;
transport->TcpOut = tcp_new(settings);
status = tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort);
if (!tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort) ||
!tcp_set_blocking_mode(transport->TcpIn, FALSE))
return FALSE;
if (status)
status = tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort);
if (!tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort) ||
!tcp_set_blocking_mode(transport->TcpOut, FALSE))
return FALSE;
if (status)
status = transport_tsg_connect(transport, hostname, port);
if (!transport_tsg_connect(transport, hostname, port))
return FALSE;
status = TRUE;
}
else
{
@ -478,6 +472,7 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por
transport->SplitInputOutput = FALSE;
transport->TcpOut = transport->TcpIn;
transport->frontBio = transport->TcpIn->bufferedBio;
}
if (status)
@ -510,11 +505,11 @@ BOOL transport_accept_tls(rdpTransport* transport)
transport->TlsOut = transport->TlsIn;
transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
return FALSE;
transport->frontBio = transport->TlsIn->bio;
return TRUE;
}
@ -533,10 +528,10 @@ BOOL transport_accept_nla(rdpTransport* transport)
transport->TlsOut = transport->TlsIn;
transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, settings->CertificateFile, settings->PrivateKeyFile))
return FALSE;
transport->frontBio = transport->TlsIn->bio;
/* Network Level Authentication */
@ -630,56 +625,131 @@ UINT32 nla_header_length(wStream* s)
return length;
}
static int transport_wait_for_read(rdpTransport* transport)
{
struct timeval tv;
fd_set rset, wset;
fd_set *rsetPtr = NULL, *wsetPtr = NULL;
rdpTcp *tcpIn;
tcpIn = transport->TcpIn;
if (tcpIn->readBlocked)
{
rsetPtr = &rset;
FD_ZERO(rsetPtr);
FD_SET(tcpIn->sockfd, rsetPtr);
}
else if (tcpIn->writeBlocked)
{
wsetPtr = &wset;
FD_ZERO(wsetPtr);
FD_SET(tcpIn->sockfd, wsetPtr);
}
if (!wsetPtr && !rsetPtr)
{
USleep(1000);
return 0;
}
tv.tv_sec = 0;
tv.tv_usec = 1000;
return select(tcpIn->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
}
static int transport_wait_for_write(rdpTransport* transport)
{
struct timeval tv;
fd_set rset, wset;
fd_set *rsetPtr = NULL, *wsetPtr = NULL;
rdpTcp *tcpOut;
tcpOut = transport->SplitInputOutput ? transport->TcpOut : transport->TcpIn;
if (tcpOut->writeBlocked)
{
wsetPtr = &wset;
FD_ZERO(wsetPtr);
FD_SET(tcpOut->sockfd, wsetPtr);
}
else if (tcpOut->readBlocked)
{
rsetPtr = &rset;
FD_ZERO(rsetPtr);
FD_SET(tcpOut->sockfd, rsetPtr);
}
if (!wsetPtr && !rsetPtr)
{
USleep(1000);
return 0;
}
tv.tv_sec = 0;
tv.tv_usec = 1000;
return select(tcpOut->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
}
int transport_read_layer(rdpTransport* transport, BYTE* data, int bytes)
{
int read = 0;
int status = -1;
while (read < bytes)
{
if (transport->layer == TRANSPORT_LAYER_TLS)
status = tls_read(transport->TlsIn, data + read, bytes - read);
else if (transport->layer == TRANSPORT_LAYER_TCP)
status = tcp_read(transport->TcpIn, data + read, bytes - read);
else if (transport->layer == TRANSPORT_LAYER_TSG)
status = tsg_read(transport->tsg, data + read, bytes - read);
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) {
status = tls_read(transport->TsgTls, data + read, bytes - read);
status = BIO_read(transport->frontBio, data + read, bytes - read);
if (!status)
{
transport->layer = TRANSPORT_LAYER_CLOSED;
return -1;
}
/* blocking means that we can't continue until this is read */
if (!transport->blocking)
return status;
if (status < 0)
{
/* A read error indicates that the peer has dropped the connection */
transport->layer = TRANSPORT_LAYER_CLOSED;
return status;
if (!BIO_should_retry(transport->frontBio))
{
/* something unexpected happened, let's close */
transport->layer = TRANSPORT_LAYER_CLOSED;
return -1;
}
/* non blocking will survive a partial read */
if (!transport->blocking)
return read;
/* blocking means that we can't continue until we have read the number of
* requested bytes */
if (transport_wait_for_read(transport) < 0)
{
fprintf(stderr, "%s: error when selecting for read\n", __FUNCTION__);
return -1;
}
continue;
}
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(data + read, bytes - read);
#endif
read += status;
if (status == 0)
{
/*
* instead of sleeping, we should wait timeout on the
* socket but this only happens on initial connection
*/
USleep(transport->SleepInterval);
}
}
return read;
}
int transport_read(rdpTransport* transport, wStream* s)
{
int status;
int position;
int pduLength;
BYTE header[4];
BYTE *header;
int transport_status;
position = 0;
@ -710,7 +780,7 @@ int transport_read(rdpTransport* transport, wStream* s)
position += status;
}
CopyMemory(header, Stream_Buffer(s), 4); /* peek at first 4 bytes */
header = Stream_Buffer(s);
/* if header is present, read exactly one PDU */
@ -802,6 +872,8 @@ static int transport_read_nonblocking(rdpTransport* transport)
return status;
}
BOOL transport_bio_buffered_drain(BIO *bio);
int transport_write(rdpTransport* transport, wStream* s)
{
int length;
@ -827,36 +899,48 @@ int transport_write(rdpTransport* transport, wStream* s)
while (length > 0)
{
if (transport->layer == TRANSPORT_LAYER_TLS)
status = tls_write(transport->TlsOut, Stream_Pointer(s), length);
else if (transport->layer == TRANSPORT_LAYER_TCP)
status = tcp_write(transport->TcpOut, Stream_Pointer(s), length);
else if (transport->layer == TRANSPORT_LAYER_TSG)
status = tsg_write(transport->tsg, Stream_Pointer(s), length);
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
status = tls_write(transport->TsgTls, Stream_Pointer(s), length);
status = BIO_write(transport->frontBio, Stream_Pointer(s), length);
if (status < 0)
break; /* error occurred */
if (status == 0)
if (status <= 0)
{
/* when sending is blocked in nonblocking mode, the receiving buffer should be checked */
if (!transport->blocking)
{
/* and in case we do have buffered some data, we set the event so next loop will get it */
if (transport_read_nonblocking(transport) > 0)
SetEvent(transport->ReceiveEvent);
}
/* the buffered BIO that is at the end of the chain always says OK for writing,
* so a retry means that for any reason we need to read. The most probable
* is a SSL or TSG BIO in the chain.
*/
if (!BIO_should_retry(transport->frontBio))
return status;
if (transport->layer == TRANSPORT_LAYER_TLS)
tls_wait_write(transport->TlsOut);
else if (transport->layer == TRANSPORT_LAYER_TCP)
tcp_wait_write(transport->TcpOut);
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
tls_wait_write(transport->TsgTls);
else
USleep(transport->SleepInterval);
/* non-blocking can live with blocked IOs */
if (!transport->blocking)
return status;
if (transport_wait_for_write(transport) < 0)
{
fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__);
return -1;
}
continue;
}
if (transport->blocking || transport->settings->WaitForOutputBufferFlush)
{
/* blocking transport, we must ensure the write buffer is really empty */
rdpTcp *out = transport->TcpOut;
while (out->writeBlocked)
{
if (transport_wait_for_write(transport) < 0)
{
fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__);
return -1;
}
if (!transport_bio_buffered_drain(out->bufferedBio))
{
fprintf(stderr, "%s: error when draining outputBuffer\n", __FUNCTION__);
return -1;
}
}
}
length -= status;
@ -945,6 +1029,38 @@ void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD*
}
}
BOOL tranport_is_write_blocked(rdpTransport* transport)
{
if (transport->TcpIn->writeBlocked)
return TRUE;
return transport->SplitInputOutput &&
transport->TcpOut &&
transport->TcpOut->writeBlocked;
}
int tranport_drain_output_buffer(rdpTransport* transport)
{
BOOL ret = FALSE;
/* First try to send some accumulated bytes in the send buffer */
if (transport->TcpIn->writeBlocked)
{
if (!transport_bio_buffered_drain(transport->TcpIn->bufferedBio))
return -1;
ret |= transport->TcpIn->writeBlocked;
}
if (transport->SplitInputOutput && transport->TcpOut && transport->TcpOut->writeBlocked)
{
if (!transport_bio_buffered_drain(transport->TcpOut->bufferedBio))
return -1;
ret |= transport->TcpOut->writeBlocked;
}
return ret;
}
int transport_check_fds(rdpTransport* transport)
{
int pos;
@ -1079,15 +1195,14 @@ int transport_check_fds(rdpTransport* transport)
recv_status = transport->ReceiveCallback(transport, received, transport->ReceiveExtra);
Stream_Release(received);
if (recv_status < 0)
return -1;
if (recv_status == 1)
{
return 1; /* session redirection */
}
Stream_Release(received);
if (recv_status < 0)
return -1;
}
return 0;
@ -1198,80 +1313,107 @@ rdpTransport* transport_new(rdpSettings* settings)
{
rdpTransport* transport;
transport = (rdpTransport*) malloc(sizeof(rdpTransport));
transport = (rdpTransport *)calloc(1, sizeof(rdpTransport));
if (!transport)
return NULL;
if (transport)
{
ZeroMemory(transport, sizeof(rdpTransport));
WLog_Init();
transport->log = WLog_Get("com.freerdp.core.transport");
if (!transport->log)
goto out_free;
WLog_Init();
transport->log = WLog_Get("com.freerdp.core.transport");
transport->TcpIn = tcp_new(settings);
if (!transport->TcpIn)
goto out_free;
transport->TcpIn = tcp_new(settings);
transport->settings = settings;
transport->settings = settings;
/* a small 0.1ms delay when transport is blocking. */
transport->SleepInterval = 100;
/* a small 0.1ms delay when transport is blocking. */
transport->SleepInterval = 100;
transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE);
if (!transport->ReceivePool)
goto out_free_tcpin;
transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE);
/* receive buffer for non-blocking read. */
transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
if (!transport->ReceiveBuffer)
goto out_free_receivepool;
/* receive buffer for non-blocking read. */
transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!transport->ReceiveEvent || transport->ReceiveEvent == INVALID_HANDLE_VALUE)
goto out_free_receivebuffer;
transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!transport->connectedEvent || transport->connectedEvent == INVALID_HANDLE_VALUE)
goto out_free_receiveEvent;
transport->blocking = TRUE;
transport->GatewayEnabled = FALSE;
transport->blocking = TRUE;
transport->GatewayEnabled = FALSE;
transport->layer = TRANSPORT_LAYER_TCP;
InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000);
InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000);
transport->layer = TRANSPORT_LAYER_TCP;
}
if (!InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000))
goto out_free_connectedEvent;
if (!InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000))
goto out_free_readlock;
return transport;
out_free_readlock:
DeleteCriticalSection(&(transport->ReadLock));
out_free_connectedEvent:
CloseHandle(transport->connectedEvent);
out_free_receiveEvent:
CloseHandle(transport->ReceiveEvent);
out_free_receivebuffer:
StreamPool_Return(transport->ReceivePool, transport->ReceiveBuffer);
out_free_receivepool:
StreamPool_Free(transport->ReceivePool);
out_free_tcpin:
tcp_free(transport->TcpIn);
out_free:
free(transport);
return NULL;
}
void transport_free(rdpTransport* transport)
{
if (transport)
{
transport_stop(transport);
if (!transport)
return;
if (transport->ReceiveBuffer)
Stream_Release(transport->ReceiveBuffer);
transport_stop(transport);
StreamPool_Free(transport->ReceivePool);
if (transport->ReceiveBuffer)
Stream_Release(transport->ReceiveBuffer);
CloseHandle(transport->ReceiveEvent);
CloseHandle(transport->connectedEvent);
StreamPool_Free(transport->ReceivePool);
if (transport->TlsIn)
tls_free(transport->TlsIn);
CloseHandle(transport->ReceiveEvent);
CloseHandle(transport->connectedEvent);
if (transport->TlsOut != transport->TlsIn)
tls_free(transport->TlsOut);
if (transport->TlsIn)
tls_free(transport->TlsIn);
transport->TlsIn = NULL;
transport->TlsOut = NULL;
if (transport->TlsOut != transport->TlsIn)
tls_free(transport->TlsOut);
if (transport->TcpIn)
tcp_free(transport->TcpIn);
transport->TlsIn = NULL;
transport->TlsOut = NULL;
if (transport->TcpOut != transport->TcpIn)
tcp_free(transport->TcpOut);
if (transport->TcpIn)
tcp_free(transport->TcpIn);
transport->TcpIn = NULL;
transport->TcpOut = NULL;
if (transport->TcpOut != transport->TcpIn)
tcp_free(transport->TcpOut);
tsg_free(transport->tsg);
transport->tsg = NULL;
transport->TcpIn = NULL;
transport->TcpOut = NULL;
DeleteCriticalSection(&(transport->ReadLock));
DeleteCriticalSection(&(transport->WriteLock));
tsg_free(transport->tsg);
transport->tsg = NULL;
free(transport);
}
DeleteCriticalSection(&(transport->ReadLock));
DeleteCriticalSection(&(transport->WriteLock));
free(transport);
}

View File

@ -49,11 +49,13 @@ typedef struct rdp_transport rdpTransport;
#include <freerdp/types.h>
#include <freerdp/settings.h>
typedef int (*TransportRecv) (rdpTransport* transport, wStream* stream, void* extra);
struct rdp_transport
{
TRANSPORT_LAYER layer;
BIO *frontBio;
rdpTsg* tsg;
rdpTcp* TcpIn;
rdpTcp* TcpOut;
@ -102,6 +104,8 @@ BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking);
void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled);
void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode);
void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count);
BOOL tranport_is_write_blocked(rdpTransport* transport);
int tranport_drain_output_buffer(rdpTransport* transport);
wStream* transport_receive_pool_take(rdpTransport* transport);
int transport_receive_pool_return(rdpTransport* transport, wStream* pdu);

View File

@ -544,7 +544,7 @@ static void update_end_paint(rdpContext* context)
if (update->numberOrders > 0)
{
printf("Sending %d orders\n", update->numberOrders);
fprintf(stderr, "%s: sending %d orders\n", __FUNCTION__, update->numberOrders);
fastpath_send_update_pdu(context->rdp->fastpath, FASTPATH_UPDATETYPE_ORDERS, s);
}

1
libfreerdp/crypto/test/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
TestFreeRDPCrypto.c

View File

@ -28,34 +28,35 @@
#include <winpr/stream.h>
#include <freerdp/utils/tcp.h>
#include <freerdp/utils/ringbuffer.h>
#include <freerdp/crypto/tls.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "../core/tcp.h"
static CryptoCert tls_get_certificate(rdpTls* tls, BOOL peer)
{
CryptoCert cert;
X509* server_cert;
X509* remote_cert;
if (peer)
server_cert = SSL_get_peer_certificate(tls->ssl);
remote_cert = SSL_get_peer_certificate(tls->ssl);
else
server_cert = SSL_get_certificate(tls->ssl);
remote_cert = SSL_get_certificate(tls->ssl);
if (!server_cert)
if (!remote_cert)
{
fprintf(stderr, "tls_get_certificate: failed to get the server TLS certificate\n");
cert = NULL;
}
else
{
cert = malloc(sizeof(*cert));
cert->px509 = server_cert;
fprintf(stderr, "%s: failed to get the server TLS certificate\n", __FUNCTION__);
return NULL;
}
cert = malloc(sizeof(*cert));
if (!cert)
{
X509_free(remote_cert);
return NULL;
}
cert->px509 = remote_cert;
return cert;
}
@ -83,12 +84,14 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert)
PrefixLength = strlen(TLS_SERVER_END_POINT);
ChannelBindingTokenLength = PrefixLength + CertificateHashLength;
ContextBindings = (SecPkgContext_Bindings*) malloc(sizeof(SecPkgContext_Bindings));
ZeroMemory(ContextBindings, sizeof(SecPkgContext_Bindings));
ContextBindings = (SecPkgContext_Bindings*) calloc(1, sizeof(SecPkgContext_Bindings));
if (!ContextBindings)
return NULL;
ContextBindings->BindingsLength = sizeof(SEC_CHANNEL_BINDINGS) + ChannelBindingTokenLength;
ChannelBindings = (SEC_CHANNEL_BINDINGS*) malloc(ContextBindings->BindingsLength);
ZeroMemory(ChannelBindings, ContextBindings->BindingsLength);
ChannelBindings = (SEC_CHANNEL_BINDINGS*) calloc(1, ContextBindings->BindingsLength);
if (!ChannelBindings)
goto out_free;
ContextBindings->Bindings = ChannelBindings;
ChannelBindings->cbApplicationDataLength = ChannelBindingTokenLength;
@ -99,32 +102,121 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert)
CopyMemory(&ChannelBindingToken[PrefixLength], CertificateHash, CertificateHashLength);
return ContextBindings;
out_free:
free(ContextBindings);
return NULL;
}
static void tls_ssl_info_callback(const SSL* ssl, int type, int val)
BOOL tls_prepare(rdpTls* tls, BIO *underlying, const SSL_METHOD *method, int options, BOOL clientMode)
{
if (type & SSL_CB_HANDSHAKE_START)
{
}
}
int tls_connect(rdpTls* tls)
{
CryptoCert cert;
long options = 0;
int verify_status;
int connection_status;
tls->ctx = SSL_CTX_new(TLSv1_client_method());
tls->ctx = SSL_CTX_new(method);
if (!tls->ctx)
{
fprintf(stderr, "SSL_CTX_new failed\n");
fprintf(stderr, "%s: SSL_CTX_new failed\n", __FUNCTION__);
return FALSE;
}
SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_options(tls->ctx, options);
SSL_CTX_set_read_ahead(tls->ctx, 1);
tls->bio = BIO_new_ssl(tls->ctx, clientMode);
if (BIO_get_ssl(tls->bio, &tls->ssl) < 0)
{
fprintf(stderr, "%s: unable to retrieve the SSL of the connection\n", __FUNCTION__);
return FALSE;
}
BIO_push(tls->bio, underlying);
return TRUE;
}
int tls_do_handshake(rdpTls* tls, BOOL clientMode)
{
CryptoCert cert;
int verify_status, status;
do
{
struct timeval tv;
fd_set rset;
int fd;
status = BIO_do_handshake(tls->bio);
if (status == 1)
break;
if (!BIO_should_retry(tls->bio))
return -1;
/* we select() only for read even if we should test both read and write
* depending of what have blocked */
FD_ZERO(&rset);
fd = BIO_get_fd(tls->bio, NULL);
if (fd < 0)
{
fprintf(stderr, "%s: unable to retrieve BIO fd\n", __FUNCTION__);
return -1;
}
FD_SET(fd, &rset);
tv.tv_sec = 0;
tv.tv_usec = 10 * 1000; /* 10ms */
status = select(fd + 1, &rset, NULL, NULL, &tv);
if (status < 0)
{
fprintf(stderr, "%s: error during select()\n", __FUNCTION__);
return -1;
}
}
while (TRUE);
if (!clientMode)
return 1;
cert = tls_get_certificate(tls, clientMode);
if (!cert)
{
fprintf(stderr, "%s: tls_get_certificate failed to return the server certificate.\n", __FUNCTION__);
return -1;
}
//SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
tls->Bindings = tls_get_channel_bindings(cert->px509);
if (!tls->Bindings)
{
fprintf(stderr, "%s: unable to retrieve bindings\n", __FUNCTION__);
return -1;
}
if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
{
fprintf(stderr, "%s: crypto_cert_get_public_key failed to return the server public key.\n", __FUNCTION__);
tls_free_certificate(cert);
return -1;
}
verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port);
if (verify_status < 1)
{
fprintf(stderr, "%s: certificate not trusted, aborting.\n", __FUNCTION__);
tls_disconnect(tls);
tls_free_certificate(cert);
return 0;
}
tls_free_certificate(cert);
return verify_status;
}
int tls_connect(rdpTls* tls, BIO *underlying)
{
int options = 0;
/**
* SSL_OP_NO_COMPRESSION:
@ -155,96 +247,19 @@ int tls_connect(rdpTls* tls)
*/
options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
SSL_CTX_set_options(tls->ctx, options);
if (!tls_prepare(tls, underlying, TLSv1_client_method(), options, TRUE))
return FALSE;
tls->ssl = SSL_new(tls->ctx);
if (!tls->ssl)
{
fprintf(stderr, "SSL_new failed\n");
return -1;
}
if (tls->tsg)
{
tls->bio = BIO_new(tls->methods);
if (!tls->bio)
{
fprintf(stderr, "BIO_new failed\n");
return -1;
}
tls->bio->ptr = tls->tsg;
SSL_set_bio(tls->ssl, tls->bio, tls->bio);
SSL_CTX_set_info_callback(tls->ctx, tls_ssl_info_callback);
}
else
{
if (SSL_set_fd(tls->ssl, tls->sockfd) < 1)
{
fprintf(stderr, "SSL_set_fd failed\n");
return -1;
}
}
connection_status = SSL_connect(tls->ssl);
if (connection_status <= 0)
{
if (tls_print_error("SSL_connect", tls->ssl, connection_status))
{
return -1;
}
}
cert = tls_get_certificate(tls, TRUE);
if (!cert)
{
fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n");
return -1;
}
tls->Bindings = tls_get_channel_bindings(cert->px509);
if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
{
fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n");
tls_free_certificate(cert);
return -1;
}
verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port);
if (verify_status < 1)
{
fprintf(stderr, "tls_connect: certificate not trusted, aborting.\n");
tls_disconnect(tls);
}
tls_free_certificate(cert);
return verify_status;
return tls_do_handshake(tls, TRUE);
}
BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file)
BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file)
{
CryptoCert cert;
long options = 0;
int connection_status;
tls->ctx = SSL_CTX_new(SSLv23_server_method());
if (tls->ctx == NULL)
{
fprintf(stderr, "SSL_CTX_new failed\n");
return FALSE;
}
/*
/**
* SSL_OP_NO_SSLv2:
*
* We only want SSLv3 and TLSv1, so disable SSLv2.
@ -281,80 +296,23 @@ BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file)
*/
options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
SSL_CTX_set_options(tls->ctx, options);
if (SSL_CTX_use_RSAPrivateKey_file(tls->ctx, privatekey_file, SSL_FILETYPE_PEM) <= 0)
{
fprintf(stderr, "SSL_CTX_use_RSAPrivateKey_file failed\n");
fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file);
if (!tls_prepare(tls, underlying, SSLv23_server_method(), options, FALSE))
return FALSE;
}
tls->ssl = SSL_new(tls->ctx);
if (!tls->ssl)
if (SSL_use_RSAPrivateKey_file(tls->ssl, privatekey_file, SSL_FILETYPE_PEM) <= 0)
{
fprintf(stderr, "SSL_new failed\n");
fprintf(stderr, "%s: SSL_CTX_use_RSAPrivateKey_file failed\n", __FUNCTION__);
fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file);
return FALSE;
}
if (SSL_use_certificate_file(tls->ssl, cert_file, SSL_FILETYPE_PEM) <= 0)
{
fprintf(stderr, "SSL_use_certificate_file failed\n");
fprintf(stderr, "%s: SSL_use_certificate_file failed\n", __FUNCTION__);
return FALSE;
}
if (SSL_set_fd(tls->ssl, tls->sockfd) < 1)
{
fprintf(stderr, "SSL_set_fd failed\n");
return FALSE;
}
while (1)
{
connection_status = SSL_accept(tls->ssl);
if (connection_status <= 0)
{
switch (SSL_get_error(tls->ssl, connection_status))
{
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
break;
default:
if (tls_print_error("SSL_accept", tls->ssl, connection_status))
return FALSE;
break;
}
}
else
{
break;
}
}
cert = tls_get_certificate(tls, FALSE);
if (!cert)
{
fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n");
return FALSE;
}
if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
{
fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n");
tls_free_certificate(cert);
return FALSE;
}
free(cert);
fprintf(stderr, "TLS connection accepted\n");
return TRUE;
return tls_do_handshake(tls, FALSE) > 0;
}
BOOL tls_disconnect(rdpTls* tls)
@ -362,256 +320,161 @@ BOOL tls_disconnect(rdpTls* tls)
if (!tls)
return FALSE;
if (tls->ssl)
if (!tls->ssl)
return TRUE;
if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY)
{
if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY)
{
/**
* OpenSSL doesn't really expose an API for sending a TLS alert manually.
*
* The following code disables the sending of the default "close notify"
* and then proceeds to force sending a custom TLS alert before shutting down.
*
* Manually sending a TLS alert is necessary in certain cases,
* like when server-side NLA results in an authentication failure.
*/
/**
* OpenSSL doesn't really expose an API for sending a TLS alert manually.
*
* The following code disables the sending of the default "close notify"
* and then proceeds to force sending a custom TLS alert before shutting down.
*
* Manually sending a TLS alert is necessary in certain cases,
* like when server-side NLA results in an authentication failure.
*/
SSL_set_quiet_shutdown(tls->ssl, 1);
SSL_set_quiet_shutdown(tls->ssl, 1);
if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session))
SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session);
if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session))
SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session);
tls->ssl->s3->alert_dispatch = 1;
tls->ssl->s3->send_alert[0] = tls->alertLevel;
tls->ssl->s3->send_alert[1] = tls->alertDescription;
tls->ssl->s3->alert_dispatch = 1;
tls->ssl->s3->send_alert[0] = tls->alertLevel;
tls->ssl->s3->send_alert[1] = tls->alertDescription;
if (tls->ssl->s3->wbuf.left == 0)
tls->ssl->method->ssl_dispatch_alert(tls->ssl);
if (tls->ssl->s3->wbuf.left == 0)
tls->ssl->method->ssl_dispatch_alert(tls->ssl);
SSL_shutdown(tls->ssl);
}
else
{
SSL_shutdown(tls->ssl);
}
SSL_shutdown(tls->ssl);
}
else
{
SSL_shutdown(tls->ssl);
}
return TRUE;
}
int tls_read(rdpTls* tls, BYTE* data, int length)
BIO *findBufferedBio(BIO *front)
{
int error;
int status;
BIO *ret = front;
if (!tls)
return -1;
if (!tls->ssl)
return -1;
status = SSL_read(tls->ssl, data, length);
if (status == 0)
while (ret)
{
return -1; /* peer disconnected */
if (BIO_method_type(ret) == BIO_TYPE_BUFFERED)
return ret;
ret = ret->next_bio;
}
if (status <= 0)
{
error = SSL_get_error(tls->ssl, status);
//fprintf(stderr, "tls_read: length: %d status: %d error: 0x%08X\n",
// length, status, error);
switch (error)
{
case SSL_ERROR_NONE:
break;
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
status = 0;
break;
case SSL_ERROR_SYSCALL:
#ifdef _WIN32
if (WSAGetLastError() == WSAEWOULDBLOCK)
#else
if ((errno == EAGAIN) || (errno == 0))
#endif
{
status = 0;
}
else
{
if (tls_print_error("SSL_read", tls->ssl, status))
{
status = -1;
}
else
{
status = 0;
}
}
break;
default:
if (tls_print_error("SSL_read", tls->ssl, status))
{
status = -1;
}
else
{
status = 0;
}
break;
}
}
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(data, status);
#endif
return status;
return ret;
}
int tls_write(rdpTls* tls, BYTE* data, int length)
int tls_write_all(rdpTls* tls, const BYTE* data, int length)
{
int error;
int status;
int status, nchunks, commitedBytes;
rdpTcp *tcp;
fd_set rset, wset;
fd_set *rsetPtr, *wsetPtr;
struct timeval tv;
BIO *bio = tls->bio;
DataChunk chunks[2];
if (!tls)
return -1;
if (!tls->ssl)
return -1;
status = SSL_write(tls->ssl, data, length);
if (status == 0)
BIO *bufferedBio = findBufferedBio(bio);
if (!bufferedBio)
{
return -1; /* peer disconnected */
fprintf(stderr, "%s: error unable to retrieve the bufferedBio in the BIO chain\n", __FUNCTION__);
return -1;
}
if (status < 0)
{
error = SSL_get_error(tls->ssl, status);
//fprintf(stderr, "tls_write: length: %d status: %d error: 0x%08X\n", length, status, error);
switch (error)
{
case SSL_ERROR_NONE:
break;
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
status = 0;
break;
case SSL_ERROR_SYSCALL:
if (errno == EAGAIN)
{
status = 0;
}
else
{
tls_print_error("SSL_write", tls->ssl, status);
status = -1;
}
break;
default:
tls_print_error("SSL_write", tls->ssl, status);
status = -1;
break;
}
}
return status;
}
int tls_write_all(rdpTls* tls, BYTE* data, int length)
{
int status;
int sent = 0;
tcp = (rdpTcp *)bufferedBio->ptr;
do
{
status = tls_write(tls, &data[sent], length - sent);
status = BIO_write(bio, data, length);
/*fprintf(stderr, "%s: BIO_write(len=%d) = %d (retry=%d)\n", __FUNCTION__, length, status, BIO_should_retry(bio));*/
if (status > 0)
sent += status;
else if (status == 0)
tls_wait_write(tls);
if (sent >= length)
break;
if (!BIO_should_retry(bio))
return -1;
/* we try to handle SSL want_read and want_write nicely */
rsetPtr = wsetPtr = 0;
if (tcp->writeBlocked)
{
wsetPtr = &wset;
FD_ZERO(&wset);
FD_SET(tcp->sockfd, &wset);
}
else if (tcp->readBlocked)
{
rsetPtr = &rset;
FD_ZERO(&rset);
FD_SET(tcp->sockfd, &rset);
}
else
{
fprintf(stderr, "%s: weird we're blocked but the underlying is not read or write blocked !\n", __FUNCTION__);
USleep(10);
continue;
}
tv.tv_sec = 0;
tv.tv_usec = 100 * 1000;
status = select(tcp->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
if (status < 0)
return -1;
}
while (status >= 0);
while (TRUE);
if (status > 0)
return length;
else
return status;
}
int tls_wait_read(rdpTls* tls)
{
return freerdp_tcp_wait_read(tls->sockfd);
}
int tls_wait_write(rdpTls* tls)
{
return freerdp_tcp_wait_write(tls->sockfd);
}
static void tls_errors(const char *prefix)
{
unsigned long error;
while ((error = ERR_get_error()) != 0)
fprintf(stderr, "%s: %s\n", prefix, ERR_error_string(error, NULL));
}
BOOL tls_print_error(char* func, SSL* connection, int value)
{
switch (SSL_get_error(connection, value))
/* make sure the output buffer is empty */
commitedBytes = 0;
while ((nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer))))
{
case SSL_ERROR_ZERO_RETURN:
fprintf(stderr, "%s: Server closed TLS connection\n", func);
return TRUE;
int i;
case SSL_ERROR_WANT_READ:
fprintf(stderr, "%s: SSL_ERROR_WANT_READ\n", func);
return FALSE;
for (i = 0; i < nchunks; i++)
{
while (chunks[i].size)
{
status = BIO_write(tcp->socketBio, chunks[i].data, chunks[i].size);
if (status > 0)
{
chunks[i].size -= status;
chunks[i].data += status;
commitedBytes += status;
continue;
}
case SSL_ERROR_WANT_WRITE:
fprintf(stderr, "%s: SSL_ERROR_WANT_WRITE\n", func);
return FALSE;
if (!BIO_should_retry(tcp->socketBio))
goto out_fail;
FD_ZERO(&rset);
FD_SET(tcp->sockfd, &rset);
tv.tv_sec = 0;
tv.tv_usec = 100 * 1000;
case SSL_ERROR_SYSCALL:
#ifdef _WIN32
fprintf(stderr, "%s: I/O error: %d\n", func, WSAGetLastError());
#else
fprintf(stderr, "%s: I/O error: %s (%d)\n", func, strerror(errno), errno);
#endif
tls_errors(func);
return TRUE;
status = select(tcp->sockfd + 1, &rset, NULL, NULL, &tv);
if (status < 0)
goto out_fail;
}
case SSL_ERROR_SSL:
fprintf(stderr, "%s: Failure in SSL library (protocol error?)\n", func);
tls_errors(func);
return TRUE;
default:
fprintf(stderr, "%s: Unknown error\n", func);
tls_errors(func);
return TRUE;
}
}
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes);
return length;
out_fail:
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes);
return -1;
}
int tls_set_alert_code(rdpTls* tls, int level, int description)
{
tls->alertLevel = level;
@ -672,7 +535,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (!bio)
{
fprintf(stderr, "tls_verify_certificate: BIO_new() failure\n");
fprintf(stderr, "%s: BIO_new() failure\n", __FUNCTION__);
return -1;
}
@ -680,7 +543,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (status < 0)
{
fprintf(stderr, "tls_verify_certificate: PEM_write_bio_X509 failure: %d\n", status);
fprintf(stderr, "%s: PEM_write_bio_X509 failure: %d\n", __FUNCTION__, status);
return -1;
}
@ -692,7 +555,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (status < 0)
{
fprintf(stderr, "tls_verify_certificate: failed to read certificate\n");
fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__);
return -1;
}
@ -713,7 +576,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (status < 0)
{
fprintf(stderr, "tls_verify_certificate: failed to read certificate\n");
fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__);
return -1;
}
@ -727,8 +590,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
status = instance->VerifyX509Certificate(instance, pemCert, length, hostname, port, 0);
}
fprintf(stderr, "VerifyX509Certificate: (length = %d) status: %d\n%s\n",
length, status, pemCert);
fprintf(stderr, "%s: (length = %d) status: %d\n%s\n", __FUNCTION__, length, status, pemCert);
free(pemCert);
BIO_free(bio);
@ -932,57 +794,53 @@ rdpTls* tls_new(rdpSettings* settings)
{
rdpTls* tls;
tls = (rdpTls*) malloc(sizeof(rdpTls));
tls = (rdpTls *)calloc(1, sizeof(rdpTls));
if (!tls)
return NULL;
if (tls)
{
ZeroMemory(tls, sizeof(rdpTls));
SSL_load_error_strings();
SSL_library_init();
SSL_load_error_strings();
SSL_library_init();
tls->settings = settings;
tls->certificate_store = certificate_store_new(settings);
tls->alertLevel = TLS_ALERT_LEVEL_WARNING;
tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY;
}
tls->settings = settings;
tls->certificate_store = certificate_store_new(settings);
if (!tls->certificate_store)
goto out_free;
tls->alertLevel = TLS_ALERT_LEVEL_WARNING;
tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY;
return tls;
out_free:
free(tls);
return NULL;
}
void tls_free(rdpTls* tls)
{
if (tls)
if (!tls)
return;
if (tls->ctx)
{
if (tls->ssl)
{
SSL_free(tls->ssl);
tls->ssl = NULL;
}
if (tls->ctx)
{
SSL_CTX_free(tls->ctx);
tls->ctx = NULL;
}
if (tls->PublicKey)
{
free(tls->PublicKey);
tls->PublicKey = NULL;
}
if (tls->Bindings)
{
free(tls->Bindings->Bindings);
free(tls->Bindings);
tls->Bindings = NULL;
}
certificate_store_free(tls->certificate_store);
tls->certificate_store = NULL;
free(tls);
SSL_CTX_free(tls->ctx);
tls->ctx = NULL;
}
if (tls->PublicKey)
{
free(tls->PublicKey);
tls->PublicKey = NULL;
}
if (tls->Bindings)
{
free(tls->Bindings->Bindings);
free(tls->Bindings);
tls->Bindings = NULL;
}
certificate_store_free(tls->certificate_store);
tls->certificate_store = NULL;
free(tls);
}

View File

@ -25,6 +25,7 @@ set(${MODULE_PREFIX}_SRCS
pcap.c
profiler.c
rail.c
ringbuffer.c
signal.c
stopwatch.c
svc_plugin.c
@ -68,3 +69,9 @@ else()
endif()
set_property(TARGET ${MODULE_NAME} PROPERTY FOLDER "FreeRDP/libfreerdp")
if(BUILD_TESTING)
add_subdirectory(test)
endif()

View File

@ -0,0 +1,251 @@
/**
* FreeRDP: A Remote Desktop Protocol Implementation
*
* Copyright © 2014 Thincast Technologies GmbH
* Copyright © 2014 Hardening <contact@hardening-consulting.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <freerdp/utils/ringbuffer.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
BOOL ringbuffer_init(RingBuffer *rb, size_t initialSize)
{
rb->buffer = malloc(initialSize);
if (!rb->buffer)
return FALSE;
rb->readPtr = rb->writePtr = 0;
rb->initialSize = rb->size = rb->freeSize = initialSize;
return TRUE;
}
size_t ringbuffer_used(const RingBuffer *ringbuffer)
{
return ringbuffer->size - ringbuffer->freeSize;
}
size_t ringbuffer_capacity(const RingBuffer *ringbuffer)
{
return ringbuffer->size;
}
void ringbuffer_destroy(RingBuffer *ringbuffer)
{
free(ringbuffer->buffer);
ringbuffer->buffer = 0;
}
static BOOL ringbuffer_realloc(RingBuffer *rb, size_t targetSize)
{
BYTE *newData;
if (rb->writePtr == rb->readPtr)
{
/* when no size is used we can realloc() and set the heads at the
* beginning of the buffer
*/
newData = (BYTE *)realloc(rb->buffer, targetSize);
if (!newData)
return FALSE;
rb->readPtr = rb->writePtr = 0;
rb->buffer = newData;
}
else if ((rb->writePtr >= rb->readPtr) && (rb->writePtr < targetSize))
{
/* we reallocate only if we're in that case, realloc don't touch read
* and write heads
*
* readPtr writePtr
* | |
* v v
* [............|XXXXXXXXXXXXXX|..........]
*/
newData = (BYTE *)realloc(rb->buffer, targetSize);
if (!newData)
return FALSE;
rb->buffer = newData;
}
else
{
/* in case of malloc the read head is moved at the beginning of the new buffer
* and the write head is set accordingly
*/
newData = (BYTE *)malloc(targetSize);
if (!newData)
return FALSE;
if (rb->readPtr < rb->writePtr)
{
/* readPtr writePtr
* | |
* v v
* [............|XXXXXXXXXXXXXX|..........]
*/
memcpy(newData, rb->buffer + rb->readPtr, ringbuffer_used(rb));
}
else
{
/* writePtr readPtr
* | |
* v v
* [XXXXXXXXXXXX|..............|XXXXXXXXXX]
*/
BYTE *dst = newData;
memcpy(dst, rb->buffer + rb->readPtr, rb->size - rb->readPtr);
dst += (rb->size - rb->readPtr);
if (rb->writePtr)
memcpy(dst, rb->buffer, rb->writePtr);
}
rb->writePtr = rb->size - rb->freeSize;
rb->readPtr = 0;
free(rb->buffer);
rb->buffer = newData;
}
rb->freeSize += (targetSize - rb->size);
rb->size = targetSize;
return TRUE;
}
/**
*
* @param rb
* @param ptr
* @param sz
* @return
*/
BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz)
{
size_t toWrite;
size_t remaining;
if ((rb->freeSize <= sz) && !ringbuffer_realloc(rb, rb->size + sz))
return FALSE;
/* the write could be split in two
* readHead writeHead
* | |
* v v
* [ ################ ]
*/
toWrite = sz;
remaining = sz;
if (rb->size - rb->writePtr < sz)
toWrite = rb->size - rb->writePtr;
if (toWrite)
{
memcpy(rb->buffer + rb->writePtr, ptr, toWrite);
remaining -= toWrite;
ptr += toWrite;
}
if (remaining)
memcpy(rb->buffer, ptr, remaining);
rb->writePtr = (rb->writePtr + sz) % rb->size;
rb->freeSize -= sz;
return TRUE;
}
BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz)
{
if (rb->freeSize < sz)
{
if (!ringbuffer_realloc(rb, rb->size + sz - rb->freeSize + 32))
return NULL;
}
if (rb->writePtr == rb->readPtr)
{
rb->writePtr = rb->readPtr = 0;
}
if (rb->writePtr + sz < rb->size)
return rb->buffer + rb->writePtr;
/*
* to add: .......
* [ XXXXXXXXX ]
*
* result:
* [XXXXXXXXX....... ]
*/
memmove(rb->buffer, rb->buffer + rb->readPtr, rb->writePtr - rb->readPtr);
rb->readPtr = 0;
rb->writePtr = rb->size - rb->freeSize;
return rb->buffer + rb->writePtr;
}
BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz)
{
if (rb->writePtr + sz > rb->size)
return FALSE;
rb->writePtr = (rb->writePtr + sz) % rb->size;
rb->freeSize -= sz;
return TRUE;
}
int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz)
{
size_t remaining = sz;
size_t toRead;
int chunkIndex = 0;
int ret = 0;
if (rb->size - rb->freeSize < sz)
remaining = rb->size - rb->freeSize;
toRead = remaining;
if (rb->readPtr + remaining > rb->size)
toRead = rb->size - rb->readPtr;
if (toRead)
{
chunks[0].data = rb->buffer + rb->readPtr;
chunks[0].size = toRead;
remaining -= toRead;
chunkIndex++;
ret++;
}
if (remaining)
{
chunks[chunkIndex].data = rb->buffer;
chunks[chunkIndex].size = remaining;
ret++;
}
return ret;
}
void ringbuffer_commit_read_bytes(RingBuffer *rb, size_t sz)
{
assert(rb->size - rb->freeSize >= sz);
rb->readPtr = (rb->readPtr + sz) % rb->size;
rb->freeSize += sz;
/* when we reach a reasonable free size, we can go back to the original size */
if ((rb->size != rb->initialSize) && (ringbuffer_used(rb) < rb->initialSize / 2))
ringbuffer_realloc(rb, rb->initialSize);
}

1
libfreerdp/utils/test/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
TestFreeRDPutils.c

View File

@ -0,0 +1,34 @@
set(MODULE_NAME "TestFreeRDPutils")
set(MODULE_PREFIX "TEST_FREERDP_UTILS")
set(${MODULE_PREFIX}_DRIVER ${MODULE_NAME}.c)
set(${MODULE_PREFIX}_TESTS
TestRingBuffer.c
)
create_test_sourcelist(${MODULE_PREFIX}_SRCS
${${MODULE_PREFIX}_DRIVER}
${${MODULE_PREFIX}_TESTS}
)
add_executable(${MODULE_NAME} ${${MODULE_PREFIX}_SRCS})
set_complex_link_libraries(VARIABLE ${MODULE_PREFIX}_LIBS
MONOLITHIC ${MONOLITHIC_BUILD}
MODULE winpr
MODULES winpr-thread winpr-synch winpr-file winpr-utils winpr-crt freerdp-utils
)
target_link_libraries(${MODULE_NAME} ${${MODULE_PREFIX}_LIBS})
set_target_properties(${MODULE_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${TESTING_OUTPUT_DIRECTORY}")
foreach(test ${${MODULE_PREFIX}_TESTS})
get_filename_component(TestName ${test} NAME_WE)
add_test(${TestName} ${TESTING_OUTPUT_DIRECTORY}/${MODULE_NAME} ${TestName})
endforeach()
set_property(TARGET ${MODULE_NAME} PROPERTY FOLDER "FreeRDP/Test")

View File

@ -0,0 +1,228 @@
/**
* FreeRDP: A Remote Desktop Protocol Implementation
*
* Copyright © 2014 Thincast Technologies GmbH
* Copyright © 2014 Hardening <contact@hardening-consulting.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <string.h>
#include <freerdp/utils/ringbuffer.h>
BOOL test_overlaps(void)
{
RingBuffer rb;
DataChunk chunks[2];
BYTE bytes[200];
int nchunks, i, j, k, counter = 0;
for (i = 0; i < sizeof(bytes); i++)
bytes[i] = (BYTE)i;
ringbuffer_init(&rb, 5);
if (!ringbuffer_write(&rb, bytes, 4)) /* [0123.] */
goto error;
counter += 4;
ringbuffer_commit_read_bytes(&rb, 2); /* [..23.] */
if (!ringbuffer_write(&rb, &bytes[counter], 2)) /* [5.234] */
goto error;
counter += 2;
nchunks = ringbuffer_peek(&rb, chunks, 4);
if (nchunks != 2 || chunks[0].size != 3 || chunks[1].size != 1)
goto error;
for (i = 0, j = 2; i < nchunks; i++)
{
for (k = 0; k < chunks[i].size; k++, j++)
{
if (chunks[i].data[k] != (BYTE)j)
goto error;
}
}
ringbuffer_commit_read_bytes(&rb, 3); /* [5....] */
if (ringbuffer_used(&rb) != 1)
goto error;
if (!ringbuffer_write(&rb, &bytes[counter], 6)) /* [56789ab....] */
goto error;
counter += 6;
ringbuffer_commit_read_bytes(&rb, 6); /* [......b....] */
nchunks = ringbuffer_peek(&rb, chunks, 10);
if (nchunks != 1 || chunks[0].size != 1 || (*chunks[0].data != 0xb))
goto error;
if (ringbuffer_capacity(&rb) != 5)
goto error;
ringbuffer_destroy(&rb);
return TRUE;
error:
ringbuffer_destroy(&rb);
return FALSE;
}
int TestRingBuffer(int argc, char* argv[])
{
RingBuffer ringBuffer;
int testNo = 0;
BYTE *tmpBuf;
BYTE *rb_ptr;
int i/*, chunkNb, counter*/;
DataChunk chunks[2];
if (!ringbuffer_init(&ringBuffer, 10))
{
fprintf(stderr, "unable to initialize ringbuffer\n");
return -1;
}
tmpBuf = (BYTE *)malloc(50);
if (!tmpBuf)
return -1;
for (i = 0; i < 50; i++)
tmpBuf[i] = (char)i;
fprintf(stderr, "%d: basic tests...", ++testNo);
if (!ringbuffer_write(&ringBuffer, tmpBuf, 5) || !ringbuffer_write(&ringBuffer, tmpBuf, 5) ||
!ringbuffer_write(&ringBuffer, tmpBuf, 5))
{
fprintf(stderr, "error when writing bytes\n");
return -1;
}
if (ringbuffer_used(&ringBuffer) != 15)
{
fprintf(stderr, "invalid used size got %d when i would expect 15\n", ringbuffer_used(&ringBuffer));
return -1;
}
if (ringbuffer_peek(&ringBuffer, chunks, 10) != 1 || chunks[0].size != 10)
{
fprintf(stderr, "error when reading bytes\n");
return -1;
}
ringbuffer_commit_read_bytes(&ringBuffer, chunks[0].size);
/* check retrieved bytes */
for (i = 0; i < chunks[0].size; i++)
{
if (chunks[0].data[i] != i % 5)
{
fprintf(stderr, "invalid byte at %d, got %d instead of %d\n", i, chunks[0].data[i], i % 5);
return -1;
}
}
if (ringbuffer_used(&ringBuffer) != 5)
{
fprintf(stderr, "invalid used size after read got %d when i would expect 5\n", ringbuffer_used(&ringBuffer));
return -1;
}
/* write some more bytes to have writePtr < readPtr and data splitted in 2 chunks */
if (!ringbuffer_write(&ringBuffer, tmpBuf, 6) ||
ringbuffer_peek(&ringBuffer, chunks, 11) != 2 ||
chunks[0].size != 10 ||
chunks[1].size != 1)
{
fprintf(stderr, "invalid read of splitted data\n");
return -1;
}
ringbuffer_commit_read_bytes(&ringBuffer, 11);
fprintf(stderr, "ok\n");
fprintf(stderr, "%d: peek with nothing to read...", ++testNo);
if (ringbuffer_peek(&ringBuffer, chunks, 10))
{
fprintf(stderr, "peek returns some chunks\n");
return -1;
}
fprintf(stderr, "ok\n");
fprintf(stderr, "%d: ensure_linear_write / read() shouldn't grow...", ++testNo);
for (i = 0; i < 1000; i++)
{
rb_ptr = ringbuffer_ensure_linear_write(&ringBuffer, 50);
if (!rb_ptr)
{
fprintf(stderr, "ringbuffer_ensure_linear_write() error\n");
return -1;
}
memcpy(rb_ptr, tmpBuf, 50);
if (!ringbuffer_commit_written_bytes(&ringBuffer, 50))
{
fprintf(stderr, "ringbuffer_commit_written_bytes() error, i=%d\n", i);
return -1;
}
//ringbuffer_commit_read_bytes(&ringBuffer, 25);
}
for (i = 0; i < 1000; i++)
ringbuffer_commit_read_bytes(&ringBuffer, 25);
for (i = 0; i < 1000; i++)
ringbuffer_commit_read_bytes(&ringBuffer, 25);
if (ringbuffer_capacity(&ringBuffer) != 10)
{
fprintf(stderr, "not the expected capacity, have %d and expects 10\n", ringbuffer_capacity(&ringBuffer));
return -1;
}
fprintf(stderr, "ok\n");
fprintf(stderr, "%d: free size is correctly computed...", ++testNo);
for (i = 0; i < 1000; i++)
{
ringbuffer_ensure_linear_write(&ringBuffer, 50);
if (!ringbuffer_commit_written_bytes(&ringBuffer, 50))
{
fprintf(stderr, "ringbuffer_commit_written_bytes() error, i=%d\n", i);
return -1;
}
}
ringbuffer_commit_read_bytes(&ringBuffer, 50 * 1000);
fprintf(stderr, "ok\n");
ringbuffer_destroy(&ringBuffer);
fprintf(stderr, "%d: specific overlaps test...", ++testNo);
if (!test_overlaps())
{
fprintf(stderr, "ko\n", i);
return -1;
}
fprintf(stderr, "ok\n");
ringbuffer_destroy(&ringBuffer);
free(tmpBuf);
return 0;
}