libfreerdp-core: enforce checking of NLA packets in transport only when expecting NLA

This commit is contained in:
Marc-André Moreau 2014-03-27 14:24:15 -04:00
parent 75302e2cc2
commit 3f07157637
2 changed files with 103 additions and 73 deletions

View File

@ -317,6 +317,7 @@ BOOL transport_connect_nla(rdpTransport* transport)
if (!transport->credssp) if (!transport->credssp)
{ {
transport->credssp = credssp_new(instance, transport, settings); transport->credssp = credssp_new(instance, transport, settings);
transport_set_nla_mode(transport, TRUE);
if (settings->AuthenticationServiceClass) if (settings->AuthenticationServiceClass)
{ {
@ -338,11 +339,14 @@ BOOL transport_connect_nla(rdpTransport* transport)
fprintf(stderr, "Authentication failure, check credentials.\n" fprintf(stderr, "Authentication failure, check credentials.\n"
"If credentials are valid, the NTLMSSP implementation may be to blame.\n"); "If credentials are valid, the NTLMSSP implementation may be to blame.\n");
transport_set_nla_mode(transport, FALSE);
credssp_free(transport->credssp); credssp_free(transport->credssp);
transport->credssp = NULL; transport->credssp = NULL;
return FALSE; return FALSE;
} }
transport_set_nla_mode(transport, FALSE);
credssp_free(transport->credssp); credssp_free(transport->credssp);
transport->credssp = NULL; transport->credssp = NULL;
@ -481,11 +485,16 @@ BOOL transport_accept_nla(rdpTransport* transport)
return TRUE; return TRUE;
if (!transport->credssp) if (!transport->credssp)
{
transport->credssp = credssp_new(instance, transport, settings); transport->credssp = credssp_new(instance, transport, settings);
transport_set_nla_mode(transport, TRUE);
}
if (credssp_authenticate(transport->credssp) < 0) if (credssp_authenticate(transport->credssp) < 0)
{ {
fprintf(stderr, "client authentication failure\n"); fprintf(stderr, "client authentication failure\n");
transport_set_nla_mode(transport, FALSE);
credssp_free(transport->credssp); credssp_free(transport->credssp);
transport->credssp = NULL; transport->credssp = NULL;
@ -495,6 +504,7 @@ BOOL transport_accept_nla(rdpTransport* transport)
} }
/* don't free credssp module yet, we need to copy the credentials from it first */ /* don't free credssp module yet, we need to copy the credentials from it first */
transport_set_nla_mode(transport, FALSE);
return TRUE; return TRUE;
} }
@ -643,49 +653,56 @@ int transport_read(rdpTransport* transport, wStream* s)
CopyMemory(header, Stream_Buffer(s), 4); /* peek at first 4 bytes */ CopyMemory(header, Stream_Buffer(s), 4); /* peek at first 4 bytes */
/* if header is present, read in exactly one PDU */ /* if header is present, read exactly one PDU */
if (header[0] == 0x03)
{
/* TPKT header */
pduLength = (header[2] << 8) | header[3]; if (transport->NlaMode)
}
else if (header[0] == 0x30)
{ {
/* TSRequest (NLA) */ if (header[0] == 0x30)
if (header[1] & 0x80)
{ {
if ((header[1] & ~(0x80)) == 1) /* TSRequest (NLA) */
if (header[1] & 0x80)
{ {
pduLength = header[2]; if ((header[1] & ~(0x80)) == 1)
pduLength += 3; {
} pduLength = header[2];
else if ((header[1] & ~(0x80)) == 2) pduLength += 3;
{ }
pduLength = (header[2] << 8) | header[3]; else if ((header[1] & ~(0x80)) == 2)
pduLength += 4; {
pduLength = (header[2] << 8) | header[3];
pduLength += 4;
}
else
{
fprintf(stderr, "Error reading TSRequest!\n");
return -1;
}
} }
else else
{ {
fprintf(stderr, "Error reading TSRequest!\n"); pduLength = header[1];
return -1; pduLength += 2;
} }
} }
else
{
pduLength = header[1];
pduLength += 2;
}
} }
else else
{ {
/* Fast-Path Header */ if (header[0] == 0x03)
{
/* TPKT header */
if (header[1] & 0x80) pduLength = (header[2] << 8) | header[3];
pduLength = ((header[1] & 0x7F) << 8) | header[2]; }
else else
pduLength = header[1]; {
/* Fast-Path Header */
if (header[1] & 0x80)
pduLength = ((header[1] & 0x7F) << 8) | header[2];
else
pduLength = header[1];
}
} }
status = transport_read_layer(transport, Stream_Buffer(s) + position, pduLength - position); status = transport_read_layer(transport, Stream_Buffer(s) + position, pduLength - position);
@ -889,7 +906,7 @@ int transport_check_fds(rdpTransport* transport)
* Loop through and read all available PDUs. Since multiple * Loop through and read all available PDUs. Since multiple
* PDUs can exist, it's important to deliver them all before * PDUs can exist, it's important to deliver them all before
* returning. Otherwise we run the risk of having a thread * returning. Otherwise we run the risk of having a thread
* wait for a socket to get signalled that data is available * wait for a socket to get signaled that data is available
* (which may never happen). * (which may never happen).
*/ */
for (;;) for (;;)
@ -903,58 +920,64 @@ int transport_check_fds(rdpTransport* transport)
{ {
Stream_SetPosition(transport->ReceiveBuffer, 0); Stream_SetPosition(transport->ReceiveBuffer, 0);
if (tpkt_verify_header(transport->ReceiveBuffer)) /* TPKT */ if (transport->NlaMode)
{ {
/* Ensure the TPKT header is available. */ if (nla_verify_header(transport->ReceiveBuffer))
if (pos <= 4)
{ {
Stream_SetPosition(transport->ReceiveBuffer, pos); /* TSRequest */
return 0;
}
length = tpkt_read_header(transport->ReceiveBuffer); /* Ensure the TSRequest header is available. */
if (pos <= 4)
{
Stream_SetPosition(transport->ReceiveBuffer, pos);
return 0;
}
/* TSRequest header can be 2, 3 or 4 bytes long */
length = nla_header_length(transport->ReceiveBuffer);
if (pos < length)
{
Stream_SetPosition(transport->ReceiveBuffer, pos);
return 0;
}
length = nla_read_header(transport->ReceiveBuffer);
}
} }
else if (nla_verify_header(transport->ReceiveBuffer)) else
{ {
/* TSRequest */ if (tpkt_verify_header(transport->ReceiveBuffer)) /* TPKT */
/* Ensure the TSRequest header is available. */
if (pos <= 4)
{ {
Stream_SetPosition(transport->ReceiveBuffer, pos); /* Ensure the TPKT header is available. */
return 0; if (pos <= 4)
{
Stream_SetPosition(transport->ReceiveBuffer, pos);
return 0;
}
length = tpkt_read_header(transport->ReceiveBuffer);
} }
else /* Fast Path */
/* TSRequest header can be 2, 3 or 4 bytes long */
length = nla_header_length(transport->ReceiveBuffer);
if (pos < length)
{ {
Stream_SetPosition(transport->ReceiveBuffer, pos); /* Ensure the Fast Path header is available. */
return 0; if (pos <= 2)
{
Stream_SetPosition(transport->ReceiveBuffer, pos);
return 0;
}
/* Fastpath header can be two or three bytes long. */
length = fastpath_header_length(transport->ReceiveBuffer);
if (pos < length)
{
Stream_SetPosition(transport->ReceiveBuffer, pos);
return 0;
}
length = fastpath_read_header(NULL, transport->ReceiveBuffer);
} }
length = nla_read_header(transport->ReceiveBuffer);
}
else /* Fast Path */
{
/* Ensure the Fast Path header is available. */
if (pos <= 2)
{
Stream_SetPosition(transport->ReceiveBuffer, pos);
return 0;
}
/* Fastpath header can be two or three bytes long. */
length = fastpath_header_length(transport->ReceiveBuffer);
if (pos < length)
{
Stream_SetPosition(transport->ReceiveBuffer, pos);
return 0;
}
length = fastpath_read_header(NULL, transport->ReceiveBuffer);
} }
if (length == 0) if (length == 0)
@ -1039,6 +1062,11 @@ void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled)
transport->GatewayEnabled = GatewayEnabled; transport->GatewayEnabled = GatewayEnabled;
} }
void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode)
{
transport->NlaMode = NlaMode;
}
static void* transport_client_thread(void* arg) static void* transport_client_thread(void* arg)
{ {
DWORD status; DWORD status;

View File

@ -75,6 +75,7 @@ struct rdp_transport
HANDLE stopEvent; HANDLE stopEvent;
HANDLE thread; HANDLE thread;
BOOL async; BOOL async;
BOOL NlaMode;
BOOL GatewayEnabled; BOOL GatewayEnabled;
CRITICAL_SECTION ReadLock; CRITICAL_SECTION ReadLock;
CRITICAL_SECTION WriteLock; CRITICAL_SECTION WriteLock;
@ -99,6 +100,7 @@ void transport_get_fds(rdpTransport* transport, void** rfds, int* rcount);
int transport_check_fds(rdpTransport* transport); int transport_check_fds(rdpTransport* transport);
BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking); BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking);
void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled); void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled);
void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode);
void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count); void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count);
wStream* transport_receive_pool_take(rdpTransport* transport); wStream* transport_receive_pool_take(rdpTransport* transport);