[core,mcs] unify message channel handling

This commit is contained in:
akallabeth 2024-09-26 14:44:39 +02:00
parent 2fe0435e79
commit ab88e79a36
No known key found for this signature in database
GPG Key ID: A49454A3FC909FD5
5 changed files with 117 additions and 87 deletions

View File

@ -4469,8 +4469,6 @@ fail:
BOOL rdp_recv_get_active_header(rdpRdp* rdp, wStream* s, UINT16* pChannelId, UINT16* length)
{
UINT16 securityFlags = 0;
WINPR_ASSERT(rdp);
WINPR_ASSERT(rdp->context);
@ -4480,18 +4478,6 @@ BOOL rdp_recv_get_active_header(rdpRdp* rdp, wStream* s, UINT16* pChannelId, UIN
if (freerdp_shall_disconnect_context(rdp->context))
return TRUE;
if (rdp->settings->UseRdpSecurityLayer)
{
if (!rdp_read_security_header(rdp, s, &securityFlags, length))
return FALSE;
if (securityFlags & SEC_ENCRYPT)
{
if (!rdp_decrypt(rdp, s, length, securityFlags))
return FALSE;
}
}
if (*pChannelId != MCS_GLOBAL_CHANNEL_ID)
{
UINT16 mcsMessageChannelId = rdp->mcs->messageChannelId;

View File

@ -1128,26 +1128,31 @@ BOOL rdp_client_connect_mcs_channel_join_confirm(rdpRdp* rdp, wStream* s)
return TRUE;
}
BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s)
BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT16 length)
{
WINPR_ASSERT(rdp);
WINPR_ASSERT(rdp->mcs);
if (!rdp->mcs->messageChannelJoined)
{
WLog_Print(rdp->log, WLOG_WARN, "MCS message channel not joined!");
return FALSE;
}
const UINT16 messageChannelId = rdp->mcs->messageChannelId;
/* If the MCS message channel has been joined... */
if (messageChannelId != 0)
if (messageChannelId == 0)
{
/* Process any MCS message channel PDUs. */
const size_t pos = Stream_GetPosition(s);
UINT16 length = 0;
UINT16 channelId = 0;
WLog_Print(rdp->log, WLOG_WARN, "MCS message channel id == 0");
return FALSE;
}
if (rdp_read_header(rdp, s, &length, &channelId))
{
if (channelId == messageChannelId)
if (channelId != messageChannelId)
{
WLog_Print(rdp->log, WLOG_WARN, "MCS message channel expected id=%" PRIu16 ", got %" PRIu16,
messageChannelId, channelId);
return FALSE;
}
UINT16 securityFlags = 0;
if (!rdp_read_security_header(rdp, s, &securityFlags, &length))
return FALSE;
@ -1157,19 +1162,40 @@ BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s)
return FALSE;
}
if (rdp_recv_message_channel_pdu(rdp, s, securityFlags) == STATE_RUN_SUCCESS)
if (rdp_recv_message_channel_pdu(rdp, s, securityFlags) != STATE_RUN_SUCCESS)
return FALSE;
return tpkt_ensure_stream_consumed(s, length);
}
}
BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s)
{
WINPR_ASSERT(rdp);
WINPR_ASSERT(rdp->mcs);
const size_t pos = Stream_GetPosition(s);
UINT16 length = 0;
UINT16 channelId = 0;
if (rdp_read_header(rdp, s, &length, &channelId))
{
const UINT16 messageChannelId = rdp->mcs->messageChannelId;
/* If the MCS message channel has been joined... */
/* Process any MCS message channel PDUs. */
if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId))
{
if (rdp_handle_message_channel(rdp, s, channelId, length))
return TRUE;
}
else
{
WLog_WARN(TAG, "expected messageChannelId=%" PRIu16 ", got %" PRIu16, messageChannelId,
channelId);
}
}
Stream_SetPosition(s, pos);
}
else
WLog_WARN(TAG, "messageChannelId == 0");
return FALSE;
}
@ -1185,6 +1211,17 @@ state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s)
if (!rdp_read_header(rdp, s, &length, &channelId))
return STATE_RUN_FAILED;
/* there might be autodetect messages mixed in between licensing messages.
* that has been observed with 2k12 R2 and 2k19
*/
const UINT16 messageChannelId = rdp->mcs->messageChannelId;
if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId))
{
if (!rdp_handle_message_channel(rdp, s, channelId, length))
return STATE_RUN_FAILED;
return STATE_RUN_SUCCESS;
}
if (!rdp_read_security_header(rdp, s, &securityFlags, &length))
return STATE_RUN_FAILED;
@ -1194,15 +1231,6 @@ state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s)
return STATE_RUN_FAILED;
}
/* there might be autodetect messages mixed in between licensing messages.
* that has been observed with 2k12 R2 and 2k19
*/
const UINT16 messageChannelId = rdp->mcs->messageChannelId;
if ((channelId == messageChannelId) || (securityFlags & SEC_AUTODETECT_REQ))
{
return rdp_recv_message_channel_pdu(rdp, s, securityFlags);
}
if (channelId != MCS_GLOBAL_CHANNEL_ID)
WLog_WARN(TAG, "unexpected message for channel %u, expected %u", channelId,
MCS_GLOBAL_CHANNEL_ID);
@ -1266,19 +1294,15 @@ state_run_t rdp_client_connect_demand_active(rdpRdp* rdp, wStream* s)
if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId))
{
UINT16 securityFlags = 0;
if (!rdp_read_security_header(rdp, s, &securityFlags, NULL))
return STATE_RUN_FAILED;
if (securityFlags & SEC_ENCRYPT)
{
if (!rdp_decrypt(rdp, s, &length, securityFlags))
return STATE_RUN_FAILED;
}
rdp->inPackets++;
return rdp_recv_message_channel_pdu(rdp, s, securityFlags);
if (!rdp_handle_message_channel(rdp, s, channelId, length))
return STATE_RUN_FAILED;
return STATE_RUN_SUCCESS;
}
if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, NULL))
return STATE_RUN_FAILED;
if (!rdp_read_share_control_header(rdp, s, NULL, NULL, &pduType, &pduSource))
return STATE_RUN_FAILED;
@ -2154,3 +2178,31 @@ state_run_t rdp_client_connect_confirm_active(rdpRdp* rdp, wStream* s)
status = STATE_RUN_FAILED;
return status;
}
BOOL rdp_handle_optional_rdp_decryption(rdpRdp* rdp, wStream* s, UINT16* length,
UINT16* pSecurityFlags)
{
BOOL rc = FALSE;
WINPR_ASSERT(rdp);
WINPR_ASSERT(rdp->settings);
UINT16 securityFlags = 0;
if (rdp->settings->UseRdpSecurityLayer)
{
if (!rdp_read_security_header(rdp, s, &securityFlags, length))
goto fail;
if (securityFlags & SEC_ENCRYPT)
{
if (!rdp_decrypt(rdp, s, length, securityFlags))
goto fail;
}
}
rc = TRUE;
fail:
if (pSecurityFlags)
*pSecurityFlags = securityFlags;
return rc;
}

View File

@ -47,6 +47,7 @@ FREERDP_LOCAL BOOL rdp_client_redirect(rdpRdp* rdp);
FREERDP_LOCAL BOOL rdp_client_skip_mcs_channel_join(rdpRdp* rdp);
FREERDP_LOCAL BOOL rdp_client_connect_mcs_channel_join_confirm(rdpRdp* rdp, wStream* s);
FREERDP_LOCAL BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s);
FREERDP_LOCAL state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s);
FREERDP_LOCAL state_run_t rdp_client_connect_demand_active(rdpRdp* rdp, wStream* s);
FREERDP_LOCAL state_run_t rdp_client_connect_confirm_active(rdpRdp* rdp, wStream* s);
@ -72,4 +73,9 @@ FREERDP_LOCAL const char* rdp_client_connection_state_string(int state);
FREERDP_LOCAL BOOL rdp_channels_from_mcs(rdpSettings* settings, const rdpRdp* rdp);
FREERDP_LOCAL BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId,
UINT16 length);
FREERDP_LOCAL BOOL rdp_handle_optional_rdp_decryption(rdpRdp* rdp, wStream* s, UINT16* length,
UINT16* pSecurityFlags);
#endif /* FREERDP_LIB_CORE_CONNECTION_H */

View File

@ -449,28 +449,14 @@ static state_run_t peer_recv_tpkt_pdu(freerdp_peer* client, wStream* s)
if (rdp_get_state(rdp) <= CONNECTION_STATE_LICENSING)
{
if (!rdp_read_security_header(rdp, s, &securityFlags, &length))
if (!rdp_handle_message_channel(rdp, s, channelId, length))
return STATE_RUN_FAILED;
if (securityFlags & SEC_ENCRYPT)
{
if (!rdp_decrypt(rdp, s, &length, securityFlags))
return STATE_RUN_FAILED;
}
return rdp_recv_message_channel_pdu(rdp, s, securityFlags);
return STATE_RUN_SUCCESS;
}
if (settings->UseRdpSecurityLayer)
{
if (!rdp_read_security_header(rdp, s, &securityFlags, &length))
if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, &securityFlags))
return STATE_RUN_FAILED;
if (securityFlags & SEC_ENCRYPT)
{
if (!rdp_decrypt(rdp, s, &length, securityFlags))
return STATE_RUN_FAILED;
}
}
if (channelId == MCS_GLOBAL_CHANNEL_ID)
{
char buffer[256] = { 0 };

View File

@ -1620,6 +1620,14 @@ static state_run_t rdp_recv_tpkt_pdu(rdpRdp* rdp, wStream* s)
rdp->autodetect->bandwidthMeasureByteCount += length;
}
if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId))
{
rdp->inPackets++;
if (!rdp_handle_message_channel(rdp, s, channelId, length))
return STATE_RUN_FAILED;
return STATE_RUN_SUCCESS;
}
if (rdp->settings->UseRdpSecurityLayer)
{
if (!rdp_read_security_header(rdp, s, &securityFlags, &length))
@ -1714,14 +1722,6 @@ static state_run_t rdp_recv_tpkt_pdu(rdpRdp* rdp, wStream* s)
}
}
}
else if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId))
{
if (!rdp->settings->UseRdpSecurityLayer)
if (!rdp_read_security_header(rdp, s, &securityFlags, NULL))
return STATE_RUN_FAILED;
rdp->inPackets++;
rc = rdp_recv_message_channel_pdu(rdp, s, securityFlags);
}
else
{
rdp->inPackets++;