Merge branch 'master' of github.com:awakecoding/FreeRDP into xcrush

This commit is contained in:
Marc-André Moreau 2014-05-22 15:22:42 -04:00
commit 6cd6d63e42
35 changed files with 2193 additions and 1155 deletions

View File

@ -70,7 +70,6 @@ struct rdp_tls
SSL* ssl; SSL* ssl;
BIO* bio; BIO* bio;
void* tsg; void* tsg;
int sockfd;
SSL_CTX* ctx; SSL_CTX* ctx;
BYTE* PublicKey; BYTE* PublicKey;
BIO_METHOD* methods; BIO_METHOD* methods;
@ -84,17 +83,11 @@ struct rdp_tls
int alertDescription; int alertDescription;
}; };
FREERDP_API int tls_connect(rdpTls* tls); FREERDP_API int tls_connect(rdpTls* tls, BIO *underlying);
FREERDP_API BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file); FREERDP_API BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file);
FREERDP_API BOOL tls_disconnect(rdpTls* tls); FREERDP_API BOOL tls_disconnect(rdpTls* tls);
FREERDP_API int tls_read(rdpTls* tls, BYTE* data, int length); FREERDP_API int tls_write_all(rdpTls* tls, const BYTE* data, int length);
FREERDP_API int tls_write(rdpTls* tls, BYTE* data, int length);
FREERDP_API int tls_write_all(rdpTls* tls, BYTE* data, int length);
FREERDP_API int tls_wait_read(rdpTls* tls);
FREERDP_API int tls_wait_write(rdpTls* tls);
FREERDP_API int tls_set_alert_code(rdpTls* tls, int level, int description); FREERDP_API int tls_set_alert_code(rdpTls* tls, int level, int description);

View File

@ -34,7 +34,10 @@ typedef void (*psPeerContextFree)(freerdp_peer* client, rdpContext* context);
typedef BOOL (*psPeerInitialize)(freerdp_peer* client); typedef BOOL (*psPeerInitialize)(freerdp_peer* client);
typedef BOOL (*psPeerGetFileDescriptor)(freerdp_peer* client, void** rfds, int* rcount); typedef BOOL (*psPeerGetFileDescriptor)(freerdp_peer* client, void** rfds, int* rcount);
typedef HANDLE (*psPeerGetEventHandle)(freerdp_peer* client); typedef HANDLE (*psPeerGetEventHandle)(freerdp_peer* client);
typedef HANDLE (*psPeerGetReceiveEventHandle)(freerdp_peer* client);
typedef BOOL (*psPeerCheckFileDescriptor)(freerdp_peer* client); typedef BOOL (*psPeerCheckFileDescriptor)(freerdp_peer* client);
typedef BOOL (*psPeerIsWriteBlocked)(freerdp_peer* client);
typedef int (*psPeerDrainOutputBuffer)(freerdp_peer* client);
typedef BOOL (*psPeerClose)(freerdp_peer* client); typedef BOOL (*psPeerClose)(freerdp_peer* client);
typedef void (*psPeerDisconnect)(freerdp_peer* client); typedef void (*psPeerDisconnect)(freerdp_peer* client);
typedef BOOL (*psPeerCapabilities)(freerdp_peer* client); typedef BOOL (*psPeerCapabilities)(freerdp_peer* client);
@ -62,6 +65,7 @@ struct rdp_freerdp_peer
psPeerInitialize Initialize; psPeerInitialize Initialize;
psPeerGetFileDescriptor GetFileDescriptor; psPeerGetFileDescriptor GetFileDescriptor;
psPeerGetEventHandle GetEventHandle; psPeerGetEventHandle GetEventHandle;
psPeerGetReceiveEventHandle GetReceiveEventHandle;
psPeerCheckFileDescriptor CheckFileDescriptor; psPeerCheckFileDescriptor CheckFileDescriptor;
psPeerClose Close; psPeerClose Close;
psPeerDisconnect Disconnect; psPeerDisconnect Disconnect;
@ -81,6 +85,9 @@ struct rdp_freerdp_peer
BOOL activated; BOOL activated;
BOOL authenticated; BOOL authenticated;
SEC_WINNT_AUTH_IDENTITY identity; SEC_WINNT_AUTH_IDENTITY identity;
psPeerIsWriteBlocked IsWriteBlocked;
psPeerDrainOutputBuffer DrainOutputBuffer;
}; };
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -799,7 +799,8 @@ struct rdp_settings
ALIGN64 char* Password; /* 22 */ ALIGN64 char* Password; /* 22 */
ALIGN64 char* Domain; /* 23 */ ALIGN64 char* Domain; /* 23 */
ALIGN64 char* PasswordHash; /* 24 */ ALIGN64 char* PasswordHash; /* 24 */
UINT64 padding0064[64 - 25]; /* 25 */ ALIGN64 BOOL WaitForOutputBufferFlush; /* 25 */
UINT64 padding0064[64 - 26]; /* 26 */
UINT64 padding0128[128 - 64]; /* 64 */ UINT64 padding0128[128 - 64]; /* 64 */
/** /**

View File

@ -0,0 +1,132 @@
/**
* Copyright © 2014 Thincast Technologies GmbH
* Copyright © 2014 Hardening <contact@hardening-consulting.com>
*
* Permission to use, copy, modify, distribute, and sell this software and
* its documentation for any purpose is hereby granted without fee, provided
* that the above copyright notice appear in all copies and that both that
* copyright notice and this permission notice appear in supporting
* documentation, and that the name of the copyright holders not be used in
* advertising or publicity pertaining to distribution of the software
* without specific, written prior permission. The copyright holders make
* no representations about the suitability of this software for any
* purpose. It is provided "as is" without express or implied warranty.
*
* THE COPYRIGHT HOLDERS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS
* SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS, IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY
* SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER
* RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF
* CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
* CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
#ifndef __RINGBUFFER_H___
#define __RINGBUFFER_H___
#include <winpr/wtypes.h>
#include <freerdp/api.h>
/** @brief ring buffer meta data */
struct _RingBuffer {
size_t initialSize;
size_t freeSize;
size_t size;
size_t readPtr;
size_t writePtr;
BYTE *buffer;
};
typedef struct _RingBuffer RingBuffer;
/** @brief a piece of data in the ring buffer, exactly like a glibc iovec */
struct _DataChunk {
size_t size;
const BYTE *data;
};
typedef struct _DataChunk DataChunk;
#ifdef __cplusplus
extern "C" {
#endif
/** initialise a ringbuffer
* @param initialSize the initial capacity of the ringBuffer
* @return if the initialisation was successful
*/
FREERDP_API BOOL ringbuffer_init(RingBuffer *rb, size_t initialSize);
/** destroys internal data used by this ringbuffer
* @param ringbuffer
*/
FREERDP_API void ringbuffer_destroy(RingBuffer *ringbuffer);
/** computes the space used in this ringbuffer
* @param ringbuffer
* @return the number of bytes stored in that ringbuffer
*/
FREERDP_API size_t ringbuffer_used(const RingBuffer *ringbuffer);
/** returns the capacity of the ring buffer
* @param ringbuffer
* @return the capacity of this ring buffer
*/
FREERDP_API size_t ringbuffer_capacity(const RingBuffer *ringbuffer);
/** writes some bytes in the ringbuffer, if the data doesn't fit, the ringbuffer
* is resized automatically
*
* @param rb the ringbuffer
* @param ptr a pointer on the data to add
* @param sz the size of the data to add
* @return if the operation was successful, it could fail in case of OOM during realloc()
*/
FREERDP_API BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz);
/** ensures that we have sz bytes available at the write head, and return a pointer
* on the write head
*
* @param rb the ring buffer
* @param sz the size to ensure
* @return a pointer on the write head, or NULL in case of OOM
*/
FREERDP_API BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz);
/** move ahead the write head in case some byte were written directly by using
* a pointer retrieved via ringbuffer_ensure_linear_write(). This function is
* used to commit the written bytes. The provided size should not exceed the
* size ensured by ringbuffer_ensure_linear_write()
*
* @param rb the ring buffer
* @param sz the number of bytes that have been written
* @return if the operation was successful, FALSE is sz is too big
*/
FREERDP_API BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz);
/** peeks the buffer chunks for sz bytes and returns how many chunks are filled.
* Note that the sum of the resulting chunks may be smaller than sz.
*
* @param rb the ringbuffer
* @param chunks an array of data chunks that will contain data / size of chunks
* @param sz the requested size
* @return the number of chunks used for reading sz bytes
*/
FREERDP_API int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz);
/** move ahead the read head in case some byte were read using ringbuffer_peek()
* This function is used to commit the bytes that were effectively consumed.
*
* @param rb the ring buffer
* @param sz the
*/
FREERDP_API void ringbuffer_commit_read_bytes(RingBuffer *rb, size_t sz);
#ifdef __cplusplus
}
#endif
#endif /* __RINGBUFFER_H___ */

View File

@ -26,6 +26,10 @@
#include <winpr/stream.h> #include <winpr/stream.h>
#include <winpr/string.h> #include <winpr/string.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "http.h" #include "http.h"
HttpContext* http_context_new() HttpContext* http_context_new()
@ -472,7 +476,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
nbytes = 0; nbytes = 0;
length = 10000; length = 10000;
content = NULL; content = NULL;
buffer = malloc(length); buffer = calloc(length, 1);
if (!buffer) if (!buffer)
return NULL; return NULL;
@ -487,14 +491,20 @@ HttpResponse* http_response_recv(rdpTls* tls)
{ {
while (nbytes < 5) while (nbytes < 5)
{ {
status = tls_read(tls, p, length - nbytes); status = BIO_read(tls->bio, p, length - nbytes);
if (status < 0) if (status <= 0)
goto out_error; {
if (!BIO_should_retry(tls->bio))
goto out_error;
if (!status) USleep(100);
continue; continue;
}
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(p, status);
#endif
nbytes += status; nbytes += status;
p = (BYTE*) &buffer[nbytes]; p = (BYTE*) &buffer[nbytes];
} }
@ -503,7 +513,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
if (!header_end) if (!header_end)
{ {
fprintf(stderr, "http_response_recv: invalid response:\n"); fprintf(stderr, "%s: invalid response:\n", __FUNCTION__);
winpr_HexDump(buffer, status); winpr_HexDump(buffer, status);
goto out_error; goto out_error;
} }
@ -517,7 +527,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
header_end[0] = '\0'; header_end[0] = '\0';
header_end[1] = '\0'; header_end[1] = '\0';
content = &header_end[2]; content = header_end + 2;
count = 0; count = 0;
line = (char*) buffer; line = (char*) buffer;
@ -552,11 +562,14 @@ HttpResponse* http_response_recv(rdpTls* tls)
if (!http_response_parse_header(http_response)) if (!http_response_parse_header(http_response))
goto out_error; goto out_error;
if (http_response->ContentLength > 0) http_response->bodyLen = nbytes - (content - (char *)buffer);
if (http_response->bodyLen > 0)
{ {
http_response->Content = _strdup(content); http_response->BodyContent = (BYTE *)malloc(http_response->bodyLen);
if (!http_response->Content) if (!http_response->BodyContent)
goto out_error; goto out_error;
CopyMemory(http_response->BodyContent, content, http_response->bodyLen);
} }
break; break;
@ -627,7 +640,7 @@ void http_response_free(HttpResponse* http_response)
ListDictionary_Free(http_response->Authenticates); ListDictionary_Free(http_response->Authenticates);
if (http_response->ContentLength > 0) if (http_response->ContentLength > 0)
free(http_response->Content); free(http_response->BodyContent);
free(http_response); free(http_response);
} }

View File

@ -84,7 +84,8 @@ struct _http_response
wListDictionary *Authenticates; wListDictionary *Authenticates;
int ContentLength; int ContentLength;
char* Content; BYTE *BodyContent;
int bodyLen;
}; };
void http_response_print(HttpResponse* http_response); void http_response_print(HttpResponse* http_response);

View File

@ -98,6 +98,8 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc)
rdpNtlm* ntlm = rpc->NtlmHttpIn->ntlm; rdpNtlm* ntlm = rpc->NtlmHttpIn->ntlm;
http_response = http_response_recv(rpc->TlsIn); http_response = http_response_recv(rpc->TlsIn);
if (!http_response)
return -1;
if (ListDictionary_Contains(http_response->Authenticates, "NTLM")) if (ListDictionary_Contains(http_response->Authenticates, "NTLM"))
{ {
@ -105,14 +107,12 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc)
if (!token64) if (!token64)
goto out; goto out;
ntlm_token_data = NULL;
crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length); crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length);
} }
out:
ntlm->inputBuffer[0].pvBuffer = ntlm_token_data; ntlm->inputBuffer[0].pvBuffer = ntlm_token_data;
ntlm->inputBuffer[0].cbBuffer = ntlm_token_length; ntlm->inputBuffer[0].cbBuffer = ntlm_token_length;
out:
http_response_free(http_response); http_response_free(http_response);
return 0; return 0;
@ -123,25 +123,19 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, TSG_CHANNEL channel)
rdpNtlm* ntlm = NULL; rdpNtlm* ntlm = NULL;
rdpSettings* settings = rpc->settings; rdpSettings* settings = rpc->settings;
freerdp* instance = (freerdp*) rpc->settings->instance; freerdp* instance = (freerdp*) rpc->settings->instance;
BOOL promptPassword = FALSE;
if (channel == TSG_CHANNEL_IN) if (channel == TSG_CHANNEL_IN)
ntlm = rpc->NtlmHttpIn->ntlm; ntlm = rpc->NtlmHttpIn->ntlm;
else if (channel == TSG_CHANNEL_OUT) else if (channel == TSG_CHANNEL_OUT)
ntlm = rpc->NtlmHttpOut->ntlm; ntlm = rpc->NtlmHttpOut->ntlm;
if ((!settings->GatewayPassword) || (!settings->GatewayUsername) if (!settings->GatewayPassword || !settings->GatewayUsername ||
|| (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername))) !strlen(settings->GatewayPassword) || !strlen(settings->GatewayUsername))
{
promptPassword = TRUE;
}
if (promptPassword)
{ {
if (instance->GatewayAuthenticate) if (instance->GatewayAuthenticate)
{ {
BOOL proceed = instance->GatewayAuthenticate(instance, BOOL proceed = instance->GatewayAuthenticate(instance, &settings->GatewayUsername,
&settings->GatewayUsername, &settings->GatewayPassword, &settings->GatewayDomain); &settings->GatewayPassword, &settings->GatewayDomain);
if (!proceed) if (!proceed)
{ {
@ -240,12 +234,10 @@ int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc)
char *token64 = ListDictionary_GetItemValue(http_response->Authenticates, "NTLM"); char *token64 = ListDictionary_GetItemValue(http_response->Authenticates, "NTLM");
crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length); crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length);
} }
ntlm->inputBuffer[0].pvBuffer = ntlm_token_data; ntlm->inputBuffer[0].pvBuffer = ntlm_token_data;
ntlm->inputBuffer[0].cbBuffer = ntlm_token_length; ntlm->inputBuffer[0].cbBuffer = ntlm_token_length;
http_response_free(http_response); http_response_free(http_response);
return 0; return 0;
} }
@ -259,15 +251,12 @@ BOOL rpc_ntlm_http_out_connect(rdpRpc* rpc)
success = TRUE; success = TRUE;
/* Send OUT Channel Request */ /* Send OUT Channel Request */
rpc_ncacn_http_send_out_channel_request(rpc); rpc_ncacn_http_send_out_channel_request(rpc);
/* Receive OUT Channel Response */ /* Receive OUT Channel Response */
rpc_ncacn_http_recv_out_channel_response(rpc); rpc_ncacn_http_recv_out_channel_response(rpc);
/* Send OUT Channel Request */ /* Send OUT Channel Request */
rpc_ncacn_http_send_out_channel_request(rpc); rpc_ncacn_http_send_out_channel_request(rpc);
ntlm_client_uninit(ntlm); ntlm_client_uninit(ntlm);
@ -296,13 +285,11 @@ void rpc_ntlm_http_init_channel(rdpRpc* rpc, rdpNtlmHttp* ntlm_http, TSG_CHANNEL
if (channel == TSG_CHANNEL_IN) if (channel == TSG_CHANNEL_IN)
{ {
http_context_set_pragma(ntlm_http->context, http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729");
"ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729");
} }
else if (channel == TSG_CHANNEL_OUT) else if (channel == TSG_CHANNEL_OUT)
{ {
http_context_set_pragma(ntlm_http->context, http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, "
"ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729" ", "
"SessionId=fbd9c34f-397d-471d-a109-1b08cc554624"); "SessionId=fbd9c34f-397d-471d-a109-1b08cc554624");
} }
} }

View File

@ -33,6 +33,11 @@
#include <winpr/dsparse.h> #include <winpr/dsparse.h>
#include <openssl/rand.h> #include <openssl/rand.h>
#include <openssl/bio.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "http.h" #include "http.h"
#include "ntlm.h" #include "ntlm.h"
@ -235,80 +240,77 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l
{ {
UINT32 alloc_hint = 0; UINT32 alloc_hint = 0;
rpcconn_hdr_t* header; rpcconn_hdr_t* header;
UINT32 frag_length;
UINT32 auth_length;
UINT32 auth_pad_length;
UINT32 sec_trailer_offset;
rpc_sec_trailer* sec_trailer;
*offset = RPC_COMMON_FIELDS_LENGTH; *offset = RPC_COMMON_FIELDS_LENGTH;
header = ((rpcconn_hdr_t*) buffer); header = ((rpcconn_hdr_t*) buffer);
if (header->common.ptype == PTYPE_RESPONSE) switch (header->common.ptype)
{ {
*offset += 8; case PTYPE_RESPONSE:
rpc_offset_align(offset, 8); *offset += 8;
alloc_hint = header->response.alloc_hint; rpc_offset_align(offset, 8);
} alloc_hint = header->response.alloc_hint;
else if (header->common.ptype == PTYPE_REQUEST) break;
{ case PTYPE_REQUEST:
*offset += 4; *offset += 4;
rpc_offset_align(offset, 8); rpc_offset_align(offset, 8);
alloc_hint = header->request.alloc_hint; alloc_hint = header->request.alloc_hint;
} break;
else if (header->common.ptype == PTYPE_RTS) case PTYPE_RTS:
{ *offset += 4;
*offset += 4; break;
} default:
else fprintf(stderr, "%s: unknown ptype=0x%x\n", __FUNCTION__, header->common.ptype);
{ return FALSE;
return FALSE;
} }
if (length) if (!length)
return TRUE;
if (header->common.ptype == PTYPE_REQUEST)
{ {
if (header->common.ptype == PTYPE_REQUEST) UINT32 sec_trailer_offset;
{
UINT32 sec_trailer_offset;
sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8; sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8;
*length = sec_trailer_offset - *offset; *length = sec_trailer_offset - *offset;
} return TRUE;
else }
{
UINT32 frag_length;
UINT32 auth_length;
UINT32 auth_pad_length;
UINT32 sec_trailer_offset;
rpc_sec_trailer* sec_trailer;
frag_length = header->common.frag_length;
auth_length = header->common.auth_length;
sec_trailer_offset = frag_length - auth_length - 8; frag_length = header->common.frag_length;
sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset]; auth_length = header->common.auth_length;
auth_pad_length = sec_trailer->auth_pad_length;
sec_trailer_offset = frag_length - auth_length - 8;
sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset];
auth_pad_length = sec_trailer->auth_pad_length;
#if 0 #if 0
fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n", fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n",
sec_trailer->auth_type, sec_trailer->auth_type,
sec_trailer->auth_level, sec_trailer->auth_level,
sec_trailer->auth_pad_length, sec_trailer->auth_pad_length,
sec_trailer->auth_reserved, sec_trailer->auth_reserved,
sec_trailer->auth_context_id); sec_trailer->auth_context_id);
#endif #endif
/** /**
* According to [MS-RPCE], auth_pad_length is the number of padding * According to [MS-RPCE], auth_pad_length is the number of padding
* octets used to 4-byte align the security trailer, but in practice * octets used to 4-byte align the security trailer, but in practice
* we get values up to 15, which indicates 16-byte alignment. * we get values up to 15, which indicates 16-byte alignment.
*/ */
if ((frag_length - (sec_trailer_offset + 8)) != auth_length) if ((frag_length - (sec_trailer_offset + 8)) != auth_length)
{ {
fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length, fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length,
(frag_length - (sec_trailer_offset + 8))); (frag_length - (sec_trailer_offset + 8)));
}
*length = frag_length - auth_length - 24 - 8 - auth_pad_length;
}
} }
*length = frag_length - auth_length - 24 - 8 - auth_pad_length;
return TRUE; return TRUE;
} }
@ -316,12 +318,23 @@ int rpc_out_read(rdpRpc* rpc, BYTE* data, int length)
{ {
int status; int status;
status = tls_read(rpc->TlsOut, data, length); status = BIO_read(rpc->TlsOut->bio, data, length);
/* fprintf(stderr, "%s: length=%d => status=%d shouldRetry=%d\n", __FUNCTION__, length,
* status, BIO_should_retry(rpc->TlsOut->bio)); */
if (status > 0) {
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(data, status);
#endif
return status;
}
return status; if (BIO_should_retry(rpc->TlsOut->bio))
return 0;
return -1;
} }
int rpc_out_write(rdpRpc* rpc, BYTE* data, int length) int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length)
{ {
int status; int status;
@ -330,7 +343,7 @@ int rpc_out_write(rdpRpc* rpc, BYTE* data, int length)
return status; return status;
} }
int rpc_in_write(rdpRpc* rpc, BYTE* data, int length) int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length)
{ {
int status; int status;
@ -360,20 +373,21 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
ntlm = rpc->ntlm; ntlm = rpc->ntlm;
if ((!ntlm) || (!ntlm->table)) if (!ntlm || !ntlm->table)
{ {
fprintf(stderr, "rpc_write: invalid ntlm context\n"); fprintf(stderr, "%s: invalid ntlm context\n", __FUNCTION__);
return -1; return -1;
} }
if (ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, &ntlm->ContextSizes) != SEC_E_OK) if (ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, &ntlm->ContextSizes) != SEC_E_OK)
{ {
fprintf(stderr, "QueryContextAttributes SECPKG_ATTR_SIZES failure\n"); fprintf(stderr, "%s: QueryContextAttributes SECPKG_ATTR_SIZES failure\n", __FUNCTION__);
return -1; return -1;
} }
request_pdu = (rpcconn_request_hdr_t*) malloc(sizeof(rpcconn_request_hdr_t)); request_pdu = (rpcconn_request_hdr_t*) calloc(1, sizeof(rpcconn_request_hdr_t));
ZeroMemory(request_pdu, sizeof(rpcconn_request_hdr_t)); if (!request_pdu)
return -1;
rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu); rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu);
@ -386,7 +400,11 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
request_pdu->opnum = opnum; request_pdu->opnum = opnum;
clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum); clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum);
ArrayList_Add(rpc->client->ClientCallList, clientCall); if (!clientCall)
goto out_free_pdu;
if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0)
goto out_free_clientCall;
if (request_pdu->opnum == TsProxySetupReceivePipeOpnum) if (request_pdu->opnum == TsProxySetupReceivePipeOpnum)
rpc->PipeCallId = request_pdu->call_id; rpc->PipeCallId = request_pdu->call_id;
@ -407,8 +425,9 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
request_pdu->frag_length = offset; request_pdu->frag_length = offset;
buffer = (BYTE*) malloc(request_pdu->frag_length); buffer = (BYTE*) calloc(1, request_pdu->frag_length);
if (!buffer)
goto out_free_pdu;
CopyMemory(buffer, request_pdu, 24); CopyMemory(buffer, request_pdu, 24);
offset = 24; offset = 24;
@ -427,15 +446,15 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
Buffers[0].cbBuffer = offset; Buffers[0].cbBuffer = offset;
Buffers[1].cbBuffer = ntlm->ContextSizes.cbMaxSignature; Buffers[1].cbBuffer = ntlm->ContextSizes.cbMaxSignature;
Buffers[1].pvBuffer = malloc(Buffers[1].cbBuffer); Buffers[1].pvBuffer = calloc(1, Buffers[1].cbBuffer);
ZeroMemory(Buffers[1].pvBuffer, Buffers[1].cbBuffer); if (!Buffers[1].pvBuffer)
return -1;
Message.cBuffers = 2; Message.cBuffers = 2;
Message.ulVersion = SECBUFFER_VERSION; Message.ulVersion = SECBUFFER_VERSION;
Message.pBuffers = (PSecBuffer) &Buffers; Message.pBuffers = (PSecBuffer) &Buffers;
encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, 0, &Message, rpc->SendSeqNum++); encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, 0, &Message, rpc->SendSeqNum++);
if (encrypt_status != SEC_E_OK) if (encrypt_status != SEC_E_OK)
{ {
fprintf(stderr, "EncryptMessage status: 0x%08X\n", encrypt_status); fprintf(stderr, "EncryptMessage status: 0x%08X\n", encrypt_status);
@ -447,12 +466,18 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
offset += Buffers[1].cbBuffer; offset += Buffers[1].cbBuffer;
free(Buffers[1].pvBuffer); free(Buffers[1].pvBuffer);
if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) != 0) if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) < 0)
length = -1; length = -1;
free(request_pdu); free(request_pdu);
return length; return length;
out_free_clientCall:
rpc_client_call_free(clientCall);
out_free_pdu:
free(request_pdu);
return -1;
} }
BOOL rpc_connect(rdpRpc* rpc) BOOL rpc_connect(rdpRpc* rpc)
@ -592,13 +617,17 @@ rdpRpc* rpc_new(rdpTransport* transport)
rpc->CallId = 2; rpc->CallId = 2;
rpc_client_new(rpc); if (rpc_client_new(rpc) < 0)
goto out_free_virtualConnectionCookieTable;
rpc->client->SynchronousSend = TRUE; rpc->client->SynchronousSend = TRUE;
rpc->client->SynchronousReceive = TRUE; rpc->client->SynchronousReceive = TRUE;
return rpc; return rpc;
out_free_virtualConnectionCookieTable:
rpc_client_free(rpc);
ArrayList_Free(rpc->VirtualConnectionCookieTable);
out_free_virtual_connection: out_free_virtual_connection:
rpc_client_virtual_connection_free(rpc->VirtualConnection); rpc_client_virtual_connection_free(rpc->VirtualConnection);
out_free_ntlm_http_out: out_free_ntlm_http_out:

View File

@ -772,8 +772,8 @@ UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad);
int rpc_out_read(rdpRpc* rpc, BYTE* data, int length); int rpc_out_read(rdpRpc* rpc, BYTE* data, int length);
int rpc_out_write(rdpRpc* rpc, BYTE* data, int length); int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length);
int rpc_in_write(rdpRpc* rpc, BYTE* data, int length); int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length);
BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length); BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length);

View File

@ -103,6 +103,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
DEBUG_RPC("Sending bind PDU"); DEBUG_RPC("Sending bind PDU");
rpc->ntlm = ntlm_new(); rpc->ntlm = ntlm_new();
if (!rpc->ntlm)
return -1;
if ((!settings->GatewayPassword) || (!settings->GatewayUsername) if ((!settings->GatewayPassword) || (!settings->GatewayUsername)
|| (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername))) || (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername)))
@ -129,17 +131,22 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
settings->Username = _strdup(settings->GatewayUsername); settings->Username = _strdup(settings->GatewayUsername);
settings->Domain = _strdup(settings->GatewayDomain); settings->Domain = _strdup(settings->GatewayDomain);
settings->Password = _strdup(settings->GatewayPassword); settings->Password = _strdup(settings->GatewayPassword);
if (!settings->Username || !settings->Domain || settings->Password)
return -1;
} }
} }
} }
ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL); if (!ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL) ||
ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname); !ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname) ||
!ntlm_authenticate(rpc->ntlm)
)
return -1;
ntlm_authenticate(rpc->ntlm); bind_pdu = (rpcconn_bind_hdr_t*) calloc(1, sizeof(rpcconn_bind_hdr_t));
if (!bind_pdu)
bind_pdu = (rpcconn_bind_hdr_t*) malloc(sizeof(rpcconn_bind_hdr_t)); return -1;
ZeroMemory(bind_pdu, sizeof(rpcconn_bind_hdr_t));
rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) bind_pdu); rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) bind_pdu);
@ -159,6 +166,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
bind_pdu->p_context_elem.reserved2 = 0; bind_pdu->p_context_elem.reserved2 = 0;
bind_pdu->p_context_elem.p_cont_elem = malloc(sizeof(p_cont_elem_t) * bind_pdu->p_context_elem.n_context_elem); bind_pdu->p_context_elem.p_cont_elem = malloc(sizeof(p_cont_elem_t) * bind_pdu->p_context_elem.n_context_elem);
if (!bind_pdu->p_context_elem.p_cont_elem)
return -1;
p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0]; p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0];
@ -196,6 +205,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
bind_pdu->frag_length = offset; bind_pdu->frag_length = offset;
buffer = (BYTE*) malloc(bind_pdu->frag_length); buffer = (BYTE*) malloc(bind_pdu->frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, bind_pdu, 24); CopyMemory(buffer, bind_pdu, 24);
CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4); CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4);
@ -214,7 +225,10 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
length = bind_pdu->frag_length; length = bind_pdu->frag_length;
clientCall = rpc_client_call_new(bind_pdu->call_id, 0); clientCall = rpc_client_call_new(bind_pdu->call_id, 0);
ArrayList_Add(rpc->client->ClientCallList, clientCall); if (!clientCall)
return -1;
if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0)
return -1;
if (rpc_send_enqueue_pdu(rpc, buffer, length) != 0) if (rpc_send_enqueue_pdu(rpc, buffer, length) != 0)
length = -1; length = -1;

View File

@ -34,9 +34,7 @@
#include <winpr/stream.h> #include <winpr/stream.h>
#include "rpc_fault.h" #include "rpc_fault.h"
#include "rpc_client.h" #include "rpc_client.h"
#include "../rdp.h" #include "../rdp.h"
#define SYNCHRONOUS_TIMEOUT 5000 #define SYNCHRONOUS_TIMEOUT 5000
@ -69,8 +67,15 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
if (!pdu) if (!pdu)
{ {
pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU)); pdu = (RPC_PDU *)malloc(sizeof(RPC_PDU));
if (!pdu)
return NULL;
pdu->s = Stream_New(NULL, rpc->max_recv_frag); pdu->s = Stream_New(NULL, rpc->max_recv_frag);
if (!pdu->s)
{
free(pdu);
return NULL;
}
} }
pdu->CallId = 0; pdu->CallId = 0;
@ -84,8 +89,7 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu) int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu)
{ {
Queue_Enqueue(rpc->client->ReceivePool, pdu); return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1;
return 0;
} }
int rpc_client_on_fragment_received_event(rdpRpc* rpc) int rpc_client_on_fragment_received_event(rdpRpc* rpc)
@ -97,7 +101,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
rpcconn_hdr_t* header; rpcconn_hdr_t* header;
freerdp* instance; freerdp* instance;
instance = (freerdp*) rpc->transport->settings->instance; instance = (freerdp *)rpc->transport->settings->instance;
if (!rpc->client->pdu) if (!rpc->client->pdu)
rpc->client->pdu = rpc_client_receive_pool_take(rpc); rpc->client->pdu = rpc_client_receive_pool_take(rpc);
@ -125,34 +129,29 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
return 0; return 0;
} }
if (header->common.ptype == PTYPE_RTS) switch (header->common.ptype)
{ {
if (rpc->VirtualConnection->State >= VIRTUAL_CONNECTION_STATE_OPENED) case PTYPE_RTS:
{ if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED)
//fprintf(stderr, "Receiving Out-of-Sequence RTS PDU\n"); {
fprintf(stderr, "%s: warning: unhandled RTS PDU\n", __FUNCTION__);
return 0;
}
fprintf(stderr, "%s: Receiving Out-of-Sequence RTS PDU\n", __FUNCTION__);
rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length); rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length);
rpc_client_fragment_pool_return(rpc, fragment); rpc_client_fragment_pool_return(rpc, fragment);
} return 0;
else
{
fprintf(stderr, "warning: unhandled RTS PDU\n");
}
return 0; case PTYPE_FAULT:
} rpc_recv_fault_pdu(header);
else if (header->common.ptype == PTYPE_FAULT) Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
{ return -1;
rpc_recv_fault_pdu(header); case PTYPE_RESPONSE:
Queue_Enqueue(rpc->client->ReceiveQueue, NULL); break;
return -1; default:
} fprintf(stderr, "%s: unexpected RPC PDU type %d\n", __FUNCTION__, header->common.ptype);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
if (header->common.ptype != PTYPE_RESPONSE) return -1;
{
fprintf(stderr, "Unexpected RPC PDU type: %d\n", header->common.ptype);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
} }
rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length; rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length;
@ -160,7 +159,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength)) if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength))
{ {
fprintf(stderr, "rpc_recv_pdu_fragment: expected stub\n"); fprintf(stderr, "%s: expected stub\n", __FUNCTION__);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL); Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1; return -1;
} }
@ -196,7 +195,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
if (rpc->StubCallId != header->common.call_id) if (rpc->StubCallId != header->common.call_id)
{ {
fprintf(stderr, "invalid call_id: actual: %d, expected: %d, frag_count: %d\n", fprintf(stderr, "%s: invalid call_id: actual: %d, expected: %d, frag_count: %d\n", __FUNCTION__,
rpc->StubCallId, header->common.call_id, rpc->StubFragCount); rpc->StubCallId, header->common.call_id, rpc->StubFragCount);
} }
@ -243,27 +242,34 @@ int rpc_client_on_read_event(rdpRpc* rpc)
int status = -1; int status = -1;
rpcconn_common_hdr_t* header; rpcconn_common_hdr_t* header;
if (!rpc->client->RecvFrag) while (1)
rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
position = Stream_GetPosition(rpc->client->RecvFrag);
if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
{ {
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), if (!rpc->client->RecvFrag)
RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag)); rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
if (status < 0) position = Stream_GetPosition(rpc->client->RecvFrag);
while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
{ {
fprintf(stderr, "rpc_client_frag_read: error reading header\n"); status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
return -1; RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag));
if (status < 0)
{
fprintf(stderr, "rpc_client_frag_read: error reading header\n");
return -1;
}
if (!status)
return 0;
Stream_Seek(rpc->client->RecvFrag, status);
} }
Stream_Seek(rpc->client->RecvFrag, status); if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
} return status;
if (Stream_GetPosition(rpc->client->RecvFrag) >= RPC_COMMON_FIELDS_LENGTH)
{
header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->RecvFrag); header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->RecvFrag);
if (header->frag_length > rpc->max_recv_frag) if (header->frag_length > rpc->max_recv_frag)
@ -274,45 +280,44 @@ int rpc_client_on_read_event(rdpRpc* rpc)
return -1; return -1;
} }
if (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length) while (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length)
{ {
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
header->frag_length - Stream_GetPosition(rpc->client->RecvFrag)); header->frag_length - Stream_GetPosition(rpc->client->RecvFrag));
if (status < 0) if (status < 0)
{ {
fprintf(stderr, "rpc_client_frag_read: error reading fragment body\n"); fprintf(stderr, "%s: error reading fragment body\n", __FUNCTION__);
return -1; return -1;
} }
if (!status)
return 0;
Stream_Seek(rpc->client->RecvFrag, status); Stream_Seek(rpc->client->RecvFrag, status);
} }
}
else
{
return status;
}
if (status < 0) if (status < 0)
return -1;
status = Stream_GetPosition(rpc->client->RecvFrag) - position;
if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
{
/* complete fragment received */
Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
Stream_SetPosition(rpc->client->RecvFrag, 0);
Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
rpc->client->RecvFrag = NULL;
if (rpc_client_on_fragment_received_event(rpc) < 0)
return -1; return -1;
status = Stream_GetPosition(rpc->client->RecvFrag) - position;
if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
{
/* complete fragment received */
Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
Stream_SetPosition(rpc->client->RecvFrag, 0);
Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
rpc->client->RecvFrag = NULL;
if (rpc_client_on_fragment_received_event(rpc) < 0)
return -1;
}
} }
return status; return 0;
} }
/** /**
@ -349,13 +354,12 @@ RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum)
RpcClientCall* clientCall; RpcClientCall* clientCall;
clientCall = (RpcClientCall*) malloc(sizeof(RpcClientCall)); clientCall = (RpcClientCall*) malloc(sizeof(RpcClientCall));
if (!clientCall)
return NULL;
if (clientCall) clientCall->CallId = CallId;
{ clientCall->OpNum = OpNum;
clientCall->CallId = CallId; clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
clientCall->OpNum = OpNum;
clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
}
return clientCall; return clientCall;
} }
@ -371,16 +375,22 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
int status; int status;
pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU)); pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));
pdu->s = Stream_New(buffer, length); if (!pdu)
return -1;
Queue_Enqueue(rpc->client->SendQueue, pdu); pdu->s = Stream_New(buffer, length);
if (!pdu->s)
goto out_free;
if (!Queue_Enqueue(rpc->client->SendQueue, pdu))
goto out_free_stream;
if (rpc->client->SynchronousSend) if (rpc->client->SynchronousSend)
{ {
status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT); status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT);
if (status == WAIT_TIMEOUT) if (status == WAIT_TIMEOUT)
{ {
fprintf(stderr, "rpc_send_enqueue_pdu: timed out waiting for pdu sent event\n"); fprintf(stderr, "%s: timed out waiting for pdu sent event %p\n", __FUNCTION__, rpc->client->PduSentEvent);
return -1; return -1;
} }
@ -388,6 +398,12 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
} }
return 0; return 0;
out_free_stream:
Stream_Free(pdu->s, TRUE);
out_free:
free(pdu);
return -1;
} }
int rpc_send_dequeue_pdu(rdpRpc* rpc) int rpc_send_dequeue_pdu(rdpRpc* rpc)
@ -396,13 +412,14 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
RPC_PDU* pdu; RPC_PDU* pdu;
RpcClientCall* clientCall; RpcClientCall* clientCall;
rpcconn_common_hdr_t* header; rpcconn_common_hdr_t* header;
RpcInChannel *inChannel;
pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->SendQueue); pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->SendQueue);
if (!pdu) if (!pdu)
return 0; return 0;
WaitForSingleObject(rpc->VirtualConnection->DefaultInChannel->Mutex, INFINITE); inChannel = rpc->VirtualConnection->DefaultInChannel;
WaitForSingleObject(inChannel->Mutex, INFINITE);
status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)); status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
@ -410,7 +427,7 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
clientCall = rpc_client_call_find_by_id(rpc, header->call_id); clientCall = rpc_client_call_find_by_id(rpc, header->call_id);
clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED; clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED;
ReleaseMutex(rpc->VirtualConnection->DefaultInChannel->Mutex); ReleaseMutex(inChannel->Mutex);
/* /*
* This protocol specifies that only RPC PDUs are subject to the flow control abstract * This protocol specifies that only RPC PDUs are subject to the flow control abstract
@ -421,8 +438,8 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
if (header->ptype == PTYPE_REQUEST) if (header->ptype == PTYPE_REQUEST)
{ {
rpc->VirtualConnection->DefaultInChannel->BytesSent += status; inChannel->BytesSent += status;
rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow -= status; inChannel->SenderAvailableWindow -= status;
} }
Stream_Free(pdu->s, TRUE); Stream_Free(pdu->s, TRUE);
@ -440,57 +457,48 @@ RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc)
DWORD dwMilliseconds; DWORD dwMilliseconds;
DWORD result; DWORD result;
pdu = NULL; dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds); result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
if (result == WAIT_TIMEOUT) if (result == WAIT_TIMEOUT)
{ {
fprintf(stderr, "rpc_recv_dequeue_pdu: timed out waiting for receive event\n"); fprintf(stderr, "%s: timed out waiting for receive event\n", __FUNCTION__);
return NULL; return NULL;
} }
if (result == WAIT_OBJECT_0) if (result != WAIT_OBJECT_0)
{ return NULL;
pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->ReceiveQueue);
pdu = (RPC_PDU *)Queue_Dequeue(rpc->client->ReceiveQueue);
#ifdef WITH_DEBUG_TSG #ifdef WITH_DEBUG_TSG
if (pdu) if (pdu)
{ {
fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId); fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId);
winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s)); winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s));
fprintf(stderr, "\n"); fprintf(stderr, "\n");
}
#endif
return pdu;
} }
else
{
fprintf(stderr, "Receiving a NULL PDU\n");
}
#endif
return pdu; return pdu;
} }
RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc) RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc)
{ {
RPC_PDU* pdu;
DWORD dwMilliseconds; DWORD dwMilliseconds;
DWORD result; DWORD result;
pdu = NULL;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0; dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds); result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
if (result == WAIT_TIMEOUT) if (result != WAIT_OBJECT_0)
{
return NULL; return NULL;
}
if (result == WAIT_OBJECT_0) return (RPC_PDU *)Queue_Peek(rpc->client->ReceiveQueue);
{
pdu = (RPC_PDU*) Queue_Peek(rpc->client->ReceiveQueue);
return pdu;
}
return pdu;
} }
static void* rpc_client_thread(void* arg) static void* rpc_client_thread(void* arg)
@ -500,40 +508,52 @@ static void* rpc_client_thread(void* arg)
DWORD nCount; DWORD nCount;
HANDLE events[3]; HANDLE events[3];
HANDLE ReadEvent; HANDLE ReadEvent;
int fd;
rpc = (rdpRpc*) arg; rpc = (rdpRpc*) arg;
fd = BIO_get_fd(rpc->TlsOut->bio, NULL);
ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, rpc->TlsOut->sockfd); ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, fd);
nCount = 0; nCount = 0;
events[nCount++] = rpc->client->StopEvent; events[nCount++] = rpc->client->StopEvent;
events[nCount++] = Queue_Event(rpc->client->SendQueue); events[nCount++] = Queue_Event(rpc->client->SendQueue);
events[nCount++] = ReadEvent; events[nCount++] = ReadEvent;
/* Do a first free run in case some bytes were set from the HTTP headers.
* We also have to do it because most of the time the underlying socket has notified,
* and the ssl layer has eaten all bytes, so we won't be notified any more even if the
* bytes are buffered locally
*/
if (rpc_client_on_read_event(rpc) < 0)
{
fprintf(stderr, "%s: an error occured when treating first packet\n", __FUNCTION__);
goto out;
}
while (rpc->transport->layer != TRANSPORT_LAYER_CLOSED) while (rpc->transport->layer != TRANSPORT_LAYER_CLOSED)
{ {
status = WaitForMultipleObjects(nCount, events, FALSE, 100); status = WaitForMultipleObjects(nCount, events, FALSE, 100);
if (status != WAIT_TIMEOUT) if (status == WAIT_TIMEOUT)
continue;
if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0)
break;
if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0)
{ {
if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0) if (rpc_client_on_read_event(rpc) < 0)
{
break; break;
} }
if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0) if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
{ {
if (rpc_client_on_read_event(rpc) < 0) rpc_send_dequeue_pdu(rpc);
break;
}
if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
{
rpc_send_dequeue_pdu(rpc);
}
} }
} }
out:
CloseHandle(ReadEvent); CloseHandle(ReadEvent);
return NULL; return NULL;
@ -541,6 +561,9 @@ static void* rpc_client_thread(void* arg)
static void rpc_pdu_free(RPC_PDU* pdu) static void rpc_pdu_free(RPC_PDU* pdu)
{ {
if (!pdu)
return;
Stream_Free(pdu->s, TRUE); Stream_Free(pdu->s, TRUE);
free(pdu); free(pdu);
} }
@ -554,35 +577,55 @@ int rpc_client_new(rdpRpc* rpc)
{ {
RpcClient* client = NULL; RpcClient* client = NULL;
client = (RpcClient*) calloc(1, sizeof(RpcClient)); client = (RpcClient *)calloc(1, sizeof(RpcClient));
if (client)
{
client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
client->SendQueue = Queue_New(TRUE, -1, -1);
Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->pdu = NULL;
client->ReceivePool = Queue_New(TRUE, -1, -1);
client->ReceiveQueue = Queue_New(TRUE, -1, -1);
Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->RecvFrag = NULL;
client->FragmentPool = Queue_New(TRUE, -1, -1);
client->FragmentQueue = Queue_New(TRUE, -1, -1);
Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->ClientCallList = ArrayList_New(TRUE);
ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
}
rpc->client = client; rpc->client = client;
if (!client)
return -1;
client->Thread = CreateThread(NULL, 0,
(LPTHREAD_START_ROUTINE) rpc_client_thread,
rpc, CREATE_SUSPENDED, NULL);
if (!client->Thread)
return -1;
client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!client->StopEvent)
return -1;
client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!client->PduSentEvent)
return -1;
client->SendQueue = Queue_New(TRUE, -1, -1);
if (!client->SendQueue)
return -1;
Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->pdu = NULL;
client->ReceivePool = Queue_New(TRUE, -1, -1);
if (!client->ReceivePool)
return -1;
Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->ReceiveQueue = Queue_New(TRUE, -1, -1);
if (!client->ReceiveQueue)
return -1;
Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->RecvFrag = NULL;
client->FragmentPool = Queue_New(TRUE, -1, -1);
if (!client->FragmentPool)
return -1;
Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->FragmentQueue = Queue_New(TRUE, -1, -1);
if (!client->FragmentQueue)
return -1;
Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->ClientCallList = ArrayList_New(TRUE);
if (!client->ClientCallList)
return -1;
ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
return 0; return 0;
} }
@ -604,9 +647,7 @@ int rpc_client_stop(rdpRpc* rpc)
rpc->client->Thread = NULL; rpc->client->Thread = NULL;
} }
rpc_client_free(rpc); return rpc_client_free(rpc);
return 0;
} }
int rpc_client_free(rdpRpc* rpc) int rpc_client_free(rdpRpc* rpc)
@ -615,31 +656,39 @@ int rpc_client_free(rdpRpc* rpc)
client = rpc->client; client = rpc->client;
if (client) if (!client)
{ return 0;
if (client->SendQueue)
Queue_Free(client->SendQueue); Queue_Free(client->SendQueue);
if (client->RecvFrag) if (client->RecvFrag)
rpc_fragment_free(client->RecvFrag); rpc_fragment_free(client->RecvFrag);
if (client->FragmentPool)
Queue_Free(client->FragmentPool); Queue_Free(client->FragmentPool);
if (client->FragmentQueue)
Queue_Free(client->FragmentQueue); Queue_Free(client->FragmentQueue);
if (client->pdu) if (client->pdu)
rpc_pdu_free(client->pdu); rpc_pdu_free(client->pdu);
if (client->ReceivePool)
Queue_Free(client->ReceivePool); Queue_Free(client->ReceivePool);
if (client->ReceiveQueue)
Queue_Free(client->ReceiveQueue); Queue_Free(client->ReceiveQueue);
if (client->ClientCallList)
ArrayList_Free(client->ClientCallList); ArrayList_Free(client->ClientCallList);
if (client->StopEvent)
CloseHandle(client->StopEvent); CloseHandle(client->StopEvent);
if (client->PduSentEvent)
CloseHandle(client->PduSentEvent); CloseHandle(client->PduSentEvent);
if (client->Thread)
CloseHandle(client->Thread); CloseHandle(client->Thread);
free(client); free(client);
}
return 0; return 0;
} }

View File

@ -93,25 +93,25 @@ BOOL rts_connect(rdpRpc* rpc)
if (!rpc_ntlm_http_out_connect(rpc)) if (!rpc_ntlm_http_out_connect(rpc))
{ {
fprintf(stderr, "rpc_out_connect_http error!\n"); fprintf(stderr, "%s: rpc_out_connect_http error!\n", __FUNCTION__);
return FALSE; return FALSE;
} }
if (rts_send_CONN_A1_pdu(rpc) != 0) if (rts_send_CONN_A1_pdu(rpc) != 0)
{ {
fprintf(stderr, "rpc_send_CONN_A1_pdu error!\n"); fprintf(stderr, "%s: rpc_send_CONN_A1_pdu error!\n", __FUNCTION__);
return FALSE; return FALSE;
} }
if (!rpc_ntlm_http_in_connect(rpc)) if (!rpc_ntlm_http_in_connect(rpc))
{ {
fprintf(stderr, "rpc_in_connect_http error!\n"); fprintf(stderr, "%s: rpc_in_connect_http error!\n", __FUNCTION__);
return FALSE; return FALSE;
} }
if (rts_send_CONN_B1_pdu(rpc) != 0) if (rts_send_CONN_B1_pdu(rpc) < 0)
{ {
fprintf(stderr, "rpc_send_CONN_B1_pdu error!\n"); fprintf(stderr, "%s: rpc_send_CONN_B1_pdu error!\n", __FUNCTION__);
return FALSE; return FALSE;
} }
@ -147,10 +147,15 @@ BOOL rts_connect(rdpRpc* rpc)
*/ */
http_response = http_response_recv(rpc->TlsOut); http_response = http_response_recv(rpc->TlsOut);
if (!http_response)
{
fprintf(stderr, "%s: unable to retrieve OUT Channel Response!\n", __FUNCTION__);
return FALSE;
}
if (http_response->StatusCode != HTTP_STATUS_OK) if (http_response->StatusCode != HTTP_STATUS_OK)
{ {
fprintf(stderr, "rts_connect error! Status Code: %d\n", http_response->StatusCode); fprintf(stderr, "%s: error! Status Code: %d\n", __FUNCTION__, http_response->StatusCode);
http_response_print(http_response); http_response_print(http_response);
http_response_free(http_response); http_response_free(http_response);
@ -170,6 +175,14 @@ BOOL rts_connect(rdpRpc* rpc)
return FALSE; return FALSE;
} }
if (http_response->bodyLen)
{
/* inject bytes we have read in the body as a received packet for the RPC client */
rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
Stream_EnsureCapacity(rpc->client->RecvFrag, http_response->bodyLen);
CopyMemory(rpc->client->RecvFrag, http_response->BodyContent, http_response->bodyLen);
}
//http_response_print(http_response); //http_response_print(http_response);
http_response_free(http_response); http_response_free(http_response);
@ -195,7 +208,6 @@ BOOL rts_connect(rdpRpc* rpc)
rpc_client_start(rpc); rpc_client_start(rpc);
pdu = rpc_recv_dequeue_pdu(rpc); pdu = rpc_recv_dequeue_pdu(rpc);
if (!pdu) if (!pdu)
return FALSE; return FALSE;
@ -203,7 +215,7 @@ BOOL rts_connect(rdpRpc* rpc)
if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts)) if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts))
{ {
fprintf(stderr, "Unexpected RTS PDU: Expected CONN/A3\n"); fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/A3\n", __FUNCTION__);
return FALSE; return FALSE;
} }
@ -236,7 +248,6 @@ BOOL rts_connect(rdpRpc* rpc)
*/ */
pdu = rpc_recv_dequeue_pdu(rpc); pdu = rpc_recv_dequeue_pdu(rpc);
if (!pdu) if (!pdu)
return FALSE; return FALSE;
@ -244,7 +255,7 @@ BOOL rts_connect(rdpRpc* rpc)
if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts)) if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts))
{ {
fprintf(stderr, "Unexpected RTS PDU: Expected CONN/C2\n"); fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/C2\n", __FUNCTION__);
return FALSE; return FALSE;
} }
@ -261,7 +272,7 @@ BOOL rts_connect(rdpRpc* rpc)
return TRUE; return TRUE;
} }
#if defined WITH_DEBUG_RTS && 0 #ifdef WITH_DEBUG_RTS
static const char* const RTS_CMD_STRINGS[] = static const char* const RTS_CMD_STRINGS[] =
{ {
@ -317,6 +328,7 @@ static const char* const RTS_CMD_STRINGS[] =
void rts_pdu_header_init(rpcconn_rts_hdr_t* header) void rts_pdu_header_init(rpcconn_rts_hdr_t* header)
{ {
ZeroMemory(header, sizeof(*header));
header->rpc_vers = 5; header->rpc_vers = 5;
header->rpc_vers_minor = 0; header->rpc_vers_minor = 0;
header->ptype = PTYPE_RTS; header->ptype = PTYPE_RTS;
@ -681,6 +693,8 @@ int rts_send_CONN_A1_pdu(rdpRpc* rpc)
ReceiveWindowSize = rpc->VirtualConnection->DefaultOutChannel->ReceiveWindow; ReceiveWindowSize = rpc->VirtualConnection->DefaultOutChannel->ReceiveWindow;
buffer = (BYTE*) malloc(header.frag_length); buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
@ -718,6 +732,7 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
BYTE* INChannelCookie; BYTE* INChannelCookie;
BYTE* AssociationGroupId; BYTE* AssociationGroupId;
BYTE* VirtualConnectionCookie; BYTE* VirtualConnectionCookie;
int status;
rts_pdu_header_init(&header); rts_pdu_header_init(&header);
header.frag_length = 104; header.frag_length = 104;
@ -734,6 +749,8 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
AssociationGroupId = (BYTE*) &(rpc->VirtualConnection->AssociationGroupId); AssociationGroupId = (BYTE*) &(rpc->VirtualConnection->AssociationGroupId);
buffer = (BYTE*) malloc(header.frag_length); buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
@ -745,11 +762,11 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
length = header.frag_length; length = header.frag_length;
rpc_in_write(rpc, buffer, length); status = rpc_in_write(rpc, buffer, length);
free(buffer); free(buffer);
return 0; return status;
} }
/* CONN/C Sequence */ /* CONN/C Sequence */
@ -795,12 +812,15 @@ int rts_send_keep_alive_pdu(rdpRpc* rpc)
DEBUG_RPC("Sending Keep-Alive RTS PDU"); DEBUG_RPC("Sending Keep-Alive RTS PDU");
buffer = (BYTE*) malloc(header.frag_length); buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_client_keepalive_command_write(&buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */ rts_client_keepalive_command_write(&buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */
length = header.frag_length; length = header.frag_length;
rpc_in_write(rpc, buffer, length); if (rpc_in_write(rpc, buffer, length) < 0)
return -1;
free(buffer); free(buffer);
return length; return length;
@ -830,6 +850,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc)
rpc->VirtualConnection->DefaultOutChannel->AvailableWindowAdvertised; rpc->VirtualConnection->DefaultOutChannel->AvailableWindowAdvertised;
buffer = (BYTE*) malloc(header.frag_length); buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */ rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */
@ -839,7 +861,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc)
length = header.frag_length; length = header.frag_length;
rpc_in_write(rpc, buffer, length); if (rpc_in_write(rpc, buffer, length) < 0)
return -1;
free(buffer); free(buffer);
return 0; return 0;
@ -923,12 +946,15 @@ int rts_send_ping_pdu(rdpRpc* rpc)
DEBUG_RPC("Sending Ping RTS PDU"); DEBUG_RPC("Sending Ping RTS PDU");
buffer = (BYTE*) malloc(header.frag_length); buffer = (BYTE*) malloc(header.frag_length);
if (!buffer)
return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
length = header.frag_length; length = header.frag_length;
rpc_in_write(rpc, buffer, length); if (rpc_in_write(rpc, buffer, length) < 0)
return -1;
free(buffer); free(buffer);
return length; return length;
@ -1020,22 +1046,18 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
rts_extract_pdu_signature(rpc, &signature, rts); rts_extract_pdu_signature(rpc, &signature, rts);
SignatureId = rts_identify_pdu_signature(rpc, &signature, NULL); SignatureId = rts_identify_pdu_signature(rpc, &signature, NULL);
if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK) switch (SignatureId)
{ {
return rts_recv_flow_control_ack_pdu(rpc, buffer, length); case RTS_PDU_FLOW_CONTROL_ACK:
} return rts_recv_flow_control_ack_pdu(rpc, buffer, length);
else if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION) case RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION:
{ return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length);
return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length); case RTS_PDU_PING:
} return rts_send_ping_pdu(rpc);
else if (SignatureId == RTS_PDU_PING) default:
{ fprintf(stderr, "%s: unimplemented signature id: 0x%08X\n", __FUNCTION__, SignatureId);
rts_send_ping_pdu(rpc); rts_print_pdu_signature(rpc, &signature);
} break;
else
{
fprintf(stderr, "Unimplemented signature id: 0x%08X\n", SignatureId);
rts_print_pdu_signature(rpc, &signature);
} }
return 0; return 0;

View File

@ -234,7 +234,6 @@ BOOL rts_match_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rt
return FALSE; return FALSE;
status = rts_command_length(rpc, CommandType, &buffer[offset], length); status = rts_command_length(rpc, CommandType, &buffer[offset], length);
if (status < 0) if (status < 0)
return FALSE; return FALSE;
@ -272,7 +271,6 @@ int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_r
signature->CommandTypes[i] = CommandType; signature->CommandTypes[i] = CommandType;
status = rts_command_length(rpc, CommandType, &buffer[offset], length); status = rts_command_length(rpc, CommandType, &buffer[offset], length);
if (status < 0) if (status < 0)
return FALSE; return FALSE;
@ -294,22 +292,22 @@ UINT32 rts_identify_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, RTS_P
{ {
pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature; pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature;
if (signature->Flags == pSignature->Flags) if (signature->Flags != pSignature->Flags)
continue;
if (signature->NumberOfCommands != pSignature->NumberOfCommands)
continue;
for (j = 0; j < signature->NumberOfCommands; j++)
{ {
if (signature->NumberOfCommands == pSignature->NumberOfCommands) if (signature->CommandTypes[j] != pSignature->CommandTypes[j])
{ continue;
for (j = 0; j < signature->NumberOfCommands; j++)
{
if (signature->CommandTypes[j] != pSignature->CommandTypes[j])
continue;
}
if (entry)
*entry = &RTS_PDU_SIGNATURE_TABLE[i];
return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
}
} }
if (entry)
*entry = &RTS_PDU_SIGNATURE_TABLE[i];
return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
} }
return 0; return 0;

View File

@ -33,9 +33,9 @@
#include <winpr/stream.h> #include <winpr/stream.h>
#include "rpc_client.h" #include "rpc_client.h"
#include "tsg.h" #include "tsg.h"
/** /**
* RPC Functions: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378623/ * RPC Functions: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378623/
* Remote Procedure Call: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378651/ * Remote Procedure Call: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378651/
@ -96,7 +96,9 @@ DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 count,
} }
length = 28 + totalDataBytes; length = 28 + totalDataBytes;
buffer = (BYTE*) malloc(length); buffer = (BYTE*) calloc(1, length);
if (!buffer)
return -1;
s = Stream_New(buffer, length); s = Stream_New(buffer, length);
@ -228,8 +230,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
buffer = &buffer[24]; buffer = &buffer[24];
packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET)); packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET));
ZeroMemory(packet, sizeof(TSG_PACKET)); if (!packet)
return FALSE;
offset = 4; // Skip Packet Pointer offset = 4; // Skip Packet Pointer
packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */ packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */
@ -237,8 +240,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
if ((packet->packetId == TSG_PACKET_TYPE_CAPS_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_CAPS_RESPONSE)) if ((packet->packetId == TSG_PACKET_TYPE_CAPS_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_CAPS_RESPONSE))
{ {
packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) malloc(sizeof(TSG_PACKET_CAPS_RESPONSE)); packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) calloc(1, sizeof(TSG_PACKET_CAPS_RESPONSE));
ZeroMemory(packetCapsResponse, sizeof(TSG_PACKET_CAPS_RESPONSE)); if (!packetCapsResponse) // TODO: correct cleanup
return FALSE;
packet->tsgPacket.packetCapsResponse = packetCapsResponse; packet->tsgPacket.packetCapsResponse = packetCapsResponse;
/* PacketQuarResponsePtr (4 bytes) */ /* PacketQuarResponsePtr (4 bytes) */
@ -258,8 +262,7 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
IsMessagePresent = *((UINT32*) &buffer[offset]); IsMessagePresent = *((UINT32*) &buffer[offset]);
offset += 4; offset += 4;
MessageSwitchValue = *((UINT32*) &buffer[offset]); MessageSwitchValue = *((UINT32*) &buffer[offset]);
DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", IsMessagePresent, MessageSwitchValue);
IsMessagePresent, MessageSwitchValue);
offset += 4; offset += 4;
} }
@ -289,8 +292,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
offset += 4; offset += 4;
} }
versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS)); versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS));
ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS)); if (!versionCaps) // TODO: correct cleanup
return FALSE;
packetCapsResponse->pktQuarEncResponse.versionCaps = versionCaps; packetCapsResponse->pktQuarEncResponse.versionCaps = versionCaps;
versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */ versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */
@ -317,8 +321,10 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
/* 4-byte alignment */ /* 4-byte alignment */
rpc_offset_align(&offset, 4); rpc_offset_align(&offset, 4);
tsgCaps = (PTSG_PACKET_CAPABILITIES) malloc(sizeof(TSG_PACKET_CAPABILITIES)); tsgCaps = (PTSG_PACKET_CAPABILITIES) calloc(1, sizeof(TSG_PACKET_CAPABILITIES));
ZeroMemory(tsgCaps, sizeof(TSG_PACKET_CAPABILITIES)); if (!tsgCaps)
return FALSE;
versionCaps->tsgCaps = tsgCaps; versionCaps->tsgCaps = tsgCaps;
offset += 4; /* MaxCount (4 bytes) */ offset += 4; /* MaxCount (4 bytes) */
@ -406,8 +412,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
} }
else if ((packet->packetId == TSG_PACKET_TYPE_QUARENC_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_QUARENC_RESPONSE)) else if ((packet->packetId == TSG_PACKET_TYPE_QUARENC_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_QUARENC_RESPONSE))
{ {
packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) malloc(sizeof(TSG_PACKET_QUARENC_RESPONSE)); packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) calloc(1, sizeof(TSG_PACKET_QUARENC_RESPONSE));
ZeroMemory(packetQuarEncResponse, sizeof(TSG_PACKET_QUARENC_RESPONSE)); if (!packetQuarEncResponse) // TODO: handle cleanup
return FALSE;
packet->tsgPacket.packetQuarEncResponse = packetQuarEncResponse; packet->tsgPacket.packetQuarEncResponse = packetQuarEncResponse;
/* PacketQuarResponsePtr (4 bytes) */ /* PacketQuarResponsePtr (4 bytes) */
@ -443,8 +450,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
offset += 4; offset += 4;
} }
versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS)); versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS));
ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS)); if (!versionCaps) // TODO: handle cleanup
return FALSE;
packetQuarEncResponse->versionCaps = versionCaps; packetQuarEncResponse->versionCaps = versionCaps;
versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */ versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */
@ -779,8 +787,9 @@ BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
buffer = &buffer[24]; buffer = &buffer[24];
packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET)); packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET));
ZeroMemory(packet, sizeof(TSG_PACKET)); if (!packet)
return FALSE;
offset = 4; offset = 4;
packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */ packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */
@ -923,6 +932,8 @@ BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_NOSERI
length = 60 + (count * 2); length = 60 + (count * 2);
buffer = (BYTE*) malloc(length); buffer = (BYTE*) malloc(length);
if (!buffer)
return FALSE;
/* TunnelContext */ /* TunnelContext */
handle = (CONTEXT_HANDLE*) tunnelContext; handle = (CONTEXT_HANDLE*) tunnelContext;
@ -1526,48 +1537,53 @@ int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length)
return CopyLength; return CopyLength;
} }
else
tsg->pdu = rpc_recv_peek_pdu(rpc);
if (!tsg->pdu)
{ {
tsg->pdu = rpc_recv_peek_pdu(rpc); if (!tsg->rpc->client->SynchronousReceive)
return 0;
if (!tsg->pdu) // weird !!!!
{ return tsg_read(tsg, data, length);
if (tsg->rpc->client->SynchronousReceive)
return tsg_read(tsg, data, length);
else
return 0;
}
tsg->PendingPdu = TRUE;
tsg->BytesAvailable = Stream_Length(tsg->pdu->s);
tsg->BytesRead = 0;
CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable;
CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength);
tsg->BytesAvailable -= CopyLength;
tsg->BytesRead += CopyLength;
if (tsg->BytesAvailable < 1)
{
tsg->PendingPdu = FALSE;
rpc_recv_dequeue_pdu(rpc);
rpc_client_receive_pool_return(rpc, tsg->pdu);
}
return CopyLength;
} }
tsg->PendingPdu = TRUE;
tsg->BytesAvailable = Stream_Length(tsg->pdu->s);
tsg->BytesRead = 0;
CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable;
CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength);
tsg->BytesAvailable -= CopyLength;
tsg->BytesRead += CopyLength;
if (tsg->BytesAvailable < 1)
{
tsg->PendingPdu = FALSE;
rpc_recv_dequeue_pdu(rpc);
rpc_client_receive_pool_return(rpc, tsg->pdu);
}
return CopyLength;
} }
int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length) int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length)
{ {
int status;
if (tsg->rpc->transport->layer == TRANSPORT_LAYER_CLOSED) if (tsg->rpc->transport->layer == TRANSPORT_LAYER_CLOSED)
{ {
fprintf(stderr, "tsg_write error: connection lost\n"); fprintf(stderr, "%s: error, connection lost\n", __FUNCTION__);
return -1; return -1;
} }
return TsProxySendToServer((handle_t) tsg, data, 1, &length); status = TsProxySendToServer((handle_t) tsg, data, 1, &length);
if (status < 0)
return -1;
return length;
} }
BOOL tsg_set_blocking_mode(rdpTsg* tsg, BOOL blocking) BOOL tsg_set_blocking_mode(rdpTsg* tsg, BOOL blocking)
@ -1584,18 +1600,21 @@ rdpTsg* tsg_new(rdpTransport* transport)
{ {
rdpTsg* tsg; rdpTsg* tsg;
tsg = (rdpTsg*) malloc(sizeof(rdpTsg)); tsg = (rdpTsg*) calloc(1, sizeof(rdpTsg));
ZeroMemory(tsg, sizeof(rdpTsg)); if (!tsg)
return NULL;
if (tsg != NULL)
{
tsg->transport = transport;
tsg->settings = transport->settings;
tsg->rpc = rpc_new(tsg->transport);
tsg->PendingPdu = FALSE;
}
tsg->transport = transport;
tsg->settings = transport->settings;
tsg->rpc = rpc_new(tsg->transport);
if (!tsg->rpc)
goto out_free;
tsg->PendingPdu = FALSE;
return tsg; return tsg;
out_free:
free(tsg);
return NULL;
} }
void tsg_free(rdpTsg* tsg) void tsg_free(rdpTsg* tsg)

View File

@ -241,7 +241,7 @@ int license_recv(rdpLicense* license, wStream* s)
if (!rdp_read_header(license->rdp, s, &length, &channelId)) if (!rdp_read_header(license->rdp, s, &length, &channelId))
{ {
fprintf(stderr, "Incorrect RDP header.\n"); fprintf(stderr, "%s: Incorrect RDP header.\n", __FUNCTION__);
return -1; return -1;
} }
@ -252,7 +252,7 @@ int license_recv(rdpLicense* license, wStream* s)
{ {
if (!rdp_decrypt(license->rdp, s, length - 4, securityFlags)) if (!rdp_decrypt(license->rdp, s, length - 4, securityFlags))
{ {
fprintf(stderr, "rdp_decrypt failed\n"); fprintf(stderr, "%s: rdp_decrypt failed\n", __FUNCTION__);
return -1; return -1;
} }
} }
@ -268,7 +268,7 @@ int license_recv(rdpLicense* license, wStream* s)
if (status < 0) if (status < 0)
{ {
fprintf(stderr, "Unexpected license packet.\n"); fprintf(stderr, "%s: unexpected license packet.\n", __FUNCTION__);
return status; return status;
} }
@ -308,7 +308,7 @@ int license_recv(rdpLicense* license, wStream* s)
break; break;
default: default:
fprintf(stderr, "invalid bMsgType:%d\n", bMsgType); fprintf(stderr, "%s: invalid bMsgType:%d\n", __FUNCTION__, bMsgType);
return FALSE; return FALSE;
} }

View File

@ -1056,26 +1056,29 @@ rdpMcs* mcs_new(rdpTransport* transport)
{ {
rdpMcs* mcs; rdpMcs* mcs;
mcs = (rdpMcs*) malloc(sizeof(rdpMcs)); mcs = (rdpMcs *)calloc(1, sizeof(rdpMcs));
if (!mcs)
return NULL;
if (mcs) mcs->transport = transport;
{ mcs->settings = transport->settings;
ZeroMemory(mcs, sizeof(rdpMcs));
mcs->transport = transport; mcs_init_domain_parameters(&mcs->targetParameters, 34, 2, 0, 0xFFFF);
mcs->settings = transport->settings; mcs_init_domain_parameters(&mcs->minimumParameters, 1, 1, 1, 0x420);
mcs_init_domain_parameters(&mcs->maximumParameters, 0xFFFF, 0xFC17, 0xFFFF, 0xFFFF);
mcs_init_domain_parameters(&mcs->domainParameters, 0, 0, 0, 0xFFFF);
mcs_init_domain_parameters(&mcs->targetParameters, 34, 2, 0, 0xFFFF); mcs->channelCount = 0;
mcs_init_domain_parameters(&mcs->minimumParameters, 1, 1, 1, 0x420); mcs->channelMaxCount = CHANNEL_MAX_COUNT;
mcs_init_domain_parameters(&mcs->maximumParameters, 0xFFFF, 0xFC17, 0xFFFF, 0xFFFF); mcs->channels = (rdpMcsChannel *)calloc(mcs->channelMaxCount, sizeof(rdpMcsChannel));
mcs_init_domain_parameters(&mcs->domainParameters, 0, 0, 0, 0xFFFF); if (!mcs->channels)
goto out_free;
mcs->channelCount = 0;
mcs->channelMaxCount = CHANNEL_MAX_COUNT;
mcs->channels = (rdpMcsChannel*) calloc(mcs->channelMaxCount, sizeof(rdpMcsChannel));
}
return mcs; return mcs;
out_free:
free(mcs);
return NULL;
} }
/** /**

View File

@ -52,13 +52,13 @@ static BOOL freerdp_peer_initialize(freerdp_peer* client)
fprintf(stderr, "%s: inavlid RDP key file %s\n", __FUNCTION__, settings->RdpKeyFile); fprintf(stderr, "%s: inavlid RDP key file %s\n", __FUNCTION__, settings->RdpKeyFile);
return FALSE; return FALSE;
} }
if (settings->RdpServerRsaKey->ModulusLength > 256) if (settings->RdpServerRsaKey->ModulusLength > 256)
{ {
fprintf(stderr, "%s: Key sizes > 2048 are currently not supported for RDP security.\n", __FUNCTION__); fprintf(stderr, "%s: Key sizes > 2048 are currently not supported for RDP security.\n", __FUNCTION__);
fprintf(stderr, "%s: Set a different key file than %s\n", __FUNCTION__, settings->RdpKeyFile); fprintf(stderr, "%s: Set a different key file than %s\n", __FUNCTION__, settings->RdpKeyFile);
exit(1); exit(1);
} }
} }
return TRUE; return TRUE;
@ -77,12 +77,13 @@ static HANDLE freerdp_peer_get_event_handle(freerdp_peer* client)
return client->context->rdp->transport->TcpIn->event; return client->context->rdp->transport->TcpIn->event;
} }
static BOOL freerdp_peer_check_fds(freerdp_peer* client)
static BOOL freerdp_peer_check_fds(freerdp_peer* peer)
{ {
int status; int status;
rdpRdp* rdp; rdpRdp* rdp;
rdp = client->context->rdp; rdp = peer->context->rdp;
status = rdp_check_fds(rdp); status = rdp_check_fds(rdp);
@ -413,6 +414,19 @@ static int freerdp_peer_send_channel_data(freerdp_peer* client, UINT16 channelId
return rdp_send_channel_data(client->context->rdp, channelId, data, size); return rdp_send_channel_data(client->context->rdp, channelId, data, size);
} }
static BOOL freerdp_peer_is_write_blocked(freerdp_peer* peer)
{
return tranport_is_write_blocked(peer->context->rdp->transport);
}
static int freerdp_peer_drain_output_buffer(freerdp_peer* peer)
{
rdpTransport *transport = peer->context->rdp->transport;
return tranport_drain_output_buffer(transport);
}
void freerdp_peer_context_new(freerdp_peer* client) void freerdp_peer_context_new(freerdp_peer* client)
{ {
rdpRdp* rdp; rdpRdp* rdp;
@ -445,6 +459,9 @@ void freerdp_peer_context_new(freerdp_peer* client)
rdp->transport->ReceiveExtra = client; rdp->transport->ReceiveExtra = client;
transport_set_blocking_mode(rdp->transport, FALSE); transport_set_blocking_mode(rdp->transport, FALSE);
client->IsWriteBlocked = freerdp_peer_is_write_blocked;
client->DrainOutputBuffer = freerdp_peer_drain_output_buffer;
IFCALL(client->ContextNew, client, client->context); IFCALL(client->ContextNew, client, client->context);
} }
@ -473,6 +490,8 @@ freerdp_peer* freerdp_peer_new(int sockfd)
client->Close = freerdp_peer_close; client->Close = freerdp_peer_close;
client->Disconnect = freerdp_peer_disconnect; client->Disconnect = freerdp_peer_disconnect;
client->SendChannelData = freerdp_peer_send_channel_data; client->SendChannelData = freerdp_peer_send_channel_data;
client->IsWriteBlocked = freerdp_peer_is_write_blocked;
client->DrainOutputBuffer = freerdp_peer_drain_output_buffer;
} }
return client; return client;
@ -480,10 +499,10 @@ freerdp_peer* freerdp_peer_new(int sockfd)
void freerdp_peer_free(freerdp_peer* client) void freerdp_peer_free(freerdp_peer* client)
{ {
if (client) if (!client)
{ return;
rdp_free(client->context->rdp);
free(client->context); rdp_free(client->context->rdp);
free(client); free(client->context);
} free(client);
} }

View File

@ -209,6 +209,7 @@ rdpSettings* freerdp_settings_new(DWORD flags)
ZeroMemory(settings, sizeof(rdpSettings)); ZeroMemory(settings, sizeof(rdpSettings));
settings->ServerMode = (flags & FREERDP_SETTINGS_SERVER_MODE) ? TRUE : FALSE; settings->ServerMode = (flags & FREERDP_SETTINGS_SERVER_MODE) ? TRUE : FALSE;
settings->WaitForOutputBufferFlush = TRUE;
settings->DesktopWidth = 1024; settings->DesktopWidth = 1024;
settings->DesktopHeight = 768; settings->DesktopHeight = 768;
@ -582,6 +583,7 @@ rdpSettings* freerdp_settings_clone(rdpSettings* settings)
/* BOOL values */ /* BOOL values */
_settings->ServerMode = settings->ServerMode; /* 16 */ _settings->ServerMode = settings->ServerMode; /* 16 */
_settings->WaitForOutputBufferFlush = settings->WaitForOutputBufferFlush; /* 25 */
_settings->NetworkAutoDetect = settings->NetworkAutoDetect; /* 137 */ _settings->NetworkAutoDetect = settings->NetworkAutoDetect; /* 137 */
_settings->SupportAsymetricKeys = settings->SupportAsymetricKeys; /* 138 */ _settings->SupportAsymetricKeys = settings->SupportAsymetricKeys; /* 138 */
_settings->SupportErrorInfoPdu = settings->SupportErrorInfoPdu; /* 139 */ _settings->SupportErrorInfoPdu = settings->SupportErrorInfoPdu; /* 139 */

View File

@ -66,6 +66,165 @@
#include "tcp.h" #include "tcp.h"
long transport_bio_buffered_callback(BIO* bio, int mode, const char* argp, int argi, long argl, long ret)
{
return 1;
}
static int transport_bio_buffered_write(BIO* bio, const char* buf, int num)
{
int status, ret;
rdpTcp *tcp = (rdpTcp *)bio->ptr;
int nchunks, committedBytes, i;
DataChunk chunks[2];
ret = num;
BIO_clear_retry_flags(bio);
tcp->writeBlocked = FALSE;
/* we directly append extra bytes in the xmit buffer, this could be prevented
* but for now it makes the code more simple.
*/
if (buf && num && !ringbuffer_write(&tcp->xmitBuffer, (const BYTE *)buf, num))
{
fprintf(stderr, "%s: an error occured when writing(toWrite=%d)\n", __FUNCTION__, num);
return -1;
}
committedBytes = 0;
nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer));
for (i = 0; i < nchunks; i++)
{
while (chunks[i].size)
{
status = BIO_write(bio->next_bio, chunks[i].data, chunks[i].size);
/*fprintf(stderr, "%s: i=%d/%d size=%d/%d status=%d retry=%d\n", __FUNCTION__, i, nchunks,
chunks[i].size, ringbuffer_used(&tcp->xmitBuffer), status,
BIO_should_retry(bio->next_bio)
);*/
if (status <= 0)
{
if (BIO_should_retry(bio->next_bio))
{
tcp->writeBlocked = TRUE;
goto out; /* EWOULDBLOCK */
}
/* any other is an error, but we still have to commit written bytes */
ret = -1;
goto out;
}
committedBytes += status;
chunks[i].size -= status;
chunks[i].data += status;
}
}
out:
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, committedBytes);
return ret;
}
static int transport_bio_buffered_read(BIO* bio, char* buf, int size)
{
int status;
rdpTcp *tcp = (rdpTcp *)bio->ptr;
tcp->readBlocked = FALSE;
BIO_clear_retry_flags(bio);
status = BIO_read(bio->next_bio, buf, size);
/*fprintf(stderr, "%s: size=%d status=%d shouldRetry=%d\n", __FUNCTION__, size, status, BIO_should_retry(bio->next_bio)); */
if (status <= 0 && BIO_should_retry(bio->next_bio))
{
BIO_set_retry_read(bio);
tcp->readBlocked = TRUE;
}
return status;
}
static int transport_bio_buffered_puts(BIO* bio, const char* str)
{
return 1;
}
static int transport_bio_buffered_gets(BIO* bio, char* str, int size)
{
return 1;
}
static long transport_bio_buffered_ctrl(BIO* bio, int cmd, long arg1, void* arg2)
{
rdpTcp *tcp = (rdpTcp *)bio->ptr;
switch (cmd)
{
case BIO_CTRL_FLUSH:
return 1;
case BIO_CTRL_WPENDING:
return ringbuffer_used(&tcp->xmitBuffer);
case BIO_CTRL_PENDING:
return 0;
default:
/*fprintf(stderr, "%s: passing to next BIO, bio=%p cmd=%d arg1=%d arg2=%p\n", __FUNCTION__, bio, cmd, arg1, arg2); */
return BIO_ctrl(bio->next_bio, cmd, arg1, arg2);
}
return 0;
}
static int transport_bio_buffered_new(BIO* bio)
{
bio->init = 1;
bio->num = 0;
bio->ptr = NULL;
bio->flags = 0;
return 1;
}
static int transport_bio_buffered_free(BIO* bio)
{
return 1;
}
static BIO_METHOD transport_bio_buffered_socket_methods =
{
BIO_TYPE_BUFFERED,
"BufferedSocket",
transport_bio_buffered_write,
transport_bio_buffered_read,
transport_bio_buffered_puts,
transport_bio_buffered_gets,
transport_bio_buffered_ctrl,
transport_bio_buffered_new,
transport_bio_buffered_free,
NULL,
};
BIO_METHOD* BIO_s_buffered_socket(void)
{
return &transport_bio_buffered_socket_methods;
}
BOOL transport_bio_buffered_drain(BIO *bio)
{
rdpTcp *tcp = (rdpTcp *)bio->ptr;
int status;
if (!ringbuffer_used(&tcp->xmitBuffer))
return 1;
status = transport_bio_buffered_write(bio, NULL, 0);
return status >= 0;
}
void tcp_get_ip_address(rdpTcp* tcp) void tcp_get_ip_address(rdpTcp* tcp)
{ {
BYTE* ip; BYTE* ip;
@ -136,62 +295,65 @@ BOOL tcp_connect(rdpTcp* tcp, const char* hostname, int port)
if (hostname[0] == '/') if (hostname[0] == '/')
{ {
tcp->sockfd = freerdp_uds_connect(hostname); tcp->sockfd = freerdp_uds_connect(hostname);
if (tcp->sockfd < 0) if (tcp->sockfd < 0)
return FALSE; return FALSE;
tcp->socketBio = BIO_new_fd(tcp->sockfd, 1);
if (!tcp->socketBio)
return FALSE;
} }
else else
{ {
tcp->sockfd = freerdp_tcp_connect(hostname, port); tcp->socketBio = BIO_new(BIO_s_connect());
if (!tcp->socketBio)
if (tcp->sockfd < 0)
return FALSE; return FALSE;
SetEventFileDescriptor(tcp->event, tcp->sockfd); if (BIO_set_conn_hostname(tcp->socketBio, hostname) < 0 || BIO_set_conn_int_port(tcp->socketBio, &port) < 0)
return FALSE;
tcp_get_ip_address(tcp); if (BIO_do_connect(tcp->socketBio) <= 0)
tcp_get_mac_address(tcp); return FALSE;
option_value = 1; tcp->sockfd = BIO_get_fd(tcp->socketBio, NULL);
option_len = sizeof(option_value);
setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len);
/* receive buffer must be a least 32 K */
if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0)
{
if (option_value < (1024 * 32))
{
option_value = 1024 * 32;
option_len = sizeof(option_value);
setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len);
}
}
tcp_set_keep_alive_mode(tcp);
} }
SetEventFileDescriptor(tcp->event, tcp->sockfd);
tcp_get_ip_address(tcp);
tcp_get_mac_address(tcp);
option_value = 1;
option_len = sizeof(option_value);
if (setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len) < 0)
fprintf(stderr, "%s: unable to set TCP_NODELAY\n", __FUNCTION__);
/* receive buffer must be a least 32 K */
if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0)
{
if (option_value < (1024 * 32))
{
option_value = 1024 * 32;
option_len = sizeof(option_value);
if (setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len) < 0)
{
fprintf(stderr, "%s: unable to set receive buffer len\n", __FUNCTION__);
return FALSE;
}
}
}
if (!tcp_set_keep_alive_mode(tcp))
return FALSE;
tcp->bufferedBio = BIO_new(BIO_s_buffered_socket());
if (!tcp->bufferedBio)
return FALSE;
tcp->bufferedBio->ptr = tcp;
tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio);
return TRUE; return TRUE;
} }
int tcp_read(rdpTcp* tcp, BYTE* data, int length)
{
return freerdp_tcp_read(tcp->sockfd, data, length);
}
int tcp_write(rdpTcp* tcp, BYTE* data, int length)
{
return freerdp_tcp_write(tcp->sockfd, data, length);
}
int tcp_wait_read(rdpTcp* tcp)
{
return freerdp_tcp_wait_read(tcp->sockfd);
}
int tcp_wait_write(rdpTcp* tcp)
{
return freerdp_tcp_wait_write(tcp->sockfd);
}
BOOL tcp_disconnect(rdpTcp* tcp) BOOL tcp_disconnect(rdpTcp* tcp)
{ {
@ -209,7 +371,7 @@ BOOL tcp_set_blocking_mode(rdpTcp* tcp, BOOL blocking)
if (flags == -1) if (flags == -1)
{ {
fprintf(stderr, "tcp_set_blocking_mode: fcntl failed.\n"); fprintf(stderr, "%s: fcntl failed, %s.\n", __FUNCTION__, strerror(errno));
return FALSE; return FALSE;
} }
@ -297,6 +459,31 @@ int tcp_attach(rdpTcp* tcp, int sockfd)
{ {
tcp->sockfd = sockfd; tcp->sockfd = sockfd;
SetEventFileDescriptor(tcp->event, tcp->sockfd); SetEventFileDescriptor(tcp->event, tcp->sockfd);
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, ringbuffer_used(&tcp->xmitBuffer));
if (tcp->socketBio)
{
if (BIO_set_fd(tcp->socketBio, sockfd, 1) < 0)
return -1;
}
else
{
tcp->socketBio = BIO_new_socket(sockfd, 1);
if (!tcp->socketBio)
return -1;
}
if (!tcp->bufferedBio)
{
tcp->bufferedBio = BIO_new(BIO_s_buffered_socket());
if (!tcp->bufferedBio)
return FALSE;
tcp->bufferedBio->ptr = tcp;
tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio);
}
return 0; return 0;
} }
@ -316,25 +503,36 @@ rdpTcp* tcp_new(rdpSettings* settings)
{ {
rdpTcp* tcp; rdpTcp* tcp;
tcp = (rdpTcp*) malloc(sizeof(rdpTcp)); tcp = (rdpTcp *)calloc(1, sizeof(rdpTcp));
if (!tcp)
return NULL;
if (tcp) if (!ringbuffer_init(&tcp->xmitBuffer, 0x10000))
{ goto out_free;
ZeroMemory(tcp, sizeof(rdpTcp));
tcp->sockfd = -1; tcp->sockfd = -1;
tcp->settings = settings; tcp->settings = settings;
tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd);
} #ifndef _WIN32
tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd);
if (!tcp->event || tcp->event == INVALID_HANDLE_VALUE)
goto out_ringbuffer;
#endif
return tcp; return tcp;
out_ringbuffer:
ringbuffer_destroy(&tcp->xmitBuffer);
out_free:
free(tcp);
return NULL;
} }
void tcp_free(rdpTcp* tcp) void tcp_free(rdpTcp* tcp)
{ {
if (tcp) if (!tcp)
{ return;
CloseHandle(tcp->event);
free(tcp); ringbuffer_destroy(&tcp->xmitBuffer);
} CloseHandle(tcp->event);
free(tcp);
} }

View File

@ -31,10 +31,15 @@
#include <winpr/stream.h> #include <winpr/stream.h>
#include <winpr/winsock.h> #include <winpr/winsock.h>
#include <freerdp/utils/ringbuffer.h>
#include <openssl/bio.h>
#ifndef MSG_NOSIGNAL #ifndef MSG_NOSIGNAL
#define MSG_NOSIGNAL 0 #define MSG_NOSIGNAL 0
#endif #endif
#define BIO_TYPE_BUFFERED 66
typedef struct rdp_tcp rdpTcp; typedef struct rdp_tcp rdpTcp;
struct rdp_tcp struct rdp_tcp
@ -46,6 +51,12 @@ struct rdp_tcp
#ifdef _WIN32 #ifdef _WIN32
WSAEVENT wsa_event; WSAEVENT wsa_event;
#endif #endif
BIO *socketBio;
BIO *bufferedBio;
RingBuffer xmitBuffer;
BOOL writeBlocked;
BOOL readBlocked;
HANDLE event; HANDLE event;
}; };

View File

@ -33,7 +33,9 @@
#include <freerdp/error.h> #include <freerdp/error.h>
#include <freerdp/utils/tcp.h> #include <freerdp/utils/tcp.h>
#include <freerdp/utils/ringbuffer.h>
#include <openssl/bio.h>
#include <time.h> #include <time.h>
#include <errno.h> #include <errno.h>
#include <fcntl.h> #include <fcntl.h>
@ -41,6 +43,12 @@
#ifndef _WIN32 #ifndef _WIN32
#include <netdb.h> #include <netdb.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/select.h>
#include <sys/time.h>
#endif
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif #endif
#include "tpkt.h" #include "tpkt.h"
@ -48,6 +56,7 @@
#include "transport.h" #include "transport.h"
#include "rdp.h" #include "rdp.h"
#define BUFFER_SIZE 16384 #define BUFFER_SIZE 16384
static void* transport_client_thread(void* arg); static void* transport_client_thread(void* arg);
@ -69,6 +78,7 @@ void transport_attach(rdpTransport* transport, int sockfd)
tcp_attach(transport->TcpIn, sockfd); tcp_attach(transport->TcpIn, sockfd);
transport->SplitInputOutput = FALSE; transport->SplitInputOutput = FALSE;
transport->TcpOut = transport->TcpIn; transport->TcpOut = transport->TcpIn;
transport->frontBio = transport->TcpIn->bufferedBio;
} }
void transport_stop(rdpTransport* transport) void transport_stop(rdpTransport* transport)
@ -98,18 +108,9 @@ BOOL transport_disconnect(rdpTransport* transport)
transport_stop(transport); transport_stop(transport);
if (transport->layer == TRANSPORT_LAYER_TLS) BIO_free_all(transport->frontBio);
status &= tls_disconnect(transport->TlsIn);
if ((transport->layer == TRANSPORT_LAYER_TSG) || (transport->layer == TRANSPORT_LAYER_TSG_TLS))
{
status &= tsg_disconnect(transport->tsg);
}
else
{
status &= tcp_disconnect(transport->TcpIn);
}
transport->frontBio = 0;
return status; return status;
} }
@ -131,16 +132,16 @@ static int transport_bio_tsg_write(BIO* bio, const char* buf, int num)
rdpTsg* tsg; rdpTsg* tsg;
tsg = (rdpTsg*) bio->ptr; tsg = (rdpTsg*) bio->ptr;
status = tsg_write(tsg, (BYTE*) buf, num);
BIO_clear_retry_flags(bio); BIO_clear_retry_flags(bio);
status = tsg_write(tsg, (BYTE*) buf, num);
if (status > 0)
return status;
if (status == 0) if (status == 0)
{
BIO_set_retry_write(bio); BIO_set_retry_write(bio);
}
return status < 0 ? 0 : num; return -1;
} }
static int transport_bio_tsg_read(BIO* bio, char* buf, int size) static int transport_bio_tsg_read(BIO* bio, char* buf, int size)
@ -222,8 +223,13 @@ BIO_METHOD* BIO_s_tsg(void)
return &transport_bio_tsg_methods; return &transport_bio_tsg_methods;
} }
BOOL transport_connect_tls(rdpTransport* transport) BOOL transport_connect_tls(rdpTransport* transport)
{ {
rdpSettings *settings = transport->settings;
rdpTls *targetTls;
BIO *targetBio;
int tls_status; int tls_status;
freerdp* instance; freerdp* instance;
rdpContext* context; rdpContext* context;
@ -234,61 +240,33 @@ BOOL transport_connect_tls(rdpTransport* transport)
if (transport->layer == TRANSPORT_LAYER_TSG) if (transport->layer == TRANSPORT_LAYER_TSG)
{ {
transport->TsgTls = tls_new(transport->settings); transport->TsgTls = tls_new(transport->settings);
transport->TsgTls->methods = BIO_s_tsg();
transport->TsgTls->tsg = (void*) transport->tsg;
transport->layer = TRANSPORT_LAYER_TSG_TLS; transport->layer = TRANSPORT_LAYER_TSG_TLS;
transport->TsgTls->hostname = transport->settings->ServerHostname; targetTls = transport->TsgTls;
transport->TsgTls->port = transport->settings->ServerPort; targetBio = transport->frontBio;
}
else
{
if (!transport->TlsIn)
transport->TlsIn = tls_new(settings);
if (transport->TsgTls->port == 0) if (!transport->TlsOut)
transport->TsgTls->port = 3389; transport->TlsOut = transport->TlsIn;
tls_status = tls_connect(transport->TsgTls); targetTls = transport->TlsIn;
targetBio = transport->TcpIn->bufferedBio;
if (tls_status < 1) transport->layer = TRANSPORT_LAYER_TLS;
{
if (tls_status < 0)
{
if (!connectErrorCode)
connectErrorCode = TLSCONNECTERROR;
if (!freerdp_get_last_error(context))
freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED);
}
else
{
if (!freerdp_get_last_error(context))
freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
}
tls_free(transport->TsgTls);
transport->TsgTls = NULL;
return FALSE;
}
return TRUE;
} }
if (!transport->TlsIn)
transport->TlsIn = tls_new(transport->settings);
if (!transport->TlsOut) targetTls->hostname = settings->ServerHostname;
transport->TlsOut = transport->TlsIn; targetTls->port = settings->ServerPort;
transport->layer = TRANSPORT_LAYER_TLS; if (targetTls->port == 0)
transport->TlsIn->sockfd = transport->TcpIn->sockfd; targetTls->port = 3389;
transport->TlsIn->hostname = transport->settings->ServerHostname; tls_status = tls_connect(targetTls, targetBio);
transport->TlsIn->port = transport->settings->ServerPort;
if (transport->TlsIn->port == 0)
transport->TlsIn->port = 3389;
tls_status = tls_connect(transport->TlsIn);
if (tls_status < 1) if (tls_status < 1)
{ {
@ -306,13 +284,13 @@ BOOL transport_connect_tls(rdpTransport* transport)
freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED); freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
} }
tls_free(transport->TlsIn); return FALSE;
}
if (transport->TlsIn == transport->TlsOut)
transport->TlsIn = transport->TlsOut = NULL;
else
transport->TlsIn = NULL;
transport->frontBio = targetTls->bio;
if (!transport->frontBio)
{
fprintf(stderr, "%s: unable to prepend a filtering TLS bio", __FUNCTION__);
return FALSE; return FALSE;
} }
@ -323,6 +301,7 @@ BOOL transport_connect_nla(rdpTransport* transport)
{ {
freerdp* instance; freerdp* instance;
rdpSettings* settings; rdpSettings* settings;
rdpCredssp *credSsp;
settings = transport->settings; settings = transport->settings;
instance = (freerdp*) settings->instance; instance = (freerdp*) settings->instance;
@ -338,16 +317,22 @@ BOOL transport_connect_nla(rdpTransport* transport)
if (!transport->credssp) if (!transport->credssp)
{ {
transport->credssp = credssp_new(instance, transport, settings); transport->credssp = credssp_new(instance, transport, settings);
if (!transport->credssp)
return FALSE;
transport_set_nla_mode(transport, TRUE); transport_set_nla_mode(transport, TRUE);
if (settings->AuthenticationServiceClass) if (settings->AuthenticationServiceClass)
{ {
transport->credssp->ServicePrincipalName = transport->credssp->ServicePrincipalName =
credssp_make_spn(settings->AuthenticationServiceClass, settings->ServerHostname); credssp_make_spn(settings->AuthenticationServiceClass, settings->ServerHostname);
if (!transport->credssp->ServicePrincipalName)
return FALSE;
} }
} }
if (credssp_authenticate(transport->credssp) < 0) credSsp = transport->credssp;
if (credssp_authenticate(credSsp) < 0)
{ {
if (!connectErrorCode) if (!connectErrorCode)
connectErrorCode = AUTHENTICATIONERROR; connectErrorCode = AUTHENTICATIONERROR;
@ -361,14 +346,14 @@ BOOL transport_connect_nla(rdpTransport* transport)
"If credentials are valid, the NTLMSSP implementation may be to blame.\n"); "If credentials are valid, the NTLMSSP implementation may be to blame.\n");
transport_set_nla_mode(transport, FALSE); transport_set_nla_mode(transport, FALSE);
credssp_free(transport->credssp); credssp_free(credSsp);
transport->credssp = NULL; transport->credssp = NULL;
return FALSE; return FALSE;
} }
transport_set_nla_mode(transport, FALSE); transport_set_nla_mode(transport, FALSE);
credssp_free(transport->credssp); credssp_free(credSsp);
transport->credssp = NULL; transport->credssp = NULL;
return TRUE; return TRUE;
@ -380,38 +365,41 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
int tls_status; int tls_status;
freerdp* instance; freerdp* instance;
rdpContext* context; rdpContext* context;
rdpSettings *settings = transport->settings;
instance = (freerdp*) transport->settings->instance; instance = (freerdp*) transport->settings->instance;
context = instance->context; context = instance->context;
tsg = tsg_new(transport); tsg = tsg_new(transport);
if (!tsg)
return FALSE;
tsg->transport = transport; tsg->transport = transport;
transport->tsg = tsg; transport->tsg = tsg;
transport->SplitInputOutput = TRUE; transport->SplitInputOutput = TRUE;
if (!transport->TlsIn) if (!transport->TlsIn)
transport->TlsIn = tls_new(transport->settings); {
transport->TlsIn = tls_new(settings);
transport->TlsIn->sockfd = transport->TcpIn->sockfd; if (!transport->TlsIn)
transport->TlsIn->hostname = transport->settings->GatewayHostname; return FALSE;
transport->TlsIn->port = transport->settings->GatewayPort; }
if (transport->TlsIn->port == 0)
transport->TlsIn->port = 443;
if (!transport->TlsOut) if (!transport->TlsOut)
transport->TlsOut = tls_new(transport->settings); {
transport->TlsOut = tls_new(settings);
if (!transport->TlsOut)
return FALSE;
}
transport->TlsOut->sockfd = transport->TcpOut->sockfd; /* put a decent default value for gateway port */
transport->TlsOut->hostname = transport->settings->GatewayHostname; if (!settings->GatewayPort)
transport->TlsOut->port = transport->settings->GatewayPort; settings->GatewayPort = 443;
if (transport->TlsOut->port == 0) transport->TlsIn->hostname = transport->TlsOut->hostname = settings->GatewayHostname;
transport->TlsOut->port = 443; transport->TlsIn->port = transport->TlsOut->port = settings->GatewayPort;
tls_status = tls_connect(transport->TlsIn);
tls_status = tls_connect(transport->TlsIn, transport->TcpIn->bufferedBio);
if (tls_status < 1) if (tls_status < 1)
{ {
if (tls_status < 0) if (tls_status < 0)
@ -428,8 +416,7 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
return FALSE; return FALSE;
} }
tls_status = tls_connect(transport->TlsOut); tls_status = tls_connect(transport->TlsOut, transport->TcpOut->bufferedBio);
if (tls_status < 1) if (tls_status < 1)
{ {
if (tls_status < 0) if (tls_status < 0)
@ -449,6 +436,8 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
if (!tsg_connect(tsg, hostname, port)) if (!tsg_connect(tsg, hostname, port))
return FALSE; return FALSE;
transport->frontBio = BIO_new(BIO_s_tsg());
transport->frontBio->ptr = tsg;
return TRUE; return TRUE;
} }
@ -462,15 +451,20 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por
if (transport->GatewayEnabled) if (transport->GatewayEnabled)
{ {
transport->layer = TRANSPORT_LAYER_TSG; transport->layer = TRANSPORT_LAYER_TSG;
transport->SplitInputOutput = TRUE;
transport->TcpOut = tcp_new(settings); transport->TcpOut = tcp_new(settings);
status = tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort); if (!tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort) ||
!tcp_set_blocking_mode(transport->TcpIn, FALSE))
return FALSE;
if (status) if (!tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort) ||
status = tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort); !tcp_set_blocking_mode(transport->TcpOut, FALSE))
return FALSE;
if (status) if (!transport_tsg_connect(transport, hostname, port))
status = transport_tsg_connect(transport, hostname, port); return FALSE;
status = TRUE;
} }
else else
{ {
@ -478,6 +472,7 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por
transport->SplitInputOutput = FALSE; transport->SplitInputOutput = FALSE;
transport->TcpOut = transport->TcpIn; transport->TcpOut = transport->TcpIn;
transport->frontBio = transport->TcpIn->bufferedBio;
} }
if (status) if (status)
@ -510,11 +505,11 @@ BOOL transport_accept_tls(rdpTransport* transport)
transport->TlsOut = transport->TlsIn; transport->TlsOut = transport->TlsIn;
transport->layer = TRANSPORT_LAYER_TLS; transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
return FALSE; return FALSE;
transport->frontBio = transport->TlsIn->bio;
return TRUE; return TRUE;
} }
@ -533,10 +528,10 @@ BOOL transport_accept_nla(rdpTransport* transport)
transport->TlsOut = transport->TlsIn; transport->TlsOut = transport->TlsIn;
transport->layer = TRANSPORT_LAYER_TLS; transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, settings->CertificateFile, settings->PrivateKeyFile))
return FALSE; return FALSE;
transport->frontBio = transport->TlsIn->bio;
/* Network Level Authentication */ /* Network Level Authentication */
@ -630,56 +625,131 @@ UINT32 nla_header_length(wStream* s)
return length; return length;
} }
static int transport_wait_for_read(rdpTransport* transport)
{
struct timeval tv;
fd_set rset, wset;
fd_set *rsetPtr = NULL, *wsetPtr = NULL;
rdpTcp *tcpIn;
tcpIn = transport->TcpIn;
if (tcpIn->readBlocked)
{
rsetPtr = &rset;
FD_ZERO(rsetPtr);
FD_SET(tcpIn->sockfd, rsetPtr);
}
else if (tcpIn->writeBlocked)
{
wsetPtr = &wset;
FD_ZERO(wsetPtr);
FD_SET(tcpIn->sockfd, wsetPtr);
}
if (!wsetPtr && !rsetPtr)
{
USleep(1000);
return 0;
}
tv.tv_sec = 0;
tv.tv_usec = 1000;
return select(tcpIn->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
}
static int transport_wait_for_write(rdpTransport* transport)
{
struct timeval tv;
fd_set rset, wset;
fd_set *rsetPtr = NULL, *wsetPtr = NULL;
rdpTcp *tcpOut;
tcpOut = transport->SplitInputOutput ? transport->TcpOut : transport->TcpIn;
if (tcpOut->writeBlocked)
{
wsetPtr = &wset;
FD_ZERO(wsetPtr);
FD_SET(tcpOut->sockfd, wsetPtr);
}
else if (tcpOut->readBlocked)
{
rsetPtr = &rset;
FD_ZERO(rsetPtr);
FD_SET(tcpOut->sockfd, rsetPtr);
}
if (!wsetPtr && !rsetPtr)
{
USleep(1000);
return 0;
}
tv.tv_sec = 0;
tv.tv_usec = 1000;
return select(tcpOut->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
}
int transport_read_layer(rdpTransport* transport, BYTE* data, int bytes) int transport_read_layer(rdpTransport* transport, BYTE* data, int bytes)
{ {
int read = 0; int read = 0;
int status = -1; int status = -1;
while (read < bytes) while (read < bytes)
{ {
if (transport->layer == TRANSPORT_LAYER_TLS) status = BIO_read(transport->frontBio, data + read, bytes - read);
status = tls_read(transport->TlsIn, data + read, bytes - read);
else if (transport->layer == TRANSPORT_LAYER_TCP) if (!status)
status = tcp_read(transport->TcpIn, data + read, bytes - read); {
else if (transport->layer == TRANSPORT_LAYER_TSG) transport->layer = TRANSPORT_LAYER_CLOSED;
status = tsg_read(transport->tsg, data + read, bytes - read); return -1;
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) {
status = tls_read(transport->TsgTls, data + read, bytes - read);
} }
/* blocking means that we can't continue until this is read */
if (!transport->blocking)
return status;
if (status < 0) if (status < 0)
{ {
/* A read error indicates that the peer has dropped the connection */ if (!BIO_should_retry(transport->frontBio))
transport->layer = TRANSPORT_LAYER_CLOSED; {
return status; /* something unexpected happened, let's close */
transport->layer = TRANSPORT_LAYER_CLOSED;
return -1;
}
/* non blocking will survive a partial read */
if (!transport->blocking)
return read;
/* blocking means that we can't continue until we have read the number of
* requested bytes */
if (transport_wait_for_read(transport) < 0)
{
fprintf(stderr, "%s: error when selecting for read\n", __FUNCTION__);
return -1;
}
continue;
} }
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(data + read, bytes - read);
#endif
read += status; read += status;
if (status == 0)
{
/*
* instead of sleeping, we should wait timeout on the
* socket but this only happens on initial connection
*/
USleep(transport->SleepInterval);
}
} }
return read; return read;
} }
int transport_read(rdpTransport* transport, wStream* s) int transport_read(rdpTransport* transport, wStream* s)
{ {
int status; int status;
int position; int position;
int pduLength; int pduLength;
BYTE header[4]; BYTE *header;
int transport_status; int transport_status;
position = 0; position = 0;
@ -710,7 +780,7 @@ int transport_read(rdpTransport* transport, wStream* s)
position += status; position += status;
} }
CopyMemory(header, Stream_Buffer(s), 4); /* peek at first 4 bytes */ header = Stream_Buffer(s);
/* if header is present, read exactly one PDU */ /* if header is present, read exactly one PDU */
@ -802,6 +872,8 @@ static int transport_read_nonblocking(rdpTransport* transport)
return status; return status;
} }
BOOL transport_bio_buffered_drain(BIO *bio);
int transport_write(rdpTransport* transport, wStream* s) int transport_write(rdpTransport* transport, wStream* s)
{ {
int length; int length;
@ -827,36 +899,48 @@ int transport_write(rdpTransport* transport, wStream* s)
while (length > 0) while (length > 0)
{ {
if (transport->layer == TRANSPORT_LAYER_TLS) status = BIO_write(transport->frontBio, Stream_Pointer(s), length);
status = tls_write(transport->TlsOut, Stream_Pointer(s), length);
else if (transport->layer == TRANSPORT_LAYER_TCP)
status = tcp_write(transport->TcpOut, Stream_Pointer(s), length);
else if (transport->layer == TRANSPORT_LAYER_TSG)
status = tsg_write(transport->tsg, Stream_Pointer(s), length);
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
status = tls_write(transport->TsgTls, Stream_Pointer(s), length);
if (status < 0) if (status <= 0)
break; /* error occurred */
if (status == 0)
{ {
/* when sending is blocked in nonblocking mode, the receiving buffer should be checked */ /* the buffered BIO that is at the end of the chain always says OK for writing,
if (!transport->blocking) * so a retry means that for any reason we need to read. The most probable
{ * is a SSL or TSG BIO in the chain.
/* and in case we do have buffered some data, we set the event so next loop will get it */ */
if (transport_read_nonblocking(transport) > 0) if (!BIO_should_retry(transport->frontBio))
SetEvent(transport->ReceiveEvent); return status;
}
if (transport->layer == TRANSPORT_LAYER_TLS) /* non-blocking can live with blocked IOs */
tls_wait_write(transport->TlsOut); if (!transport->blocking)
else if (transport->layer == TRANSPORT_LAYER_TCP) return status;
tcp_wait_write(transport->TcpOut);
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) if (transport_wait_for_write(transport) < 0)
tls_wait_write(transport->TsgTls); {
else fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__);
USleep(transport->SleepInterval); return -1;
}
continue;
}
if (transport->blocking || transport->settings->WaitForOutputBufferFlush)
{
/* blocking transport, we must ensure the write buffer is really empty */
rdpTcp *out = transport->TcpOut;
while (out->writeBlocked)
{
if (transport_wait_for_write(transport) < 0)
{
fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__);
return -1;
}
if (!transport_bio_buffered_drain(out->bufferedBio))
{
fprintf(stderr, "%s: error when draining outputBuffer\n", __FUNCTION__);
return -1;
}
}
} }
length -= status; length -= status;
@ -945,6 +1029,38 @@ void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD*
} }
} }
BOOL tranport_is_write_blocked(rdpTransport* transport)
{
if (transport->TcpIn->writeBlocked)
return TRUE;
return transport->SplitInputOutput &&
transport->TcpOut &&
transport->TcpOut->writeBlocked;
}
int tranport_drain_output_buffer(rdpTransport* transport)
{
BOOL ret = FALSE;
/* First try to send some accumulated bytes in the send buffer */
if (transport->TcpIn->writeBlocked)
{
if (!transport_bio_buffered_drain(transport->TcpIn->bufferedBio))
return -1;
ret |= transport->TcpIn->writeBlocked;
}
if (transport->SplitInputOutput && transport->TcpOut && transport->TcpOut->writeBlocked)
{
if (!transport_bio_buffered_drain(transport->TcpOut->bufferedBio))
return -1;
ret |= transport->TcpOut->writeBlocked;
}
return ret;
}
int transport_check_fds(rdpTransport* transport) int transport_check_fds(rdpTransport* transport)
{ {
int pos; int pos;
@ -1079,15 +1195,14 @@ int transport_check_fds(rdpTransport* transport)
recv_status = transport->ReceiveCallback(transport, received, transport->ReceiveExtra); recv_status = transport->ReceiveCallback(transport, received, transport->ReceiveExtra);
Stream_Release(received);
if (recv_status < 0)
return -1;
if (recv_status == 1) if (recv_status == 1)
{ {
return 1; /* session redirection */ return 1; /* session redirection */
} }
Stream_Release(received);
if (recv_status < 0)
return -1;
} }
return 0; return 0;
@ -1198,80 +1313,107 @@ rdpTransport* transport_new(rdpSettings* settings)
{ {
rdpTransport* transport; rdpTransport* transport;
transport = (rdpTransport*) malloc(sizeof(rdpTransport)); transport = (rdpTransport *)calloc(1, sizeof(rdpTransport));
if (!transport)
return NULL;
if (transport) WLog_Init();
{ transport->log = WLog_Get("com.freerdp.core.transport");
ZeroMemory(transport, sizeof(rdpTransport)); if (!transport->log)
goto out_free;
WLog_Init(); transport->TcpIn = tcp_new(settings);
transport->log = WLog_Get("com.freerdp.core.transport"); if (!transport->TcpIn)
goto out_free;
transport->TcpIn = tcp_new(settings); transport->settings = settings;
transport->settings = settings; /* a small 0.1ms delay when transport is blocking. */
transport->SleepInterval = 100;
/* a small 0.1ms delay when transport is blocking. */ transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE);
transport->SleepInterval = 100; if (!transport->ReceivePool)
goto out_free_tcpin;
transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE); /* receive buffer for non-blocking read. */
transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
if (!transport->ReceiveBuffer)
goto out_free_receivepool;
/* receive buffer for non-blocking read. */ transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0); if (!transport->ReceiveEvent || transport->ReceiveEvent == INVALID_HANDLE_VALUE)
transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL); goto out_free_receivebuffer;
transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL); transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!transport->connectedEvent || transport->connectedEvent == INVALID_HANDLE_VALUE)
goto out_free_receiveEvent;
transport->blocking = TRUE; transport->blocking = TRUE;
transport->GatewayEnabled = FALSE; transport->GatewayEnabled = FALSE;
transport->layer = TRANSPORT_LAYER_TCP;
InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000); if (!InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000))
InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000); goto out_free_connectedEvent;
if (!InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000))
transport->layer = TRANSPORT_LAYER_TCP; goto out_free_readlock;
}
return transport; return transport;
out_free_readlock:
DeleteCriticalSection(&(transport->ReadLock));
out_free_connectedEvent:
CloseHandle(transport->connectedEvent);
out_free_receiveEvent:
CloseHandle(transport->ReceiveEvent);
out_free_receivebuffer:
StreamPool_Return(transport->ReceivePool, transport->ReceiveBuffer);
out_free_receivepool:
StreamPool_Free(transport->ReceivePool);
out_free_tcpin:
tcp_free(transport->TcpIn);
out_free:
free(transport);
return NULL;
} }
void transport_free(rdpTransport* transport) void transport_free(rdpTransport* transport)
{ {
if (transport) if (!transport)
{ return;
transport_stop(transport);
if (transport->ReceiveBuffer) transport_stop(transport);
Stream_Release(transport->ReceiveBuffer);
StreamPool_Free(transport->ReceivePool); if (transport->ReceiveBuffer)
Stream_Release(transport->ReceiveBuffer);
CloseHandle(transport->ReceiveEvent); StreamPool_Free(transport->ReceivePool);
CloseHandle(transport->connectedEvent);
if (transport->TlsIn) CloseHandle(transport->ReceiveEvent);
tls_free(transport->TlsIn); CloseHandle(transport->connectedEvent);
if (transport->TlsOut != transport->TlsIn) if (transport->TlsIn)
tls_free(transport->TlsOut); tls_free(transport->TlsIn);
transport->TlsIn = NULL; if (transport->TlsOut != transport->TlsIn)
transport->TlsOut = NULL; tls_free(transport->TlsOut);
if (transport->TcpIn) transport->TlsIn = NULL;
tcp_free(transport->TcpIn); transport->TlsOut = NULL;
if (transport->TcpOut != transport->TcpIn) if (transport->TcpIn)
tcp_free(transport->TcpOut); tcp_free(transport->TcpIn);
transport->TcpIn = NULL; if (transport->TcpOut != transport->TcpIn)
transport->TcpOut = NULL; tcp_free(transport->TcpOut);
tsg_free(transport->tsg); transport->TcpIn = NULL;
transport->tsg = NULL; transport->TcpOut = NULL;
DeleteCriticalSection(&(transport->ReadLock)); tsg_free(transport->tsg);
DeleteCriticalSection(&(transport->WriteLock)); transport->tsg = NULL;
free(transport); DeleteCriticalSection(&(transport->ReadLock));
} DeleteCriticalSection(&(transport->WriteLock));
free(transport);
} }

View File

@ -49,11 +49,13 @@ typedef struct rdp_transport rdpTransport;
#include <freerdp/types.h> #include <freerdp/types.h>
#include <freerdp/settings.h> #include <freerdp/settings.h>
typedef int (*TransportRecv) (rdpTransport* transport, wStream* stream, void* extra); typedef int (*TransportRecv) (rdpTransport* transport, wStream* stream, void* extra);
struct rdp_transport struct rdp_transport
{ {
TRANSPORT_LAYER layer; TRANSPORT_LAYER layer;
BIO *frontBio;
rdpTsg* tsg; rdpTsg* tsg;
rdpTcp* TcpIn; rdpTcp* TcpIn;
rdpTcp* TcpOut; rdpTcp* TcpOut;
@ -102,6 +104,8 @@ BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking);
void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled); void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled);
void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode); void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode);
void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count); void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count);
BOOL tranport_is_write_blocked(rdpTransport* transport);
int tranport_drain_output_buffer(rdpTransport* transport);
wStream* transport_receive_pool_take(rdpTransport* transport); wStream* transport_receive_pool_take(rdpTransport* transport);
int transport_receive_pool_return(rdpTransport* transport, wStream* pdu); int transport_receive_pool_return(rdpTransport* transport, wStream* pdu);

View File

@ -544,7 +544,7 @@ static void update_end_paint(rdpContext* context)
if (update->numberOrders > 0) if (update->numberOrders > 0)
{ {
printf("Sending %d orders\n", update->numberOrders); fprintf(stderr, "%s: sending %d orders\n", __FUNCTION__, update->numberOrders);
fastpath_send_update_pdu(context->rdp->fastpath, FASTPATH_UPDATETYPE_ORDERS, s); fastpath_send_update_pdu(context->rdp->fastpath, FASTPATH_UPDATETYPE_ORDERS, s);
} }

1
libfreerdp/crypto/test/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
TestFreeRDPCrypto.c

View File

@ -28,34 +28,35 @@
#include <winpr/stream.h> #include <winpr/stream.h>
#include <freerdp/utils/tcp.h> #include <freerdp/utils/tcp.h>
#include <freerdp/utils/ringbuffer.h>
#include <freerdp/crypto/tls.h> #include <freerdp/crypto/tls.h>
#include "../core/tcp.h"
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
static CryptoCert tls_get_certificate(rdpTls* tls, BOOL peer) static CryptoCert tls_get_certificate(rdpTls* tls, BOOL peer)
{ {
CryptoCert cert; CryptoCert cert;
X509* server_cert; X509* remote_cert;
if (peer) if (peer)
server_cert = SSL_get_peer_certificate(tls->ssl); remote_cert = SSL_get_peer_certificate(tls->ssl);
else else
server_cert = SSL_get_certificate(tls->ssl); remote_cert = SSL_get_certificate(tls->ssl);
if (!server_cert) if (!remote_cert)
{ {
fprintf(stderr, "tls_get_certificate: failed to get the server TLS certificate\n"); fprintf(stderr, "%s: failed to get the server TLS certificate\n", __FUNCTION__);
cert = NULL; return NULL;
}
else
{
cert = malloc(sizeof(*cert));
cert->px509 = server_cert;
} }
cert = malloc(sizeof(*cert));
if (!cert)
{
X509_free(remote_cert);
return NULL;
}
cert->px509 = remote_cert;
return cert; return cert;
} }
@ -83,12 +84,14 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert)
PrefixLength = strlen(TLS_SERVER_END_POINT); PrefixLength = strlen(TLS_SERVER_END_POINT);
ChannelBindingTokenLength = PrefixLength + CertificateHashLength; ChannelBindingTokenLength = PrefixLength + CertificateHashLength;
ContextBindings = (SecPkgContext_Bindings*) malloc(sizeof(SecPkgContext_Bindings)); ContextBindings = (SecPkgContext_Bindings*) calloc(1, sizeof(SecPkgContext_Bindings));
ZeroMemory(ContextBindings, sizeof(SecPkgContext_Bindings)); if (!ContextBindings)
return NULL;
ContextBindings->BindingsLength = sizeof(SEC_CHANNEL_BINDINGS) + ChannelBindingTokenLength; ContextBindings->BindingsLength = sizeof(SEC_CHANNEL_BINDINGS) + ChannelBindingTokenLength;
ChannelBindings = (SEC_CHANNEL_BINDINGS*) malloc(ContextBindings->BindingsLength); ChannelBindings = (SEC_CHANNEL_BINDINGS*) calloc(1, ContextBindings->BindingsLength);
ZeroMemory(ChannelBindings, ContextBindings->BindingsLength); if (!ChannelBindings)
goto out_free;
ContextBindings->Bindings = ChannelBindings; ContextBindings->Bindings = ChannelBindings;
ChannelBindings->cbApplicationDataLength = ChannelBindingTokenLength; ChannelBindings->cbApplicationDataLength = ChannelBindingTokenLength;
@ -99,32 +102,121 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert)
CopyMemory(&ChannelBindingToken[PrefixLength], CertificateHash, CertificateHashLength); CopyMemory(&ChannelBindingToken[PrefixLength], CertificateHash, CertificateHashLength);
return ContextBindings; return ContextBindings;
out_free:
free(ContextBindings);
return NULL;
} }
static void tls_ssl_info_callback(const SSL* ssl, int type, int val)
BOOL tls_prepare(rdpTls* tls, BIO *underlying, const SSL_METHOD *method, int options, BOOL clientMode)
{ {
if (type & SSL_CB_HANDSHAKE_START) tls->ctx = SSL_CTX_new(method);
{
}
}
int tls_connect(rdpTls* tls)
{
CryptoCert cert;
long options = 0;
int verify_status;
int connection_status;
tls->ctx = SSL_CTX_new(TLSv1_client_method());
if (!tls->ctx) if (!tls->ctx)
{ {
fprintf(stderr, "SSL_CTX_new failed\n"); fprintf(stderr, "%s: SSL_CTX_new failed\n", __FUNCTION__);
return FALSE;
}
SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_options(tls->ctx, options);
SSL_CTX_set_read_ahead(tls->ctx, 1);
tls->bio = BIO_new_ssl(tls->ctx, clientMode);
if (BIO_get_ssl(tls->bio, &tls->ssl) < 0)
{
fprintf(stderr, "%s: unable to retrieve the SSL of the connection\n", __FUNCTION__);
return FALSE;
}
BIO_push(tls->bio, underlying);
return TRUE;
}
int tls_do_handshake(rdpTls* tls, BOOL clientMode)
{
CryptoCert cert;
int verify_status, status;
do
{
struct timeval tv;
fd_set rset;
int fd;
status = BIO_do_handshake(tls->bio);
if (status == 1)
break;
if (!BIO_should_retry(tls->bio))
return -1;
/* we select() only for read even if we should test both read and write
* depending of what have blocked */
FD_ZERO(&rset);
fd = BIO_get_fd(tls->bio, NULL);
if (fd < 0)
{
fprintf(stderr, "%s: unable to retrieve BIO fd\n", __FUNCTION__);
return -1;
}
FD_SET(fd, &rset);
tv.tv_sec = 0;
tv.tv_usec = 10 * 1000; /* 10ms */
status = select(fd + 1, &rset, NULL, NULL, &tv);
if (status < 0)
{
fprintf(stderr, "%s: error during select()\n", __FUNCTION__);
return -1;
}
}
while (TRUE);
if (!clientMode)
return 1;
cert = tls_get_certificate(tls, clientMode);
if (!cert)
{
fprintf(stderr, "%s: tls_get_certificate failed to return the server certificate.\n", __FUNCTION__);
return -1; return -1;
} }
//SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); tls->Bindings = tls_get_channel_bindings(cert->px509);
if (!tls->Bindings)
{
fprintf(stderr, "%s: unable to retrieve bindings\n", __FUNCTION__);
return -1;
}
if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
{
fprintf(stderr, "%s: crypto_cert_get_public_key failed to return the server public key.\n", __FUNCTION__);
tls_free_certificate(cert);
return -1;
}
verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port);
if (verify_status < 1)
{
fprintf(stderr, "%s: certificate not trusted, aborting.\n", __FUNCTION__);
tls_disconnect(tls);
tls_free_certificate(cert);
return 0;
}
tls_free_certificate(cert);
return verify_status;
}
int tls_connect(rdpTls* tls, BIO *underlying)
{
int options = 0;
/** /**
* SSL_OP_NO_COMPRESSION: * SSL_OP_NO_COMPRESSION:
@ -138,7 +230,7 @@ int tls_connect(rdpTls* tls)
#ifdef SSL_OP_NO_COMPRESSION #ifdef SSL_OP_NO_COMPRESSION
options |= SSL_OP_NO_COMPRESSION; options |= SSL_OP_NO_COMPRESSION;
#endif #endif
/** /**
* SSL_OP_TLS_BLOCK_PADDING_BUG: * SSL_OP_TLS_BLOCK_PADDING_BUG:
* *
@ -155,96 +247,19 @@ int tls_connect(rdpTls* tls)
*/ */
options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS; options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
SSL_CTX_set_options(tls->ctx, options); if (!tls_prepare(tls, underlying, TLSv1_client_method(), options, TRUE))
return FALSE;
tls->ssl = SSL_new(tls->ctx); return tls_do_handshake(tls, TRUE);
if (!tls->ssl)
{
fprintf(stderr, "SSL_new failed\n");
return -1;
}
if (tls->tsg)
{
tls->bio = BIO_new(tls->methods);
if (!tls->bio)
{
fprintf(stderr, "BIO_new failed\n");
return -1;
}
tls->bio->ptr = tls->tsg;
SSL_set_bio(tls->ssl, tls->bio, tls->bio);
SSL_CTX_set_info_callback(tls->ctx, tls_ssl_info_callback);
}
else
{
if (SSL_set_fd(tls->ssl, tls->sockfd) < 1)
{
fprintf(stderr, "SSL_set_fd failed\n");
return -1;
}
}
connection_status = SSL_connect(tls->ssl);
if (connection_status <= 0)
{
if (tls_print_error("SSL_connect", tls->ssl, connection_status))
{
return -1;
}
}
cert = tls_get_certificate(tls, TRUE);
if (!cert)
{
fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n");
return -1;
}
tls->Bindings = tls_get_channel_bindings(cert->px509);
if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
{
fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n");
tls_free_certificate(cert);
return -1;
}
verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port);
if (verify_status < 1)
{
fprintf(stderr, "tls_connect: certificate not trusted, aborting.\n");
tls_disconnect(tls);
}
tls_free_certificate(cert);
return verify_status;
} }
BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file)
BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file)
{ {
CryptoCert cert;
long options = 0; long options = 0;
int connection_status;
tls->ctx = SSL_CTX_new(SSLv23_server_method()); /**
if (tls->ctx == NULL)
{
fprintf(stderr, "SSL_CTX_new failed\n");
return FALSE;
}
/*
* SSL_OP_NO_SSLv2: * SSL_OP_NO_SSLv2:
* *
* We only want SSLv3 and TLSv1, so disable SSLv2. * We only want SSLv3 and TLSv1, so disable SSLv2.
@ -281,80 +296,23 @@ BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file)
*/ */
options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS; options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
SSL_CTX_set_options(tls->ctx, options); if (!tls_prepare(tls, underlying, SSLv23_server_method(), options, FALSE))
if (SSL_CTX_use_RSAPrivateKey_file(tls->ctx, privatekey_file, SSL_FILETYPE_PEM) <= 0)
{
fprintf(stderr, "SSL_CTX_use_RSAPrivateKey_file failed\n");
fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file);
return FALSE; return FALSE;
}
tls->ssl = SSL_new(tls->ctx); if (SSL_use_RSAPrivateKey_file(tls->ssl, privatekey_file, SSL_FILETYPE_PEM) <= 0)
if (!tls->ssl)
{ {
fprintf(stderr, "SSL_new failed\n"); fprintf(stderr, "%s: SSL_CTX_use_RSAPrivateKey_file failed\n", __FUNCTION__);
fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file);
return FALSE; return FALSE;
} }
if (SSL_use_certificate_file(tls->ssl, cert_file, SSL_FILETYPE_PEM) <= 0) if (SSL_use_certificate_file(tls->ssl, cert_file, SSL_FILETYPE_PEM) <= 0)
{ {
fprintf(stderr, "SSL_use_certificate_file failed\n"); fprintf(stderr, "%s: SSL_use_certificate_file failed\n", __FUNCTION__);
return FALSE; return FALSE;
} }
if (SSL_set_fd(tls->ssl, tls->sockfd) < 1) return tls_do_handshake(tls, FALSE) > 0;
{
fprintf(stderr, "SSL_set_fd failed\n");
return FALSE;
}
while (1)
{
connection_status = SSL_accept(tls->ssl);
if (connection_status <= 0)
{
switch (SSL_get_error(tls->ssl, connection_status))
{
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
break;
default:
if (tls_print_error("SSL_accept", tls->ssl, connection_status))
return FALSE;
break;
}
}
else
{
break;
}
}
cert = tls_get_certificate(tls, FALSE);
if (!cert)
{
fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n");
return FALSE;
}
if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
{
fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n");
tls_free_certificate(cert);
return FALSE;
}
free(cert);
fprintf(stderr, "TLS connection accepted\n");
return TRUE;
} }
BOOL tls_disconnect(rdpTls* tls) BOOL tls_disconnect(rdpTls* tls)
@ -362,256 +320,161 @@ BOOL tls_disconnect(rdpTls* tls)
if (!tls) if (!tls)
return FALSE; return FALSE;
if (tls->ssl) if (!tls->ssl)
return TRUE;
if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY)
{ {
if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY) /**
{ * OpenSSL doesn't really expose an API for sending a TLS alert manually.
/** *
* OpenSSL doesn't really expose an API for sending a TLS alert manually. * The following code disables the sending of the default "close notify"
* * and then proceeds to force sending a custom TLS alert before shutting down.
* The following code disables the sending of the default "close notify" *
* and then proceeds to force sending a custom TLS alert before shutting down. * Manually sending a TLS alert is necessary in certain cases,
* * like when server-side NLA results in an authentication failure.
* Manually sending a TLS alert is necessary in certain cases, */
* like when server-side NLA results in an authentication failure.
*/
SSL_set_quiet_shutdown(tls->ssl, 1); SSL_set_quiet_shutdown(tls->ssl, 1);
if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session)) if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session))
SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session); SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session);
tls->ssl->s3->alert_dispatch = 1; tls->ssl->s3->alert_dispatch = 1;
tls->ssl->s3->send_alert[0] = tls->alertLevel; tls->ssl->s3->send_alert[0] = tls->alertLevel;
tls->ssl->s3->send_alert[1] = tls->alertDescription; tls->ssl->s3->send_alert[1] = tls->alertDescription;
if (tls->ssl->s3->wbuf.left == 0) if (tls->ssl->s3->wbuf.left == 0)
tls->ssl->method->ssl_dispatch_alert(tls->ssl); tls->ssl->method->ssl_dispatch_alert(tls->ssl);
SSL_shutdown(tls->ssl); SSL_shutdown(tls->ssl);
} }
else else
{ {
SSL_shutdown(tls->ssl); SSL_shutdown(tls->ssl);
}
} }
return TRUE; return TRUE;
} }
int tls_read(rdpTls* tls, BYTE* data, int length)
BIO *findBufferedBio(BIO *front)
{ {
int error; BIO *ret = front;
int status;
if (!tls) while (ret)
return -1;
if (!tls->ssl)
return -1;
status = SSL_read(tls->ssl, data, length);
if (status == 0)
{ {
return -1; /* peer disconnected */ if (BIO_method_type(ret) == BIO_TYPE_BUFFERED)
return ret;
ret = ret->next_bio;
} }
if (status <= 0) return ret;
{
error = SSL_get_error(tls->ssl, status);
//fprintf(stderr, "tls_read: length: %d status: %d error: 0x%08X\n",
// length, status, error);
switch (error)
{
case SSL_ERROR_NONE:
break;
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
status = 0;
break;
case SSL_ERROR_SYSCALL:
#ifdef _WIN32
if (WSAGetLastError() == WSAEWOULDBLOCK)
#else
if ((errno == EAGAIN) || (errno == 0))
#endif
{
status = 0;
}
else
{
if (tls_print_error("SSL_read", tls->ssl, status))
{
status = -1;
}
else
{
status = 0;
}
}
break;
default:
if (tls_print_error("SSL_read", tls->ssl, status))
{
status = -1;
}
else
{
status = 0;
}
break;
}
}
#ifdef HAVE_VALGRIND_MEMCHECK_H
VALGRIND_MAKE_MEM_DEFINED(data, status);
#endif
return status;
} }
int tls_write(rdpTls* tls, BYTE* data, int length) int tls_write_all(rdpTls* tls, const BYTE* data, int length)
{ {
int error; int status, nchunks, commitedBytes;
int status; rdpTcp *tcp;
fd_set rset, wset;
fd_set *rsetPtr, *wsetPtr;
struct timeval tv;
BIO *bio = tls->bio;
DataChunk chunks[2];
if (!tls) BIO *bufferedBio = findBufferedBio(bio);
return -1; if (!bufferedBio)
if (!tls->ssl)
return -1;
status = SSL_write(tls->ssl, data, length);
if (status == 0)
{ {
return -1; /* peer disconnected */ fprintf(stderr, "%s: error unable to retrieve the bufferedBio in the BIO chain\n", __FUNCTION__);
return -1;
} }
if (status < 0) tcp = (rdpTcp *)bufferedBio->ptr;
{
error = SSL_get_error(tls->ssl, status);
//fprintf(stderr, "tls_write: length: %d status: %d error: 0x%08X\n", length, status, error);
switch (error)
{
case SSL_ERROR_NONE:
break;
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
status = 0;
break;
case SSL_ERROR_SYSCALL:
if (errno == EAGAIN)
{
status = 0;
}
else
{
tls_print_error("SSL_write", tls->ssl, status);
status = -1;
}
break;
default:
tls_print_error("SSL_write", tls->ssl, status);
status = -1;
break;
}
}
return status;
}
int tls_write_all(rdpTls* tls, BYTE* data, int length)
{
int status;
int sent = 0;
do do
{ {
status = tls_write(tls, &data[sent], length - sent); status = BIO_write(bio, data, length);
/*fprintf(stderr, "%s: BIO_write(len=%d) = %d (retry=%d)\n", __FUNCTION__, length, status, BIO_should_retry(bio));*/
if (status > 0) if (status > 0)
sent += status;
else if (status == 0)
tls_wait_write(tls);
if (sent >= length)
break; break;
if (!BIO_should_retry(bio))
return -1;
/* we try to handle SSL want_read and want_write nicely */
rsetPtr = wsetPtr = 0;
if (tcp->writeBlocked)
{
wsetPtr = &wset;
FD_ZERO(&wset);
FD_SET(tcp->sockfd, &wset);
}
else if (tcp->readBlocked)
{
rsetPtr = &rset;
FD_ZERO(&rset);
FD_SET(tcp->sockfd, &rset);
}
else
{
fprintf(stderr, "%s: weird we're blocked but the underlying is not read or write blocked !\n", __FUNCTION__);
USleep(10);
continue;
}
tv.tv_sec = 0;
tv.tv_usec = 100 * 1000;
status = select(tcp->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
if (status < 0)
return -1;
} }
while (status >= 0); while (TRUE);
if (status > 0) /* make sure the output buffer is empty */
return length; commitedBytes = 0;
else while ((nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer))))
return status;
}
int tls_wait_read(rdpTls* tls)
{
return freerdp_tcp_wait_read(tls->sockfd);
}
int tls_wait_write(rdpTls* tls)
{
return freerdp_tcp_wait_write(tls->sockfd);
}
static void tls_errors(const char *prefix)
{
unsigned long error;
while ((error = ERR_get_error()) != 0)
fprintf(stderr, "%s: %s\n", prefix, ERR_error_string(error, NULL));
}
BOOL tls_print_error(char* func, SSL* connection, int value)
{
switch (SSL_get_error(connection, value))
{ {
case SSL_ERROR_ZERO_RETURN: int i;
fprintf(stderr, "%s: Server closed TLS connection\n", func);
return TRUE;
case SSL_ERROR_WANT_READ: for (i = 0; i < nchunks; i++)
fprintf(stderr, "%s: SSL_ERROR_WANT_READ\n", func); {
return FALSE; while (chunks[i].size)
{
status = BIO_write(tcp->socketBio, chunks[i].data, chunks[i].size);
if (status > 0)
{
chunks[i].size -= status;
chunks[i].data += status;
commitedBytes += status;
continue;
}
case SSL_ERROR_WANT_WRITE: if (!BIO_should_retry(tcp->socketBio))
fprintf(stderr, "%s: SSL_ERROR_WANT_WRITE\n", func); goto out_fail;
return FALSE; FD_ZERO(&rset);
FD_SET(tcp->sockfd, &rset);
tv.tv_sec = 0;
tv.tv_usec = 100 * 1000;
case SSL_ERROR_SYSCALL: status = select(tcp->sockfd + 1, &rset, NULL, NULL, &tv);
#ifdef _WIN32 if (status < 0)
fprintf(stderr, "%s: I/O error: %d\n", func, WSAGetLastError()); goto out_fail;
#else }
fprintf(stderr, "%s: I/O error: %s (%d)\n", func, strerror(errno), errno);
#endif
tls_errors(func);
return TRUE;
case SSL_ERROR_SSL: }
fprintf(stderr, "%s: Failure in SSL library (protocol error?)\n", func);
tls_errors(func);
return TRUE;
default:
fprintf(stderr, "%s: Unknown error\n", func);
tls_errors(func);
return TRUE;
} }
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes);
return length;
out_fail:
ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes);
return -1;
} }
int tls_set_alert_code(rdpTls* tls, int level, int description) int tls_set_alert_code(rdpTls* tls, int level, int description)
{ {
tls->alertLevel = level; tls->alertLevel = level;
@ -672,7 +535,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (!bio) if (!bio)
{ {
fprintf(stderr, "tls_verify_certificate: BIO_new() failure\n"); fprintf(stderr, "%s: BIO_new() failure\n", __FUNCTION__);
return -1; return -1;
} }
@ -680,7 +543,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (status < 0) if (status < 0)
{ {
fprintf(stderr, "tls_verify_certificate: PEM_write_bio_X509 failure: %d\n", status); fprintf(stderr, "%s: PEM_write_bio_X509 failure: %d\n", __FUNCTION__, status);
return -1; return -1;
} }
@ -692,7 +555,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (status < 0) if (status < 0)
{ {
fprintf(stderr, "tls_verify_certificate: failed to read certificate\n"); fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__);
return -1; return -1;
} }
@ -713,7 +576,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
if (status < 0) if (status < 0)
{ {
fprintf(stderr, "tls_verify_certificate: failed to read certificate\n"); fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__);
return -1; return -1;
} }
@ -727,8 +590,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
status = instance->VerifyX509Certificate(instance, pemCert, length, hostname, port, 0); status = instance->VerifyX509Certificate(instance, pemCert, length, hostname, port, 0);
} }
fprintf(stderr, "VerifyX509Certificate: (length = %d) status: %d\n%s\n", fprintf(stderr, "%s: (length = %d) status: %d\n%s\n", __FUNCTION__, length, status, pemCert);
length, status, pemCert);
free(pemCert); free(pemCert);
BIO_free(bio); BIO_free(bio);
@ -932,57 +794,53 @@ rdpTls* tls_new(rdpSettings* settings)
{ {
rdpTls* tls; rdpTls* tls;
tls = (rdpTls*) malloc(sizeof(rdpTls)); tls = (rdpTls *)calloc(1, sizeof(rdpTls));
if (!tls)
return NULL;
if (tls) SSL_load_error_strings();
{ SSL_library_init();
ZeroMemory(tls, sizeof(rdpTls));
SSL_load_error_strings(); tls->settings = settings;
SSL_library_init(); tls->certificate_store = certificate_store_new(settings);
if (!tls->certificate_store)
tls->settings = settings; goto out_free;
tls->certificate_store = certificate_store_new(settings);
tls->alertLevel = TLS_ALERT_LEVEL_WARNING;
tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY;
}
tls->alertLevel = TLS_ALERT_LEVEL_WARNING;
tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY;
return tls; return tls;
out_free:
free(tls);
return NULL;
} }
void tls_free(rdpTls* tls) void tls_free(rdpTls* tls)
{ {
if (tls) if (!tls)
return;
if (tls->ctx)
{ {
if (tls->ssl) SSL_CTX_free(tls->ctx);
{ tls->ctx = NULL;
SSL_free(tls->ssl);
tls->ssl = NULL;
}
if (tls->ctx)
{
SSL_CTX_free(tls->ctx);
tls->ctx = NULL;
}
if (tls->PublicKey)
{
free(tls->PublicKey);
tls->PublicKey = NULL;
}
if (tls->Bindings)
{
free(tls->Bindings->Bindings);
free(tls->Bindings);
tls->Bindings = NULL;
}
certificate_store_free(tls->certificate_store);
tls->certificate_store = NULL;
free(tls);
} }
if (tls->PublicKey)
{
free(tls->PublicKey);
tls->PublicKey = NULL;
}
if (tls->Bindings)
{
free(tls->Bindings->Bindings);
free(tls->Bindings);
tls->Bindings = NULL;
}
certificate_store_free(tls->certificate_store);
tls->certificate_store = NULL;
free(tls);
} }

View File

@ -25,6 +25,7 @@ set(${MODULE_PREFIX}_SRCS
pcap.c pcap.c
profiler.c profiler.c
rail.c rail.c
ringbuffer.c
signal.c signal.c
stopwatch.c stopwatch.c
svc_plugin.c svc_plugin.c
@ -68,3 +69,9 @@ else()
endif() endif()
set_property(TARGET ${MODULE_NAME} PROPERTY FOLDER "FreeRDP/libfreerdp") set_property(TARGET ${MODULE_NAME} PROPERTY FOLDER "FreeRDP/libfreerdp")
if(BUILD_TESTING)
add_subdirectory(test)
endif()

View File

@ -0,0 +1,251 @@
/**
* FreeRDP: A Remote Desktop Protocol Implementation
*
* Copyright © 2014 Thincast Technologies GmbH
* Copyright © 2014 Hardening <contact@hardening-consulting.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <freerdp/utils/ringbuffer.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
BOOL ringbuffer_init(RingBuffer *rb, size_t initialSize)
{
rb->buffer = malloc(initialSize);
if (!rb->buffer)
return FALSE;
rb->readPtr = rb->writePtr = 0;
rb->initialSize = rb->size = rb->freeSize = initialSize;
return TRUE;
}
size_t ringbuffer_used(const RingBuffer *ringbuffer)
{
return ringbuffer->size - ringbuffer->freeSize;
}
size_t ringbuffer_capacity(const RingBuffer *ringbuffer)
{
return ringbuffer->size;
}
void ringbuffer_destroy(RingBuffer *ringbuffer)
{
free(ringbuffer->buffer);
ringbuffer->buffer = 0;
}
static BOOL ringbuffer_realloc(RingBuffer *rb, size_t targetSize)
{
BYTE *newData;
if (rb->writePtr == rb->readPtr)
{
/* when no size is used we can realloc() and set the heads at the
* beginning of the buffer
*/
newData = (BYTE *)realloc(rb->buffer, targetSize);
if (!newData)
return FALSE;
rb->readPtr = rb->writePtr = 0;
rb->buffer = newData;
}
else if ((rb->writePtr >= rb->readPtr) && (rb->writePtr < targetSize))
{
/* we reallocate only if we're in that case, realloc don't touch read
* and write heads
*
* readPtr writePtr
* | |
* v v
* [............|XXXXXXXXXXXXXX|..........]
*/
newData = (BYTE *)realloc(rb->buffer, targetSize);
if (!newData)
return FALSE;
rb->buffer = newData;
}
else
{
/* in case of malloc the read head is moved at the beginning of the new buffer
* and the write head is set accordingly
*/
newData = (BYTE *)malloc(targetSize);
if (!newData)
return FALSE;
if (rb->readPtr < rb->writePtr)
{
/* readPtr writePtr
* | |
* v v
* [............|XXXXXXXXXXXXXX|..........]
*/
memcpy(newData, rb->buffer + rb->readPtr, ringbuffer_used(rb));
}
else
{
/* writePtr readPtr
* | |
* v v
* [XXXXXXXXXXXX|..............|XXXXXXXXXX]
*/
BYTE *dst = newData;
memcpy(dst, rb->buffer + rb->readPtr, rb->size - rb->readPtr);
dst += (rb->size - rb->readPtr);
if (rb->writePtr)
memcpy(dst, rb->buffer, rb->writePtr);
}
rb->writePtr = rb->size - rb->freeSize;
rb->readPtr = 0;
free(rb->buffer);
rb->buffer = newData;
}
rb->freeSize += (targetSize - rb->size);
rb->size = targetSize;
return TRUE;
}
/**
*
* @param rb
* @param ptr
* @param sz
* @return
*/
BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz)
{
size_t toWrite;
size_t remaining;
if ((rb->freeSize <= sz) && !ringbuffer_realloc(rb, rb->size + sz))
return FALSE;
/* the write could be split in two
* readHead writeHead
* | |
* v v
* [ ################ ]
*/
toWrite = sz;
remaining = sz;
if (rb->size - rb->writePtr < sz)
toWrite = rb->size - rb->writePtr;
if (toWrite)
{
memcpy(rb->buffer + rb->writePtr, ptr, toWrite);
remaining -= toWrite;
ptr += toWrite;
}
if (remaining)
memcpy(rb->buffer, ptr, remaining);
rb->writePtr = (rb->writePtr + sz) % rb->size;
rb->freeSize -= sz;
return TRUE;
}
BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz)
{
if (rb->freeSize < sz)
{
if (!ringbuffer_realloc(rb, rb->size + sz - rb->freeSize + 32))
return NULL;
}
if (rb->writePtr == rb->readPtr)
{
rb->writePtr = rb->readPtr = 0;
}
if (rb->writePtr + sz < rb->size)
return rb->buffer + rb->writePtr;
/*
* to add: .......
* [ XXXXXXXXX ]
*
* result:
* [XXXXXXXXX....... ]
*/
memmove(rb->buffer, rb->buffer + rb->readPtr, rb->writePtr - rb->readPtr);
rb->readPtr = 0;
rb->writePtr = rb->size - rb->freeSize;
return rb->buffer + rb->writePtr;
}
BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz)
{
if (rb->writePtr + sz > rb->size)
return FALSE;
rb->writePtr = (rb->writePtr + sz) % rb->size;
rb->freeSize -= sz;
return TRUE;
}
int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz)
{
size_t remaining = sz;
size_t toRead;
int chunkIndex = 0;
int ret = 0;
if (rb->size - rb->freeSize < sz)
remaining = rb->size - rb->freeSize;
toRead = remaining;
if (rb->readPtr + remaining > rb->size)
toRead = rb->size - rb->readPtr;
if (toRead)
{
chunks[0].data = rb->buffer + rb->readPtr;
chunks[0].size = toRead;
remaining -= toRead;
chunkIndex++;
ret++;
}
if (remaining)
{
chunks[chunkIndex].data = rb->buffer;
chunks[chunkIndex].size = remaining;
ret++;
}
return ret;
}
void ringbuffer_commit_read_bytes(RingBuffer *rb, size_t sz)
{
assert(rb->size - rb->freeSize >= sz);
rb->readPtr = (rb->readPtr + sz) % rb->size;
rb->freeSize += sz;
/* when we reach a reasonable free size, we can go back to the original size */
if ((rb->size != rb->initialSize) && (ringbuffer_used(rb) < rb->initialSize / 2))
ringbuffer_realloc(rb, rb->initialSize);
}

1
libfreerdp/utils/test/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
TestFreeRDPutils.c

View File

@ -0,0 +1,36 @@
set(MODULE_NAME "TestFreeRDPUtils")
set(MODULE_PREFIX "TEST_FREERDP_UTILS")
set(${MODULE_PREFIX}_DRIVER ${MODULE_NAME}.c)
set(${MODULE_PREFIX}_TESTS
TestRingBuffer.c)
create_test_sourcelist(${MODULE_PREFIX}_SRCS
${${MODULE_PREFIX}_DRIVER}
${${MODULE_PREFIX}_TESTS})
add_executable(${MODULE_NAME} ${${MODULE_PREFIX}_SRCS})
set_complex_link_libraries(VARIABLE ${MODULE_PREFIX}_LIBS
MONOLITHIC ${MONOLITHIC_BUILD}
MODULE winpr
MODULES winpr-thread winpr-synch winpr-file winpr-utils winpr-crt)
set_complex_link_libraries(VARIABLE ${MODULE_PREFIX}_LIBS
MONOLITHIC ${MONOLITHIC_BUILD}
MODULE freerdp
MODULES freerdp-utils)
target_link_libraries(${MODULE_NAME} ${${MODULE_PREFIX}_LIBS})
set_target_properties(${MODULE_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${TESTING_OUTPUT_DIRECTORY}")
foreach(test ${${MODULE_PREFIX}_TESTS})
get_filename_component(TestName ${test} NAME_WE)
add_test(${TestName} ${TESTING_OUTPUT_DIRECTORY}/${MODULE_NAME} ${TestName})
endforeach()
set_property(TARGET ${MODULE_NAME} PROPERTY FOLDER "FreeRDP/Test")

View File

@ -0,0 +1,228 @@
/**
* FreeRDP: A Remote Desktop Protocol Implementation
*
* Copyright © 2014 Thincast Technologies GmbH
* Copyright © 2014 Hardening <contact@hardening-consulting.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <string.h>
#include <freerdp/utils/ringbuffer.h>
BOOL test_overlaps(void)
{
RingBuffer rb;
DataChunk chunks[2];
BYTE bytes[200];
int nchunks, i, j, k, counter = 0;
for (i = 0; i < sizeof(bytes); i++)
bytes[i] = (BYTE)i;
ringbuffer_init(&rb, 5);
if (!ringbuffer_write(&rb, bytes, 4)) /* [0123.] */
goto error;
counter += 4;
ringbuffer_commit_read_bytes(&rb, 2); /* [..23.] */
if (!ringbuffer_write(&rb, &bytes[counter], 2)) /* [5.234] */
goto error;
counter += 2;
nchunks = ringbuffer_peek(&rb, chunks, 4);
if (nchunks != 2 || chunks[0].size != 3 || chunks[1].size != 1)
goto error;
for (i = 0, j = 2; i < nchunks; i++)
{
for (k = 0; k < (int) chunks[i].size; k++, j++)
{
if (chunks[i].data[k] != (BYTE)j)
goto error;
}
}
ringbuffer_commit_read_bytes(&rb, 3); /* [5....] */
if (ringbuffer_used(&rb) != 1)
goto error;
if (!ringbuffer_write(&rb, &bytes[counter], 6)) /* [56789ab....] */
goto error;
counter += 6;
ringbuffer_commit_read_bytes(&rb, 6); /* [......b....] */
nchunks = ringbuffer_peek(&rb, chunks, 10);
if (nchunks != 1 || chunks[0].size != 1 || (*chunks[0].data != 0xb))
goto error;
if (ringbuffer_capacity(&rb) != 5)
goto error;
ringbuffer_destroy(&rb);
return TRUE;
error:
ringbuffer_destroy(&rb);
return FALSE;
}
int TestRingBuffer(int argc, char* argv[])
{
RingBuffer ringBuffer;
int testNo = 0;
BYTE *tmpBuf;
BYTE *rb_ptr;
int i/*, chunkNb, counter*/;
DataChunk chunks[2];
if (!ringbuffer_init(&ringBuffer, 10))
{
fprintf(stderr, "unable to initialize ringbuffer\n");
return -1;
}
tmpBuf = (BYTE *)malloc(50);
if (!tmpBuf)
return -1;
for (i = 0; i < 50; i++)
tmpBuf[i] = (char)i;
fprintf(stderr, "%d: basic tests...", ++testNo);
if (!ringbuffer_write(&ringBuffer, tmpBuf, 5) || !ringbuffer_write(&ringBuffer, tmpBuf, 5) ||
!ringbuffer_write(&ringBuffer, tmpBuf, 5))
{
fprintf(stderr, "error when writing bytes\n");
return -1;
}
if (ringbuffer_used(&ringBuffer) != 15)
{
fprintf(stderr, "invalid used size got %d when i would expect 15\n", ringbuffer_used(&ringBuffer));
return -1;
}
if (ringbuffer_peek(&ringBuffer, chunks, 10) != 1 || chunks[0].size != 10)
{
fprintf(stderr, "error when reading bytes\n");
return -1;
}
ringbuffer_commit_read_bytes(&ringBuffer, chunks[0].size);
/* check retrieved bytes */
for (i = 0; i < (int) chunks[0].size; i++)
{
if (chunks[0].data[i] != i % 5)
{
fprintf(stderr, "invalid byte at %d, got %d instead of %d\n", i, chunks[0].data[i], i % 5);
return -1;
}
}
if (ringbuffer_used(&ringBuffer) != 5)
{
fprintf(stderr, "invalid used size after read got %d when i would expect 5\n", ringbuffer_used(&ringBuffer));
return -1;
}
/* write some more bytes to have writePtr < readPtr and data splitted in 2 chunks */
if (!ringbuffer_write(&ringBuffer, tmpBuf, 6) ||
ringbuffer_peek(&ringBuffer, chunks, 11) != 2 ||
chunks[0].size != 10 ||
chunks[1].size != 1)
{
fprintf(stderr, "invalid read of splitted data\n");
return -1;
}
ringbuffer_commit_read_bytes(&ringBuffer, 11);
fprintf(stderr, "ok\n");
fprintf(stderr, "%d: peek with nothing to read...", ++testNo);
if (ringbuffer_peek(&ringBuffer, chunks, 10))
{
fprintf(stderr, "peek returns some chunks\n");
return -1;
}
fprintf(stderr, "ok\n");
fprintf(stderr, "%d: ensure_linear_write / read() shouldn't grow...", ++testNo);
for (i = 0; i < 1000; i++)
{
rb_ptr = ringbuffer_ensure_linear_write(&ringBuffer, 50);
if (!rb_ptr)
{
fprintf(stderr, "ringbuffer_ensure_linear_write() error\n");
return -1;
}
memcpy(rb_ptr, tmpBuf, 50);
if (!ringbuffer_commit_written_bytes(&ringBuffer, 50))
{
fprintf(stderr, "ringbuffer_commit_written_bytes() error, i=%d\n", i);
return -1;
}
//ringbuffer_commit_read_bytes(&ringBuffer, 25);
}
for (i = 0; i < 1000; i++)
ringbuffer_commit_read_bytes(&ringBuffer, 25);
for (i = 0; i < 1000; i++)
ringbuffer_commit_read_bytes(&ringBuffer, 25);
if (ringbuffer_capacity(&ringBuffer) != 10)
{
fprintf(stderr, "not the expected capacity, have %d and expects 10\n", ringbuffer_capacity(&ringBuffer));
return -1;
}
fprintf(stderr, "ok\n");
fprintf(stderr, "%d: free size is correctly computed...", ++testNo);
for (i = 0; i < 1000; i++)
{
ringbuffer_ensure_linear_write(&ringBuffer, 50);
if (!ringbuffer_commit_written_bytes(&ringBuffer, 50))
{
fprintf(stderr, "ringbuffer_commit_written_bytes() error, i=%d\n", i);
return -1;
}
}
ringbuffer_commit_read_bytes(&ringBuffer, 50 * 1000);
fprintf(stderr, "ok\n");
ringbuffer_destroy(&ringBuffer);
fprintf(stderr, "%d: specific overlaps test...", ++testNo);
if (!test_overlaps())
{
fprintf(stderr, "ko\n", i);
return -1;
}
fprintf(stderr, "ok\n");
ringbuffer_destroy(&ringBuffer);
free(tmpBuf);
return 0;
}

View File

@ -22,6 +22,9 @@
#include <winpr/winpr.h> #include <winpr/winpr.h>
#include <winpr/wtypes.h> #include <winpr/wtypes.h>
#ifdef __cplusplus
extern "C" {
#endif
typedef struct _MAKECERT_CONTEXT MAKECERT_CONTEXT; typedef struct _MAKECERT_CONTEXT MAKECERT_CONTEXT;
@ -34,4 +37,8 @@ WINPR_API int makecert_context_output_private_key_file(MAKECERT_CONTEXT* context
WINPR_API MAKECERT_CONTEXT* makecert_context_new(); WINPR_API MAKECERT_CONTEXT* makecert_context_new();
WINPR_API void makecert_context_free(MAKECERT_CONTEXT* context); WINPR_API void makecert_context_free(MAKECERT_CONTEXT* context);
#ifdef __cplusplus
}
#endif
#endif /* MAKECERT_TOOL_H */ #endif /* MAKECERT_TOOL_H */

View File

@ -152,7 +152,7 @@ static PVOID TestSynchCritical_Main(PVOID arg)
InitializeCriticalSection(&critical); InitializeCriticalSection(&critical);
for (i=0; i<1000; i++) for (i = 0; i < 1000; i++)
{ {
if (critical.RecursionCount != i) if (critical.RecursionCount != i)
{ {
@ -200,9 +200,9 @@ static PVOID TestSynchCritical_Main(PVOID arg)
dwThreadCount = sysinfo.dwNumberOfProcessors > 1 ? sysinfo.dwNumberOfProcessors : 2; dwThreadCount = sysinfo.dwNumberOfProcessors > 1 ? sysinfo.dwNumberOfProcessors : 2;
hThreads = (HANDLE*)calloc(dwThreadCount, sizeof(HANDLE)); hThreads = (HANDLE*) calloc(dwThreadCount, sizeof(HANDLE));
for (j=0; j < TEST_SYNC_CRITICAL_TEST1_RUNS; j++) for (j = 0; j < TEST_SYNC_CRITICAL_TEST1_RUNS; j++)
{ {
dwSpinCount = j * 1000; dwSpinCount = j * 1000;
InitializeCriticalSectionAndSpinCount(&critical, dwSpinCount); InitializeCriticalSectionAndSpinCount(&critical, dwSpinCount);
@ -212,14 +212,15 @@ static PVOID TestSynchCritical_Main(PVOID arg)
/* the TestSynchCritical_Test1 threads shall run until bTest1Running is FALSE */ /* the TestSynchCritical_Test1 threads shall run until bTest1Running is FALSE */
bTest1Running = TRUE; bTest1Running = TRUE;
for (i=0; i<dwThreadCount; i++) { for (i = 0; i < (int) dwThreadCount; i++) {
hThreads[i] = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) TestSynchCritical_Test1, &bTest1Running, 0, NULL); hThreads[i] = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) TestSynchCritical_Test1, &bTest1Running, 0, NULL);
} }
/* let it run for TEST_SYNC_CRITICAL_TEST1_RUNTIME_MS ... */ /* let it run for TEST_SYNC_CRITICAL_TEST1_RUNTIME_MS ... */
Sleep(TEST_SYNC_CRITICAL_TEST1_RUNTIME_MS); Sleep(TEST_SYNC_CRITICAL_TEST1_RUNTIME_MS);
bTest1Running = FALSE; bTest1Running = FALSE;
for (i=0; i<dwThreadCount; i++) for (i = 0; i < (int) dwThreadCount; i++)
{ {
if (WaitForSingleObject(hThreads[i], INFINITE) != WAIT_OBJECT_0) if (WaitForSingleObject(hThreads[i], INFINITE) != WAIT_OBJECT_0)
{ {
@ -288,13 +289,13 @@ int TestSynchCritical(int argc, char* argv[])
HANDLE hThread; HANDLE hThread;
DWORD dwThreadExitCode; DWORD dwThreadExitCode;
DWORD dwDeadLockDetectionTimeMs; DWORD dwDeadLockDetectionTimeMs;
int i; DWORD i;
dwDeadLockDetectionTimeMs = 2 * TEST_SYNC_CRITICAL_TEST1_RUNTIME_MS * TEST_SYNC_CRITICAL_TEST1_RUNS; dwDeadLockDetectionTimeMs = 2 * TEST_SYNC_CRITICAL_TEST1_RUNTIME_MS * TEST_SYNC_CRITICAL_TEST1_RUNS;
printf("Deadlock will be assumed after %u ms.\n", dwDeadLockDetectionTimeMs); printf("Deadlock will be assumed after %u ms.\n", dwDeadLockDetectionTimeMs);
hThread = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) TestSynchCritical_Main, &bThreadTerminated, 0, NULL); hThread = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) TestSynchCritical_Main, &bThreadTerminated, 0, NULL);
/** /**
* We have to be able to detect dead locks in this test. * We have to be able to detect dead locks in this test.
@ -303,10 +304,11 @@ int TestSynchCritical(int argc, char* argv[])
* Workaround checking the value of bThreadTerminated which is passed in the thread arg * Workaround checking the value of bThreadTerminated which is passed in the thread arg
*/ */
for (i=0; i<dwDeadLockDetectionTimeMs; i+=100) for (i = 0; i < dwDeadLockDetectionTimeMs; i += 100)
{ {
if (bThreadTerminated) if (bThreadTerminated)
break; break;
Sleep(100); Sleep(100);
} }
@ -319,7 +321,7 @@ int TestSynchCritical(int argc, char* argv[])
GetExitCodeThread(hThread, &dwThreadExitCode); GetExitCodeThread(hThread, &dwThreadExitCode);
CloseHandle(hThread); CloseHandle(hThread);
if(dwThreadExitCode != 0) if (dwThreadExitCode != 0)
{ {
return -1; return -1;
} }

View File

@ -1,6 +1,7 @@
#include <winpr/crt.h> #include <winpr/crt.h>
#include <winpr/wnd.h> #include <winpr/wnd.h>
#include <winpr/tchar.h>
#include <winpr/wtsapi.h> #include <winpr/wtsapi.h>
#include <winpr/library.h> #include <winpr/library.h>

View File

@ -1,6 +1,7 @@
#include <winpr/crt.h> #include <winpr/crt.h>
#include <winpr/wnd.h> #include <winpr/wnd.h>
#include <winpr/tchar.h>
#include <winpr/library.h> #include <winpr/library.h>
static LRESULT CALLBACK TestWndProc(HWND hwnd, UINT uMsg, WPARAM wParam, LPARAM lParam) static LRESULT CALLBACK TestWndProc(HWND hwnd, UINT uMsg, WPARAM wParam, LPARAM lParam)

View File

@ -31,8 +31,6 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
for (index = 0; index < count; index++) for (index = 0; index < count; index++)
{ {
pBuffer = NULL;
bytesReturned = 0;
char* Username; char* Username;
char* Domain; char* Domain;
char* ClientName; char* ClientName;
@ -44,6 +42,9 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
PWTS_CLIENT_ADDRESS ClientAddress; PWTS_CLIENT_ADDRESS ClientAddress;
WTS_CONNECTSTATE_CLASS ConnectState; WTS_CONNECTSTATE_CLASS ConnectState;
pBuffer = NULL;
bytesReturned = 0;
sessionId = pSessionInfo[index].SessionId; sessionId = pSessionInfo[index].SessionId;
printf("[%d] SessionId: %d State: %d\n", (int) index, printf("[%d] SessionId: %d State: %d\n", (int) index,
@ -52,7 +53,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
/* WTSUserName */ /* WTSUserName */
bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSUserName, &pBuffer, &bytesReturned); bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSUserName, &pBuffer, &bytesReturned);
if (!bSuccess) if (!bSuccess)
{ {
@ -65,7 +66,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
/* WTSDomainName */ /* WTSDomainName */
bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSDomainName, &pBuffer, &bytesReturned); bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSDomainName, &pBuffer, &bytesReturned);
if (!bSuccess) if (!bSuccess)
{ {
@ -78,7 +79,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
/* WTSConnectState */ /* WTSConnectState */
bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSConnectState, &pBuffer, &bytesReturned); bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSConnectState, &pBuffer, &bytesReturned);
if (!bSuccess) if (!bSuccess)
{ {
@ -91,7 +92,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
/* WTSClientBuildNumber */ /* WTSClientBuildNumber */
bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientBuildNumber, &pBuffer, &bytesReturned); bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientBuildNumber, &pBuffer, &bytesReturned);
if (!bSuccess) if (!bSuccess)
{ {
@ -104,7 +105,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
/* WTSClientName */ /* WTSClientName */
bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientName, &pBuffer, &bytesReturned); bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientName, &pBuffer, &bytesReturned);
if (!bSuccess) if (!bSuccess)
{ {
@ -117,7 +118,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
/* WTSClientProductId */ /* WTSClientProductId */
bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientProductId, &pBuffer, &bytesReturned); bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientProductId, &pBuffer, &bytesReturned);
if (!bSuccess) if (!bSuccess)
{ {
@ -130,7 +131,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
/* WTSClientHardwareId */ /* WTSClientHardwareId */
bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientHardwareId, &pBuffer, &bytesReturned); bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientHardwareId, &pBuffer, &bytesReturned);
if (!bSuccess) if (!bSuccess)
{ {
@ -143,7 +144,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
/* WTSClientAddress */ /* WTSClientAddress */
bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientAddress, &pBuffer, &bytesReturned); bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientAddress, &pBuffer, &bytesReturned);
if (!bSuccess) if (!bSuccess)
{ {
@ -157,7 +158,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
/* WTSClientDisplay */ /* WTSClientDisplay */
bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientDisplay, &pBuffer, &bytesReturned); bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientDisplay, &pBuffer, &bytesReturned);
if (!bSuccess) if (!bSuccess)
{ {
@ -172,7 +173,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
/* WTSClientProtocolType */ /* WTSClientProtocolType */
bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientProtocolType, &pBuffer, &bytesReturned); bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientProtocolType, &pBuffer, &bytesReturned);
if (!bSuccess) if (!bSuccess)
{ {