libfreerdp-core: simplify TS Gateway RPC fragment receiving

This commit is contained in:
Marc-André Moreau 2015-01-31 16:56:25 -05:00
parent 5e53063d55
commit 85191391d5
8 changed files with 81 additions and 112 deletions

View File

@ -253,6 +253,7 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l
UINT32 auth_pad_length;
UINT32 sec_trailer_offset;
rpc_sec_trailer* sec_trailer;
*offset = RPC_COMMON_FIELDS_LENGTH;
header = ((rpcconn_hdr_t*) buffer);
@ -309,7 +310,7 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l
if ((frag_length - (sec_trailer_offset + 8)) != auth_length)
{
WLog_ERR(TAG, "invalid auth_length: actual: %d, expected: %d", auth_length,
WLog_ERR(TAG, "invalid auth_length: actual: %d, expected: %d", auth_length,
(frag_length - (sec_trailer_offset + 8)));
}
@ -445,7 +446,7 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
if (encrypt_status != SEC_E_OK)
{
WLog_ERR(TAG, "EncryptMessage status: 0x%08X", encrypt_status);
WLog_ERR(TAG, "EncryptMessage status: 0x%08X", encrypt_status);
goto out_free_pdu;
}
@ -475,7 +476,7 @@ BOOL rpc_connect(rdpRpc* rpc)
if (!rts_connect(rpc))
{
WLog_ERR(TAG, "rts_connect error!");
WLog_ERR(TAG, "rts_connect error!");
return FALSE;
}
@ -483,7 +484,7 @@ BOOL rpc_connect(rdpRpc* rpc)
if (rpc_secure_bind(rpc) != 0)
{
WLog_ERR(TAG, "rpc_secure_bind error!");
WLog_ERR(TAG, "rpc_secure_bind error!");
return FALSE;
}

View File

@ -30,6 +30,8 @@
typedef struct rdp_rpc rdpRpc;
#pragma pack(push, 1)
#define DEFINE_RPC_COMMON_FIELDS() \
BYTE rpc_vers; \
BYTE rpc_vers_minor; \
@ -57,10 +59,13 @@ typedef struct
typedef struct _RPC_PDU
{
wStream* s;
DWORD Flags;
DWORD CallId;
UINT32 Type;
UINT32 Flags;
UINT32 CallId;
} RPC_PDU, *PRPC_PDU;
#pragma pack(pop)
#include "../tcp.h"
#include "../transport.h"
@ -131,6 +136,8 @@ typedef struct _RPC_PDU
*/
#define RPC_PDU_HEADER_MAX_LENGTH 32
#pragma pack(push, 1)
typedef struct
{
DEFINE_RPC_COMMON_FIELDS();
@ -532,6 +539,8 @@ typedef union
rpcconn_rts_hdr_t rts;
} rpcconn_hdr_t;
#pragma pack(pop)
struct _RPC_SECURITY_PROVIDER_INFO
{
UINT32 Id;
@ -706,8 +715,6 @@ struct rpc_client
wQueue* ReceiveQueue;
wStream* RecvFrag;
wQueue* FragmentPool;
wQueue* FragmentQueue;
wArrayList* ClientCallList;

View File

@ -412,7 +412,7 @@ int rpc_secure_bind(rdpRpc* rpc)
if (status <= 0)
{
WLog_ERR(TAG, "rpc_secure_bind: error sending bind pdu!");
WLog_ERR(TAG, "rpc_secure_bind: error sending bind pdu!");
return -1;
}
@ -424,13 +424,13 @@ int rpc_secure_bind(rdpRpc* rpc)
if (!pdu)
{
WLog_ERR(TAG, "rpc_secure_bind: error receiving bind ack pdu!");
WLog_ERR(TAG, "rpc_secure_bind: error receiving bind ack pdu!");
return -1;
}
if (rpc_recv_bind_ack_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)) <= 0)
{
WLog_ERR(TAG, "rpc_secure_bind: error receiving bind ack pdu!");
WLog_ERR(TAG, "rpc_secure_bind: error receiving bind ack pdu!");
return -1;
}
@ -438,7 +438,7 @@ int rpc_secure_bind(rdpRpc* rpc)
if (rpc_send_rpc_auth_3_pdu(rpc) <= 0)
{
WLog_ERR(TAG, "rpc_secure_bind: error sending rpc_auth_3 pdu!");
WLog_ERR(TAG, "rpc_secure_bind: error sending rpc_auth_3 pdu!");
return -1;
}
@ -446,7 +446,7 @@ int rpc_secure_bind(rdpRpc* rpc)
}
else
{
WLog_ERR(TAG, "rpc_secure_bind: invalid state: %d", rpc->State);
WLog_ERR(TAG, "rpc_secure_bind: invalid state: %d", rpc->State);
return -1;
}
}

View File

@ -34,27 +34,9 @@
#include "../rdp.h"
#define TAG FREERDP_TAG("core.gateway")
#define SYNCHRONOUS_TIMEOUT 5000
wStream* rpc_client_fragment_pool_take(rdpRpc* rpc)
{
wStream* fragment = NULL;
if (WaitForSingleObject(Queue_Event(rpc->client->FragmentPool), 0) == WAIT_OBJECT_0)
fragment = Queue_Dequeue(rpc->client->FragmentPool);
if (!fragment)
fragment = Stream_New(NULL, rpc->max_recv_frag);
return fragment;
}
int rpc_client_fragment_pool_return(rdpRpc* rpc, wStream* fragment)
{
Queue_Enqueue(rpc->client->FragmentPool, fragment);
return 0;
}
RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
{
RPC_PDU* pdu = NULL;
@ -90,32 +72,35 @@ int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu)
return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1;
}
int rpc_client_on_fragment_received_event(rdpRpc* rpc)
int rpc_client_on_pdu_received_event(rdpRpc* rpc)
{
Queue_Enqueue(rpc->client->ReceiveQueue, rpc->client->pdu);
rpc->client->pdu = NULL;
return 1;
}
int rpc_client_on_fragment_received_event(rdpRpc* rpc, wStream* fragment)
{
BYTE* buffer;
UINT32 StubOffset;
UINT32 StubLength;
wStream* fragment;
rpcconn_hdr_t* header;
if (!rpc->client->pdu)
rpc->client->pdu = rpc_client_receive_pool_take(rpc);
fragment = Queue_Dequeue(rpc->client->FragmentQueue);
buffer = (BYTE*) Stream_Buffer(fragment);
header = (rpcconn_hdr_t*) Stream_Buffer(fragment);
if (rpc->State < RPC_CLIENT_STATE_CONTEXT_NEGOTIATED)
{
rpc->client->pdu->Flags = 0;
rpc->client->pdu->Type = header->common.ptype;
rpc->client->pdu->CallId = header->common.call_id;
Stream_EnsureCapacity(rpc->client->pdu->s, Stream_Length(fragment));
Stream_Write(rpc->client->pdu->s, buffer, Stream_Length(fragment));
Stream_Length(rpc->client->pdu->s) = Stream_GetPosition(rpc->client->pdu->s);
rpc_client_fragment_pool_return(rpc, fragment);
Queue_Enqueue(rpc->client->ReceiveQueue, rpc->client->pdu);
SetEvent(rpc->transport->ReceiveEvent);
rpc->client->pdu = NULL;
Stream_SealLength(rpc->client->pdu->s);
rpc_client_on_pdu_received_event(rpc);
return 0;
}
@ -131,12 +116,10 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
WLog_DBG(TAG, "Receiving Out-of-Sequence RTS PDU");
rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length);
rpc_client_fragment_pool_return(rpc, fragment);
return 0;
case PTYPE_FAULT:
rpc_recv_fault_pdu(header);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
case PTYPE_RESPONSE:
@ -144,17 +127,17 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
default:
WLog_ERR(TAG, "unexpected RPC PDU type %d", header->common.ptype);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
}
/* PTYPE_RESPONSE */
rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length;
rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow -= header->common.frag_length;
if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength))
{
WLog_ERR(TAG, "expected stub");
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
}
@ -174,7 +157,6 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
PubSub_OnTerminate(rpc->context->pubSub, rpc->context, &e);
}
rpc_client_fragment_pool_return(rpc, fragment);
return 0;
}
@ -193,7 +175,6 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
Stream_Write(rpc->client->pdu->s, &buffer[StubOffset], StubLength);
rpc->StubFragCount++;
rpc_client_fragment_pool_return(rpc, fragment);
if (rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow < (rpc->ReceiveWindow / 2))
{
@ -209,12 +190,12 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
if (header->response.alloc_hint == StubLength)
{
rpc->client->pdu->Flags = RPC_PDU_FLAG_STUB;
rpc->client->pdu->Type = PTYPE_RESPONSE;
rpc->client->pdu->CallId = rpc->StubCallId;
Stream_Length(rpc->client->pdu->s) = Stream_GetPosition(rpc->client->pdu->s);
Stream_SealLength(rpc->client->pdu->s);
rpc->StubFragCount = 0;
rpc->StubCallId = 0;
Queue_Enqueue(rpc->client->ReceiveQueue, rpc->client->pdu);
rpc->client->pdu = NULL;
rpc_client_on_pdu_received_event(rpc);
return 0;
}
@ -229,9 +210,6 @@ int rpc_client_on_read_event(rdpRpc* rpc)
while (1)
{
if (!rpc->client->RecvFrag)
rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
position = Stream_GetPosition(rpc->client->RecvFrag);
while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
@ -286,13 +264,16 @@ int rpc_client_on_read_event(rdpRpc* rpc)
if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
{
/* complete fragment received */
Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
Stream_SetPosition(rpc->client->RecvFrag, 0);
Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
rpc->client->RecvFrag = NULL;
if (rpc_client_on_fragment_received_event(rpc) < 0)
return -1;
Stream_SealLength(rpc->client->RecvFrag);
Stream_SetPosition(rpc->client->RecvFrag, 0);
status = rpc_client_on_fragment_received_event(rpc, rpc->client->RecvFrag);
if (status < 0)
return status;
Stream_SetPosition(rpc->client->RecvFrag, 0);
}
if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
@ -440,12 +421,12 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc)
{
RPC_PDU* pdu;
DWORD timeout;
DWORD waitStatus;
DWORD dwMilliseconds;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0;
timeout = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0;
waitStatus = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
waitStatus = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), timeout);
if (waitStatus == WAIT_TIMEOUT)
{
@ -463,17 +444,20 @@ RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc)
RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc)
{
RPC_PDU* pdu;
DWORD timeout;
DWORD waitStatus;
DWORD dwMilliseconds;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
timeout = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
waitStatus = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
waitStatus = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), timeout);
if (waitStatus != WAIT_OBJECT_0)
return NULL;
return (RPC_PDU*) Queue_Peek(rpc->client->ReceiveQueue);
pdu = (RPC_PDU*) Queue_Peek(rpc->client->ReceiveQueue);
return pdu;
}
static void* rpc_client_thread(void* arg)
@ -525,11 +509,6 @@ static void rpc_pdu_free(RPC_PDU* pdu)
free(pdu);
}
static void rpc_fragment_free(wStream* fragment)
{
Stream_Free(fragment, TRUE);
}
int rpc_client_new(rdpRpc* rpc)
{
RpcClient* client;
@ -557,32 +536,25 @@ int rpc_client_new(rdpRpc* rpc)
return -1;
Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->pdu = NULL;
client->ReceivePool = Queue_New(TRUE, -1, -1);
if (!client->ReceivePool)
return -1;
Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->ReceiveQueue = Queue_New(TRUE, -1, -1);
if (!client->ReceiveQueue)
return -1;
Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->RecvFrag = NULL;
client->FragmentPool = Queue_New(TRUE, -1, -1);
if (!client->FragmentPool)
return -1;
client->RecvFrag = Stream_New(NULL, rpc->max_recv_frag);
Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->FragmentQueue = Queue_New(TRUE, -1, -1);
if (!client->FragmentQueue)
return -1;
Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->ClientCallList = ArrayList_New(TRUE);
if (!client->ClientCallList)
@ -625,13 +597,7 @@ int rpc_client_free(rdpRpc* rpc)
Queue_Free(client->SendQueue);
if (client->RecvFrag)
rpc_fragment_free(client->RecvFrag);
if (client->FragmentPool)
Queue_Free(client->FragmentPool);
if (client->FragmentQueue)
Queue_Free(client->FragmentQueue);
Stream_Free(client->RecvFrag, TRUE);
if (client->pdu)
rpc_pdu_free(client->pdu);

View File

@ -64,7 +64,7 @@ BOOL rts_connect(rdpRpc* rpc)
{
RPC_PDU* pdu;
rpcconn_rts_hdr_t* rts;
HttpResponse* http_response;
HttpResponse* httpResponse;
freerdp* instance = (freerdp*) rpc->settings->instance;
rdpContext* context = instance->context;
@ -130,7 +130,7 @@ BOOL rts_connect(rdpRpc* rpc)
* A client implementation MUST NOT accept the OUT channel HTTP response in any state other than
* Out Channel Wait. If received in any other state, this HTTP response is a protocol error. Therefore,
* the client MUST consider the virtual connection opening a failure and indicate this to higher layers
* in an implementation-specific way. The Microsoft Windows® implementation returns
* in an implementation-specific way. The Microsoft Windows implementation returns
* RPC_S_PROTOCOL_ERROR, as specified in [MS-ERREF], to higher-layer protocols.
*
* If this HTTP response is received in Out Channel Wait state, the client MUST process the fields of
@ -152,20 +152,21 @@ BOOL rts_connect(rdpRpc* rpc)
*
*/
http_response = http_response_recv(rpc->TlsOut);
if (!http_response)
httpResponse = http_response_recv(rpc->TlsOut);
if (!httpResponse)
{
WLog_ERR(TAG, "unable to retrieve OUT Channel Response!");
return FALSE;
}
if (http_response->StatusCode != HTTP_STATUS_OK)
if (httpResponse->StatusCode != HTTP_STATUS_OK)
{
WLog_ERR(TAG, "error! Status Code: %d", http_response->StatusCode);
http_response_print(http_response);
http_response_free(http_response);
WLog_ERR(TAG, "error! Status Code: %d", httpResponse->StatusCode);
http_response_print(httpResponse);
http_response_free(httpResponse);
if (http_response->StatusCode == HTTP_STATUS_DENIED)
if (httpResponse->StatusCode == HTTP_STATUS_DENIED)
{
if (!connectErrorCode)
{
@ -181,16 +182,7 @@ BOOL rts_connect(rdpRpc* rpc)
return FALSE;
}
if (http_response->bodyLen)
{
/* inject bytes we have read in the body as a received packet for the RPC client */
rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
Stream_EnsureCapacity(rpc->client->RecvFrag, http_response->bodyLen);
CopyMemory(rpc->client->RecvFrag, http_response->BodyContent, http_response->bodyLen);
}
//http_response_print(http_response);
http_response_free(http_response);
http_response_free(httpResponse);
rpc->VirtualConnection->State = VIRTUAL_CONNECTION_STATE_WAIT_A3W;
WLog_DBG(TAG, "VIRTUAL_CONNECTION_STATE_WAIT_A3W");
@ -255,6 +247,7 @@ BOOL rts_connect(rdpRpc* rpc)
*/
pdu = rpc_recv_dequeue_pdu(rpc);
if (!pdu)
return FALSE;

View File

@ -1512,7 +1512,7 @@ int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length)
{
tsg->pdu = rpc_recv_peek_pdu(rpc);
/* there is a pdu to process - move on*/
/* there is a pdu to process - move on */
if (tsg->pdu)
break;

View File

@ -155,7 +155,7 @@ VOID EnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection)
#endif
/* First try the fastest posssible path to get the lock. */
/* First try the fastest possible path to get the lock. */
if (InterlockedIncrement(&lpCriticalSection->LockCount))
{
/* Section is already locked. Check if it is owned by the current thread. */

View File

@ -257,7 +257,7 @@ BOOL ListDictionary_Contains(wListDictionary* listDictionary, void* key)
OBJECT_EQUALS_FN keyEquals;
if (listDictionary->synchronized)
EnterCriticalSection(&listDictionary->lock);
EnterCriticalSection(&(listDictionary->lock));
keyEquals = listDictionary->objectKey.fnObjectEquals;
item = listDictionary->head;
@ -271,7 +271,7 @@ BOOL ListDictionary_Contains(wListDictionary* listDictionary, void* key)
}
if (listDictionary->synchronized)
LeaveCriticalSection(&listDictionary->lock);
LeaveCriticalSection(&(listDictionary->lock));
return (item) ? TRUE : FALSE;
}
@ -434,12 +434,13 @@ wListDictionary* ListDictionary_New(BOOL synchronized)
wListDictionary* listDictionary = NULL;
listDictionary = (wListDictionary*) calloc(1, sizeof(wListDictionary));
if (!listDictionary)
return NULL;
listDictionary->synchronized = synchronized;
if (!InitializeCriticalSectionAndSpinCount(&listDictionary->lock, 4000))
if (!InitializeCriticalSectionAndSpinCount(&(listDictionary->lock), 4000))
{
free(listDictionary);
return NULL;
@ -447,6 +448,7 @@ wListDictionary* ListDictionary_New(BOOL synchronized)
listDictionary->objectKey.fnObjectEquals = default_equal_function;
listDictionary->objectValue.fnObjectEquals = default_equal_function;
return listDictionary;
}