libfreerdp-core: add better state machine transitions to ts gateway

This commit is contained in:
Marc-André Moreau 2015-02-02 18:50:26 -05:00
parent e0b0c77ecb
commit 5aea07d401
10 changed files with 229 additions and 75 deletions

View File

@ -26,13 +26,15 @@
#include <winpr/stream.h>
#include <winpr/string.h>
#include <freerdp/log.h>
#ifdef HAVE_VALGRIND_MEMCHECK_H
#include <valgrind/memcheck.h>
#endif
#include "http.h"
#define TAG "gateway"
#define TAG FREERDP_TAG("core.gateway.http")
static char* string_strnstr(const char* str1, const char* str2, size_t slen)
{

View File

@ -31,7 +31,7 @@
#include <openssl/rand.h>
#define TAG FREERDP_TAG("core.gateway")
#define TAG FREERDP_TAG("core.gateway.ntlm")
wStream* rpc_ntlm_http_request(rdpRpc* rpc, SecBuffer* ntlm_token, int content_length, TSG_CHANNEL channel)
{

View File

@ -46,7 +46,7 @@
#include "rpc.h"
#define TAG FREERDP_TAG("core.gateway")
#define TAG FREERDP_TAG("core.gateway.rpc")
/* Security Verification Trailer Signature */
@ -465,6 +465,45 @@ out_free_pdu:
return -1;
}
int rpc_client_virtual_connection_transition_to_state(rdpRpc* rpc,
RpcVirtualConnection* connection, VIRTUAL_CONNECTION_STATE state)
{
int status = 1;
const char* str = "VIRTUAL_CONNECTION_STATE_UNKNOWN";
switch (state)
{
case VIRTUAL_CONNECTION_STATE_INITIAL:
str = "VIRTUAL_CONNECTION_STATE_INITIAL";
break;
case VIRTUAL_CONNECTION_STATE_OUT_CHANNEL_WAIT:
str = "VIRTUAL_CONNECTION_STATE_OUT_CHANNEL_WAIT";
break;
case VIRTUAL_CONNECTION_STATE_WAIT_A3W:
str = "VIRTUAL_CONNECTION_STATE_WAIT_A3W";
break;
case VIRTUAL_CONNECTION_STATE_WAIT_C2:
str = "VIRTUAL_CONNECTION_STATE_WAIT_C2";
break;
case VIRTUAL_CONNECTION_STATE_OPENED:
str = "VIRTUAL_CONNECTION_STATE_OPENED";
break;
case VIRTUAL_CONNECTION_STATE_FINAL:
str = "VIRTUAL_CONNECTION_STATE_FINAL";
break;
}
connection->State = state;
WLog_DBG(TAG, "%s", str);
return status;
}
void rpc_client_virtual_connection_init(rdpRpc* rpc, RpcVirtualConnection* connection)
{
connection->DefaultInChannel->State = CLIENT_IN_CHANNEL_STATE_INITIAL;

View File

@ -773,6 +773,9 @@ int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length);
BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length);
int rpc_client_virtual_connection_transition_to_state(rdpRpc* rpc,
RpcVirtualConnection* connection, VIRTUAL_CONNECTION_STATE state);
int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum);
rdpRpc* rpc_new(rdpTransport* transport);

View File

@ -29,7 +29,7 @@
#include "rpc_bind.h"
#define TAG FREERDP_TAG("core.gateway")
#define TAG FREERDP_TAG("core.gateway.rpc")
/**
* Connection-Oriented RPC Protocol Client Details:
@ -118,7 +118,7 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
BOOL promptPassword = FALSE;
freerdp* instance = (freerdp*) settings->instance;
WLog_DBG(TAG, "Sending bind PDU");
WLog_DBG(TAG, "Sending Bind PDU");
ntlm_free(rpc->ntlm);
rpc->ntlm = ntlm_new();
@ -310,6 +310,8 @@ int rpc_recv_bind_ack_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
header = (rpcconn_hdr_t*) buffer;
WLog_DBG(TAG, "Receiving BindAck PDU");
rpc->max_recv_frag = header->bind_ack.max_xmit_frag;
rpc->max_xmit_frag = header->bind_ack.max_recv_frag;
@ -343,7 +345,7 @@ int rpc_send_rpc_auth_3_pdu(rdpRpc* rpc)
RpcClientCall* clientCall;
rpcconn_rpc_auth_3_hdr_t* auth_3_pdu;
WLog_DBG(TAG, "Sending rpc_auth_3 PDU");
WLog_DBG(TAG, "Sending RpcAuth3 PDU");
auth_3_pdu = (rpcconn_rpc_auth_3_hdr_t*) calloc(1, sizeof(rpcconn_rpc_auth_3_hdr_t));

View File

@ -34,9 +34,7 @@
#include "rpc_client.h"
#include "../rdp.h"
#define TAG FREERDP_TAG("core.gateway")
#define SYNCHRONOUS_TIMEOUT 5000
#define TAG FREERDP_TAG("core.gateway.rpc")
static void rpc_pdu_reset(RPC_PDU* pdu)
{
@ -124,8 +122,55 @@ int rpc_client_receive_pipe_read(rdpRpc* rpc, BYTE* buffer, size_t length)
return status;
}
int rpc_client_transition_to_state(rdpRpc* rpc, RPC_CLIENT_STATE state)
{
int status = 1;
const char* str = "RPC_CLIENT_STATE_UNKNOWN";
switch (state)
{
case RPC_CLIENT_STATE_INITIAL:
str = "RPC_CLIENT_STATE_INITIAL";
break;
case RPC_CLIENT_STATE_ESTABLISHED:
str = "RPC_CLIENT_STATE_ESTABLISHED";
break;
case RPC_CLIENT_STATE_WAIT_SECURE_BIND_ACK:
str = "RPC_CLIENT_STATE_WAIT_SECURE_BIND_ACK";
break;
case RPC_CLIENT_STATE_WAIT_UNSECURE_BIND_ACK:
str = "RPC_CLIENT_STATE_WAIT_UNSECURE_BIND_ACK";
break;
case RPC_CLIENT_STATE_WAIT_SECURE_ALTER_CONTEXT_RESPONSE:
str = "RPC_CLIENT_STATE_WAIT_SECURE_ALTER_CONTEXT_RESPONSE";
break;
case RPC_CLIENT_STATE_CONTEXT_NEGOTIATED:
str = "RPC_CLIENT_STATE_CONTEXT_NEGOTIATED";
break;
case RPC_CLIENT_STATE_WAIT_RESPONSE:
str = "RPC_CLIENT_STATE_WAIT_RESPONSE";
break;
case RPC_CLIENT_STATE_FINAL:
str = "RPC_CLIENT_STATE_FINAL";
break;
}
rpc->State = state;
WLog_DBG(TAG, "%s", str);
return status;
}
int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
{
int status = -1;
rpcconn_rts_hdr_t* rts;
rdpTsg* tsg = rpc->transport->tsg;
@ -151,8 +196,10 @@ int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
rts_recv_CONN_A3_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
rpc->VirtualConnection->State = VIRTUAL_CONNECTION_STATE_WAIT_C2;
WLog_DBG(TAG, "VIRTUAL_CONNECTION_STATE_WAIT_C2");
rpc_client_virtual_connection_transition_to_state(rpc,
rpc->VirtualConnection, VIRTUAL_CONNECTION_STATE_WAIT_C2);
status = 1;
break;
@ -168,10 +215,10 @@ int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
rts_recv_CONN_C2_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
rpc->VirtualConnection->State = VIRTUAL_CONNECTION_STATE_OPENED;
WLog_DBG(TAG, "VIRTUAL_CONNECTION_STATE_OPENED");
rpc_client_virtual_connection_transition_to_state(rpc,
rpc->VirtualConnection, VIRTUAL_CONNECTION_STATE_OPENED);
rpc->State = RPC_CLIENT_STATE_ESTABLISHED;
rpc_client_transition_to_state(rpc, RPC_CLIENT_STATE_ESTABLISHED);
if (rpc_send_bind_pdu(rpc) < 0)
{
@ -179,7 +226,9 @@ int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
return -1;
}
rpc->State = RPC_CLIENT_STATE_WAIT_SECURE_BIND_ACK;
rpc_client_transition_to_state(rpc, RPC_CLIENT_STATE_WAIT_SECURE_BIND_ACK);
status = 1;
break;
@ -189,19 +238,24 @@ int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
case VIRTUAL_CONNECTION_STATE_FINAL:
break;
}
return 1;
}
if (rpc->State < RPC_CLIENT_STATE_CONTEXT_NEGOTIATED)
else if (rpc->State < RPC_CLIENT_STATE_CONTEXT_NEGOTIATED)
{
if (rpc->State == RPC_CLIENT_STATE_WAIT_SECURE_BIND_ACK)
{
if (pdu->Type == PTYPE_BIND_ACK)
{
if (rpc_recv_bind_ack_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)) <= 0)
{
WLog_ERR(TAG, "rpc_recv_bind_ack_pdu failure");
return -1;
}
}
else
{
WLog_ERR(TAG, "RPC_CLIENT_STATE_WAIT_SECURE_BIND_ACK unexpected pdu type: 0x%04X", pdu->Type);
return -1;
}
if (rpc_send_rpc_auth_3_pdu(rpc) <= 0)
{
@ -209,7 +263,7 @@ int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
return -1;
}
rpc->State = RPC_CLIENT_STATE_CONTEXT_NEGOTIATED;
rpc_client_transition_to_state(rpc, RPC_CLIENT_STATE_CONTEXT_NEGOTIATED);
if (!TsProxyCreateTunnel(tsg, NULL, NULL, NULL, NULL))
{
@ -218,23 +272,24 @@ int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
return -1;
}
tsg->state = TSG_STATE_INITIAL;
tsg_transition_to_state(tsg, TSG_STATE_INITIAL);
status = 1;
}
else
{
WLog_ERR(TAG, "rpc_client_recv_pdu: invalid rpc->State: %d", rpc->State);
return -1;
}
return 1;
}
else if (rpc->State >= RPC_CLIENT_STATE_CONTEXT_NEGOTIATED)
{
if (tsg->state != TSG_STATE_PIPE_CREATED)
{
return tsg_recv_pdu(tsg, pdu);
status = tsg_recv_pdu(tsg, pdu);
}
}
return 1;
return status;
}
int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
@ -384,18 +439,18 @@ int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
int rpc_client_recv(rdpRpc* rpc)
{
int position;
int status = -1;
wStream* fragment;
rpcconn_common_hdr_t* header;
fragment = rpc->client->ReceiveFragment;
while (1)
{
position = Stream_GetPosition(rpc->client->ReceiveFragment);
while (Stream_GetPosition(rpc->client->ReceiveFragment) < RPC_COMMON_FIELDS_LENGTH)
while (Stream_GetPosition(fragment) < RPC_COMMON_FIELDS_LENGTH)
{
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->ReceiveFragment),
RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->ReceiveFragment));
status = rpc_out_read(rpc, Stream_Pointer(fragment),
RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(fragment));
if (status < 0)
return -1;
@ -403,26 +458,26 @@ int rpc_client_recv(rdpRpc* rpc)
if (!status)
return 0;
Stream_Seek(rpc->client->ReceiveFragment, status);
Stream_Seek(fragment, status);
}
if (Stream_GetPosition(rpc->client->ReceiveFragment) < RPC_COMMON_FIELDS_LENGTH)
if (Stream_GetPosition(fragment) < RPC_COMMON_FIELDS_LENGTH)
return status;
header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->ReceiveFragment);
header = (rpcconn_common_hdr_t*) Stream_Buffer(fragment);
if (header->frag_length > rpc->max_recv_frag)
{
WLog_ERR(TAG, "rpc_client_frag_read: invalid fragment size: %d (max: %d)",
WLog_ERR(TAG, "rpc_client_recv: invalid fragment size: %d (max: %d)",
header->frag_length, rpc->max_recv_frag);
winpr_HexDump(TAG, WLOG_ERROR, Stream_Buffer(rpc->client->ReceiveFragment), Stream_GetPosition(rpc->client->ReceiveFragment));
winpr_HexDump(TAG, WLOG_ERROR, Stream_Buffer(fragment), Stream_GetPosition(fragment));
return -1;
}
while (Stream_GetPosition(rpc->client->ReceiveFragment) < header->frag_length)
while (Stream_GetPosition(fragment) < header->frag_length)
{
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->ReceiveFragment),
header->frag_length - Stream_GetPosition(rpc->client->ReceiveFragment));
status = rpc_out_read(rpc, Stream_Pointer(fragment),
header->frag_length - Stream_GetPosition(fragment));
if (status < 0)
{
@ -433,27 +488,25 @@ int rpc_client_recv(rdpRpc* rpc)
if (!status)
return 0;
Stream_Seek(rpc->client->ReceiveFragment, status);
Stream_Seek(fragment, status);
}
if (status < 0)
return -1;
status = Stream_GetPosition(rpc->client->ReceiveFragment) - position;
if (Stream_GetPosition(rpc->client->ReceiveFragment) >= header->frag_length)
if (Stream_GetPosition(fragment) >= header->frag_length)
{
/* complete fragment received */
Stream_SealLength(rpc->client->ReceiveFragment);
Stream_SetPosition(rpc->client->ReceiveFragment, 0);
Stream_SealLength(fragment);
Stream_SetPosition(fragment, 0);
status = rpc_client_recv_fragment(rpc, rpc->client->ReceiveFragment);
status = rpc_client_recv_fragment(rpc, fragment);
if (status < 0)
return status;
Stream_SetPosition(rpc->client->ReceiveFragment, 0);
Stream_SetPosition(fragment, 0);
}
}

View File

@ -31,7 +31,7 @@
#include "rts.h"
#define TAG FREERDP_TAG("core.gateway")
#define TAG FREERDP_TAG("core.gateway.rts")
/**
* [MS-RPCH]: Remote Procedure Call over HTTP Protocol Specification:
@ -66,8 +66,8 @@ BOOL rts_connect(rdpRpc* rpc)
freerdp* instance = (freerdp*) rpc->settings->instance;
rdpContext* context = instance->context;
rpc->VirtualConnection->State = VIRTUAL_CONNECTION_STATE_INITIAL;
WLog_DBG(TAG, "VIRTUAL_CONNECTION_STATE_INITIAL");
rpc_client_virtual_connection_transition_to_state(rpc,
rpc->VirtualConnection, VIRTUAL_CONNECTION_STATE_INITIAL);
if (!rpc_ntlm_http_out_connect(rpc))
{
@ -93,8 +93,8 @@ BOOL rts_connect(rdpRpc* rpc)
return FALSE;
}
rpc->VirtualConnection->State = VIRTUAL_CONNECTION_STATE_OUT_CHANNEL_WAIT;
WLog_DBG(TAG, "VIRTUAL_CONNECTION_STATE_OUT_CHANNEL_WAIT");
rpc_client_virtual_connection_transition_to_state(rpc,
rpc->VirtualConnection, VIRTUAL_CONNECTION_STATE_OUT_CHANNEL_WAIT);
response = http_response_recv(rpc->TlsOut);
@ -126,15 +126,10 @@ BOOL rts_connect(rdpRpc* rpc)
return FALSE;
}
WLog_DBG(TAG, "HTTP Body (%d):", response->BodyLength);
if (response->BodyLength)
winpr_HexDump(TAG, WLOG_DEBUG, response->BodyContent, response->BodyLength);
http_response_free(response);
rpc->VirtualConnection->State = VIRTUAL_CONNECTION_STATE_WAIT_A3W;
WLog_DBG(TAG, "VIRTUAL_CONNECTION_STATE_WAIT_A3W");
rpc_client_virtual_connection_transition_to_state(rpc,
rpc->VirtualConnection, VIRTUAL_CONNECTION_STATE_WAIT_A3W);
return TRUE;
}

View File

@ -21,7 +21,7 @@
#include "rts_signature.h"
#define TAG FREERDP_TAG("core.gateway")
#define TAG FREERDP_TAG("core.gateway.rts")
RtsPduSignature RTS_PDU_CONN_A1_SIGNATURE = { RTS_FLAG_NONE, 4,
{ RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_RECEIVE_WINDOW_SIZE, 0, 0, 0, 0 } };

View File

@ -1200,8 +1200,54 @@ BOOL TsProxySetupReceivePipe(handle_t IDL_handle, BYTE* pRpcMessage)
return TRUE;
}
int tsg_transition_to_state(rdpTsg* tsg, TSG_STATE state)
{
const char* str = "TSG_STATE_UNKNOWN";
switch (tsg->state)
{
case TSG_STATE_INITIAL:
str = "TSG_STATE_INITIAL";
break;
case TSG_STATE_CONNECTED:
str = "TSG_STATE_CONNECTED";
break;
case TSG_STATE_AUTHORIZED:
str = "TSG_STATE_AUTHORIZED";
break;
case TSG_STATE_CHANNEL_CREATED:
str = "TSG_STATE_CHANNEL_CREATED";
break;
case TSG_STATE_PIPE_CREATED:
str = "TSG_STATE_PIPE_CREATED";
break;
case TSG_STATE_TUNNEL_CLOSE_PENDING:
str = "TSG_STATE_TUNNEL_CLOSE_PENDING";
break;
case TSG_STATE_CHANNEL_CLOSE_PENDING:
str = "TSG_STATE_CHANNEL_CLOSE_PENDING";
break;
case TSG_STATE_FINAL:
str = "TSG_STATE_FINAL";
break;
}
tsg->state = state;
WLog_DBG(TAG, "%s", str);
return 1;
}
int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
{
int status = -1;
RpcClientCall* call;
rdpRpc* rpc = tsg->rpc;
@ -1215,15 +1261,16 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
return -1;
}
tsg->state = TSG_STATE_CONNECTED;
tsg_transition_to_state(tsg, TSG_STATE_CONNECTED);
if (!TsProxyAuthorizeTunnel(tsg, &tsg->TunnelContext, NULL, NULL))
{
WLog_ERR(TAG, "TsProxyAuthorizeTunnel failure");
tsg->state = TSG_STATE_TUNNEL_CLOSE_PENDING;
return -1;
}
status = 1;
break;
case TSG_STATE_CONNECTED:
@ -1234,7 +1281,7 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
return -1;
}
tsg->state = TSG_STATE_AUTHORIZED;
tsg_transition_to_state(tsg, TSG_STATE_AUTHORIZED);
if (!TsProxyMakeTunnelCall(tsg, &tsg->TunnelContext, TSG_TUNNEL_CALL_ASYNC_MSG_REQUEST, NULL, NULL))
{
@ -1248,6 +1295,8 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
return -1;
}
status = 1;
break;
case TSG_STATE_AUTHORIZED:
@ -1261,6 +1310,8 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
WLog_ERR(TAG, "TsProxyMakeTunnelCallReadResponse failure");
return -1;
}
status = 1;
}
else if (call->OpNum == TsProxyCreateChannelOpnum)
{
@ -1270,7 +1321,7 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
return -1;
}
tsg->state = TSG_STATE_CHANNEL_CREATED;
tsg_transition_to_state(tsg, TSG_STATE_CHANNEL_CREATED);
if (!TsProxySetupReceivePipe((handle_t) tsg, NULL))
{
@ -1278,7 +1329,9 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
return -1;
}
tsg->state = TSG_STATE_PIPE_CREATED;
tsg_transition_to_state(tsg, TSG_STATE_PIPE_CREATED);
status = 1;
}
else
{
@ -1301,7 +1354,7 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
return FALSE;
}
tsg->state = TSG_STATE_CHANNEL_CLOSE_PENDING;
tsg_transition_to_state(tsg, TSG_STATE_CHANNEL_CLOSE_PENDING);
if (!TsProxyCloseChannelWriteRequest(tsg, NULL))
{
@ -1315,6 +1368,8 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
return FALSE;
}
status = 1;
break;
case TSG_STATE_CHANNEL_CLOSE_PENDING:
@ -1325,7 +1380,9 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
return FALSE;
}
tsg->state = TSG_STATE_FINAL;
tsg_transition_to_state(tsg, TSG_STATE_FINAL);
status = 1;
break;
@ -1333,7 +1390,7 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
break;
}
return 1;
return status;
}
int tsg_check(rdpTsg* tsg)
@ -1379,8 +1436,9 @@ BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port)
{
if (rpc_client_recv(rpc) < 0)
{
WLog_ERR(TAG, "tsg_connect: rpc_client_recv failure");
rpc->transport->layer = TRANSPORT_LAYER_CLOSED;
break;
return FALSE;
}
}
}

View File

@ -304,6 +304,8 @@ BOOL TsProxyCreateTunnel(rdpTsg* tsg, PTSG_PACKET tsgPacket, PTSG_PACKET* tsgPac
DWORD TsProxySendToServer(handle_t IDL_handle, BYTE pRpcMessage[], UINT32 count, UINT32* lengths);
int tsg_transition_to_state(rdpTsg* tsg, TSG_STATE state);
BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port);
BOOL tsg_disconnect(rdpTsg* tsg);