[core,rdstls] log state checks

when checking expected states print a proper log message when the
requirement is not met
This commit is contained in:
akallabeth 2023-03-08 12:34:49 +01:00 committed by akallabeth
parent bc1d291b44
commit 9a51f3b77b

View File

@ -141,78 +141,42 @@ static BOOL check_transition(wLog* log, RDSTLS_STATE current, RDSTLS_STATE expec
return TRUE;
}
static BOOL rdstls_set_client_state(rdpRdstls* rdstls, RDSTLS_STATE state)
{
BOOL rc = FALSE;
WINPR_ASSERT(rdstls);
switch (rdstls->state)
{
case RDSTLS_STATE_CAPABILITIES:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state);
break;
case RDSTLS_STATE_AUTH_REQ:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state);
break;
case RDSTLS_STATE_AUTH_RSP:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_COMPLETED, state);
break;
case RDSTLS_STATE_COMPLETED:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state);
break;
default:
WLog_Print(rdstls->log, WLOG_ERROR,
"Invalid rdstls state %s [%d], requested transition to %s [%d]",
rdstls_get_state_str(rdstls->state), rdstls->state,
rdstls_get_state_str(state), state);
break;
}
if (rc)
rdstls->state = state;
return rc;
}
static BOOL rdstls_set_server_state(rdpRdstls* rdstls, RDSTLS_STATE state)
{
BOOL rc = FALSE;
WINPR_ASSERT(rdstls);
switch (rdstls->state)
{
case RDSTLS_STATE_CAPABILITIES:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state);
break;
case RDSTLS_STATE_AUTH_REQ:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state);
break;
case RDSTLS_STATE_AUTH_RSP:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_COMPLETED, state);
break;
case RDSTLS_STATE_COMPLETED:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state);
break;
default:
WLog_Print(rdstls->log, WLOG_ERROR,
"Invalid rdstls state %s [%d], requested transition to %s [%d]",
rdstls_get_state_str(rdstls->state), rdstls->state,
rdstls_get_state_str(state), state);
break;
}
if (rc)
rdstls->state = state;
return rc;
}
static BOOL rdstls_set_state(rdpRdstls* rdstls, RDSTLS_STATE state)
{
BOOL rc = FALSE;
WINPR_ASSERT(rdstls);
WLog_Print(rdstls->log, WLOG_DEBUG, "-- %s\t--> %s", rdstls_get_state_str(rdstls->state),
rdstls_get_state_str(state));
if (rdstls->server)
return rdstls_set_server_state(rdstls, state);
else
return rdstls_set_client_state(rdstls, state);
switch (rdstls->state)
{
case RDSTLS_STATE_INITIAL:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state);
break;
case RDSTLS_STATE_CAPABILITIES:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state);
break;
case RDSTLS_STATE_AUTH_REQ:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state);
break;
case RDSTLS_STATE_AUTH_RSP:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_FINAL, state);
break;
case RDSTLS_STATE_FINAL:
rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state);
break;
default:
WLog_Print(rdstls->log, WLOG_ERROR,
"Invalid rdstls state %s [%d], requested transition to %s [%d]",
rdstls_get_state_str(rdstls->state), rdstls->state,
rdstls_get_state_str(state), state);
break;
}
if (rc)
rdstls->state = state;
return rc;
}
static BOOL rdstls_write_capabilities(rdpRdstls* rdstls, wStream* s)
@ -684,14 +648,31 @@ static int rdstls_recv(rdpTransport* transport, wStream* s, void* extra)
return 1;
}
#define rdstls_check_state_requirements(rdstls, expected) \
rdstls_check_state_requirements_((rdstls), (expected), __FILE__, __FUNCTION__, __LINE__)
static BOOL rdstls_check_state_requirements_(rdpRdstls* rdstls, RDSTLS_STATE expected,
const char* file, const char* fkt, size_t line)
{
const RDSTLS_STATE current = rdstls_get_state(rdstls);
if (current == expected)
return TRUE;
const DWORD log_level = WLOG_ERROR;
if (WLog_IsLevelActive(rdstls->log, log_level))
WLog_PrintMessage(rdstls->log, WLOG_MESSAGE_TEXT, log_level, line, file, fkt,
"Unexpected rdstls state %s [%d], expected %s [%d]",
rdstls_get_state_str(current), current, rdstls_get_state_str(expected),
expected);
return FALSE;
}
static BOOL rdstls_send_capabilities(rdpRdstls* rdstls)
{
BOOL rc = FALSE;
wStream* s;
WINPR_ASSERT(rdstls);
if (rdstls_get_state(rdstls) != RDSTLS_STATE_CAPABILITIES)
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES))
goto fail;
s = Stream_New(NULL, 512);
@ -713,9 +694,7 @@ static BOOL rdstls_recv_authentication_request(rdpRdstls* rdstls)
int status;
wStream* s;
WINPR_ASSERT(rdstls);
if (rdstls_get_state(rdstls) != RDSTLS_STATE_AUTH_REQ)
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ))
goto fail;
s = Stream_New(NULL, 4096);
@ -743,9 +722,7 @@ static BOOL rdstls_send_authentication_response(rdpRdstls* rdstls)
BOOL rc = FALSE;
wStream* s;
WINPR_ASSERT(rdstls);
if (rdstls_get_state(rdstls) != RDSTLS_STATE_AUTH_RSP)
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP))
goto fail;
s = Stream_New(NULL, 512);
@ -767,9 +744,7 @@ static BOOL rdstls_recv_capabilities(rdpRdstls* rdstls)
int status;
wStream* s;
WINPR_ASSERT(rdstls);
if (rdstls_get_state(rdstls) != RDSTLS_STATE_CAPABILITIES)
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES))
goto fail;
s = Stream_New(NULL, 512);
@ -797,9 +772,7 @@ static BOOL rdstls_send_authentication_request(rdpRdstls* rdstls)
BOOL rc = FALSE;
wStream* s;
WINPR_ASSERT(rdstls);
if (rdstls_get_state(rdstls) != RDSTLS_STATE_AUTH_REQ)
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ))
goto fail;
s = Stream_New(NULL, 4096);
@ -823,7 +796,7 @@ static BOOL rdstls_recv_authentication_response(rdpRdstls* rdstls)
WINPR_ASSERT(rdstls);
if (rdstls_get_state(rdstls) != RDSTLS_STATE_AUTH_RSP)
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP))
goto fail;
s = Stream_New(NULL, 512);