diff --git a/libfreerdp/core/capabilities.c b/libfreerdp/core/capabilities.c index 1152989a8..2934e5e17 100644 --- a/libfreerdp/core/capabilities.c +++ b/libfreerdp/core/capabilities.c @@ -4469,8 +4469,6 @@ fail: BOOL rdp_recv_get_active_header(rdpRdp* rdp, wStream* s, UINT16* pChannelId, UINT16* length) { - UINT16 securityFlags = 0; - WINPR_ASSERT(rdp); WINPR_ASSERT(rdp->context); @@ -4480,18 +4478,6 @@ BOOL rdp_recv_get_active_header(rdpRdp* rdp, wStream* s, UINT16* pChannelId, UIN if (freerdp_shall_disconnect_context(rdp->context)) return TRUE; - if (rdp->settings->UseRdpSecurityLayer) - { - if (!rdp_read_security_header(rdp, s, &securityFlags, length)) - return FALSE; - - if (securityFlags & SEC_ENCRYPT) - { - if (!rdp_decrypt(rdp, s, length, securityFlags)) - return FALSE; - } - } - if (*pChannelId != MCS_GLOBAL_CHANNEL_ID) { UINT16 mcsMessageChannelId = rdp->mcs->messageChannelId; diff --git a/libfreerdp/core/connection.c b/libfreerdp/core/connection.c index a84de180f..8506b797d 100644 --- a/libfreerdp/core/connection.c +++ b/libfreerdp/core/connection.c @@ -1128,48 +1128,74 @@ BOOL rdp_client_connect_mcs_channel_join_confirm(rdpRdp* rdp, wStream* s) return TRUE; } +BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT16 length) +{ + WINPR_ASSERT(rdp); + WINPR_ASSERT(rdp->mcs); + + if (!rdp->mcs->messageChannelJoined) + { + WLog_Print(rdp->log, WLOG_WARN, "MCS message channel not joined!"); + return FALSE; + } + const UINT16 messageChannelId = rdp->mcs->messageChannelId; + if (messageChannelId == 0) + { + WLog_Print(rdp->log, WLOG_WARN, "MCS message channel id == 0"); + return FALSE; + } + + if (channelId != messageChannelId) + { + WLog_Print(rdp->log, WLOG_WARN, "MCS message channel expected id=%" PRIu16 ", got %" PRIu16, + messageChannelId, channelId); + return FALSE; + } + + UINT16 securityFlags = 0; + if (!rdp_read_security_header(rdp, s, &securityFlags, &length)) + return FALSE; + + if (securityFlags & SEC_ENCRYPT) + { + if (!rdp_decrypt(rdp, s, &length, securityFlags)) + return FALSE; + } + + if (rdp_recv_message_channel_pdu(rdp, s, securityFlags) != STATE_RUN_SUCCESS) + return FALSE; + + return tpkt_ensure_stream_consumed(s, length); +} + BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s) { WINPR_ASSERT(rdp); WINPR_ASSERT(rdp->mcs); - const UINT16 messageChannelId = rdp->mcs->messageChannelId; - /* If the MCS message channel has been joined... */ - if (messageChannelId != 0) + const size_t pos = Stream_GetPosition(s); + UINT16 length = 0; + UINT16 channelId = 0; + + if (rdp_read_header(rdp, s, &length, &channelId)) { + const UINT16 messageChannelId = rdp->mcs->messageChannelId; + /* If the MCS message channel has been joined... */ + /* Process any MCS message channel PDUs. */ - const size_t pos = Stream_GetPosition(s); - UINT16 length = 0; - UINT16 channelId = 0; - - if (rdp_read_header(rdp, s, &length, &channelId)) + if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId)) { - if (channelId == messageChannelId) - { - UINT16 securityFlags = 0; - - if (!rdp_read_security_header(rdp, s, &securityFlags, &length)) - return FALSE; - - if (securityFlags & SEC_ENCRYPT) - { - if (!rdp_decrypt(rdp, s, &length, securityFlags)) - return FALSE; - } - - if (rdp_recv_message_channel_pdu(rdp, s, securityFlags) == STATE_RUN_SUCCESS) - return tpkt_ensure_stream_consumed(s, length); - } + if (rdp_handle_message_channel(rdp, s, channelId, length)) + return TRUE; } else + { WLog_WARN(TAG, "expected messageChannelId=%" PRIu16 ", got %" PRIu16, messageChannelId, channelId); - - Stream_SetPosition(s, pos); + } } - else - WLog_WARN(TAG, "messageChannelId == 0"); + Stream_SetPosition(s, pos); return FALSE; } @@ -1185,6 +1211,17 @@ state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s) if (!rdp_read_header(rdp, s, &length, &channelId)) return STATE_RUN_FAILED; + /* there might be autodetect messages mixed in between licensing messages. + * that has been observed with 2k12 R2 and 2k19 + */ + const UINT16 messageChannelId = rdp->mcs->messageChannelId; + if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId)) + { + if (!rdp_handle_message_channel(rdp, s, channelId, length)) + return STATE_RUN_FAILED; + return STATE_RUN_SUCCESS; + } + if (!rdp_read_security_header(rdp, s, &securityFlags, &length)) return STATE_RUN_FAILED; @@ -1194,15 +1231,6 @@ state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s) return STATE_RUN_FAILED; } - /* there might be autodetect messages mixed in between licensing messages. - * that has been observed with 2k12 R2 and 2k19 - */ - const UINT16 messageChannelId = rdp->mcs->messageChannelId; - if ((channelId == messageChannelId) || (securityFlags & SEC_AUTODETECT_REQ)) - { - return rdp_recv_message_channel_pdu(rdp, s, securityFlags); - } - if (channelId != MCS_GLOBAL_CHANNEL_ID) WLog_WARN(TAG, "unexpected message for channel %u, expected %u", channelId, MCS_GLOBAL_CHANNEL_ID); @@ -1266,19 +1294,15 @@ state_run_t rdp_client_connect_demand_active(rdpRdp* rdp, wStream* s) if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId)) { - UINT16 securityFlags = 0; - if (!rdp_read_security_header(rdp, s, &securityFlags, NULL)) - return STATE_RUN_FAILED; - - if (securityFlags & SEC_ENCRYPT) - { - if (!rdp_decrypt(rdp, s, &length, securityFlags)) - return STATE_RUN_FAILED; - } rdp->inPackets++; - return rdp_recv_message_channel_pdu(rdp, s, securityFlags); + if (!rdp_handle_message_channel(rdp, s, channelId, length)) + return STATE_RUN_FAILED; + return STATE_RUN_SUCCESS; } + if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, NULL)) + return STATE_RUN_FAILED; + if (!rdp_read_share_control_header(rdp, s, NULL, NULL, &pduType, &pduSource)) return STATE_RUN_FAILED; @@ -2154,3 +2178,31 @@ state_run_t rdp_client_connect_confirm_active(rdpRdp* rdp, wStream* s) status = STATE_RUN_FAILED; return status; } + +BOOL rdp_handle_optional_rdp_decryption(rdpRdp* rdp, wStream* s, UINT16* length, + UINT16* pSecurityFlags) +{ + BOOL rc = FALSE; + WINPR_ASSERT(rdp); + WINPR_ASSERT(rdp->settings); + + UINT16 securityFlags = 0; + if (rdp->settings->UseRdpSecurityLayer) + { + if (!rdp_read_security_header(rdp, s, &securityFlags, length)) + goto fail; + + if (securityFlags & SEC_ENCRYPT) + { + if (!rdp_decrypt(rdp, s, length, securityFlags)) + goto fail; + } + } + + rc = TRUE; + +fail: + if (pSecurityFlags) + *pSecurityFlags = securityFlags; + return rc; +} diff --git a/libfreerdp/core/connection.h b/libfreerdp/core/connection.h index 291ea262d..224eb72eb 100644 --- a/libfreerdp/core/connection.h +++ b/libfreerdp/core/connection.h @@ -47,6 +47,7 @@ FREERDP_LOCAL BOOL rdp_client_redirect(rdpRdp* rdp); FREERDP_LOCAL BOOL rdp_client_skip_mcs_channel_join(rdpRdp* rdp); FREERDP_LOCAL BOOL rdp_client_connect_mcs_channel_join_confirm(rdpRdp* rdp, wStream* s); FREERDP_LOCAL BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s); + FREERDP_LOCAL state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s); FREERDP_LOCAL state_run_t rdp_client_connect_demand_active(rdpRdp* rdp, wStream* s); FREERDP_LOCAL state_run_t rdp_client_connect_confirm_active(rdpRdp* rdp, wStream* s); @@ -72,4 +73,9 @@ FREERDP_LOCAL const char* rdp_client_connection_state_string(int state); FREERDP_LOCAL BOOL rdp_channels_from_mcs(rdpSettings* settings, const rdpRdp* rdp); +FREERDP_LOCAL BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, + UINT16 length); +FREERDP_LOCAL BOOL rdp_handle_optional_rdp_decryption(rdpRdp* rdp, wStream* s, UINT16* length, + UINT16* pSecurityFlags); + #endif /* FREERDP_LIB_CORE_CONNECTION_H */ diff --git a/libfreerdp/core/peer.c b/libfreerdp/core/peer.c index 4bb33d58c..46c22bb7a 100644 --- a/libfreerdp/core/peer.c +++ b/libfreerdp/core/peer.c @@ -449,27 +449,13 @@ static state_run_t peer_recv_tpkt_pdu(freerdp_peer* client, wStream* s) if (rdp_get_state(rdp) <= CONNECTION_STATE_LICENSING) { - if (!rdp_read_security_header(rdp, s, &securityFlags, &length)) + if (!rdp_handle_message_channel(rdp, s, channelId, length)) return STATE_RUN_FAILED; - if (securityFlags & SEC_ENCRYPT) - { - if (!rdp_decrypt(rdp, s, &length, securityFlags)) - return STATE_RUN_FAILED; - } - return rdp_recv_message_channel_pdu(rdp, s, securityFlags); + return STATE_RUN_SUCCESS; } - if (settings->UseRdpSecurityLayer) - { - if (!rdp_read_security_header(rdp, s, &securityFlags, &length)) - return STATE_RUN_FAILED; - - if (securityFlags & SEC_ENCRYPT) - { - if (!rdp_decrypt(rdp, s, &length, securityFlags)) - return STATE_RUN_FAILED; - } - } + if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, &securityFlags)) + return STATE_RUN_FAILED; if (channelId == MCS_GLOBAL_CHANNEL_ID) { diff --git a/libfreerdp/core/rdp.c b/libfreerdp/core/rdp.c index 922cb3346..03a48f10f 100644 --- a/libfreerdp/core/rdp.c +++ b/libfreerdp/core/rdp.c @@ -1620,6 +1620,14 @@ static state_run_t rdp_recv_tpkt_pdu(rdpRdp* rdp, wStream* s) rdp->autodetect->bandwidthMeasureByteCount += length; } + if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId)) + { + rdp->inPackets++; + if (!rdp_handle_message_channel(rdp, s, channelId, length)) + return STATE_RUN_FAILED; + return STATE_RUN_SUCCESS; + } + if (rdp->settings->UseRdpSecurityLayer) { if (!rdp_read_security_header(rdp, s, &securityFlags, &length)) @@ -1714,14 +1722,6 @@ static state_run_t rdp_recv_tpkt_pdu(rdpRdp* rdp, wStream* s) } } } - else if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId)) - { - if (!rdp->settings->UseRdpSecurityLayer) - if (!rdp_read_security_header(rdp, s, &securityFlags, NULL)) - return STATE_RUN_FAILED; - rdp->inPackets++; - rc = rdp_recv_message_channel_pdu(rdp, s, securityFlags); - } else { rdp->inPackets++;