libfreerdp-core: refactoring of TSG transport, still need to handle fragmentation correctly

This commit is contained in:
Marc-André Moreau 2012-10-30 13:01:54 -04:00
parent 9ea88f13a9
commit fa1dc6d9e1
4 changed files with 209 additions and 188 deletions

View File

@ -36,6 +36,68 @@
#include "rpc.h"
static char* PTYPE_STRINGS[] =
{
"PTYPE_REQUEST",
"PTYPE_PING",
"PTYPE_RESPONSE",
"PTYPE_FAULT",
"PTYPE_WORKING",
"PTYPE_NOCALL",
"PTYPE_REJECT",
"PTYPE_ACK",
"PTYPE_CL_CANCEL",
"PTYPE_FACK",
"PTYPE_CANCEL_ACK",
"PTYPE_BIND",
"PTYPE_BIND_ACK",
"PTYPE_BIND_NAK",
"PTYPE_ALTER_CONTEXT",
"PTYPE_ALTER_CONTEXT_RESP",
"PTYPE_RPC_AUTH_3",
"PTYPE_SHUTDOWN",
"PTYPE_CO_CANCEL",
"PTYPE_ORPHANED",
"PTYPE_RTS",
""
};
void rpc_pdu_header_print(RPC_PDU_HEADER* header)
{
printf("rpc_vers: %d\n", header->rpc_vers);
printf("rpc_vers_minor: %d\n", header->rpc_vers_minor);
if (header->ptype > PTYPE_RTS)
printf("ptype: %s (%d)\n", "PTYPE_UNKNOWN", header->ptype);
else
printf("ptype: %s (%d)\n", PTYPE_STRINGS[header->ptype], header->ptype);
printf("pfc_flags (0x%02X) = {", header->pfc_flags);
if (header->pfc_flags & PFC_FIRST_FRAG)
printf(" PFC_FIRST_FRAG");
if (header->pfc_flags & PFC_LAST_FRAG)
printf(" PFC_LAST_FRAG");
if (header->pfc_flags & PFC_PENDING_CANCEL)
printf(" PFC_PENDING_CANCEL");
if (header->pfc_flags & PFC_RESERVED_1)
printf(" PFC_RESERVED_1");
if (header->pfc_flags & PFC_CONC_MPX)
printf(" PFC_CONC_MPX");
if (header->pfc_flags & PFC_DID_NOT_EXECUTE)
printf(" PFC_DID_NOT_EXECUTE");
if (header->pfc_flags & PFC_OBJECT_UUID)
printf(" PFC_OBJECT_UUID");
printf(" }\n");
printf("packed_drep[4]: %02X %02X %02X %02X\n",
header->packed_drep[0], header->packed_drep[1],
header->packed_drep[2], header->packed_drep[3]);
printf("frag_length: %d\n", header->frag_length);
printf("auth_length: %d\n", header->auth_length);
printf("call_id: %d\n", header->call_id);
}
/**
* The Security Support Provider Interface:
* http://technet.microsoft.com/en-us/library/bb742535/
@ -425,6 +487,8 @@ int rpc_out_write(rdpRpc* rpc, BYTE* data, int length)
{
int status;
rpc_pdu_header_print((RPC_PDU_HEADER*) data);
#ifdef WITH_DEBUG_RPC
printf("rpc_out_write(): length: %d\n", length);
freerdp_hexdump(data, length);
@ -440,6 +504,8 @@ int rpc_in_write(rdpRpc* rpc, BYTE* data, int length)
{
int status;
rpc_pdu_header_print((RPC_PDU_HEADER*) data);
#ifdef WITH_DEBUG_RPC
printf("rpc_in_write() length: %d\n", length);
freerdp_hexdump(data, length);
@ -607,43 +673,25 @@ BOOL rpc_send_bind_pdu(rdpRpc* rpc)
int rpc_recv_bind_ack_pdu(rdpRpc* rpc)
{
BYTE* p;
int status;
BYTE* pdu;
BYTE* auth_data;
RPC_PDU_HEADER header;
int pdu_length = 0x8FFF;
RPC_PDU_HEADER* header;
pdu = malloc(pdu_length);
if (pdu == NULL)
return -1;
status = rpc_out_read(rpc, pdu, pdu_length);
status = rpc_recv_pdu(rpc);
if (status > 0)
{
CopyMemory(&header, pdu, 20);
header = (RPC_PDU_HEADER*) rpc->buffer;
auth_data = malloc(header.auth_length);
rpc->ntlm->inputBuffer.cbBuffer = header->auth_length;
rpc->ntlm->inputBuffer.pvBuffer = malloc(header->auth_length);
if (auth_data == NULL)
{
free(pdu);
return -1;
}
p = (pdu + (header.frag_length - header.auth_length));
memcpy(auth_data, p, header.auth_length);
rpc->ntlm->inputBuffer.pvBuffer = auth_data;
rpc->ntlm->inputBuffer.cbBuffer = header.auth_length;
auth_data = rpc->buffer + (header->frag_length - header->auth_length);
CopyMemory(rpc->ntlm->inputBuffer.pvBuffer, auth_data, header->auth_length);
ntlm_authenticate(rpc->ntlm);
}
free(pdu);
return status;
}
@ -702,75 +750,106 @@ BOOL rpc_send_rpc_auth_3_pdu(rdpRpc* rpc)
return TRUE;
}
//if (rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow < 0x00008FFF) /* Just a simple workaround */
// rts_send_flow_control_ack_pdu(rpc); /* Send FlowControlAck every time AvailableWindow reaches the half */
int rpc_out_read(rdpRpc* rpc, BYTE* data, int length)
{
BYTE* pdu;
int status;
int content_length;
RPC_PDU_HEADER header;
if (rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow < 0x00008FFF) /* Just a simple workaround */
rts_send_flow_control_ack_pdu(rpc); /* Send FlowControlAck every time AvailableWindow reaches the half */
pdu = malloc(0xFFFF);
if (pdu == NULL)
{
printf("rpc_out_read error: memory allocation failed") ;
return -1;
}
RPC_PDU_HEADER* header;
/* read first 20 bytes to get RPC PDU Header */
status = tls_read(rpc->tls_out, pdu, 20);
status = tls_read(rpc->tls_out, data, 20);
if (status <= 0)
{
free(pdu);
return status;
}
CopyMemory(&header, pdu, 20);
header = (RPC_PDU_HEADER*) data;
content_length = header.frag_length - 20;
status = tls_read(rpc->tls_out, pdu + 20, content_length);
rpc_pdu_header_print(header);
status = tls_read(rpc->tls_out, &data[20], header->frag_length - 20);
if (status < 0)
{
free(pdu);
return status;
}
if (header.ptype == PTYPE_RTS) /* RTS PDU */
if (header->ptype == PTYPE_RTS) /* RTS PDU */
{
printf("rpc_out_read error: Unexpected RTS PDU\n");
free(pdu);
return -1;
}
else
{
/* RTS PDUs are not subject to flow control */
rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header.frag_length;
rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow -= header.frag_length;
rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->frag_length;
rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow -= header->frag_length;
}
if (length < header.frag_length)
if (length < header->frag_length)
{
printf("rpc_out_read error! receive buffer is not large enough\n");
free(pdu);
printf("rpc_out_read error! receive buffer is not large enough: %d < %d\n", length, header->frag_length);
return -1;
}
memcpy(data, pdu, header.frag_length);
#ifdef WITH_DEBUG_RPC
printf("rpc_out_read(): length: %d\n", header.frag_length);
freerdp_hexdump(data, header.frag_length);
printf("rpc_out_read(): length: %d\n", header->frag_length);
freerdp_hexdump(data, header->frag_length);
printf("\n");
#endif
free(pdu);
return header->frag_length;
}
return header.frag_length;
int rpc_recv_pdu(rdpRpc* rpc)
{
int status;
RPC_PDU_HEADER* header;
/* read first 20 bytes to get RPC PDU Header */
status = tls_read(rpc->tls_out, rpc->buffer, 20);
if (status <= 0)
{
printf("rpc_recv_pdu: error reading header\n");
return status;
}
header = (RPC_PDU_HEADER*) rpc->buffer;
rpc_pdu_header_print(header);
if (header->frag_length > rpc->length)
{
rpc->length = header->frag_length;
rpc->buffer = (BYTE*) realloc(rpc->buffer, rpc->length);
}
status = tls_read(rpc->tls_out, &rpc->buffer[20], header->frag_length - 20);
if (status < 0)
{
printf("rpc_recv_pdu: error reading fragment\n");
return status;
}
if (header->ptype == PTYPE_RTS) /* RTS PDU */
{
printf("rpc_recv_pdu error: Unexpected RTS PDU\n");
return -1;
}
else
{
/* RTS PDUs are not subject to flow control */
rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->frag_length;
rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow -= header->frag_length;
}
#ifdef WITH_DEBUG_RPC
printf("rpc_recv_pdu: length: %d\n", header->frag_length);
freerdp_hexdump(rpc->buffer, rpc->length);
printf("\n");
#endif
return header->frag_length;
}
int rpc_tsg_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
@ -883,90 +962,17 @@ int rpc_tsg_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
int rpc_read(rdpRpc* rpc, BYTE* data, int length)
{
int status;
int read = 0;
int data_length;
UINT16 frag_length;
UINT16 auth_length;
BYTE auth_pad_length;
UINT32 call_id = -1;
int rpc_length = length + 0xFF;
BYTE* rpc_data = malloc(rpc_length);
RPC_PDU_HEADER header;
if (rpc_data == NULL)
status = rpc_out_read(rpc, data, length);
if (status > 0)
{
printf("rpc_read error: memory allocation failed\n") ;
return -1;
CopyMemory(&header, data, 20);
status = rpc_out_read(rpc, data, header.frag_length - 20);
}
if (rpc->read_buffer_len > 0)
{
if (rpc->read_buffer_len > (UINT32) length)
{
printf("rpc_read error: receiving buffer is not large enough\n");
free(rpc_data);
return -1;
}
memcpy(data, rpc->read_buffer, rpc->read_buffer_len);
read += rpc->read_buffer_len;
free(rpc->read_buffer);
rpc->read_buffer_len = 0;
}
while (TRUE)
{
status = rpc_out_read(rpc, rpc_data, rpc_length);
if (status == 0)
{
free(rpc_data);
return read;
}
else if (status < 0)
{
printf("Error! rpc_out_read() returned negative value. BytesSent: %d, BytesReceived: %d\n",
rpc->VirtualConnection->DefaultInChannel->BytesSent,
rpc->VirtualConnection->DefaultOutChannel->BytesReceived);
free(rpc_data);
return status;
}
frag_length = *(UINT16*)(rpc_data + 8);
auth_length = *(UINT16*)(rpc_data + 10);
call_id = *(UINT32*)(rpc_data + 12);
status = *(UINT32*)(rpc_data + 16); /* alloc_hint */
auth_pad_length = *(rpc_data + frag_length - auth_length - 6); /* -6 = -8 + 2 (sec_trailer + 2) */
/* data_length must be calculated because alloc_hint carries size of more than one pdu */
data_length = frag_length - auth_length - 24 - 8 - auth_pad_length; /* 24 is header; 8 is sec_trailer */
if (status == 4)
continue;
if (read + data_length > length) /* if read data is greater then given buffer */
{
rpc->read_buffer_len = read + data_length - length;
rpc->read_buffer = malloc(rpc->read_buffer_len);
data_length -= rpc->read_buffer_len;
memcpy(rpc->read_buffer, rpc_data + 24 + data_length, rpc->read_buffer_len);
}
memcpy(data + read, rpc_data + 24, data_length);
read += data_length;
if (status > data_length && read < length)
continue;
break;
}
free(rpc_data);
return read;
return status;
}
BOOL rpc_connect(rdpRpc* rpc)
@ -1105,10 +1111,8 @@ rdpRpc* rpc_new(rdpTransport* transport)
rpc_ntlm_http_init_channel(rpc, rpc->ntlm_http_in, TSG_CHANNEL_IN);
rpc_ntlm_http_init_channel(rpc, rpc->ntlm_http_out, TSG_CHANNEL_OUT);
rpc->read_buffer = NULL;
rpc->write_buffer = NULL;
rpc->read_buffer_len = 0;
rpc->write_buffer_len = 0;
rpc->length = 20;
rpc->buffer = (BYTE*) malloc(rpc->length);
rpc->rpc_vers = 5;
rpc->rpc_vers_minor = 0;

View File

@ -516,14 +516,12 @@ struct rdp_rpc
rdpSettings* settings;
rdpTransport* transport;
BYTE* write_buffer;
UINT32 write_buffer_len;
BYTE* read_buffer;
UINT32 read_buffer_len;
UINT32 call_id;
UINT32 pipe_call_id;
BYTE* buffer;
UINT32 length;
BYTE rpc_vers;
BYTE rpc_vers_minor;
BYTE packed_drep[4];
@ -551,6 +549,7 @@ void rpc_pdu_header_read(STREAM* s, RPC_PDU_HEADER* header);
int rpc_out_write(rdpRpc* rpc, BYTE* data, int length);
int rpc_in_write(rdpRpc* rpc, BYTE* data, int length);
int rpc_recv_pdu(rdpRpc* rpc);
int rpc_out_read(rdpRpc* rpc, BYTE* data, int length);
int rpc_tsg_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum);

View File

@ -489,10 +489,7 @@ BOOL tsg_proxy_create_tunnel(rdpTsg* tsg)
free(buffer);
length = 0x8FFF;
buffer = malloc(length);
status = rpc_read(rpc, buffer, length);
status = rpc_recv_pdu(rpc);
if (status <= 0)
{
@ -500,7 +497,7 @@ BOOL tsg_proxy_create_tunnel(rdpTsg* tsg)
return FALSE;
}
memcpy(tsg->TunnelContext, buffer + (status - 24), 16);
CopyMemory(tsg->TunnelContext, rpc->buffer + (status - 24), 16);
#ifdef WITH_DEBUG_TSG
printf("TSG TunnelContext:\n");
@ -508,8 +505,6 @@ BOOL tsg_proxy_create_tunnel(rdpTsg* tsg)
printf("\n");
#endif
free(buffer);
return TRUE;
}
@ -533,11 +528,10 @@ BOOL tsg_proxy_authorize_tunnel(rdpTsg* tsg)
DEBUG_TSG("TsProxyAuthorizeTunnel");
memcpy(tsg_packet2 + 4, tsg->TunnelContext, 16);
length = sizeof(tsg_packet2);
buffer = (BYTE*) malloc(length);
CopyMemory(buffer, tsg_packet2, length);
CopyMemory(buffer + 4, tsg->TunnelContext, 16);
status = rpc_tsg_write(rpc, buffer, length, TsProxyAuthorizeTunnelOpnum);
@ -549,10 +543,7 @@ BOOL tsg_proxy_authorize_tunnel(rdpTsg* tsg)
free(buffer);
length = 0x8FFF;
buffer = malloc(length);
status = rpc_read(rpc, buffer, length);
status = rpc_recv_pdu(rpc);
if (status <= 0)
{
@ -560,8 +551,6 @@ BOOL tsg_proxy_authorize_tunnel(rdpTsg* tsg)
return FALSE;
}
free(buffer);
return TRUE;
}
@ -585,11 +574,10 @@ BOOL tsg_proxy_make_tunnel_call(rdpTsg* tsg)
DEBUG_TSG("TsProxyMakeTunnelCall");
memcpy(tsg_packet3 + 4, tsg->TunnelContext, 16);
length = sizeof(tsg_packet3);
buffer = (BYTE*) malloc(length);
CopyMemory(buffer, tsg_packet3, length);
CopyMemory(buffer + 4, tsg->TunnelContext, 16);
status = rpc_tsg_write(rpc, buffer, length, TsProxyMakeTunnelCallOpnum);
@ -608,11 +596,11 @@ BOOL tsg_proxy_make_tunnel_call(rdpTsg* tsg)
BOOL tsg_proxy_create_channel(rdpTsg* tsg)
{
STREAM* s;
int status;
UINT32 count;
BYTE* buffer;
UINT32 length;
UINT32 offset;
rdpRpc* rpc = tsg->rpc;
/**
@ -628,19 +616,21 @@ BOOL tsg_proxy_create_channel(rdpTsg* tsg)
DEBUG_TSG("TsProxyCreateChannel");
offset = 0;
count = _wcslen(tsg->hostname) + 1;
memcpy(tsg_packet4 + 4, tsg->TunnelContext, 16);
memcpy(tsg_packet4 + 38, &tsg->port, 2);
length = 48 + 12 + (count * 2);
buffer = (BYTE*) malloc(length);
CopyMemory(buffer, tsg_packet4, 48);
CopyMemory(buffer + 4, tsg->TunnelContext, 16);
CopyMemory(buffer + 38, &tsg->port, 2);
s = stream_new(60 + (count * 2));
stream_write(s, tsg_packet4, 48);
stream_write_UINT32(s, count); /* MaximumCount */
stream_write_UINT32(s, 0x00000000); /* Offset */
stream_write_UINT32(s, count); /* ActualCount */
stream_write(s, tsg->hostname, count);
CopyMemory(&buffer[48], &count, 4); /* MaximumCount */
CopyMemory(&buffer[52], &offset, 4); /* Offset */
CopyMemory(&buffer[56], &count, 4); /* ActualCount */
CopyMemory(&buffer[60], &tsg->hostname, count);
status = rpc_tsg_write(rpc, s->data, s->size, TsProxyCreateChannelOpnum);
status = rpc_tsg_write(rpc, buffer, length, TsProxyCreateChannelOpnum);
if (status <= 0)
{
@ -648,12 +638,9 @@ BOOL tsg_proxy_create_channel(rdpTsg* tsg)
return FALSE;
}
//free(buffer);
free(buffer);
length = 0x8FFF;
buffer = malloc(length);
status = rpc_read(rpc, buffer, length);
status = rpc_recv_pdu(rpc);
if (status < 0)
{
@ -661,7 +648,7 @@ BOOL tsg_proxy_create_channel(rdpTsg* tsg)
return FALSE;
}
memcpy(tsg->ChannelContext, buffer + 4, 16);
CopyMemory(tsg->ChannelContext, rpc->buffer + 4, 16);
#ifdef WITH_DEBUG_TSG
printf("TSG ChannelContext:\n");
@ -669,8 +656,6 @@ BOOL tsg_proxy_create_channel(rdpTsg* tsg)
printf("\n");
#endif
free(buffer);
return TRUE;
}
@ -707,6 +692,8 @@ BOOL tsg_proxy_setup_receive_pipe(rdpTsg* tsg)
free(buffer);
/* read? */
return TRUE;
}
@ -741,10 +728,37 @@ BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port)
int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length)
{
int status;
RPC_PDU_HEADER* header;
rdpRpc* rpc = tsg->rpc;
status = rpc_read(tsg->rpc, data, length);
printf("tsg_read: %d, pending: %d\n", length, tsg->pendingPdu);
return status;
if (tsg->pendingPdu)
{
header = (RPC_PDU_HEADER*) rpc->buffer;
CopyMemory(data, &rpc->buffer[tsg->bytesRead], length);
tsg->bytesAvailable -= length;
tsg->bytesRead += length;
if (tsg->bytesAvailable < 1)
tsg->pendingPdu = FALSE;
}
else
{
status = rpc_recv_pdu(rpc);
tsg->pendingPdu = TRUE;
header = (RPC_PDU_HEADER*) rpc->buffer;
tsg->bytesAvailable = header->frag_length;
tsg->bytesRead = 0;
CopyMemory(data, &rpc->buffer[tsg->bytesRead], length);
tsg->bytesAvailable -= length;
tsg->bytesRead += length;
}
return length;
}
int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length)
@ -763,6 +777,7 @@ rdpTsg* tsg_new(rdpTransport* transport)
tsg->transport = transport;
tsg->settings = transport->settings;
tsg->rpc = rpc_new(tsg->transport);
tsg->pendingPdu = FALSE;
}
return tsg;

View File

@ -42,6 +42,9 @@ struct rdp_tsg
rdpRpc* rpc;
UINT16 port;
LPWSTR hostname;
BOOL pendingPdu;
BOOL bytesRead;
BOOL bytesAvailable;
rdpSettings* settings;
rdpTransport* transport;
BYTE TunnelContext[16];