libfreerdp-core: start handling client call state

This commit is contained in:
Marc-André Moreau 2012-12-07 21:09:55 -05:00
parent be98cffbd2
commit 731e606c15
5 changed files with 127 additions and 16 deletions

View File

@ -403,7 +403,7 @@ int rpc_recv_pdu_fragment(rdpRpc* rpc)
headerLength = status;
header = (rpcconn_hdr_t*) rpc->FragBuffer;
bytesRead += status;
if (header->common.frag_length > rpc->FragBufferSize)
{
rpc->FragBufferSize = header->common.frag_length;
@ -426,6 +426,11 @@ int rpc_recv_pdu_fragment(rdpRpc* rpc)
ReleaseMutex(rpc->VirtualConnection->DefaultInChannel->Mutex);
if (header->common.call_id == rpc->PipeCallId)
{
/* TsProxySetupReceivePipe response! */
}
if (header->common.ptype == PTYPE_RTS) /* RTS PDU */
{
if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED)
@ -477,18 +482,22 @@ RPC_PDU* rpc_recv_pdu(rdpRpc* rpc)
status = rpc_recv_pdu_fragment(rpc);
if (status <= 0)
return NULL;
header = (rpcconn_hdr_t*) rpc->FragBuffer;
if (rpc->State < RPC_CLIENT_STATE_CONTEXT_NEGOTIATED)
{
rpc->pdu->Flags = 0;
rpc->pdu->Buffer = rpc->FragBuffer;
rpc->pdu->Size = rpc->FragBufferSize;
rpc->pdu->Length = status;
rpc->pdu->CallId = header->common.call_id;
return rpc->pdu;
}
header = (rpcconn_hdr_t*) rpc->FragBuffer;
if (header->common.ptype != PTYPE_RESPONSE)
{
printf("rpc_recv_pdu: unexpected ptype 0x%02X\n", header->common.ptype);
@ -507,6 +516,20 @@ RPC_PDU* rpc_recv_pdu(rdpRpc* rpc)
rpc->StubBuffer = (BYTE*) realloc(rpc->StubBuffer, rpc->StubBufferSize);
}
if (rpc->StubFragCount == 0)
rpc->StubCallId = header->common.call_id;
if (rpc->StubCallId == rpc->PipeCallId)
{
/* TsProxySetupReceivePipe response! */
}
if (rpc->StubCallId != header->common.call_id)
{
printf("invalid call_id: actual: %d, expected: %d, frag_count: %d\n",
rpc->StubCallId, header->common.call_id, rpc->StubFragCount);
}
CopyMemory(&rpc->StubBuffer[rpc->StubOffset], &rpc->FragBuffer[StubOffset], StubLength);
rpc->StubOffset += StubLength;
rpc->StubFragCount++;
@ -525,9 +548,12 @@ RPC_PDU* rpc_recv_pdu(rdpRpc* rpc)
//printf("Reassembled PDU (%d):\n", rpc->StubOffset);
//freerdp_hexdump(rpc->StubBuffer, rpc->StubOffset);
rpc->pdu->CallId = rpc->StubCallId;
rpc->StubLength = rpc->StubOffset;
rpc->StubOffset = 0;
rpc->StubFragCount = 0;
rpc->StubCallId = 0;
rpc->pdu->Flags = RPC_PDU_FLAG_STUB;
rpc->pdu->Buffer = rpc->StubBuffer;
@ -548,6 +574,7 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
UINT32 stub_data_pad;
SecBuffer Buffers[2];
SecBufferDesc Message;
RpcClientCall* client_call;
SECURITY_STATUS encrypt_status;
rpcconn_request_hdr_t* request_pdu;
@ -567,17 +594,17 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
request_pdu->ptype = PTYPE_REQUEST;
request_pdu->pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG;
request_pdu->auth_length = ntlm->ContextSizes.cbMaxSignature;
request_pdu->call_id = ++rpc->call_id;
/* opnum 8 is TsProxySetupReceivePipe, save call_id for checking pipe responses */
if (opnum == 8)
rpc->pipe_call_id = rpc->call_id;
request_pdu->call_id = rpc->call_id++;
request_pdu->alloc_hint = length;
request_pdu->p_cont_id = 0x0000;
request_pdu->opnum = opnum;
client_call = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum);
ArrayList_Add(rpc->ClientCalls, client_call);
if (request_pdu->opnum == TsProxySetupReceivePipeOpnum)
rpc->PipeCallId = request_pdu->call_id;
request_pdu->stub_data = data;
offset = 24;
@ -768,11 +795,14 @@ rdpRpc* rpc_new(rdpTransport* transport)
rpc->FragBufferSize = 20;
rpc->FragBuffer = (BYTE*) malloc(rpc->FragBufferSize);
rpc->PipeCallId = 0;
rpc->StubOffset = 0;
rpc->StubBufferSize = 20;
rpc->StubLength = 0;
rpc->StubFragCount = 0;
rpc->StubBuffer = (BYTE*) malloc(rpc->FragBufferSize);
rpc->StubCallId = 0;
rpc->rpc_vers = 5;
rpc->rpc_vers_minor = 0;
@ -791,6 +821,8 @@ rdpRpc* rpc_new(rdpTransport* transport)
rpc->SendQueue = Queue_New(TRUE, -1, -1);
rpc->ReceiveQueue = Queue_New(TRUE, -1, -1);
rpc->ClientCalls = ArrayList_New(TRUE);
rpc->ReceiveWindow = 0x00010000;
rpc->ChannelLifetime = 0x40000000;
@ -803,7 +835,7 @@ rdpRpc* rpc_new(rdpTransport* transport)
rpc->VirtualConnection = rpc_client_virtual_connection_new(rpc);
rpc->VirtualConnectionCookieTable = rpc_virtual_connection_cookie_table_new(rpc);
rpc->call_id = 1;
rpc->call_id = 2;
rpc_client_new(rpc);
@ -829,6 +861,9 @@ void rpc_free(rdpRpc* rpc)
Queue_Clear(rpc->ReceiveQueue);
Queue_Free(rpc->ReceiveQueue);
ArrayList_Clear(rpc->ClientCalls);
ArrayList_Free(rpc->ClientCalls);
rpc_client_virtual_connection_free(rpc->VirtualConnection);
rpc_virtual_connection_cookie_table_free(rpc->VirtualConnectionCookieTable);

View File

@ -58,6 +58,7 @@ typedef struct _RPC_PDU
UINT32 Size;
UINT32 Length;
DWORD Flags;
DWORD CallId;
} RPC_PDU, *PRPC_PDU;
#include "tcp.h"
@ -580,6 +581,14 @@ enum _RPC_CLIENT_CALL_STATE
};
typedef enum _RPC_CLIENT_CALL_STATE RPC_CLIENT_CALL_STATE;
struct rpc_client_call
{
UINT32 CallId;
UINT32 OpNum;
RPC_CLIENT_CALL_STATE State;
};
typedef struct rpc_client_call RpcClientCall;
enum _TSG_CHANNEL
{
TSG_CHANNEL_IN,
@ -743,7 +752,7 @@ struct rdp_rpc
rdpTransport* transport;
UINT32 call_id;
UINT32 pipe_call_id;
UINT32 PipeCallId;
RPC_PDU* pdu;
@ -755,6 +764,7 @@ struct rdp_rpc
UINT32 StubLength;
UINT32 StubOffset;
UINT32 StubFragCount;
UINT32 StubCallId;
BYTE rpc_vers;
BYTE rpc_vers_minor;
@ -766,6 +776,8 @@ struct rdp_rpc
wQueue* SendQueue;
wQueue* ReceiveQueue;
wArrayList* ClientCalls;
UINT32 ReceiveWindow;
UINT32 ChannelLifetime;

View File

@ -31,6 +31,56 @@
#include "rpc_client.h"
/**
* [MS-RPCE] Client Call:
* http://msdn.microsoft.com/en-us/library/gg593159/
*/
RpcClientCall* rpc_client_call_find_by_id(rdpRpc* rpc, UINT32 CallId)
{
int index;
int count;
RpcClientCall* client_call;
ArrayList_Lock(rpc->ClientCalls);
client_call = NULL;
count = ArrayList_Count(rpc->ClientCalls);
for (index = 0; index < count; index++)
{
client_call = (RpcClientCall*) ArrayList_GetItem(rpc->ClientCalls, index);
if (client_call->CallId == CallId)
break;
}
ArrayList_Unlock(rpc->ClientCalls);
return client_call;
}
RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum)
{
RpcClientCall* client_call;
client_call = (RpcClientCall*) malloc(sizeof(RpcClientCall));
if (client_call)
{
client_call->CallId = CallId;
client_call->OpNum = OpNum;
client_call->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
}
return client_call;
}
void rpc_client_call_free(RpcClientCall* client_call)
{
free(client_call);
}
int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
{
RPC_PDU* pdu;
@ -55,6 +105,8 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
{
int status;
RPC_PDU* pdu;
RpcClientCall* client_call;
rpcconn_common_hdr_t* header;
pdu = (RPC_PDU*) Queue_Dequeue(rpc->SendQueue);
@ -65,6 +117,10 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
status = rpc_in_write(rpc, pdu->Buffer, pdu->Length);
header = (rpcconn_common_hdr_t*) pdu->Buffer;
client_call = rpc_client_call_find_by_id(rpc, header->call_id);
client_call->State = RPC_CLIENT_CALL_STATE_DISPATCHED;
ReleaseMutex(rpc->VirtualConnection->DefaultInChannel->Mutex);
/*
@ -92,10 +148,7 @@ int rpc_recv_enqueue_pdu(rdpRpc* rpc)
pdu = rpc_recv_pdu(rpc);
if (!pdu)
{
printf("rpc_recv_enqueue_pdu error\n");
return -1;
}
return 0;
rpc->pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));

View File

@ -24,6 +24,11 @@
#include <winpr/interlocked.h>
RpcClientCall* rpc_client_call_find_by_id(rdpRpc* rpc, UINT32 CallId);
RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum);
void rpc_client_call_free(RpcClientCall* client_call);
int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length);
int rpc_send_dequeue_pdu(rdpRpc* rpc);

View File

@ -192,6 +192,12 @@ int ArrayList_Add(wArrayList* arrayList, void* obj)
if (arrayList->synchronized)
WaitForSingleObject(arrayList->mutex, INFINITE);
if (arrayList->size + 1 > arrayList->capacity)
{
arrayList->capacity *= arrayList->growthFactor;
arrayList->array = (void**) realloc(arrayList->array, sizeof(void*) * arrayList->capacity);
}
arrayList->array[arrayList->size++] = obj;
index = arrayList->size;