Cleanup of client RDP state machine

* Use enum for most common return types
* Add success/failed check functions
* Add a function creating a string from the return value
This commit is contained in:
akallabeth 2022-11-11 11:38:11 +01:00 committed by Martin Fleisz
parent 872f52c014
commit 4ef72bbe14

View File

@ -1617,19 +1617,74 @@ static int rdp_recv_pdu(rdpRdp* rdp, wStream* s)
return rc;
}
/* TODO: Need to properly return:
*
* -24 ... Rerun state machine, ignore wStream* s reset
* -23 ... Rerun state machine, reset wStream* s
* -1 ... Failure
* 0 ... Success
* 2 ... State ACTIVE
*/
typedef enum
{
STATE_RUN_ACTIVE = 2,
STATE_RUN_REDIRECT = 1,
STATE_RUN_SUCCESS = 0,
STATE_RUN_FAILED = -1,
STATE_RUN_TRY_AGAIN = -23,
STATE_RUN_CONTINUE = -24
} state_run_t;
static BOOL state_run_failed(int status)
{
switch (status)
{
case STATE_RUN_CONTINUE:
case STATE_RUN_TRY_AGAIN:
return FALSE;
default:
break;
}
if (status < STATE_RUN_SUCCESS)
return TRUE;
return FALSE;
}
static BOOL state_run_success(int status)
{
return status >= STATE_RUN_SUCCESS;
}
static const char* state_run_result_string(int status, char* buffer, size_t buffersize)
{
const char* name;
switch (status)
{
case STATE_RUN_ACTIVE:
name = "STATE_RUN_ACTIVE";
break;
case STATE_RUN_REDIRECT:
name = "STATE_RUN_REDIRECT";
break;
case STATE_RUN_SUCCESS:
name = "STATE_RUN_SUCCESS";
break;
case STATE_RUN_FAILED:
name = "STATE_RUN_FAILED";
break;
case STATE_RUN_TRY_AGAIN:
name = "STATE_RUN_TRY_AGAIN";
break;
case STATE_RUN_CONTINUE:
name = "STATE_RUN_CONTINUE";
break;
default:
name = "STATE_RUN_UNKNOWN";
break;
}
_snprintf(buffer, buffersize, "%s [%d]", name, status);
return buffer;
}
static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extra)
{
const UINT32 mask = FINALIZE_SC_SYNCHRONIZE_PDU | FINALIZE_SC_CONTROL_COOPERATE_PDU |
FINALIZE_SC_CONTROL_GRANTED_PDU | FINALIZE_SC_FONT_MAP_PDU;
int status = 0;
int status = STATE_RUN_SUCCESS;
rdpRdp* rdp = (rdpRdp*)extra;
WINPR_ASSERT(transport);
@ -1640,7 +1695,7 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
{
case CONNECTION_STATE_NEGO:
rdp_client_transition_to_state(rdp, CONNECTION_STATE_MCS_CREATE_REQUEST);
status = -24;
status = STATE_RUN_CONTINUE;
break;
case CONNECTION_STATE_NLA:
if (nla_get_state(rdp->nla) < NLA_STATE_AUTH_INFO)
@ -1649,7 +1704,7 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
{
WLog_ERR(TAG, "%s: %s - nla_recv_pdu() fail", __FUNCTION__,
rdp_get_state_string(rdp));
status = -1;
status = STATE_RUN_FAILED;
}
}
else if (nla_get_state(rdp->nla) == NLA_STATE_POST_NEGO)
@ -1660,13 +1715,13 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
{
WLog_ERR(TAG, "%s: %s - nego_recv() fail", __FUNCTION__,
rdp_get_state_string(rdp));
status = -1;
status = STATE_RUN_FAILED;
}
else if (!nla_set_state(rdp->nla, NLA_STATE_FINAL))
status = -1;
status = STATE_RUN_FAILED;
}
if (status >= 0)
if (state_run_success(status))
{
if (nla_get_state(rdp->nla) == NLA_STATE_AUTH_INFO)
{
@ -1675,32 +1730,32 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
if (rdp->settings->VmConnectMode)
{
if (!nego_set_state(rdp->nego, NEGO_STATE_NLA))
status = -1;
status = STATE_RUN_FAILED;
else if (!nego_set_requested_protocols(rdp->nego,
PROTOCOL_HYBRID | PROTOCOL_SSL))
status = -1;
status = STATE_RUN_FAILED;
else
{
if (!nego_send_negotiation_request(rdp->nego))
status = -1;
status = STATE_RUN_FAILED;
else if (!nla_set_state(rdp->nla, NLA_STATE_POST_NEGO))
status = -1;
status = STATE_RUN_FAILED;
}
}
else
{
if (!nla_set_state(rdp->nla, NLA_STATE_FINAL))
status = -1;
status = STATE_RUN_FAILED;
}
}
}
if (status >= 0)
if (state_run_success(status))
{
if (nla_get_state(rdp->nla) == NLA_STATE_FINAL)
{
rdp_client_transition_to_state(rdp, CONNECTION_STATE_MCS_CREATE_REQUEST);
status = -23;
status = STATE_RUN_TRY_AGAIN;
}
}
break;
@ -1710,7 +1765,7 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
{
WLog_ERR(TAG, "%s: %s - mcs_client_begin() fail", __FUNCTION__,
rdp_get_state_string(rdp));
status = -1;
status = STATE_RUN_FAILED;
}
else
{
@ -1722,7 +1777,7 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
if (!mcs_recv_connect_response(rdp->mcs, s))
{
WLog_ERR(TAG, "mcs_recv_connect_response failure");
status = -1;
status = STATE_RUN_FAILED;
}
else
{
@ -1730,7 +1785,7 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
if (!mcs_send_erect_domain_request(rdp->mcs))
{
WLog_ERR(TAG, "mcs_send_erect_domain_request failure");
status = -1;
status = STATE_RUN_FAILED;
}
else
{
@ -1738,7 +1793,7 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
if (!mcs_send_attach_user_request(rdp->mcs))
{
WLog_ERR(TAG, "mcs_send_attach_user_request failure");
status = -1;
status = STATE_RUN_FAILED;
}
else
rdp_client_transition_to_state(rdp,
@ -1751,7 +1806,7 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
if (!mcs_recv_attach_user_confirm(rdp->mcs, s))
{
WLog_ERR(TAG, "mcs_recv_attach_user_confirm failure");
status = -1;
status = STATE_RUN_FAILED;
}
else
{
@ -1759,7 +1814,7 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
if (!mcs_send_channel_join_request(rdp->mcs, rdp->mcs->userId))
{
WLog_ERR(TAG, "mcs_send_channel_join_request failure");
status = -1;
status = STATE_RUN_FAILED;
}
else
rdp_client_transition_to_state(rdp, CONNECTION_STATE_MCS_CHANNEL_JOIN_RESPONSE);
@ -1773,7 +1828,7 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
"%s: %s - "
"rdp_client_connect_mcs_channel_join_confirm() fail",
__FUNCTION__, rdp_get_state_string(rdp));
status = -1;
status = STATE_RUN_FAILED;
}
break;
@ -1782,16 +1837,20 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
if (!rdp_client_connect_auto_detect(rdp, s))
{
rdp_client_transition_to_state(rdp, CONNECTION_STATE_LICENSING);
status = -23;
status = STATE_RUN_TRY_AGAIN;
}
break;
case CONNECTION_STATE_LICENSING:
status = rdp_client_connect_license(rdp, s);
if (status < 0)
WLog_DBG(TAG, "%s: %s - rdp_client_connect_license() - %i", __FUNCTION__,
rdp_get_state_string(rdp), status);
if (state_run_failed(status))
{
char buffer[64] = { 0 };
WLog_DBG(TAG, "%s: %s - rdp_client_connect_license() - %s", __FUNCTION__,
rdp_get_state_string(rdp),
state_run_result_string(status, buffer, ARRAYSIZE(buffer)));
}
break;
@ -1800,27 +1859,29 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
{
rdp_client_transition_to_state(
rdp, CONNECTION_STATE_CAPABILITIES_EXCHANGE_DEMAND_ACTIVE);
status = -23;
status = STATE_RUN_TRY_AGAIN;
}
break;
case CONNECTION_STATE_CAPABILITIES_EXCHANGE_DEMAND_ACTIVE:
status = rdp_client_connect_demand_active(rdp, s);
if (status < 0)
if (state_run_failed(status))
{
char buffer[64] = { 0 };
WLog_DBG(TAG,
"%s: %s - "
"rdp_client_connect_demand_active() - %i",
__FUNCTION__, rdp_get_state_string(rdp), status);
else if (status == 1)
status = 1;
else
"rdp_client_connect_demand_active() - %s",
__FUNCTION__, rdp_get_state_string(rdp),
state_run_result_string(status, buffer, ARRAYSIZE(buffer)));
}
else if (status != STATE_RUN_REDIRECT)
{
if (!rdp->settings->SupportMonitorLayoutPdu)
{
rdp_client_transition_to_state(
rdp, CONNECTION_STATE_CAPABILITIES_EXCHANGE_CONFIRM_ACTIVE);
status = -23;
status = STATE_RUN_TRY_AGAIN;
}
else
{
@ -1832,9 +1893,9 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
case CONNECTION_STATE_CAPABILITIES_EXCHANGE_MONITOR_LAYOUT:
status = rdp_recv_pdu(rdp, s);
if (status >= 0)
if (state_run_success(status))
{
status = -23;
status = STATE_RUN_TRY_AGAIN;
rdp_client_transition_to_state(
rdp, CONNECTION_STATE_CAPABILITIES_EXCHANGE_CONFIRM_ACTIVE);
}
@ -1848,14 +1909,14 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
{
const UINT32 flags = rdp->finalize_sc_pdus & mask;
status = rdp_recv_pdu(rdp, s);
if (status >= 0)
if (state_run_success(status))
{
const UINT32 uflags = rdp->finalize_sc_pdus & mask;
if (flags != uflags)
rdp_client_transition_to_state(rdp,
CONNECTION_STATE_FINALIZATION_CLIENT_COOPERATE);
else
status = -1;
status = STATE_RUN_FAILED;
}
}
break;
@ -1863,14 +1924,14 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
{
const UINT32 flags = rdp->finalize_sc_pdus & mask;
status = rdp_recv_pdu(rdp, s);
if (status >= 0)
if (state_run_success(status))
{
const UINT32 uflags = rdp->finalize_sc_pdus & mask;
if (flags != uflags)
rdp_client_transition_to_state(
rdp, CONNECTION_STATE_FINALIZATION_CLIENT_GRANTED_CONTROL);
else
status = -1;
status = STATE_RUN_FAILED;
}
}
break;
@ -1878,14 +1939,14 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
{
const UINT32 flags = rdp->finalize_sc_pdus & mask;
status = rdp_recv_pdu(rdp, s);
if (status >= 0)
if (state_run_success(status))
{
const UINT32 uflags = rdp->finalize_sc_pdus & mask;
if (flags != uflags)
rdp_client_transition_to_state(rdp,
CONNECTION_STATE_FINALIZATION_CLIENT_FONT_MAP);
else
status = -1;
status = STATE_RUN_FAILED;
}
}
break;
@ -1893,7 +1954,7 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
{
const UINT32 flags = rdp->finalize_sc_pdus & mask;
status = rdp_recv_pdu(rdp, s);
if (status >= 0)
if (state_run_success(status))
{
const UINT32 uflags = rdp->finalize_sc_pdus & mask;
if (flags == uflags)
@ -1901,49 +1962,58 @@ static int rdp_recv_callback_int(rdpTransport* transport, wStream* s, void* extr
{
rdp_client_transition_to_state(rdp, CONNECTION_STATE_ACTIVE);
status = 2;
status = STATE_RUN_ACTIVE;
}
}
if (status < 0)
WLog_DBG(TAG, "%s: %s - rdp_recv_pdu() - %i", __FUNCTION__,
rdp_get_state_string(rdp), status);
if (state_run_failed(status))
{
char buffer[64] = { 0 };
WLog_DBG(TAG, "%s: %s - rdp_recv_pdu() - %s", __FUNCTION__,
rdp_get_state_string(rdp),
state_run_result_string(status, buffer, ARRAYSIZE(buffer)));
}
}
break;
case CONNECTION_STATE_ACTIVE:
status = rdp_recv_pdu(rdp, s);
if (status < 0)
WLog_DBG(TAG, "%s: %s - rdp_recv_pdu() - %i", __FUNCTION__,
rdp_get_state_string(rdp), status);
if (state_run_failed(status))
{
char buffer[64] = { 0 };
WLog_DBG(TAG, "%s: %s - rdp_recv_pdu() - %s", __FUNCTION__,
rdp_get_state_string(rdp),
state_run_result_string(status, buffer, ARRAYSIZE(buffer)));
}
break;
default:
WLog_ERR(TAG, "%s: %s state %d", __FUNCTION__, rdp_get_state_string(rdp),
rdp_get_state(rdp));
status = -1;
status = STATE_RUN_FAILED;
break;
}
if (status < 0 && status > -23)
if (state_run_failed(status))
{
WLog_ERR(TAG, "%s: %s status %d", __FUNCTION__, rdp_get_state_string(rdp), status);
char buffer[64] = { 0 };
WLog_ERR(TAG, "%s: %s status %s", __FUNCTION__, rdp_get_state_string(rdp),
state_run_result_string(status, buffer, ARRAYSIZE(buffer)));
}
return status;
}
int rdp_recv_callback(rdpTransport* transport, wStream* s, void* extra)
{
int rc = -1;
state_run_t rc = STATE_RUN_FAILED;
const size_t start = Stream_GetPosition(s);
do
{
if (rc == -23)
if (rc == STATE_RUN_TRY_AGAIN)
Stream_SetPosition(s, start);
rc = rdp_recv_callback_int(transport, s, extra);
} while (rc <= -23);
} while ((rc == STATE_RUN_TRY_AGAIN) || (rc == STATE_RUN_CONTINUE));
return rc;
}