From 9a51f3b77bcea57f6dac1fcad7a99fde183c899a Mon Sep 17 00:00:00 2001 From: akallabeth Date: Wed, 8 Mar 2023 12:34:49 +0100 Subject: [PATCH] [core,rdstls] log state checks when checking expected states print a proper log message when the requirement is not met --- libfreerdp/core/rdstls.c | 137 ++++++++++++++++----------------------- 1 file changed, 55 insertions(+), 82 deletions(-) diff --git a/libfreerdp/core/rdstls.c b/libfreerdp/core/rdstls.c index 038603297..321251d13 100644 --- a/libfreerdp/core/rdstls.c +++ b/libfreerdp/core/rdstls.c @@ -141,78 +141,42 @@ static BOOL check_transition(wLog* log, RDSTLS_STATE current, RDSTLS_STATE expec return TRUE; } -static BOOL rdstls_set_client_state(rdpRdstls* rdstls, RDSTLS_STATE state) -{ - BOOL rc = FALSE; - WINPR_ASSERT(rdstls); - switch (rdstls->state) - { - case RDSTLS_STATE_CAPABILITIES: - rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state); - break; - case RDSTLS_STATE_AUTH_REQ: - rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state); - break; - case RDSTLS_STATE_AUTH_RSP: - rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_COMPLETED, state); - break; - case RDSTLS_STATE_COMPLETED: - rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state); - break; - default: - WLog_Print(rdstls->log, WLOG_ERROR, - "Invalid rdstls state %s [%d], requested transition to %s [%d]", - rdstls_get_state_str(rdstls->state), rdstls->state, - rdstls_get_state_str(state), state); - break; - } - if (rc) - rdstls->state = state; - - return rc; -} - -static BOOL rdstls_set_server_state(rdpRdstls* rdstls, RDSTLS_STATE state) -{ - BOOL rc = FALSE; - WINPR_ASSERT(rdstls); - switch (rdstls->state) - { - case RDSTLS_STATE_CAPABILITIES: - rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state); - break; - case RDSTLS_STATE_AUTH_REQ: - rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state); - break; - case RDSTLS_STATE_AUTH_RSP: - rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_COMPLETED, state); - break; - case RDSTLS_STATE_COMPLETED: - rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state); - break; - default: - WLog_Print(rdstls->log, WLOG_ERROR, - "Invalid rdstls state %s [%d], requested transition to %s [%d]", - rdstls_get_state_str(rdstls->state), rdstls->state, - rdstls_get_state_str(state), state); - break; - } - if (rc) - rdstls->state = state; - - return rc; -} - static BOOL rdstls_set_state(rdpRdstls* rdstls, RDSTLS_STATE state) { + BOOL rc = FALSE; WINPR_ASSERT(rdstls); WLog_Print(rdstls->log, WLOG_DEBUG, "-- %s\t--> %s", rdstls_get_state_str(rdstls->state), rdstls_get_state_str(state)); - if (rdstls->server) - return rdstls_set_server_state(rdstls, state); - else - return rdstls_set_client_state(rdstls, state); + + switch (rdstls->state) + { + case RDSTLS_STATE_INITIAL: + rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state); + break; + case RDSTLS_STATE_CAPABILITIES: + rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state); + break; + case RDSTLS_STATE_AUTH_REQ: + rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state); + break; + case RDSTLS_STATE_AUTH_RSP: + rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_FINAL, state); + break; + case RDSTLS_STATE_FINAL: + rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state); + break; + default: + WLog_Print(rdstls->log, WLOG_ERROR, + "Invalid rdstls state %s [%d], requested transition to %s [%d]", + rdstls_get_state_str(rdstls->state), rdstls->state, + rdstls_get_state_str(state), state); + break; + } + if (rc) + rdstls->state = state; + + return rc; } static BOOL rdstls_write_capabilities(rdpRdstls* rdstls, wStream* s) @@ -684,14 +648,31 @@ static int rdstls_recv(rdpTransport* transport, wStream* s, void* extra) return 1; } +#define rdstls_check_state_requirements(rdstls, expected) \ + rdstls_check_state_requirements_((rdstls), (expected), __FILE__, __FUNCTION__, __LINE__) +static BOOL rdstls_check_state_requirements_(rdpRdstls* rdstls, RDSTLS_STATE expected, + const char* file, const char* fkt, size_t line) +{ + const RDSTLS_STATE current = rdstls_get_state(rdstls); + if (current == expected) + return TRUE; + + const DWORD log_level = WLOG_ERROR; + if (WLog_IsLevelActive(rdstls->log, log_level)) + WLog_PrintMessage(rdstls->log, WLOG_MESSAGE_TEXT, log_level, line, file, fkt, + "Unexpected rdstls state %s [%d], expected %s [%d]", + rdstls_get_state_str(current), current, rdstls_get_state_str(expected), + expected); + + return FALSE; +} + static BOOL rdstls_send_capabilities(rdpRdstls* rdstls) { BOOL rc = FALSE; wStream* s; - WINPR_ASSERT(rdstls); - - if (rdstls_get_state(rdstls) != RDSTLS_STATE_CAPABILITIES) + if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES)) goto fail; s = Stream_New(NULL, 512); @@ -713,9 +694,7 @@ static BOOL rdstls_recv_authentication_request(rdpRdstls* rdstls) int status; wStream* s; - WINPR_ASSERT(rdstls); - - if (rdstls_get_state(rdstls) != RDSTLS_STATE_AUTH_REQ) + if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ)) goto fail; s = Stream_New(NULL, 4096); @@ -743,9 +722,7 @@ static BOOL rdstls_send_authentication_response(rdpRdstls* rdstls) BOOL rc = FALSE; wStream* s; - WINPR_ASSERT(rdstls); - - if (rdstls_get_state(rdstls) != RDSTLS_STATE_AUTH_RSP) + if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP)) goto fail; s = Stream_New(NULL, 512); @@ -767,9 +744,7 @@ static BOOL rdstls_recv_capabilities(rdpRdstls* rdstls) int status; wStream* s; - WINPR_ASSERT(rdstls); - - if (rdstls_get_state(rdstls) != RDSTLS_STATE_CAPABILITIES) + if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES)) goto fail; s = Stream_New(NULL, 512); @@ -797,9 +772,7 @@ static BOOL rdstls_send_authentication_request(rdpRdstls* rdstls) BOOL rc = FALSE; wStream* s; - WINPR_ASSERT(rdstls); - - if (rdstls_get_state(rdstls) != RDSTLS_STATE_AUTH_REQ) + if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ)) goto fail; s = Stream_New(NULL, 4096); @@ -823,7 +796,7 @@ static BOOL rdstls_recv_authentication_response(rdpRdstls* rdstls) WINPR_ASSERT(rdstls); - if (rdstls_get_state(rdstls) != RDSTLS_STATE_AUTH_RSP) + if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP)) goto fail; s = Stream_New(NULL, 512);