[core,rdstls] add state transition checks and logs

This commit is contained in:
akallabeth 2023-03-08 12:24:31 +01:00 committed by akallabeth
parent adbecf71c6
commit bc1d291b44
1 changed files with 87 additions and 11 deletions

View File

@ -99,9 +99,6 @@ rdpRdstls* rdstls_new(rdpContext* context, rdpTransport* transport)
void rdstls_free(rdpRdstls* rdstls)
{
if (!rdstls)
return;
free(rdstls);
}
@ -130,14 +127,92 @@ static RDSTLS_STATE rdstls_get_state(rdpRdstls* rdstls)
return rdstls->state;
}
static BOOL check_transition(wLog* log, RDSTLS_STATE current, RDSTLS_STATE expected,
RDSTLS_STATE requested)
{
if (requested != expected)
{
WLog_Print(log, WLOG_ERROR,
"Unexpected rdstls state transition from %s [%d] to %s [%d], expected %s [%d]",
rdstls_get_state_str(current), current, rdstls_get_state_str(requested),
requested, rdstls_get_state_str(expected), expected);
return FALSE;
}
return TRUE;
}
static BOOL rdstls_set_client_state(rdpRdstls* rdstls, RDSTLS_STATE state)
{
BOOL rc = FALSE;
WINPR_ASSERT(rdstls);
switch (rdstls->state)
{
case RDSTLS_STATE_CAPABILITIES:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state);
break;
case RDSTLS_STATE_AUTH_REQ:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state);
break;
case RDSTLS_STATE_AUTH_RSP:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_COMPLETED, state);
break;
case RDSTLS_STATE_COMPLETED:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state);
break;
default:
WLog_Print(rdstls->log, WLOG_ERROR,
"Invalid rdstls state %s [%d], requested transition to %s [%d]",
rdstls_get_state_str(rdstls->state), rdstls->state,
rdstls_get_state_str(state), state);
break;
}
if (rc)
rdstls->state = state;
return rc;
}
static BOOL rdstls_set_server_state(rdpRdstls* rdstls, RDSTLS_STATE state)
{
BOOL rc = FALSE;
WINPR_ASSERT(rdstls);
switch (rdstls->state)
{
case RDSTLS_STATE_CAPABILITIES:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state);
break;
case RDSTLS_STATE_AUTH_REQ:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state);
break;
case RDSTLS_STATE_AUTH_RSP:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_COMPLETED, state);
break;
case RDSTLS_STATE_COMPLETED:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state);
break;
default:
WLog_Print(rdstls->log, WLOG_ERROR,
"Invalid rdstls state %s [%d], requested transition to %s [%d]",
rdstls_get_state_str(rdstls->state), rdstls->state,
rdstls_get_state_str(state), state);
break;
}
if (rc)
rdstls->state = state;
return rc;
}
static BOOL rdstls_set_state(rdpRdstls* rdstls, RDSTLS_STATE state)
{
WINPR_ASSERT(rdstls);
WLog_Print(rdstls->log, WLOG_DEBUG, "-- %s\t--> %s", rdstls_get_state_str(rdstls->state),
rdstls_get_state_str(state));
rdstls->state = state;
return TRUE;
if (rdstls->server)
return rdstls_set_server_state(rdstls, state);
else
return rdstls_set_client_state(rdstls, state);
}
static BOOL rdstls_write_capabilities(rdpRdstls* rdstls, wStream* s)
@ -522,7 +597,8 @@ static BOOL rdstls_send(rdpTransport* transport, wStream* s, void* extra)
Stream_Write_UINT16(s, RDSTLS_VERSION_1);
switch (rdstls_get_state(rdstls))
const RDSTLS_STATE state = rdstls_get_state(rdstls);
switch (state)
{
case RDSTLS_STATE_CAPABILITIES:
if (!rdstls_write_capabilities(rdstls, s))
@ -551,9 +627,9 @@ static BOOL rdstls_send(rdpTransport* transport, wStream* s, void* extra)
return FALSE;
break;
default:
WLog_ERR(TAG, "RDSTLS in invalid receive state %s",
rdstls_get_state_str(rdstls_get_state(rdstls)));
return -1;
WLog_Print(rdstls->log, WLOG_ERROR, "Invalid rdstls state %s [%d]",
rdstls_get_state_str(state), state);
return FALSE;
}
if (transport_write(rdstls->transport, s) < 0)
@ -564,7 +640,6 @@ static BOOL rdstls_send(rdpTransport* transport, wStream* s, void* extra)
static int rdstls_recv(rdpTransport* transport, wStream* s, void* extra)
{
UINT16 length;
UINT16 version;
UINT16 pduType;
rdpRdstls* rdstls = (rdpRdstls*)extra;
@ -601,7 +676,8 @@ static int rdstls_recv(rdpTransport* transport, wStream* s, void* extra)
return -1;
break;
default:
WLog_Print(rdstls->log, WLOG_ERROR, "unknown RDSTLS PDU type");
WLog_Print(rdstls->log, WLOG_ERROR, "unknown RDSTLS PDU type [0x%04" PRIx16 "]",
pduType);
return -1;
}